diff --git a/llms/anthropic/anthropicllm.go b/llms/anthropic/anthropicllm.go
index 2d236ebe5..4494ce295 100644
--- a/llms/anthropic/anthropicllm.go
+++ b/llms/anthropic/anthropicllm.go
@@ -272,35 +272,57 @@ func handleHumanMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, e
 	return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for human message", ErrInvalidContentType)
 }
 
-func handleAIMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) {
-	if toolCall, ok := msg.Parts[0].(llms.ToolCall); ok {
-		var inputStruct map[string]interface{}
-		err := json.Unmarshal([]byte(toolCall.FunctionCall.Arguments), &inputStruct)
-		if err != nil {
-			return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: failed to unmarshal tool call arguments: %w", err)
-		}
-		toolUse := anthropicclient.ToolUseContent{
-			Type:  "tool_use",
-			ID:    toolCall.ID,
-			Name:  toolCall.FunctionCall.Name,
-			Input: inputStruct,
-		}
+// getTextPart wraps the text of part in an Anthropic "text" content block.
+func getTextPart(part llms.TextContent) *anthropicclient.TextContent {
+	return &anthropicclient.TextContent{
+		Type: "text",
+		Text: part.Text,
+	}
+}
 
-		return anthropicclient.ChatMessage{
-			Role:    RoleAssistant,
-			Content: []anthropicclient.Content{toolUse},
-		}, nil
+// getToolPart converts a tool call into an Anthropic "tool_use" content block,
+// unmarshalling the call's JSON arguments into the block's Input map.
+func getToolPart(part llms.ToolCall) (*anthropicclient.ToolUseContent, error) {
+	var inputStruct map[string]interface{}
+	err := json.Unmarshal([]byte(part.FunctionCall.Arguments), &inputStruct)
+	if err != nil {
+		return nil, fmt.Errorf("anthropic: failed to unmarshal tool call arguments: %w", err)
 	}
-	if textContent, ok := msg.Parts[0].(llms.TextContent); ok {
-		return anthropicclient.ChatMessage{
-			Role: RoleAssistant,
-			Content: []anthropicclient.Content{&anthropicclient.TextContent{
-				Type: "text",
-				Text: textContent.Text,
-			}},
-		}, nil
+	return &anthropicclient.ToolUseContent{
+		Type:  "tool_use",
+		ID:    part.ID,
+		Name:  part.FunctionCall.Name,
+		Input: inputStruct,
+	}, nil
+}
+
+// handleAIMessage builds an assistant-role message with one content block per
+// text or tool-call part. Any other part type yields ErrInvalidContentType.
+func handleAIMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) {
+	cm := anthropicclient.ChatMessage{
+		Role:    RoleAssistant,
+		Content: make([]anthropicclient.Content, 0),
 	}
-	return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for AI message", ErrInvalidContentType)
+
+	contentArr := make([]anthropicclient.Content, 0)
+
+	for _, part := range msg.Parts {
+		switch p := part.(type) {
+		case llms.TextContent:
+			contentArr = append(contentArr, getTextPart(p))
+		case llms.ToolCall:
+			tp, err := getToolPart(p)
+			if err != nil {
+				return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for AI message %T", err, part)
+			}
+			contentArr = append(contentArr, tp)
+		default:
+			return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for AI message %T", ErrInvalidContentType, part)
+		}
+	}
+
+	cm.Content = contentArr
+	return cm, nil
 }
 
 type ToolResult struct {
@@ -309,18 +331,42 @@ type ToolResult struct {
 	Content   string `json:"content"`
 }
 
+// getToolResponse converts a tool call response into an Anthropic "tool_result"
+// content block. The returned error is currently always nil.
+func getToolResponse(part llms.ToolCallResponse) (*anthropicclient.ToolResultContent, error) {
+	return &anthropicclient.ToolResultContent{
+		Type:      "tool_result",
+		ToolUseID: part.ToolCallID,
+		Content:   part.Content,
+	}, nil
+}
+
+// handleToolMessage builds a user-role message with one content block per text
+// or tool-call-response part. Any other part type yields ErrInvalidContentType.
 func handleToolMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) {
-	if toolCallResponse, ok := msg.Parts[0].(llms.ToolCallResponse); ok {
-		toolContent := anthropicclient.ToolResultContent{
-			Type:      "tool_result",
-			ToolUseID: toolCallResponse.ToolCallID,
-			Content:   toolCallResponse.Content,
-		}
+	cm := anthropicclient.ChatMessage{
+		Role:    RoleUser,
+		Content: make([]anthropicclient.Content, 0),
+	}
 
-		return anthropicclient.ChatMessage{
-			Role:    RoleUser,
-			Content: []anthropicclient.Content{toolContent},
-		}, nil
+	contentArr := make([]anthropicclient.Content, 0)
+
+	for _, part := range msg.Parts {
+		switch p := part.(type) {
+		case llms.TextContent:
+			contentArr = append(contentArr, getTextPart(p))
+		case llms.ToolCallResponse:
+			tp, err := getToolResponse(p)
+			if err != nil {
+				return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for tool part response message %T", err, part)
+			}
+			contentArr = append(contentArr, tp)
+		default:
+			return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for tool message %T", ErrInvalidContentType, part)
+		}
 	}
-	return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for tool message", ErrInvalidContentType)
+
+	cm.Content = contentArr
+
+	return cm, nil
 }
diff --git a/llms/anthropic/internal/anthropicclient/messages.go b/llms/anthropic/internal/anthropicclient/messages.go
index 10a0ca400..9f7a5619a 100644
--- a/llms/anthropic/internal/anthropicclient/messages.go
+++ b/llms/anthropic/internal/anthropicclient/messages.go
@@ -66,11 +66,50 @@ func (tc TextContent) GetType() string {
 	return tc.Type
 }
 
+// PartialJSONContent is a content block carrying a partial_json string payload.
+type PartialJSONContent struct {
+	Type        string `json:"type"`
+	PartialJSON string `json:"partial_json"`
+}
+
+// GetType returns the block's type field.
+func (tc PartialJSONContent) GetType() string {
+	return tc.Type
+}
+
 type ToolUseContent struct {
-	Type  string                 `json:"type"`
-	ID    string                 `json:"id"`
-	Name  string                 `json:"name"`
-	Input map[string]interface{} `json:"input"`
+	Type           string                 `json:"type"`
+	ID             string                 `json:"id"`
+	Name           string                 `json:"name"`
+	Input          map[string]interface{} `json:"input"`
+	rawStreamInput string
+}
+
+// AppendStreamChunk appends chunk to the raw streamed input accumulated so far.
+func (tuc *ToolUseContent) AppendStreamChunk(chunk string) {
+	tuc.rawStreamInput += chunk
+}
+
+// GetStreamInput returns the raw streamed input accumulated so far.
+func (tuc *ToolUseContent) GetStreamInput() string {
+	return tuc.rawStreamInput
+}
+
+// DecodeStream unmarshals the accumulated raw streamed input as a JSON object
+// into Input. It is a no-op when nothing has been accumulated.
+func (tuc *ToolUseContent) DecodeStream() error {
+	if tuc.rawStreamInput == "" {
+		return nil
+	}
+
+	m := make(map[string]interface{})
+	err := json.Unmarshal([]byte(tuc.rawStreamInput), &m)
+	if err != nil {
+		return err
+	}
+
+	tuc.Input = m
+	return nil
 }
 
 func (tuc ToolUseContent) GetType() string {
@@ -261,7 +300,21 @@ func processStreamEvent(ctx context.Context, event map[string]interface{}, paylo
 	case "content_block_delta":
 		return handleContentBlockDeltaEvent(ctx, event, response, payload)
 	case "content_block_stop":
-		// Nothing to do here
+		// Decode the raw input accumulated on every tool_use block so far.
+		for _, content := range response.Content {
+			if content == nil {
+				continue
+			}
+			tuc, ok := content.(*ToolUseContent)
+			if !ok {
+				continue
+			}
+
+			err := tuc.DecodeStream()
+			if err != nil {
+				return response, fmt.Errorf("error decoding stream tool data: %w", err)
+			}
+		}
 	case "message_delta":
 		return handleMessageDeltaEvent(event, response)
 	case "message_stop":
@@ -307,15 +360,38 @@ func handleContentBlockStartEvent(event map[string]interface{}, response Message
 	index := int(indexValue)
 
 	var eventType string
-	if cb, ok := event["content_block"].(map[string]any); ok {
+	cb, ok := event["content_block"].(map[string]any)
+	if ok {
 		typ, _ := cb["type"].(string)
 		eventType = typ
 	}
 
 	if len(response.Content) <= index {
-		response.Content = append(response.Content, &TextContent{
-			Type: eventType,
-		})
+		switch eventType {
+		case "text":
+			response.Content = append(response.Content, &TextContent{
+				Type: eventType,
+			})
+		case "tool_use":
+			toolID, ok := cb["id"].(string)
+			if !ok {
+				return response, fmt.Errorf("missing tool id field in content block [start]")
+			}
+
+			toolName, ok := cb["name"].(string)
+			if !ok {
+				return response, fmt.Errorf("missing name field in content block [start]")
+			}
+
+			response.Content = append(response.Content, &ToolUseContent{
+				Type:  eventType,
+				Input: make(map[string]interface{}),
+				ID:    toolID,
+				Name:  toolName,
+			})
+		default:
+			return response, fmt.Errorf("unknown content block type: %s", eventType)
+		}
 	}
 	return response, nil
 }
@@ -351,7 +427,26 @@ func handleContentBlockDeltaEvent(ctx context.Context, event map[string]interfac
 		textContent.Text += text
 	}
 
-	if payload.StreamingFunc != nil {
+	streamOutput := true
+	if deltaType == "input_json_delta" {
+		streamOutput = false
+		partial, ok := delta["partial_json"].(string)
+		if !ok {
+			return response, fmt.Errorf("partial_json field missing")
+		}
+		if len(response.Content) <= index {
+			return response, ErrContentIndexOutOfRange
+		}
+		tuc, ok := response.Content[index].(*ToolUseContent)
+		if !ok {
+			asJSON, _ := json.MarshalIndent(response, "", " ")
+			return response, fmt.Errorf("failed to cast index %v to ToolUseContent: \n%s", index, string(asJSON))
+		}
+
+		tuc.AppendStreamChunk(partial)
+	}
+
+	if payload.StreamingFunc != nil && streamOutput {
 		text, ok := delta["text"].(string)
 		if !ok {
 			return response, ErrInvalidDeltaTextField
diff --git a/llms/openai/internal/openaiclient/chat.go b/llms/openai/internal/openaiclient/chat.go
index 407dad6dd..ee7f73b4c 100644
--- a/llms/openai/internal/openaiclient/chat.go
+++ b/llms/openai/internal/openaiclient/chat.go
@@ -460,12 +460,14 @@ func combineStreamingChatResponse(
 			chunk = updateFunctionCall(response.Choices[0].Message, choice.Delta.FunctionCall)
 		}
 
+		outputToTextStream := true
 		if len(choice.Delta.ToolCalls) > 0 {
 			chunk, response.Choices[0].Message.ToolCalls = updateToolCalls(response.Choices[0].Message.ToolCalls,
 				choice.Delta.ToolCalls)
+			outputToTextStream = false
 		}
 
-		if payload.StreamingFunc != nil {
+		if payload.StreamingFunc != nil && outputToTextStream {
 			err := payload.StreamingFunc(ctx, chunk)
 			if err != nil {
 				return nil, fmt.Errorf("streaming func returned an error: %w", err)
diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go
index 699c0d304..378a278c2 100644
--- a/llms/openai/openaillm.go
+++ b/llms/openai/openaillm.go
@@ -69,21 +69,20 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 		case llms.ChatMessageTypeFunction:
 			msg.Role = RoleFunction
 		case llms.ChatMessageTypeTool:
-			msg.Role = RoleTool
 			// Here we extract tool calls from the message and populate the ToolCalls field.
-
-			// parse mc.Parts (which should have one entry of type ToolCallResponse) and populate msg.Content and msg.ToolCallID
-			if len(mc.Parts) != 1 {
-				return nil, fmt.Errorf("expected exactly one part for role %v, got %v", mc.Role, len(mc.Parts))
-			}
-			switch p := mc.Parts[0].(type) {
-			case llms.ToolCallResponse:
-				msg.ToolCallID = p.ToolCallID
-				msg.Content = p.Content
-			default:
-				return nil, fmt.Errorf("expected part of type ToolCallResponse for role %v, got %T", mc.Role, mc.Parts[0])
+			for _, p := range mc.Parts {
+				switch tr := p.(type) {
+				case llms.ToolCallResponse:
+					rep := &ChatMessage{}
+					rep.Role = RoleTool
+					rep.ToolCallID = tr.ToolCallID
+					rep.Content = tr.Content
+					chatMsgs = append(chatMsgs, rep)
+				default:
+					return nil, fmt.Errorf("expected part of type ToolCallResponse for role %v, got %T", mc.Role, p)
+				}
 			}
-
+			continue
 		default:
 			return nil, fmt.Errorf("role %v not supported", mc.Role)
 		}