diff --git a/tools/pubmed/internal/client.go b/tools/pubmed/internal/client.go new file mode 100644 index 000000000..a4899515a --- /dev/null +++ b/tools/pubmed/internal/client.go @@ -0,0 +1,203 @@ +package internal + +import ( + "context" + "encoding/xml" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// Client defines an HTTP client for communicating with PubMed. +type Client struct { + MaxResults int + UserAgent string + BaseURL string +} + +// Result defines a search query result type. +type Result struct { + Title string + Authors []string + Abstract string + PMID string + Published string +} + +var ( + ErrNoGoodResult = errors.New("no good search results found") + ErrAPIResponse = errors.New("PubMed API responded with error") +) + +// NewClient initializes a Client with arguments for setting a max +// results per search query and a value for the user agent header. +func NewClient(maxResults int, userAgent string) *Client { + if maxResults == 0 { + maxResults = 1 + } + + return &Client{ + MaxResults: maxResults, + UserAgent: userAgent, + BaseURL: "https://eutils.ncbi.nlm.nih.gov/entrez/eutils", + } +} + +func (client *Client) newRequest(ctx context.Context, queryURL string) (*http.Request, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, queryURL, nil) + if err != nil { + return nil, fmt.Errorf("creating PubMed request: %w", err) + } + + if client.UserAgent != "" { + request.Header.Add("User-Agent", client.UserAgent) + } + + return request, nil +} + +// Search performs a search query and returns +// the result as string and an error if any. +func (client *Client) Search(ctx context.Context, query string) (string, error) { + // First, search for IDs + searchURL := fmt.Sprintf("%s/esearch.fcgi?db=pubmed&term=%s&retmax=%d&usehistory=y", + client.BaseURL, url.QueryEscape(query), client.MaxResults) + + request, err := client.newRequest(ctx, searchURL) + if err != nil { + return "", err + } + + response, err := http.DefaultClient.Do(request) + if err != nil { + return "", fmt.Errorf("get %s error: %w", searchURL, err) + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return "", ErrAPIResponse + } + + body, err := io.ReadAll(response.Body) + if err != nil { + return "", fmt.Errorf("reading response body: %w", err) + } + + var searchResult struct { + IDList struct { + IDs []string `xml:"Id"` + } `xml:"IdList"` + WebEnv string `xml:"WebEnv"` + QueryKey string `xml:"QueryKey"` + } + + if err := xml.Unmarshal(body, &searchResult); err != nil { + return "", fmt.Errorf("unmarshaling XML: %w", err) + } + + if len(searchResult.IDList.IDs) == 0 { + return "", ErrNoGoodResult + } + + // Now fetch details for these IDs + fetchURL := fmt.Sprintf("%s/efetch.fcgi?db=pubmed&WebEnv=%s&query_key=%s&retmode=xml&rettype=abstract", + client.BaseURL, searchResult.WebEnv, searchResult.QueryKey) + + request, err = client.newRequest(ctx, fetchURL) + if err != nil { + return "", err + } + + response, err = http.DefaultClient.Do(request) + if err != nil { + return "", fmt.Errorf("get %s error: %w", fetchURL, err) + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return "", ErrAPIResponse + } + + body, err = io.ReadAll(response.Body) + if err != nil { + return "", fmt.Errorf("reading response body: %w", err) + } + + var fetchResult struct { + Articles []struct { + MedlineCitation struct { + Article struct { + ArticleTitle string `xml:"ArticleTitle"` + Abstract struct { + AbstractText string `xml:"AbstractText"` + } `xml:"Abstract"` + AuthorList struct { + Authors []struct { + LastName string `xml:"LastName"` + ForeName string `xml:"ForeName"` + Initials string `xml:"Initials"` + } `xml:"Author"` + } `xml:"AuthorList"` + } `xml:"Article"` + PMID string `xml:"PMID"` + } `xml:"MedlineCitation"` + PubmedData struct { + History struct { + PubMedPubDate []struct { + Year string `xml:"Year"` + Month string `xml:"Month"` + Day string `xml:"Day"` + } `xml:"PubMedPubDate"` + } `xml:"History"` + } `xml:"PubmedData"` + } `xml:"PubmedArticle"` + } + + if err := xml.Unmarshal(body, &fetchResult); err != nil { + return "", fmt.Errorf("unmarshaling XML: %w", err) + } + + results := []Result{} + for _, article := range fetchResult.Articles { + authors := []string{} + for _, author := range article.MedlineCitation.Article.AuthorList.Authors { + authors = append(authors, fmt.Sprintf("%s %s", author.ForeName, author.LastName)) + } + + var pubDate time.Time + for _, date := range article.PubmedData.History.PubMedPubDate { + if pubDate, err = time.Parse("2006-1-2", fmt.Sprintf("%s-%s-%s", date.Year, date.Month, date.Day)); err == nil { + break + } + } + + results = append(results, Result{ + Title: article.MedlineCitation.Article.ArticleTitle, + Authors: authors, + Abstract: article.MedlineCitation.Article.Abstract.AbstractText, + PMID: article.MedlineCitation.PMID, + Published: pubDate.Format("2006-01-02"), + }) + } + + return client.formatResults(results), nil +} + +// formatResults will return a structured string with the results. +func (client *Client) formatResults(results []Result) string { + var formattedResults strings.Builder + + for _, result := range results { + formattedResults.WriteString(fmt.Sprintf("Title: %s\n", result.Title)) + formattedResults.WriteString(fmt.Sprintf("Authors: %s\n", strings.Join(result.Authors, ", "))) + formattedResults.WriteString(fmt.Sprintf("Abstract: %s\n", result.Abstract)) + formattedResults.WriteString(fmt.Sprintf("PMID: %s\n", result.PMID)) + formattedResults.WriteString(fmt.Sprintf("Published: %s\n\n", result.Published)) + } + + return formattedResults.String() +} diff --git a/tools/pubmed/pubmed.go b/tools/pubmed/pubmed.go new file mode 100644 index 000000000..912e7ff25 --- /dev/null +++ b/tools/pubmed/pubmed.go @@ -0,0 +1,66 @@ +package pubmed + +import ( + "context" + "errors" + + "github.com/tmc/langchaingo/callbacks" + "github.com/tmc/langchaingo/tools" + "github.com/tmc/langchaingo/tools/pubmed/internal" +) + +// DefaultUserAgent defines a default value for user-agent header. +const DefaultUserAgent = "github.com/tmc/langchaingo/tools/pubmed" + +// Tool defines a tool implementation for the PubMed Search. +type Tool struct { + CallbacksHandler callbacks.Handler + client *internal.Client +} + +var _ tools.Tool = Tool{} + +// New initializes a new PubMed Search tool with arguments for setting a +// max results per search query and a value for the user agent header. +func New(maxResults int, userAgent string) (*Tool, error) { + return &Tool{ + client: internal.NewClient(maxResults, userAgent), + }, nil +} + +// Name returns a name for the tool. +func (t Tool) Name() string { + return "PubMed Search" +} + +// Description returns a description for the tool. +func (t Tool) Description() string { + return ` + "A wrapper around PubMed Search API." + "Search for biomedical literature from MEDLINE, life science journals, and online books." + "Input should be a search query."` +} + +// Call performs the search and return the result. +func (t Tool) Call(ctx context.Context, input string) (string, error) { + if t.CallbacksHandler != nil { + t.CallbacksHandler.HandleToolStart(ctx, input) + } + + result, err := t.client.Search(ctx, input) + if err != nil { + if errors.Is(err, internal.ErrNoGoodResult) { + return "No good PubMed Search Results were found", nil + } + if t.CallbacksHandler != nil { + t.CallbacksHandler.HandleToolError(ctx, err) + } + return "", err + } + + if t.CallbacksHandler != nil { + t.CallbacksHandler.HandleToolEnd(ctx, result) + } + + return result, nil +} diff --git a/tools/pubmed/pubmed_test.go b/tools/pubmed/pubmed_test.go new file mode 100644 index 000000000..8cb3ac0e6 --- /dev/null +++ b/tools/pubmed/pubmed_test.go @@ -0,0 +1,159 @@ +// File: tools/pubmed/pubmed_test.go + +package pubmed + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tmc/langchaingo/tools/pubmed/internal" +) + +func TestNew(t *testing.T) { + t.Parallel() + + tool, err := New(5, "TestUserAgent") + require.NoError(t, err) + assert.NotNil(t, tool) + assert.Equal(t, "PubMed Search", tool.Name()) +} + +func TestPubMedTool(t *testing.T) { + t.Parallel() + + // Create a mock HTTP server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if the request URL is correct + assert.Contains(t, r.URL.String(), "/eutils/") + + // Serve a mock XML response + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + if r.URL.Path == "/eutils/esearch.fcgi" { + _, _ = w.Write([]byte(` + + + 12345 + + NCID_1_1234567_130.14.22.215_9001_1234567890_123456789_0MetA0_S_MegaStore + 1 + `)) + } else if r.URL.Path == "/eutils/efetch.fcgi" { + _, _ = w.Write([]byte(` + + + + 12345 +
+ Mock PubMed Article + + This is a mock abstract of a PubMed article. + + + + Doe + John + J + + +
+
+ + + + 2023 + 1 + 1 + + + +
+
`)) + } + })) + defer server.Close() + + // Create a custom client that uses the test server URL + customClient := &internal.Client{ + MaxResults: 1, + UserAgent: "TestAgent", + BaseURL: server.URL + "/eutils", + } + + // Create the PubMed tool with the custom client + tool := &Tool{ + client: customClient, + } + + // Test the Call method + result, err := tool.Call(context.Background(), "test query") + + // Assert that there's no error + require.NoError(t, err) + + // Assert that the result contains expected information + assert.Contains(t, result, "Title: Mock PubMed Article") + assert.Contains(t, result, "Authors: John Doe") + assert.Contains(t, result, "Abstract: This is a mock abstract of a PubMed article.") + assert.Contains(t, result, "PMID: 12345") + assert.Contains(t, result, "Published: 2023-01-01") +} + +func TestPubMedToolNoResults(t *testing.T) { + t.Parallel() + + // Create a mock HTTP server that returns no results + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(` + + + `)) + })) + defer server.Close() + + customClient := &internal.Client{ + MaxResults: 1, + UserAgent: "TestAgent", + BaseURL: server.URL + "/eutils", + } + + tool := &Tool{ + client: customClient, + } + + result, err := tool.Call(context.Background(), "no results query") + + require.NoError(t, err) + require.Equal(t, "No good PubMed Search Results were found", result) +} + +func TestPubMedToolAPIError(t *testing.T) { + t.Parallel() + + // Create a mock HTTP server that returns an error + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + customClient := &internal.Client{ + MaxResults: 1, + UserAgent: "TestAgent", + BaseURL: server.URL + "/eutils", + } + + tool := &Tool{ + client: customClient, + } + + _, err := tool.Call(context.Background(), "error query") + + require.Error(t, err) + require.Equal(t, internal.ErrAPIResponse, err) +}