
examples: update ollama-functions-examples with ollama tool support
treywelsh committed Sep 14, 2024
1 parent b2954d3 commit 2c0b8b2
Showing 1 changed file with 102 additions and 141 deletions.

examples/ollama-functions-example/ollama_functions_example.go
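
The commit replaces the example's prompt-engineered function calling (a system prompt describing the tools as JSON, plus a retry loop that parsed the model's replies) with Ollama's native tool support: tool definitions are passed to GenerateContent through llms.WithTools, and the model's structured ToolCalls are executed and appended to the history as tool messages. The updated file: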

package main
import (
    "context"
    "encoding/json"
    "fmt"
    "log"
    "os"

    "github.com/tmc/langchaingo/llms"
    "github.com/tmc/langchaingo/llms/ollama"
)

func main() {
    // allow specifying your own model via OLLAMA_TEST_MODEL
    // (same as the Ollama unit tests).
    model := "llama3.1"
    if v := os.Getenv("OLLAMA_TEST_MODEL"); v != "" {
        model = v
    }

    llm, err := ollama.New(ollama.WithModel(model))
    if err != nil {
        log.Fatal(err)
    }

    ctx := context.Background()
    messageHistory := []llms.MessageContent{
        llms.TextParts(llms.ChatMessageTypeHuman, "What's the weather like in Beijing and Shenzhen?"),
    }

    fmt.Println("Querying for weather in Beijing and Shenzhen.")
    resp, err := llm.GenerateContent(ctx, messageHistory, llms.WithTools(availableTools))
    if err != nil {
        log.Fatal(err)
    }
    messageHistory = updateMessageHistory(messageHistory, resp)

    // Execute the tool calls requested by the model.
    messageHistory = executeToolCalls(messageHistory, resp)
    messageHistory = append(messageHistory, llms.TextParts(llms.ChatMessageTypeHuman, "Can you compare the two?"))

    // Send the query to the model again, this time with a history containing its
    // request to invoke a tool and our response to the tool call.
    fmt.Println("Querying with tool response...")
    resp, err = llm.GenerateContent(ctx, messageHistory)
    if err != nil {
        log.Fatal(err)
    }

    fmt.Println(showResponse(resp))
}
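
For debugging, a helper along these lines could be dropped into the same file to log what the model asked for before anything runs; it only reads fields the example already uses. printToolCalls is a hypothetical addition, not part of the commit.

// printToolCalls logs each tool invocation the model requested.
// Hypothetical helper for illustration; not part of the committed example.
func printToolCalls(resp *llms.ContentResponse) {
    for _, tc := range resp.Choices[0].ToolCalls {
        fmt.Printf("model requested %s(%s)\n", tc.FunctionCall.Name, tc.FunctionCall.Arguments)
    }
}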

// updateMessageHistory updates the message history with the assistant's
// response and requested tool calls.
func updateMessageHistory(messageHistory []llms.MessageContent, resp *llms.ContentResponse) []llms.MessageContent {
    respchoice := resp.Choices[0]

    assistantResponse := llms.TextParts(llms.ChatMessageTypeAI, respchoice.Content)
    for _, tc := range respchoice.ToolCalls {
        assistantResponse.Parts = append(assistantResponse.Parts, tc)
    }
    return append(messageHistory, assistantResponse)
}

// executeToolCalls executes the tool calls in the response and returns the
// updated message history.
func executeToolCalls(messageHistory []llms.MessageContent, resp *llms.ContentResponse) []llms.MessageContent {
    fmt.Println("Executing", len(resp.Choices[0].ToolCalls), "tool calls")
    for _, toolCall := range resp.Choices[0].ToolCalls {
        switch toolCall.FunctionCall.Name {
        case "getCurrentWeather":
            var args struct {
                Location string `json:"location"`
                Unit     string `json:"unit"`
            }
            if err := json.Unmarshal([]byte(toolCall.FunctionCall.Arguments), &args); err != nil {
                log.Fatal(err)
            }

            response, err := getCurrentWeather(args.Location, args.Unit)
            if err != nil {
                log.Fatal(err)
            }

            weatherCallResponse := llms.MessageContent{
                Role: llms.ChatMessageTypeTool,
                Parts: []llms.ContentPart{
                    llms.ToolCallResponse{
                        Name:    toolCall.FunctionCall.Name,
                        Content: response,
                    },
                },
            }
            messageHistory = append(messageHistory, weatherCallResponse)
        default:
            log.Fatalf("Unsupported tool: %s", toolCall.FunctionCall.Name)
        }
    }

    return messageHistory
}

type Weather struct {
    Location string `json:"location"`
    Forecast string `json:"forecast"`
}

func getCurrentWeather(location string, unit string) (string, error) {
    var weatherInfo Weather
    switch location {
    case "Shenzhen":
        weatherInfo = Weather{Location: location, Forecast: "74 and cloudy"}
    case "Beijing":
        weatherInfo = Weather{Location: location, Forecast: "80 and rainy"}
    }

    b, err := json.Marshal(weatherInfo)
    if err != nil {
        return "", err
    }

    return string(b), nil
}

// availableTools simulates the tools/functions we're making available for
// the model.
var availableTools = []llms.Tool{
    {
        Type: "function",
        Function: &llms.FunctionDefinition{
            Name:        "getCurrentWeather",
            Description: "Get the current weather in a given location",
            Parameters: map[string]any{
                "type": "object",
                "properties": map[string]any{
                    "location": map[string]any{
                        "type":        "string",
                        "description": "The city, e.g. San Francisco",
                    },
                    "unit": map[string]any{
                        "type": "string",
                        "enum": []string{"fahrenheit", "celsius"},
                    },
                },
                "required": []string{"location", "unit"},
            },
        },
    },
}
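
Parameters is a plain map[string]any holding a JSON Schema, so a second tool is just another element of availableTools. A sketch under that assumption, with a hypothetical getPopulation tool:

// Hypothetical second tool, shown only to illustrate the schema shape.
var populationTool = llms.Tool{
    Type: "function",
    Function: &llms.FunctionDefinition{
        Name:        "getPopulation",
        Description: "Get the population of a given city",
        Parameters: map[string]any{
            "type": "object",
            "properties": map[string]any{
                "location": map[string]any{
                    "type":        "string",
                    "description": "The city, e.g. Beijing",
                },
            },
            "required": []string{"location"},
        },
    },
}

Any tool added this way also needs a matching case in executeToolCalls, which otherwise exits through the "Unsupported tool" fatal path.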

func showResponse(resp *llms.ContentResponse) string {
    b, err := json.MarshalIndent(resp, "", "  ")
    if err != nil {
        log.Fatal(err)
    }
    return string(b)
}
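
To try the example, you need an Ollama server the client can reach (by default it talks to a local instance) with a tool-capable model pulled: llama3.1 by default, or whatever OLLAMA_TEST_MODEL names. With that in place, go run . from examples/ollama-functions-example should print the executed tool calls and then the model's comparison of the two cities.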
