From 673c598651d1c9858e2fbc57dd9f6398d8d46e21 Mon Sep 17 00:00:00 2001
From: takeyamakenta
Date: Fri, 13 Sep 2024 09:21:50 +0900
Subject: [PATCH] openai: add support for sending Structured Output request for chat completions api (#986)

openai: add support for structured output request
---
 llms/openai/internal/openaiclient/chat.go     |  27 ++-
 .../internal/openaiclient/openaiclient.go     |   4 +
 llms/openai/llm.go                            |   4 +-
 llms/openai/openaillm.go                      |   7 +
 llms/openai/openaillm_option.go               |   6 +
 llms/openai/structured_output_test.go         | 176 ++++++++++++++++++
 llms/options.go                               |   2 +
 7 files changed, 224 insertions(+), 2 deletions(-)
 create mode 100644 llms/openai/structured_output_test.go

diff --git a/llms/openai/internal/openaiclient/chat.go b/llms/openai/internal/openaiclient/chat.go
index 0e614cebf..2bb572a0a 100644
--- a/llms/openai/internal/openaiclient/chat.go
+++ b/llms/openai/internal/openaiclient/chat.go
@@ -108,9 +108,32 @@ type ToolCall struct {
 	Function ToolFunction `json:"function,omitempty"`
 }
 
+// ResponseFormatJSONSchemaProperty is one node of the JSON Schema sent with a
+// structured output request. AdditionalProperties carries no omitempty tag, so
+// it is always serialized, even when false.
+type ResponseFormatJSONSchemaProperty struct {
+	Type                 string                                       `json:"type"`
+	Description          string                                       `json:"description,omitempty"`
+	Enum                 []any                                        `json:"enum,omitempty"`
+	Items                *ResponseFormatJSONSchemaProperty            `json:"items,omitempty"`
+	Properties           map[string]*ResponseFormatJSONSchemaProperty `json:"properties,omitempty"`
+	AdditionalProperties bool                                         `json:"additionalProperties"`
+	Required             []string                                     `json:"required,omitempty"`
+	Ref                  string                                       `json:"$ref,omitempty"`
+}
+
+// ResponseFormatJSONSchema is the named schema carried in the json_schema field
+// of a ResponseFormat.
+type ResponseFormatJSONSchema struct {
+	Name   string                            `json:"name"`
+	Strict bool                              `json:"strict"`
+	Schema *ResponseFormatJSONSchemaProperty `json:"schema"`
+}
+
 // ResponseFormat is the format of the response.
 type ResponseFormat struct {
-	Type string `json:"type"`
+	Type       string                    `json:"type"`
+	JSONSchema *ResponseFormatJSONSchema `json:"json_schema,omitempty"`
 }
 
 // ChatMessage is a message in a chat request.
@@ -323,6 +346,8 @@ type FunctionDefinition struct {
 	Description string `json:"description,omitempty"`
 	// Parameters is a list of parameters for the function.
 	Parameters any `json:"parameters"`
+	// Strict is a flag to enable structured output mode.
+	Strict bool `json:"strict,omitempty"`
 }
 
 // FunctionCallBehavior is the behavior to use when calling functions.
diff --git a/llms/openai/internal/openaiclient/openaiclient.go b/llms/openai/internal/openaiclient/openaiclient.go
index cab5123a9..05c114e84 100644
--- a/llms/openai/internal/openaiclient/openaiclient.go
+++ b/llms/openai/internal/openaiclient/openaiclient.go
@@ -36,6 +36,8 @@ type Client struct {
 	EmbeddingModel string
 	// required when APIType is APITypeAzure or APITypeAzureAD
 	apiVersion string
+
+	ResponseFormat *ResponseFormat
 }
 
 // Option is an option for the OpenAI client.
@@ -49,6 +51,7 @@ type Doer interface {
 // New returns a new OpenAI client.
 func New(token string, model string, baseURL string, organization string,
 	apiType APIType, apiVersion string, httpClient Doer, embeddingModel string,
+	responseFormat *ResponseFormat,
 	opts ...Option,
 ) (*Client, error) {
 	c := &Client{
@@ -60,6 +63,7 @@ func New(token string, model string, baseURL string, organization string,
 		apiType:        apiType,
 		apiVersion:     apiVersion,
 		httpClient:     httpClient,
+		ResponseFormat: responseFormat,
 	}
 
 	for _, opt := range opts {
diff --git a/llms/openai/llm.go b/llms/openai/llm.go
index 6b94408e8..21a074774 100644
--- a/llms/openai/llm.go
+++ b/llms/openai/llm.go
@@ -48,7 +48,9 @@ func newClient(opts ...Option) (*options, *openaiclient.Client, error) {
 	}
 
 	cli, err := openaiclient.New(options.token, options.model, options.baseURL, options.organization,
-		openaiclient.APIType(options.apiType), options.apiVersion, options.httpClient, options.embeddingModel)
+		openaiclient.APIType(options.apiType), options.apiVersion, options.httpClient, options.embeddingModel,
+		options.responseFormat,
+	)
 	return options, cli, err
 }
 
diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go
index 0f1b25a59..78f8334d2 100644
--- a/llms/openai/openaillm.go
+++ b/llms/openai/openaillm.go
@@ -124,6 +124,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 				Name:        fn.Name,
 				Description: fn.Description,
 				Parameters:  fn.Parameters,
+				Strict:      fn.Strict,
 			},
 		})
 	}
@@ -136,6 +137,11 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 		req.Tools = append(req.Tools, t)
 	}
 
+	// Apply the client-level response format, if one was configured.
+	if o.client.ResponseFormat != nil {
+		req.ResponseFormat = o.client.ResponseFormat
+	}
+
 	result, err := o.client.CreateChat(ctx, req)
 	if err != nil {
 		return nil, err
@@ -234,6 +240,7 @@ func toolFromTool(t llms.Tool) (openaiclient.Tool, error) {
 			Name:        t.Function.Name,
 			Description: t.Function.Description,
 			Parameters:  t.Function.Parameters,
+			Strict:      t.Function.Strict,
 		}
 	default:
 		return openaiclient.Tool{}, fmt.Errorf("tool type %v not supported", t.Type)
diff --git a/llms/openai/openaillm_option.go b/llms/openai/openaillm_option.go
index ba4771ca2..42bfa1913 100644
--- a/llms/openai/openaillm_option.go
+++ b/llms/openai/openaillm_option.go
@@ -48,6 +48,12 @@ type Option func(*options)
 // ResponseFormat is the response format for the OpenAI client.
 type ResponseFormat = openaiclient.ResponseFormat
 
+// ResponseFormatJSONSchema is the JSON Schema response format in structured output.
+type ResponseFormatJSONSchema = openaiclient.ResponseFormatJSONSchema
+
+// ResponseFormatJSONSchemaProperty is the JSON Schema property in structured output.
+type ResponseFormatJSONSchemaProperty = openaiclient.ResponseFormatJSONSchemaProperty
+
 // ResponseFormatJSON is the JSON response format.
 var ResponseFormatJSON = &ResponseFormat{Type: "json_object"} //nolint:gochecknoglobals
 
diff --git a/llms/openai/structured_output_test.go b/llms/openai/structured_output_test.go
new file mode 100644
index 000000000..83b025e84
--- /dev/null
+++ b/llms/openai/structured_output_test.go
@@ -0,0 +1,176 @@
+package openai
+
+import (
+	"context"
+	"encoding/json"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/tmc/langchaingo/llms"
+	"github.com/tmc/langchaingo/llms/openai/internal/openaiclient"
+)
+
+// TestStructuredOutputObjectSchema requests a strict json_schema response with one
+// required string property and checks that the reply contains that key.
+func TestStructuredOutputObjectSchema(t *testing.T) {
+	t.Parallel()
+	responseFormat := &ResponseFormat{
+		Type: "json_schema",
+		JSONSchema: &ResponseFormatJSONSchema{
+			Name:   "math_schema",
+			Strict: true,
+			Schema: &ResponseFormatJSONSchemaProperty{
+				Type: "object",
+				Properties: map[string]*ResponseFormatJSONSchemaProperty{
+					"final_answer": {
+						Type: "string",
+					},
+				},
+				AdditionalProperties: false,
+				Required:             []string{"final_answer"},
+			},
+		},
+	}
+	llm := newTestClient(
+		t,
+		WithModel("gpt-4o-2024-08-06"),
+		WithResponseFormat(responseFormat),
+	)
+
+	content := []llms.MessageContent{
+		{
+			Role:  llms.ChatMessageTypeSystem,
+			Parts: []llms.ContentPart{llms.TextContent{Text: "You are a student taking a math exam."}},
+		},
+		{
+			Role:  llms.ChatMessageTypeGeneric,
+			Parts: []llms.ContentPart{llms.TextContent{Text: "Solve 2 + 2"}},
+		},
+	}
+
+	rsp, err := llm.GenerateContent(context.Background(), content)
+	require.NoError(t, err)
+
+	// require stops the test here; indexing an empty Choices below would panic.
+	require.NotEmpty(t, rsp.Choices)
+	c1 := rsp.Choices[0]
+	assert.Regexp(t, "\"final_answer\":", strings.ToLower(c1.Content))
+}
+
+// TestStructuredOutputObjectAndArraySchema requests a strict json_schema response
+// whose schema also requires an array of strings, and checks for the "steps" key.
+func TestStructuredOutputObjectAndArraySchema(t *testing.T) {
+	t.Parallel()
+	responseFormat := &ResponseFormat{
+		Type: "json_schema",
+		JSONSchema: &ResponseFormatJSONSchema{
+			Name:   "math_schema",
+			Strict: true,
+			Schema: &ResponseFormatJSONSchemaProperty{
+				Type: "object",
+				Properties: map[string]*ResponseFormatJSONSchemaProperty{
+					"steps": {
+						Type: "array",
+						Items: &ResponseFormatJSONSchemaProperty{
+							Type: "string",
+						},
+					},
+					"final_answer": {
+						Type: "string",
+					},
+				},
+				AdditionalProperties: false,
+				Required:             []string{"final_answer", "steps"},
+			},
+		},
+	}
+	llm := newTestClient(
+		t,
+		WithModel("gpt-4o-2024-08-06"),
+		WithResponseFormat(responseFormat),
+	)
+
+	content := []llms.MessageContent{
+		{
+			Role:  llms.ChatMessageTypeSystem,
+			Parts: []llms.ContentPart{llms.TextContent{Text: "You are a student taking a math exam."}},
+		},
+		{
+			Role:  llms.ChatMessageTypeGeneric,
+			Parts: []llms.ContentPart{llms.TextContent{Text: "Solve 2 + 2"}},
+		},
+	}
+
+	rsp, err := llm.GenerateContent(context.Background(), content)
+	require.NoError(t, err)
+
+	// require stops the test here; indexing an empty Choices below would panic.
+	require.NotEmpty(t, rsp.Choices)
+	c1 := rsp.Choices[0]
+	assert.Regexp(t, "\"steps\":", strings.ToLower(c1.Content))
+}
+
+// TestStructuredOutputFunctionCalling registers a tool with Strict set and checks
+// that the returned tool call arguments contain both required parameters.
+func TestStructuredOutputFunctionCalling(t *testing.T) {
+	t.Parallel()
+	llm := newTestClient(
+		t,
+		WithModel("gpt-4o-2024-08-06"),
+	)
+
+	toolList := []llms.Tool{
+		{
+			Type: string(openaiclient.ToolTypeFunction),
+			Function: &llms.FunctionDefinition{
+				Name:        "search",
+				Description: "Search by the web search engine",
+				Parameters: json.RawMessage(
+					`{
+					"type": "object",
+					"properties" : {
+						"search_engine" : {
+							"type" : "string",
+							"enum" : ["google", "duckduckgo", "bing"]
+						},
+						"search_query" : {
+							"type" : "string"
+						}
+					},
+					"required":["search_engine", "search_query"],
+					"additionalProperties": false
+				}`),
+				Strict: true,
+			},
+		},
+	}
+
+	content := []llms.MessageContent{
+		{
+			Role:  llms.ChatMessageTypeSystem,
+			Parts: []llms.ContentPart{llms.TextContent{Text: "You are a helpful assistant"}},
+		},
+		{
+			Role:  llms.ChatMessageTypeGeneric,
+			Parts: []llms.ContentPart{llms.TextContent{Text: "What is the age of Bob Odenkirk, a famous comedy screenwriter and an actor."}},
+		},
+	}
+
+	rsp, err := llm.GenerateContent(
+		context.Background(),
+		content,
+		llms.WithTools(toolList),
+	)
+	require.NoError(t, err)
+
+	// require stops the test here; indexing an empty Choices below would panic.
+	require.NotEmpty(t, rsp.Choices)
+	c1 := rsp.Choices[0]
+	require.NotEmpty(t, c1.ToolCalls)
+	require.NotNil(t, c1.ToolCalls[0].FunctionCall)
+	args := c1.ToolCalls[0].FunctionCall.Arguments
+	assert.Regexp(t, "\"search_engine\":", args)
+	assert.Regexp(t, "\"search_query\":", args)
+}
diff --git a/llms/options.go b/llms/options.go
index 114b66656..b6b595290 100644
--- a/llms/options.go
+++ b/llms/options.go
@@ -84,6 +84,8 @@ type FunctionDefinition struct {
 	Description string `json:"description"`
 	// Parameters is a list of parameters for the function.
 	Parameters any `json:"parameters,omitempty"`
+	// Strict is a flag to indicate if the function should be called strictly. Only used for openai llm structured output.
+	Strict bool `json:"strict,omitempty"`
 }
 
 // ToolChoice is a specific tool to use.