ollama: support tools
treywelsh committed Sep 14, 2024
1 parent 1975058 commit b2954d3
Showing 3 changed files with 174 additions and 57 deletions.
39 changes: 35 additions & 4 deletions llms/ollama/internal/ollamaclient/types.go
@@ -40,18 +40,38 @@ type GenerateRequest struct {

type ImageData []byte

type ToolCall struct {
Function ToolCallFunction `json:"function"`
}

type ToolCallFunction struct {
Name string `json:"name"`
Arguments ToolCallFunctionArguments `json:"arguments"`
}

type ToolCallFunctionArguments struct {
Content string
}

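// UnmarshalJSON stores the raw JSON bytes of the arguments object as a
// string rather than decoding them, leaving interpretation to the caller.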
func (a *ToolCallFunctionArguments) UnmarshalJSON(b []byte) error {
a.Content = string(b)
return nil
}

type Message struct {
Role string `json:"role"` // one of ["system", "user", "assistant"]
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

type ChatRequest struct {
Model string `json:"model"`
Messages []*Message `json:"messages"`
- Stream bool `json:"stream,omitempty"`
+ Stream bool `json:"stream"`
Format string `json:"format"`
KeepAlive string `json:"keep_alive,omitempty"`
Tools []Tool `json:"tools,omitempty"`

Options Options `json:"options"`
}
@@ -168,3 +188,14 @@ type Options struct {
TopP float32 `json:"top_p,omitempty"`
PenalizeNewline bool `json:"penalize_newline,omitempty"`
}

type Tool struct {
Type string `json:"type"`
Function ToolFunction `json:"function"`
}

type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters any `json:"parameters"`
}
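
To illustrate the wire format, here is a minimal, self-contained sketch of how a Tool built from these types marshals to JSON for Ollama's /api/chat tools field; the get_current_weather schema is a hypothetical example, not part of this commit:

package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the Tool types introduced above, so the sketch compiles on its own.
type Tool struct {
	Type     string       `json:"type"`
	Function ToolFunction `json:"function"`
}

type ToolFunction struct {
	Name        string `json:"name"`
	Description string `json:"description"`
	Parameters  any    `json:"parameters"`
}

func main() {
	tool := Tool{
		Type: "function",
		Function: ToolFunction{
			Name:        "get_current_weather",
			Description: "Get the current weather for a city",
			// Parameters is declared as any, so a JSON-Schema-style map marshals directly.
			Parameters: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"city": map[string]any{"type": "string"},
				},
				"required": []string{"city"},
			},
		},
	}

	out, err := json.MarshalIndent(tool, "", "  ")
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}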
184 changes: 131 additions & 53 deletions llms/ollama/ollamallm.go
@@ -3,6 +3,7 @@ package ollama
import (
"context"
"errors"
"fmt"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms"
@@ -63,52 +64,33 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten

// Our input is a sequence of MessageContent, each of which potentially has
// a sequence of Part that could be text, images etc.
- // We have to convert it to a format Ollama undestands: ChatRequest, which
+ // We have to convert it to a format Ollama understands: ChatRequest, which
// has a sequence of Message, each of which has a role and content - single
// text + potential images.
- chatMsgs := make([]*ollamaclient.Message, 0, len(messages))
- for _, mc := range messages {
- msg := &ollamaclient.Message{Role: typeToRole(mc.Role)}
-
- // Look at all the parts in mc; expect to find a single Text part and
- // any number of binary parts.
- var text string
- foundText := false
- var images []ollamaclient.ImageData
-
- for _, p := range mc.Parts {
- switch pt := p.(type) {
- case llms.TextContent:
- if foundText {
- return nil, errors.New("expecting a single Text content")
- }
- foundText = true
- text = pt.Text
- case llms.BinaryContent:
- images = append(images, ollamaclient.ImageData(pt.Data))
- default:
- return nil, errors.New("only support Text and BinaryContent parts right now")
- }
- }
-
- msg.Content = text
- msg.Images = images
- chatMsgs = append(chatMsgs, msg)
+ chatMsgs, err := makeOllamaMessages(messages)
+ if err != nil {
+ return nil, err
+ }

format := o.options.format
if opts.JSONMode {
format = "json"
}

tools := o.options.tools
if len(opts.Tools) > 0 {
tools = makeOllamaTools(opts.Tools)
}

// Get our ollamaOptions from llms.CallOptions
- ollamaOptions := makeOllamaOptionsFromOptions(o.options.ollamaOptions, opts)
+ ollamaOptions := makeOllamaOptions(o.options.ollamaOptions, opts)
req := &ollamaclient.ChatRequest{
Model: model,
Format: format,
Messages: chatMsgs,
Options: ollamaOptions,
- Stream: opts.StreamingFunc != nil,
+ Stream: opts.StreamingFunc != nil && len(opts.Tools) == 0,
Tools: tools,
}

keepAlive := o.options.keepAlive
@@ -129,24 +111,28 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
if response.Message != nil {
streamedResponse += response.Message.Content
}

if !req.Stream || response.Done {
resp = response
resp.Message = &ollamaclient.Message{
Role: "assistant",
Content: streamedResponse,
ToolCalls: response.Message.ToolCalls,
}
}
return nil
}

- err := o.client.GenerateChat(ctx, req, fn)
+ err = o.client.GenerateChat(ctx, req, fn)
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
}

toolCalls := makeLLMSToolCall(resp.Message.ToolCalls)

choices := []*llms.ContentChoice{
{
Content: resp.Message.Content,
@@ -155,6 +141,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
"PromptTokens": resp.PromptEvalCount,
"TotalTokens": resp.EvalCount + resp.PromptEvalCount,
},
ToolCalls: toolCalls,
},
}

@@ -198,25 +185,8 @@ func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]flo
return embeddings, nil
}

- func typeToRole(typ llms.ChatMessageType) string {
- switch typ {
- case llms.ChatMessageTypeSystem:
- return "system"
- case llms.ChatMessageTypeAI:
- return "assistant"
- case llms.ChatMessageTypeHuman:
- fallthrough
- case llms.ChatMessageTypeGeneric:
- return "user"
- case llms.ChatMessageTypeFunction:
- return "function"
- case llms.ChatMessageTypeTool:
- return "tool"
- }
- return ""
- }

- func makeOllamaOptionsFromOptions(ollamaOptions ollamaclient.Options, opts llms.CallOptions) ollamaclient.Options {
+ // makeOllamaOptions makes ollamaclient.Options from llms.CallOptions.
+ func makeOllamaOptions(ollamaOptions ollamaclient.Options, opts llms.CallOptions) ollamaclient.Options {
// Load back CallOptions as ollamaOptions
ollamaOptions.NumPredict = opts.MaxTokens
ollamaOptions.Temperature = float32(opts.Temperature)
@@ -230,3 +200,111 @@ func makeOllamaOptionsFromOptions(ollamaOptions ollamaclient.Options, opts llms.

return ollamaOptions
}

// makeOllamaTools makes ollamaclient.Tool values from llms.Tool values.
func makeOllamaTools(tools []llms.Tool) []ollamaclient.Tool {
ollamaTools := make([]ollamaclient.Tool, 0, len(tools))
for _, tool := range tools {
functionDef := ollamaclient.ToolFunction{
Name: tool.Function.Name,
Description: tool.Function.Description,
Parameters: tool.Function.Parameters,
}
ollamaTools = append(ollamaTools, ollamaclient.Tool{
Type: tool.Type,
Function: functionDef,
})
}
return ollamaTools
}

// makeOllamaMessages makes ollamaclient.Message values from llms.MessageContent values.
func makeOllamaMessages(messages []llms.MessageContent) ([]*ollamaclient.Message, error) {
chatMsgs := make([]*ollamaclient.Message, 0, len(messages))
for _, mc := range messages {
msg := &ollamaclient.Message{}
switch mc.Role {
case llms.ChatMessageTypeSystem:
msg.Role = "system"
case llms.ChatMessageTypeAI:
msg.Role = "assistant"
case llms.ChatMessageTypeHuman:
fallthrough
case llms.ChatMessageTypeGeneric:
msg.Role = "user"
case llms.ChatMessageTypeFunction:
msg.Role = "function"
case llms.ChatMessageTypeTool:
msg.Role = "tool"

if len(mc.Parts) != 1 {
return nil, fmt.Errorf("expected exactly one part for role %v, got %v", mc.Role, len(mc.Parts))
}
switch p := mc.Parts[0].(type) {
case llms.ToolCallResponse:
msg.Content = p.Content
default:
return nil, fmt.Errorf("expected part of type ToolCallResponse for role %v, got %T", mc.Role, mc.Parts[0])
}
}

text, images, tools, err := makeOllamaContent(mc)
if err != nil {
return nil, err
}

msg.Content = text
msg.Images = images
msg.ToolCalls = tools
chatMsgs = append(chatMsgs, msg)
}

return chatMsgs, nil
}
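
// For reference, the tool-role branch above expects a tool result to arrive
// as a single llms.ToolCallResponse part, along the lines of this sketch
// (ToolCallResponse carries further fields, such as the tool call ID, that
// are omitted here):
//
//	msg := llms.MessageContent{
//		Role:  llms.ChatMessageTypeTool,
//		Parts: []llms.ContentPart{llms.ToolCallResponse{Content: `{"temp_c": 21}`}},
//	}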

// makeOllamaContent extracts the text content, ollamaclient.ImageData values, and ollamaclient.ToolCall values from llms.MessageContent.
func makeOllamaContent(mc llms.MessageContent) (string, []ollamaclient.ImageData, []ollamaclient.ToolCall, error) {
// Look at all the parts in mc; expect to find a single Text part and
// any number of binary parts.
var text string
foundText := false
var images []ollamaclient.ImageData
var tools []ollamaclient.ToolCall
for _, p := range mc.Parts {
switch pt := p.(type) {
case llms.TextContent:
if foundText {
return "", nil, nil, errors.New("expecting a single Text content")
}
foundText = true
text = pt.Text
case llms.BinaryContent:
images = append(images, ollamaclient.ImageData(pt.Data))
case llms.ToolCall:
tools = append(tools, ollamaclient.ToolCall{
Function: ollamaclient.ToolCallFunction{
Name: pt.FunctionCall.Name,
Arguments: ollamaclient.ToolCallFunctionArguments{
Content: pt.FunctionCall.Arguments,
},
},
})
}
}
return text, images, tools, nil
}

// makeLLMSToolCall makes llms.ToolCall values from ollamaclient.ToolCall values.
func makeLLMSToolCall(toolCalls []ollamaclient.ToolCall) []llms.ToolCall {
llmsToolCalls := make([]llms.ToolCall, 0, len(toolCalls))
for _, tool := range toolCalls {
llmsToolCalls = append(llmsToolCalls, llms.ToolCall{
Type: "function",
FunctionCall: &llms.FunctionCall{
Name: tool.Function.Name,
Arguments: tool.Function.Arguments.Content,
},
})
}
return llmsToolCalls
}
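
As a usage sketch (not part of this commit): with these changes, tools can be passed per call and the model's tool calls read back from the first content choice. The helper names from the llms package (llms.WithTools, llms.TextParts, llms.FunctionDefinition) and the tool-capable model name are assumptions here:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/ollama"
)

func main() {
	llm, err := ollama.New(ollama.WithModel("llama3.1")) // assumes a model with tool support
	if err != nil {
		log.Fatal(err)
	}

	tools := []llms.Tool{{
		Type: "function",
		Function: &llms.FunctionDefinition{
			Name:        "get_current_weather",
			Description: "Get the current weather for a city",
			Parameters: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"city": map[string]any{"type": "string"},
				},
				"required": []string{"city"},
			},
		},
	}}

	resp, err := llm.GenerateContent(context.Background(),
		[]llms.MessageContent{llms.TextParts(llms.ChatMessageTypeHuman, "What is the weather in Paris?")},
		llms.WithTools(tools), // note: passing tools disables streaming (see the Stream field above)
	)
	if err != nil {
		log.Fatal(err)
	}

	// Tool calls, if any, are surfaced on the content choice.
	for _, tc := range resp.Choices[0].ToolCalls {
		fmt.Println(tc.FunctionCall.Name, tc.FunctionCall.Arguments)
	}
}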
8 changes: 8 additions & 0 deletions llms/ollama/options.go
@@ -17,6 +17,7 @@ type options struct {
system string
format string
keepAlive string
tools []ollamaclient.Tool
}

type Option func(*options)
@@ -264,3 +265,10 @@ func WithPredictPenalizeNewline(val bool) Option {
opts.ollamaOptions.PenalizeNewline = val
}
}

// WithTools provides tool definitions to the model. The model must support tool calling.
func WithTools(val []ollamaclient.Tool) Option {
return func(opts *options) {
opts.tools = val
}
}
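
Since ollamaclient lives under internal/, this option is only constructible from inside the langchaingo module itself; external callers would instead pass tools per call via llms.WithTools. A minimal sketch, assuming module-internal code and a hypothetical tool definition:

llm, err := ollama.New(
	ollama.WithModel("llama3.1"), // example tool-capable model
	ollama.WithTools([]ollamaclient.Tool{{
		Type: "function",
		Function: ollamaclient.ToolFunction{
			Name:        "get_current_weather",
			Description: "Get the current weather for a city",
			Parameters:  map[string]any{"type": "object"},
		},
	}}),
)
if err != nil {
	log.Fatal(err)
}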
