Skip to content

Commit

Permalink
tools: add pubmed tool
Browse files Browse the repository at this point in the history
  • Loading branch information
longkeyy committed Oct 4, 2024
1 parent 60fa95d commit 6bad095
Show file tree
Hide file tree
Showing 3 changed files with 428 additions and 0 deletions.
203 changes: 203 additions & 0 deletions tools/pubmed/internal/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
package internal

import (
"context"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)

// Client defines an HTTP client for communicating with PubMed.
type Client struct {
MaxResults int
UserAgent string
BaseURL string
}

// Result defines a search query result type.
type Result struct {
Title string
Authors []string
Abstract string
PMID string
Published string
}

var (
ErrNoGoodResult = errors.New("no good search results found")
ErrAPIResponse = errors.New("PubMed API responded with error")
)

// NewClient initializes a Client with arguments for setting a max
// results per search query and a value for the user agent header.
func NewClient(maxResults int, userAgent string) *Client {
if maxResults == 0 {
maxResults = 1
}

return &Client{
MaxResults: maxResults,
UserAgent: userAgent,
BaseURL: "https://eutils.ncbi.nlm.nih.gov/entrez/eutils",
}
}

func (client *Client) newRequest(ctx context.Context, queryURL string) (*http.Request, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, queryURL, nil)
if err != nil {
return nil, fmt.Errorf("creating PubMed request: %w", err)
}

if client.UserAgent != "" {
request.Header.Add("User-Agent", client.UserAgent)
}

return request, nil
}

// Search performs a search query and returns
// the result as string and an error if any.
func (client *Client) Search(ctx context.Context, query string) (string, error) {
// First, search for IDs
searchURL := fmt.Sprintf("%s/esearch.fcgi?db=pubmed&term=%s&retmax=%d&usehistory=y",
client.BaseURL, url.QueryEscape(query), client.MaxResults)

request, err := client.newRequest(ctx, searchURL)
if err != nil {
return "", err
}

response, err := http.DefaultClient.Do(request)
if err != nil {
return "", fmt.Errorf("get %s error: %w", searchURL, err)
}
defer response.Body.Close()

if response.StatusCode != http.StatusOK {
return "", ErrAPIResponse
}

body, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("reading response body: %w", err)
}

var searchResult struct {
IDList struct {
IDs []string `xml:"Id"`
} `xml:"IdList"`
WebEnv string `xml:"WebEnv"`
QueryKey string `xml:"QueryKey"`
}

if err := xml.Unmarshal(body, &searchResult); err != nil {
return "", fmt.Errorf("unmarshaling XML: %w", err)
}

if len(searchResult.IDList.IDs) == 0 {
return "", ErrNoGoodResult
}

// Now fetch details for these IDs
fetchURL := fmt.Sprintf("%s/efetch.fcgi?db=pubmed&WebEnv=%s&query_key=%s&retmode=xml&rettype=abstract",
client.BaseURL, searchResult.WebEnv, searchResult.QueryKey)

request, err = client.newRequest(ctx, fetchURL)
if err != nil {
return "", err
}

response, err = http.DefaultClient.Do(request)
if err != nil {
return "", fmt.Errorf("get %s error: %w", fetchURL, err)
}
defer response.Body.Close()

if response.StatusCode != http.StatusOK {
return "", ErrAPIResponse
}

body, err = io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("reading response body: %w", err)
}

var fetchResult struct {
Articles []struct {
MedlineCitation struct {
Article struct {
ArticleTitle string `xml:"ArticleTitle"`
Abstract struct {
AbstractText string `xml:"AbstractText"`
} `xml:"Abstract"`
AuthorList struct {
Authors []struct {
LastName string `xml:"LastName"`
ForeName string `xml:"ForeName"`
Initials string `xml:"Initials"`
} `xml:"Author"`
} `xml:"AuthorList"`
} `xml:"Article"`
PMID string `xml:"PMID"`
} `xml:"MedlineCitation"`
PubmedData struct {
History struct {
PubMedPubDate []struct {
Year string `xml:"Year"`
Month string `xml:"Month"`
Day string `xml:"Day"`
} `xml:"PubMedPubDate"`
} `xml:"History"`
} `xml:"PubmedData"`
} `xml:"PubmedArticle"`
}

if err := xml.Unmarshal(body, &fetchResult); err != nil {
return "", fmt.Errorf("unmarshaling XML: %w", err)
}

results := []Result{}
for _, article := range fetchResult.Articles {
authors := []string{}
for _, author := range article.MedlineCitation.Article.AuthorList.Authors {
authors = append(authors, fmt.Sprintf("%s %s", author.ForeName, author.LastName))
}

var pubDate time.Time
for _, date := range article.PubmedData.History.PubMedPubDate {
if pubDate, err = time.Parse("2006-1-2", fmt.Sprintf("%s-%s-%s", date.Year, date.Month, date.Day)); err == nil {
break
}
}

results = append(results, Result{
Title: article.MedlineCitation.Article.ArticleTitle,
Authors: authors,
Abstract: article.MedlineCitation.Article.Abstract.AbstractText,
PMID: article.MedlineCitation.PMID,
Published: pubDate.Format("2006-01-02"),
})
}

return client.formatResults(results), nil
}

// formatResults will return a structured string with the results.
func (client *Client) formatResults(results []Result) string {
var formattedResults strings.Builder

for _, result := range results {
formattedResults.WriteString(fmt.Sprintf("Title: %s\n", result.Title))
formattedResults.WriteString(fmt.Sprintf("Authors: %s\n", strings.Join(result.Authors, ", ")))
formattedResults.WriteString(fmt.Sprintf("Abstract: %s\n", result.Abstract))
formattedResults.WriteString(fmt.Sprintf("PMID: %s\n", result.PMID))
formattedResults.WriteString(fmt.Sprintf("Published: %s\n\n", result.Published))
}

return formattedResults.String()
}
66 changes: 66 additions & 0 deletions tools/pubmed/pubmed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package pubmed

import (
"context"
"errors"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/tools"
"github.com/tmc/langchaingo/tools/pubmed/internal"
)

// DefaultUserAgent defines a default value for user-agent header.
const DefaultUserAgent = "github.com/tmc/langchaingo/tools/pubmed"

// Tool defines a tool implementation for the PubMed Search.
type Tool struct {
CallbacksHandler callbacks.Handler
client *internal.Client
}

var _ tools.Tool = Tool{}

// New initializes a new PubMed Search tool with arguments for setting a
// max results per search query and a value for the user agent header.
func New(maxResults int, userAgent string) (*Tool, error) {
return &Tool{
client: internal.NewClient(maxResults, userAgent),
}, nil
}

// Name returns a name for the tool.
func (t Tool) Name() string {
return "PubMed Search"
}

// Description returns a description for the tool.
func (t Tool) Description() string {
return `
"A wrapper around PubMed Search API."
"Search for biomedical literature from MEDLINE, life science journals, and online books."
"Input should be a search query."`
}

// Call performs the search and return the result.
func (t Tool) Call(ctx context.Context, input string) (string, error) {
if t.CallbacksHandler != nil {
t.CallbacksHandler.HandleToolStart(ctx, input)
}

result, err := t.client.Search(ctx, input)
if err != nil {
if errors.Is(err, internal.ErrNoGoodResult) {
return "No good PubMed Search Results were found", nil
}
if t.CallbacksHandler != nil {
t.CallbacksHandler.HandleToolError(ctx, err)
}
return "", err
}

if t.CallbacksHandler != nil {
t.CallbacksHandler.HandleToolEnd(ctx, result)
}

return result, nil
}
Loading

0 comments on commit 6bad095

Please sign in to comment.