diff --git a/llms/mistral/mistralmodel.go b/llms/mistral/mistralmodel.go index 42fbf1bf0..23021682f 100644 --- a/llms/mistral/mistralmodel.go +++ b/llms/mistral/mistralmodel.go @@ -117,15 +117,28 @@ func mistralChatParamsFromCallOptions(callOpts *llms.CallOptions) sdk.ChatReques chatOpts.Temperature = callOpts.Temperature chatOpts.RandomSeed = callOpts.Seed chatOpts.Tools = make([]sdk.Tool, 0) - for _, function := range callOpts.Functions { - chatOpts.Tools = append(chatOpts.Tools, sdk.Tool{ - Type: "function", - Function: sdk.Function{ - Name: function.Name, - Description: function.Description, - Parameters: function.Parameters, - }, - }) + if len(callOpts.Tools) > 0 { + for _, tool := range callOpts.Tools { + chatOpts.Tools = append(chatOpts.Tools, sdk.Tool{ + Type: "function", + Function: sdk.Function{ + Name: tool.Function.Name, + Description: tool.Function.Description, + Parameters: tool.Function.Parameters, + }, + }) + } + } else { + for _, function := range callOpts.Functions { + chatOpts.Tools = append(chatOpts.Tools, sdk.Tool{ + Type: "function", + Function: sdk.Function{ + Name: function.Name, + Description: function.Description, + Parameters: function.Parameters, + }, + }) + } } return chatOpts } @@ -159,6 +172,16 @@ func generateNonStreamingContent(ctx context.Context, m *Model, callOptions *llm toolCalls := choice.Message.ToolCalls if len(toolCalls) > 0 { langchainContentResponse.Choices[idx].FuncCall = (*llms.FunctionCall)(&toolCalls[0].Function) + for _, tool := range toolCalls { + langchainContentResponse.Choices[0].ToolCalls = append(langchainContentResponse.Choices[0].ToolCalls, llms.ToolCall{ + ID: tool.Id, + Type: string(tool.Type), + FunctionCall: &llms.FunctionCall{ + Name: tool.Function.Name, + Arguments: tool.Function.Arguments, + }, + }) + } } } m.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, langchainContentResponse) @@ -192,6 +215,16 @@ func generateStreamingContent(ctx context.Context, m *Model, callOptions *llms.C langchainContentResponse.Choices[0].StopReason = string(choice.FinishReason) if len(choice.Delta.ToolCalls) > 0 { langchainContentResponse.Choices[0].FuncCall = (*llms.FunctionCall)(&choice.Delta.ToolCalls[0].Function) + for _, tool := range choice.Delta.ToolCalls { + langchainContentResponse.Choices[0].ToolCalls = append(langchainContentResponse.Choices[0].ToolCalls, llms.ToolCall{ + ID: tool.Id, + Type: string(tool.Type), + FunctionCall: &llms.FunctionCall{ + Name: tool.Function.Name, + Arguments: tool.Function.Arguments, + }, + }) + } } } err := callOptions.StreamingFunc(ctx, []byte(chunkStr)) @@ -209,19 +242,25 @@ func generateStreamingContent(ctx context.Context, m *Model, callOptions *llms.C func convertToMistralChatMessages(langchainMessages []llms.MessageContent) ([]sdk.ChatMessage, error) { messages := make([]sdk.ChatMessage, 0) for _, msg := range langchainMessages { - msgText := "" for _, part := range msg.Parts { - textContent, ok := part.(llms.TextContent) - if !ok { + switch p := part.(type) { + case llms.TextContent: + chatMsg := sdk.ChatMessage{Content: p.Text, Role: string(msg.Role)} + setMistralChatMessageRole(&msg, &chatMsg) // #nosec G601 + if chatMsg.Content != "" && chatMsg.Role != "" { + messages = append(messages, chatMsg) + } + case llms.ToolCallResponse: + chatMsg := sdk.ChatMessage{Role: string(msg.Role), Content: p.Content} + setMistralChatMessageRole(&msg, &chatMsg) // #nosec G601 + messages = append(messages, chatMsg) + case llms.ToolCall: + chatMsg := sdk.ChatMessage{Role: string(msg.Role), ToolCalls: []sdk.ToolCall{{Id: p.ID, Type: sdk.ToolTypeFunction, Function: sdk.FunctionCall{Name: p.FunctionCall.Name, Arguments: p.FunctionCall.Arguments}}}} + setMistralChatMessageRole(&msg, &chatMsg) // #nosec G601 + messages = append(messages, chatMsg) + default: return nil, errors.New("unsupported content type encountered while preparing chat messages to send to mistral platform") } - msgText += textContent.Text - } - chatMsg := sdk.ChatMessage{Content: msgText, Role: "user"} - - setMistralChatMessageRole(&msg, &chatMsg) // #nosec G601 - if chatMsg.Content != "" && chatMsg.Role != "" { - messages = append(messages, chatMsg) } } return messages, nil