Skip to content

Commit

Permalink
fix(mistral): supports the default llms.WithTools implementation (#970)
Browse files Browse the repository at this point in the history
  • Loading branch information
douglarek authored Sep 13, 2024
1 parent 71160f9 commit 2124f7f
Showing 1 changed file with 58 additions and 19 deletions.
77 changes: 58 additions & 19 deletions llms/mistral/mistralmodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,28 @@ func mistralChatParamsFromCallOptions(callOpts *llms.CallOptions) sdk.ChatReques
chatOpts.Temperature = callOpts.Temperature
chatOpts.RandomSeed = callOpts.Seed
chatOpts.Tools = make([]sdk.Tool, 0)
for _, function := range callOpts.Functions {
chatOpts.Tools = append(chatOpts.Tools, sdk.Tool{
Type: "function",
Function: sdk.Function{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
},
})
if len(callOpts.Tools) > 0 {
for _, tool := range callOpts.Tools {
chatOpts.Tools = append(chatOpts.Tools, sdk.Tool{
Type: "function",
Function: sdk.Function{
Name: tool.Function.Name,
Description: tool.Function.Description,
Parameters: tool.Function.Parameters,
},
})
}
} else {
for _, function := range callOpts.Functions {
chatOpts.Tools = append(chatOpts.Tools, sdk.Tool{
Type: "function",
Function: sdk.Function{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
},
})
}
}
return chatOpts
}
Expand Down Expand Up @@ -159,6 +172,16 @@ func generateNonStreamingContent(ctx context.Context, m *Model, callOptions *llm
toolCalls := choice.Message.ToolCalls
if len(toolCalls) > 0 {
langchainContentResponse.Choices[idx].FuncCall = (*llms.FunctionCall)(&toolCalls[0].Function)
for _, tool := range toolCalls {
langchainContentResponse.Choices[0].ToolCalls = append(langchainContentResponse.Choices[0].ToolCalls, llms.ToolCall{
ID: tool.Id,
Type: string(tool.Type),
FunctionCall: &llms.FunctionCall{
Name: tool.Function.Name,
Arguments: tool.Function.Arguments,
},
})
}
}
}
m.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, langchainContentResponse)
Expand Down Expand Up @@ -192,6 +215,16 @@ func generateStreamingContent(ctx context.Context, m *Model, callOptions *llms.C
langchainContentResponse.Choices[0].StopReason = string(choice.FinishReason)
if len(choice.Delta.ToolCalls) > 0 {
langchainContentResponse.Choices[0].FuncCall = (*llms.FunctionCall)(&choice.Delta.ToolCalls[0].Function)
for _, tool := range choice.Delta.ToolCalls {
langchainContentResponse.Choices[0].ToolCalls = append(langchainContentResponse.Choices[0].ToolCalls, llms.ToolCall{
ID: tool.Id,
Type: string(tool.Type),
FunctionCall: &llms.FunctionCall{
Name: tool.Function.Name,
Arguments: tool.Function.Arguments,
},
})
}
}
}
err := callOptions.StreamingFunc(ctx, []byte(chunkStr))
Expand All @@ -209,19 +242,25 @@ func generateStreamingContent(ctx context.Context, m *Model, callOptions *llms.C
func convertToMistralChatMessages(langchainMessages []llms.MessageContent) ([]sdk.ChatMessage, error) {
messages := make([]sdk.ChatMessage, 0)
for _, msg := range langchainMessages {
msgText := ""
for _, part := range msg.Parts {
textContent, ok := part.(llms.TextContent)
if !ok {
switch p := part.(type) {
case llms.TextContent:
chatMsg := sdk.ChatMessage{Content: p.Text, Role: string(msg.Role)}
setMistralChatMessageRole(&msg, &chatMsg) // #nosec G601
if chatMsg.Content != "" && chatMsg.Role != "" {
messages = append(messages, chatMsg)
}
case llms.ToolCallResponse:
chatMsg := sdk.ChatMessage{Role: string(msg.Role), Content: p.Content}
setMistralChatMessageRole(&msg, &chatMsg) // #nosec G601
messages = append(messages, chatMsg)
case llms.ToolCall:
chatMsg := sdk.ChatMessage{Role: string(msg.Role), ToolCalls: []sdk.ToolCall{{Id: p.ID, Type: sdk.ToolTypeFunction, Function: sdk.FunctionCall{Name: p.FunctionCall.Name, Arguments: p.FunctionCall.Arguments}}}}
setMistralChatMessageRole(&msg, &chatMsg) // #nosec G601
messages = append(messages, chatMsg)
default:
return nil, errors.New("unsupported content type encountered while preparing chat messages to send to mistral platform")
}
msgText += textContent.Text
}
chatMsg := sdk.ChatMessage{Content: msgText, Role: "user"}

setMistralChatMessageRole(&msg, &chatMsg) // #nosec G601
if chatMsg.Content != "" && chatMsg.Role != "" {
messages = append(messages, chatMsg)
}
}
return messages, nil
Expand Down

0 comments on commit 2124f7f

Please sign in to comment.