Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #1033 #1034

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 75 additions & 37 deletions llms/anthropic/anthropicllm.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,35 +272,52 @@
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for human message", ErrInvalidContentType)
}

func handleAIMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) {
if toolCall, ok := msg.Parts[0].(llms.ToolCall); ok {
var inputStruct map[string]interface{}
err := json.Unmarshal([]byte(toolCall.FunctionCall.Arguments), &inputStruct)
if err != nil {
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: failed to unmarshal tool call arguments: %w", err)
}
toolUse := anthropicclient.ToolUseContent{
Type: "tool_use",
ID: toolCall.ID,
Name: toolCall.FunctionCall.Name,
Input: inputStruct,
}
func getTextPart(part llms.TextContent) *anthropicclient.TextContent {
return &anthropicclient.TextContent{
Type: "text",
Text: part.Text,
}
}

return anthropicclient.ChatMessage{
Role: RoleAssistant,
Content: []anthropicclient.Content{toolUse},
}, nil
func getToolPart(part llms.ToolCall) (*anthropicclient.ToolUseContent, error) {
var inputStruct map[string]interface{}
err := json.Unmarshal([]byte(part.FunctionCall.Arguments), &inputStruct)
if err != nil {
return nil, fmt.Errorf("anthropic: failed to unmarshal tool call arguments: %w", err)
}
if textContent, ok := msg.Parts[0].(llms.TextContent); ok {
return anthropicclient.ChatMessage{
Role: RoleAssistant,
Content: []anthropicclient.Content{&anthropicclient.TextContent{
Type: "text",
Text: textContent.Text,
}},
}, nil
return &anthropicclient.ToolUseContent{
Type: "tool_use",
ID: part.ID,
Name: part.FunctionCall.Name,
Input: inputStruct,
}, nil
}

func handleAIMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) {
cm := anthropicclient.ChatMessage{
Role: RoleAssistant,
Content: make([]anthropicclient.Content, 0),
}
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for AI message", ErrInvalidContentType)

contentArr := make([]anthropicclient.Content, 0)

for _, part := range msg.Parts {
switch part.(type) {
case llms.TextContent:
contentArr = append(contentArr, getTextPart(part.(llms.TextContent)))
case llms.ToolCall:

Check failure on line 308 in llms/anthropic/anthropicllm.go

View workflow job for this annotation

GitHub Actions / Lint

typeSwitchVar: 2 cases can benefit from type switch with assignment (gocritic)
tp, err := getToolPart(part.(llms.ToolCall))
if err != nil {

Check failure on line 310 in llms/anthropic/anthropicllm.go

View workflow job for this annotation

GitHub Actions / Lint

type assertion must be checked (forcetypeassert)
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for AI message %T", err, part)
}
contentArr = append(contentArr, tp)
default:
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for AI message %T", ErrInvalidContentType, part)
}
}

cm.Content = contentArr
return cm, nil
}

type ToolResult struct {
Expand All @@ -309,18 +326,39 @@
Content string `json:"content"`
}

func getToolResponse(part llms.ToolCallResponse) (*anthropicclient.ToolResultContent, error) {
return &anthropicclient.ToolResultContent{
Type: "tool_result",
ToolUseID: part.ToolCallID,

Check failure on line 332 in llms/anthropic/anthropicllm.go

View workflow job for this annotation

GitHub Actions / Lint

getToolResponse - result 1 (error) is always nil (unparam)
Content: part.Content,
}, nil

}

func handleToolMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) {
if toolCallResponse, ok := msg.Parts[0].(llms.ToolCallResponse); ok {
toolContent := anthropicclient.ToolResultContent{
Type: "tool_result",
ToolUseID: toolCallResponse.ToolCallID,
Content: toolCallResponse.Content,
}
cm := anthropicclient.ChatMessage{

Check failure on line 339 in llms/anthropic/anthropicllm.go

View workflow job for this annotation

GitHub Actions / Lint

unnecessary trailing newline (whitespace)
Role: RoleUser,
Content: make([]anthropicclient.Content, 0),
}

return anthropicclient.ChatMessage{
Role: RoleUser,
Content: []anthropicclient.Content{toolContent},
}, nil
contentArr := make([]anthropicclient.Content, 0)

for _, part := range msg.Parts {
switch part.(type) {
case llms.TextContent:
contentArr = append(contentArr, getTextPart(part.(llms.TextContent)))
case llms.ToolCallResponse:

Check failure on line 350 in llms/anthropic/anthropicllm.go

View workflow job for this annotation

GitHub Actions / Lint

typeSwitchVar: 2 cases can benefit from type switch with assignment (gocritic)
tp, err := getToolResponse(part.(llms.ToolCallResponse))
if err != nil {

Check failure on line 352 in llms/anthropic/anthropicllm.go

View workflow job for this annotation

GitHub Actions / Lint

type assertion must be checked (forcetypeassert)
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for tool part response message %T", err, part)
}
contentArr = append(contentArr, tp)
default:
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for AI message %T", ErrInvalidContentType, part)
}
}
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for tool message", ErrInvalidContentType)

cm.Content = contentArr

return cm, nil
}
108 changes: 98 additions & 10 deletions llms/anthropic/internal/anthropicclient/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,44 @@
return tc.Type
}

type PartialJSONContent struct {
Type string `json:"type"`
PartialJSON string `json:"partial_json"`
}

func (tc PartialJSONContent) GetType() string {
return tc.Type
}

type ToolUseContent struct {
Type string `json:"type"`
ID string `json:"id"`
Name string `json:"name"`
Input map[string]interface{} `json:"input"`
Type string `json:"type"`
ID string `json:"id"`
Name string `json:"name"`
Input map[string]interface{} `json:"input"`
rawStreamInput string
}

func (tuc *ToolUseContent) AppendStreamChunk(chunk string) {
tuc.rawStreamInput += chunk
}

func (tuc *ToolUseContent) GetStreamInput() string {
return tuc.rawStreamInput
}

func (tuc *ToolUseContent) DecodeStream() error {
if tuc.rawStreamInput == "" {
return nil
}

m := make(map[string]interface{})
err := json.Unmarshal([]byte(tuc.rawStreamInput), &m)
if err != nil {
return err
}

tuc.Input = m
return nil
}

func (tuc ToolUseContent) GetType() string {
Expand Down Expand Up @@ -248,7 +281,7 @@
return event, err
}

func processStreamEvent(ctx context.Context, event map[string]interface{}, payload *messagePayload, response MessageResponsePayload, eventChan chan<- MessageEvent) (MessageResponsePayload, error) {

Check failure on line 284 in llms/anthropic/internal/anthropicclient/messages.go

View workflow job for this annotation

GitHub Actions / Lint

calculated cyclomatic complexity for function processStreamEvent is 15, max is 12 (cyclop)
eventType, ok := event["type"].(string)
if !ok {
return response, ErrInvalidEventType
Expand All @@ -261,7 +294,20 @@
case "content_block_delta":
return handleContentBlockDeltaEvent(ctx, event, response, payload)
case "content_block_stop":
// Nothing to do here
for _, content := range response.Content {
if content == nil {
continue
}
tuc, ok := content.(*ToolUseContent)
if !ok {
continue
}

err := tuc.DecodeStream()
if err != nil {
return response, fmt.Errorf("error decoding stream tool data: %w", err)
}
}
case "message_delta":
return handleMessageDeltaEvent(event, response)
case "message_stop":
Expand Down Expand Up @@ -307,15 +353,38 @@
index := int(indexValue)

var eventType string
if cb, ok := event["content_block"].(map[string]any); ok {
cb, ok := event["content_block"].(map[string]any)
if ok {
typ, _ := cb["type"].(string)
eventType = typ
}

if len(response.Content) <= index {
response.Content = append(response.Content, &TextContent{
Type: eventType,
})
switch eventType {
case "text":
response.Content = append(response.Content, &TextContent{
Type: eventType,
})
case "tool_use":
toolID, ok := cb["id"].(string)
if !ok {
return response, fmt.Errorf("missing tool id field in content block [start]")
}

toolName, ok := cb["name"].(string)
if !ok {
return response, fmt.Errorf("missing name field in content block [start]")
}

response.Content = append(response.Content, &ToolUseContent{
Type: eventType,
Input: make(map[string]interface{}),
ID: toolID,
Name: toolName,
})
default:
return response, fmt.Errorf("unknown content block type: %s", eventType)
}
}
return response, nil
}
Expand Down Expand Up @@ -351,7 +420,26 @@
textContent.Text += text
}

if payload.StreamingFunc != nil {
streamOutput := true
if deltaType == "input_json_delta" {
streamOutput = false
partial, ok := delta["partial_json"].(string)
if !ok {
return response, fmt.Errorf("partial_json field missing")
}
if len(response.Content) <= index {
return response, ErrContentIndexOutOfRange
}
tuc, ok := response.Content[index].(*ToolUseContent)
if !ok {
asJson, _ := json.MarshalIndent(response, "", " ")
return response, fmt.Errorf("failed to cast index %v to ToolUseContent: \n%s", index, string(asJson))
}

Check failure on line 437 in llms/anthropic/internal/anthropicclient/messages.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: var asJson should be asJSON (revive)

tuc.AppendStreamChunk(partial)
}

if payload.StreamingFunc != nil && streamOutput {
text, ok := delta["text"].(string)
if !ok {
return response, ErrInvalidDeltaTextField
Expand Down
4 changes: 3 additions & 1 deletion llms/openai/internal/openaiclient/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,12 +460,14 @@ func combineStreamingChatResponse(
chunk = updateFunctionCall(response.Choices[0].Message, choice.Delta.FunctionCall)
}

outputToTextStream := true
if len(choice.Delta.ToolCalls) > 0 {
chunk, response.Choices[0].Message.ToolCalls = updateToolCalls(response.Choices[0].Message.ToolCalls,
choice.Delta.ToolCalls)
outputToTextStream = false
}

if payload.StreamingFunc != nil {
if payload.StreamingFunc != nil && outputToTextStream {
err := payload.StreamingFunc(ctx, chunk)
if err != nil {
return nil, fmt.Errorf("streaming func returned an error: %w", err)
Expand Down
26 changes: 13 additions & 13 deletions llms/openai/openaillm.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,21 @@
case llms.ChatMessageTypeFunction:
msg.Role = RoleFunction
case llms.ChatMessageTypeTool:
msg.Role = RoleTool
// Here we extract tool calls from the message and populate the ToolCalls field.

// parse mc.Parts (which should have one entry of type ToolCallResponse) and populate msg.Content and msg.ToolCallID
if len(mc.Parts) != 1 {
return nil, fmt.Errorf("expected exactly one part for role %v, got %v", mc.Role, len(mc.Parts))
}
switch p := mc.Parts[0].(type) {
case llms.ToolCallResponse:
msg.ToolCallID = p.ToolCallID
msg.Content = p.Content
default:
return nil, fmt.Errorf("expected part of type ToolCallResponse for role %v, got %T", mc.Role, mc.Parts[0])
for _, p := range mc.Parts {
switch p.(type) {

Check failure on line 74 in llms/openai/openaillm.go

View workflow job for this annotation

GitHub Actions / Lint

typeSwitchVar: 1 case can benefit from type switch with assignment (gocritic)
case llms.ToolCallResponse:
tr := p.(llms.ToolCallResponse)

Check failure on line 76 in llms/openai/openaillm.go

View workflow job for this annotation

GitHub Actions / Lint

type assertion must be checked (forcetypeassert)
rep := &ChatMessage{}
rep.Role = RoleTool
rep.ToolCallID = tr.ToolCallID
rep.Content = tr.Content
chatMsgs = append(chatMsgs, rep)
default:
return nil, fmt.Errorf("expected part of type ToolCallResponse for role %v, got %T", mc.Role, mc.Parts[0])
}
}

continue
default:
return nil, fmt.Errorf("role %v not supported", mc.Role)
}
Expand Down