Skip to content

Commit

Permalink
openai: add support for sending Structured Output request for chat completions api (#986)
Browse files Browse the repository at this point in the history

openai: add support for structured output request
  • Loading branch information
takeyamakenta authored Sep 13, 2024
1 parent 39a5ac9 commit 673c598
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 2 deletions.
22 changes: 21 additions & 1 deletion llms/openai/internal/openaiclient/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,27 @@ type ToolCall struct {
Function ToolFunction `json:"function,omitempty"`
}

// ResponseFormatJSONSchemaProperty is a single node of a JSON Schema sent in a
// structured-output request. A node can describe a scalar (Type/Enum), an
// array (Items), or an object (Properties/Required), and may point at a schema
// defined elsewhere via Ref.
type ResponseFormatJSONSchemaProperty struct {
	// Type is the JSON Schema type of this node, e.g. "object", "array", or "string".
	Type string `json:"type"`
	// Description is an optional human-readable description of the property.
	Description string `json:"description,omitempty"`
	// Enum restricts the value to a fixed set of allowed values.
	Enum []interface{} `json:"enum,omitempty"`
	// Items describes the element schema when Type is "array".
	Items *ResponseFormatJSONSchemaProperty `json:"items,omitempty"`
	// Properties maps property names to their schemas when Type is "object".
	Properties map[string]*ResponseFormatJSONSchemaProperty `json:"properties,omitempty"`
	// AdditionalProperties deliberately has no omitempty tag, so an explicit
	// false is always serialized — presumably required by the API's strict
	// structured-output mode; TODO(review): confirm against the OpenAI API docs.
	AdditionalProperties bool `json:"additionalProperties"`
	// Required lists the property names that must be present in the object.
	Required []string `json:"required,omitempty"`
	// Ref is a JSON Schema "$ref" pointer to a schema defined elsewhere.
	Ref string `json:"$ref,omitempty"`
}

// ResponseFormatJSONSchema is the "json_schema" payload of a ResponseFormat:
// a named root schema used for a structured-output chat completion request.
type ResponseFormatJSONSchema struct {
	// Name identifies the schema in the request.
	Name string `json:"name"`
	// Strict requests strict adherence to the schema.
	Strict bool `json:"strict"`
	// Schema is the root JSON Schema node describing the expected output.
	Schema *ResponseFormatJSONSchemaProperty `json:"schema"`
}

// ResponseFormat is the format of the response.
type ResponseFormat struct {
Type string `json:"type"`
Type string `json:"type"`
JSONSchema *ResponseFormatJSONSchema `json:"json_schema,omitempty"`
}

// ChatMessage is a message in a chat request.
Expand Down Expand Up @@ -323,6 +341,8 @@ type FunctionDefinition struct {
Description string `json:"description,omitempty"`
// Parameters is a list of parameters for the function.
Parameters any `json:"parameters"`
// Strict is a flag to enable structured output mode.
Strict bool `json:"strict,omitempty"`
}

// FunctionCallBehavior is the behavior to use when calling functions.
Expand Down
4 changes: 4 additions & 0 deletions llms/openai/internal/openaiclient/openaiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ type Client struct {
EmbeddingModel string
// required when APIType is APITypeAzure or APITypeAzureAD
apiVersion string

ResponseFormat *ResponseFormat
}

// Option is an option for the OpenAI client.
Expand All @@ -49,6 +51,7 @@ type Doer interface {
// New returns a new OpenAI client.
func New(token string, model string, baseURL string, organization string,
apiType APIType, apiVersion string, httpClient Doer, embeddingModel string,
responseFormat *ResponseFormat,
opts ...Option,
) (*Client, error) {
c := &Client{
Expand All @@ -60,6 +63,7 @@ func New(token string, model string, baseURL string, organization string,
apiType: apiType,
apiVersion: apiVersion,
httpClient: httpClient,
ResponseFormat: responseFormat,
}

for _, opt := range opts {
Expand Down
4 changes: 3 additions & 1 deletion llms/openai/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ func newClient(opts ...Option) (*options, *openaiclient.Client, error) {
}

cli, err := openaiclient.New(options.token, options.model, options.baseURL, options.organization,
openaiclient.APIType(options.apiType), options.apiVersion, options.httpClient, options.embeddingModel)
openaiclient.APIType(options.apiType), options.apiVersion, options.httpClient, options.embeddingModel,
options.responseFormat,
)
return options, cli, err
}

Expand Down
7 changes: 7 additions & 0 deletions llms/openai/openaillm.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
Name: fn.Name,
Description: fn.Description,
Parameters: fn.Parameters,
Strict: fn.Strict,
},
})
}
Expand All @@ -136,6 +137,11 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
req.Tools = append(req.Tools, t)
}

// if o.client.ResponseFormat is set, use it for the request
if o.client.ResponseFormat != nil {
req.ResponseFormat = o.client.ResponseFormat
}

result, err := o.client.CreateChat(ctx, req)
if err != nil {
return nil, err
Expand Down Expand Up @@ -234,6 +240,7 @@ func toolFromTool(t llms.Tool) (openaiclient.Tool, error) {
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: t.Function.Parameters,
Strict: t.Function.Strict,
}
default:
return openaiclient.Tool{}, fmt.Errorf("tool type %v not supported", t.Type)
Expand Down
6 changes: 6 additions & 0 deletions llms/openai/openaillm_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ type Option func(*options)
// ResponseFormat is the response format for the OpenAI client.
type ResponseFormat = openaiclient.ResponseFormat

// ResponseFormatJSONSchema is the JSON Schema response format in structured output.
type ResponseFormatJSONSchema = openaiclient.ResponseFormatJSONSchema

// ResponseFormatJSONSchemaProperty is the JSON Schema property in structured output.
type ResponseFormatJSONSchemaProperty = openaiclient.ResponseFormatJSONSchemaProperty

// ResponseFormatJSON is the JSON response format.
var ResponseFormatJSON = &ResponseFormat{Type: "json_object"} //nolint:gochecknoglobals

Expand Down
164 changes: 164 additions & 0 deletions llms/openai/structured_output_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package openai

import (
"context"
"encoding/json"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/openai/internal/openaiclient"
)

// TestStructuredOutputObjectSchema sends a chat completion request with a
// "json_schema" response format whose object schema requires a single
// "final_answer" string, and checks the returned content contains that key.
func TestStructuredOutputObjectSchema(t *testing.T) {
	t.Parallel()
	responseFormat := &ResponseFormat{
		Type: "json_schema",
		JSONSchema: &ResponseFormatJSONSchema{
			Name:   "math_schema",
			Strict: true,
			Schema: &ResponseFormatJSONSchemaProperty{
				Type: "object",
				Properties: map[string]*ResponseFormatJSONSchemaProperty{
					"final_answer": {
						Type: "string",
					},
				},
				AdditionalProperties: false,
				Required:             []string{"final_answer"},
			},
		},
	}
	llm := newTestClient(
		t,
		WithModel("gpt-4o-2024-08-06"),
		WithResponseFormat(responseFormat),
	)

	content := []llms.MessageContent{
		{
			Role:  llms.ChatMessageTypeSystem,
			Parts: []llms.ContentPart{llms.TextContent{Text: "You are a student taking a math exam."}},
		},
		{
			Role:  llms.ChatMessageTypeGeneric,
			Parts: []llms.ContentPart{llms.TextContent{Text: "Solve 2 + 2"}},
		},
	}

	rsp, err := llm.GenerateContent(context.Background(), content)
	require.NoError(t, err)

	// require (not assert): a non-halting assert would let the index on the
	// next line panic when Choices is empty.
	require.NotEmpty(t, rsp.Choices)
	c1 := rsp.Choices[0]
	assert.Regexp(t, "\"final_answer\":", strings.ToLower(c1.Content))
}

// TestStructuredOutputObjectAndArraySchema sends a structured-output request
// whose schema nests an array of strings ("steps") inside an object, and
// checks the returned content contains the "steps" key.
func TestStructuredOutputObjectAndArraySchema(t *testing.T) {
	t.Parallel()
	responseFormat := &ResponseFormat{
		Type: "json_schema",
		JSONSchema: &ResponseFormatJSONSchema{
			Name:   "math_schema",
			Strict: true,
			Schema: &ResponseFormatJSONSchemaProperty{
				Type: "object",
				Properties: map[string]*ResponseFormatJSONSchemaProperty{
					"steps": {
						Type: "array",
						Items: &ResponseFormatJSONSchemaProperty{
							Type: "string",
						},
					},
					"final_answer": {
						Type: "string",
					},
				},
				AdditionalProperties: false,
				Required:             []string{"final_answer", "steps"},
			},
		},
	}
	llm := newTestClient(
		t,
		WithModel("gpt-4o-2024-08-06"),
		WithResponseFormat(responseFormat),
	)

	content := []llms.MessageContent{
		{
			Role:  llms.ChatMessageTypeSystem,
			Parts: []llms.ContentPart{llms.TextContent{Text: "You are a student taking a math exam."}},
		},
		{
			Role:  llms.ChatMessageTypeGeneric,
			Parts: []llms.ContentPart{llms.TextContent{Text: "Solve 2 + 2"}},
		},
	}

	rsp, err := llm.GenerateContent(context.Background(), content)
	require.NoError(t, err)

	// require (not assert): a non-halting assert would let the index on the
	// next line panic when Choices is empty.
	require.NotEmpty(t, rsp.Choices)
	c1 := rsp.Choices[0]
	assert.Regexp(t, "\"steps\":", strings.ToLower(c1.Content))
}

// TestStructuredOutputFunctionCalling sends a request with a strict-mode
// function tool ("search") and checks the model's tool-call arguments contain
// the schema's required keys.
func TestStructuredOutputFunctionCalling(t *testing.T) {
	t.Parallel()
	llm := newTestClient(
		t,
		WithModel("gpt-4o-2024-08-06"),
	)

	toolList := []llms.Tool{
		{
			Type: string(openaiclient.ToolTypeFunction),
			Function: &llms.FunctionDefinition{
				Name:        "search",
				Description: "Search by the web search engine",
				Parameters: json.RawMessage(
					`{
						"type": "object",
						"properties" : {
							"search_engine" : {
								"type" : "string",
								"enum" : ["google", "duckduckgo", "bing"]
							},
							"search_query" : {
								"type" : "string"
							}
						},
						"required":["search_engine", "search_query"],
						"additionalProperties": false
					}`),
				Strict: true,
			},
		},
	}

	content := []llms.MessageContent{
		{
			Role:  llms.ChatMessageTypeSystem,
			Parts: []llms.ContentPart{llms.TextContent{Text: "You are a helpful assistant"}},
		},
		{
			Role:  llms.ChatMessageTypeGeneric,
			Parts: []llms.ContentPart{llms.TextContent{Text: "What is the age of Bob Odenkirk, a famous comedy screenwriter and an actor."}},
		},
	}

	rsp, err := llm.GenerateContent(
		context.Background(),
		content,
		llms.WithTools(toolList),
	)
	require.NoError(t, err)

	// require (not assert) on both slices: indexing Choices[0] and
	// ToolCalls[0] below panics if either is empty, so the guards must halt
	// the test rather than merely record a failure.
	require.NotEmpty(t, rsp.Choices)
	c1 := rsp.Choices[0]
	require.NotEmpty(t, c1.ToolCalls)
	assert.Regexp(t, "\"search_engine\":", c1.ToolCalls[0].FunctionCall.Arguments)
	assert.Regexp(t, "\"search_query\":", c1.ToolCalls[0].FunctionCall.Arguments)
}
2 changes: 2 additions & 0 deletions llms/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ type FunctionDefinition struct {
Description string `json:"description"`
// Parameters is a list of parameters for the function.
Parameters any `json:"parameters,omitempty"`
// Strict is a flag to indicate if the function should be called strictly. Only used for openai llm structured output.
Strict bool `json:"strict,omitempty"`
}

// ToolChoice is a specific tool to use.
Expand Down

0 comments on commit 673c598

Please sign in to comment.