Breaking changes to allow custom generation options and multi-completions #28

Merged (8 commits) on Apr 16, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -27,7 +27,7 @@ const { ChatSession, CompletionService } = require('langxlang')

```js
const service = new CompletionService({ openai: [key], gemini: [key] })
-const response = await service.requestCompletion(
+const [response] = await service.requestCompletion(
'gemini-1.0-pro', // Model name
'', // System prompt (optional)
'Tell me about yourself' // User prompt
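The breaking change here is that `requestCompletion` now resolves to an array of completions rather than a single object. A minimal sketch of the updated call site (the `key` placeholders follow the README's own convention; `response.text` matches the shape returned in the diff below):

```js
const { CompletionService } = require('langxlang')

const service = new CompletionService({ openai: [key], gemini: [key] })
// requestCompletion now resolves to an array of completion objects;
// destructuring takes the first (and, by default, only) one.
const [response] = await service.requestCompletion(
  'gemini-1.0-pro',         // Model name
  '',                       // System prompt (optional)
  'Tell me about yourself'  // User prompt
)
console.log(response.text)
```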
25 changes: 13 additions & 12 deletions src/ChatSession.js
@@ -71,12 +71,14 @@ class ChatSession {
// This calls a function and adds the response to the context so the model can be called again
async _callFunction (functionName, payload, metadata) {
if (this.modelAuthor === 'googleaistudio') {
-      let content
-      if (metadata.text) content = metadata.text + '\n'
-      content = content.trim()
-      const arStr = Object.keys(payload).length ? JSON.stringify(payload) : ''
-      content += `\n<FUNCTION_CALL>${functionName}(${arStr})</FUNCTION_CALL>`
-      this.messages.push({ role: 'assistant', content })
+      if (metadata.content) {
+        let content = ''
+        content = metadata.content + '\n'
+        content = content.trim()
+        const arStr = Object.keys(payload).length ? JSON.stringify(payload) : ''
+        content += `\n<FUNCTION_CALL>${functionName}(${arStr})</FUNCTION_CALL>`
+        this.messages.push({ role: 'assistant', content })
+      }
const result = await this._callFunctionWithArgs(functionName, payload)
this.messages.push({ role: 'function', name: functionName, content: JSON.stringify(result) })
} else if (this.modelFamily === 'openai') {
@@ -126,11 +128,10 @@ class ChatSession {

async _submitRequest (chunkCb) {
debug('Sending to', this.model, this.messages)
-    const response = await this.service.requestStreamingChat(this.model, {
+    const [response] = await this.service.requestChatCompletion(this.model, {
maxTokens: this.maxTokens,
messages: this.messages,
functions: this.functionsPayload
-      // stream: !!chunkCb
}, chunkCb)
debug('Streaming response', JSON.stringify(response))
if (response.type === 'function') {
@@ -141,12 +142,12 @@
// we need to call the function with the payload and then send the result back to the model
for (const index in response.fnCalls) {
const call = response.fnCalls[index]
-        const args = typeof call.args === 'string' ? JSON.parse(call.args) : call.args
+        const args = (typeof call.args === 'string' && call.args.length) ? JSON.parse(call.args) : call.args
await this._callFunction(call.name, args ?? {}, response)
}
return this._submitRequest(chunkCb)
} else if (response.type === 'text') {
-      this.messages.push({ role: 'assistant', content: response.completeMessage })
+      this.messages.push({ role: 'assistant', content: response.content })
}
return response
}
@@ -174,9 +175,9 @@ class ChatSession {
break
}
}
-      response.completeMessage = message.guidanceText + response.completeMessage
+      response.content = message.guidanceText + response.content
}
-    return { text: response.completeMessage, calledFunctions: this._calledFunctionsForRound }
+    return { text: response.content, calledFunctions: this._calledFunctionsForRound }
}
}

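For reference, a sketch of the two response shapes `_submitRequest` handles after the rename from `completeMessage` to `content` (field names are taken from the diff; `getWeather` is a hypothetical function name, and provider responses may carry extra properties):

```js
// Text completion: pushed onto the chat history as an assistant message.
const textResponse = { type: 'text', content: 'Sure, here is a summary...' }

// Function call: args may arrive as a JSON string, so the guarded parse
// below mirrors the fixed line in _submitRequest (empty strings are skipped).
const fnResponse = {
  type: 'function',
  fnCalls: { 0: { id: 'call_0', name: 'getWeather', args: '{"city":"Paris"}' } }
}
for (const index in fnResponse.fnCalls) {
  const call = fnResponse.fnCalls[index]
  const args = (typeof call.args === 'string' && call.args.length)
    ? JSON.parse(call.args)
    : call.args
  console.log(call.name, args ?? {})
}
```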
117 changes: 54 additions & 63 deletions src/CompletionService.js
@@ -15,6 +15,8 @@ class CompletionService {
this.palm2ApiKey = keys.palm2 || process.env.PALM2_API_KEY
this.geminiApiKey = keys.gemini || process.env.GEMINI_API_KEY
    this.openaiApiKey = keys.openai || process.env.OPENAI_API_KEY
+
+    this.defaultGenerationOptions = options.generationOptions
}

async listModels () {
@@ -31,17 +33,22 @@
return { openai: openaiModels, google: geminiModels }
}

-  async _requestCompletionOpenAI (model, system, user) {
+  async _requestCompletionOpenAI (model, system, user, { maxTokens, temperature, topP }, chunkCb) {
if (!this.openaiApiKey) throw new Error('OpenAI API key not set')
const guidance = system?.guidanceText || user?.guidanceText || ''
-    const result = await openai.generateCompletion(model, system.basePrompt || system, user.basePrompt || user, {
+    const response = await openai.generateCompletion(model, system.basePrompt || system, user.basePrompt || user, {
apiKey: this.openaiApiKey,
-      guidanceMessage: guidance
+      guidanceMessage: guidance,
+      generationConfig: {
+        max_tokens: maxTokens,
+        temperature,
+        top_p: topP
+      }
})
-    return { text: guidance + result.message.content }
+    return response.choices.map((choice) => ({ text: guidance + choice.content }))
}

-  async _requestCompletionGemini (model, system, user, chunkCb) {
+  async _requestCompletionGemini (model, system, user, { maxTokens, temperature, topP, topK }, chunkCb) {
if (!this.geminiApiKey) throw new Error('Gemini API key not set')
const guidance = system?.guidanceText || user?.guidanceText || ''
// April 2024 - Only Gemini 1.5 supports instructions
@@ -50,8 +57,11 @@
system = ''
user = mergedPrompt
}
-    const result = await gemini.generateCompletion(model, system, user, { apiKey: this.geminiApiKey }, chunkCb)
-    return { text: guidance + result.text() }
+    const result = await gemini.generateCompletion(model, system, user, {
+      apiKey: this.geminiApiKey,
+      generationConfig: { maxOutputTokens: maxTokens, temperature, topP, topK }
+    }, chunkCb)
+    return [{ text: guidance + result.text() }]
}

async requestCompletion (model, system, user, chunkCb, options = {}) {
@@ -73,11 +83,17 @@
}
return response
}

+    const genOpts = {
+      ...this.defaultGenerationOptions,
+      maxTokens: options.maxTokens,
+      temperature: options.temperature,
+      topP: options.topP,
+      topK: options.topK
+    }
const { family } = getModelInfo(model)
switch (family) {
-      case 'openai': return saveIfCaching(await this._requestCompletionOpenAI(model, system, user))
-      case 'gemini': return saveIfCaching(await this._requestCompletionGemini(model, system, user, chunkCb))
+      case 'openai': return saveIfCaching(await this._requestCompletionOpenAI(model, system, user, genOpts))
+      case 'gemini': return saveIfCaching(await this._requestCompletionGemini(model, system, user, genOpts, chunkCb))
case 'palm2': {
if (!this.palm2ApiKey) throw new Error('PaLM2 API key not set')
const result = await palm2.requestPalmCompletion(system + '\n' + user, this.palm2ApiKey, model)
@@ -88,62 +104,36 @@
}
}

-  async _requestStreamingChatOpenAI (model, messages, maxTokens, functions, chunkCb) {
+  async _requestStreamingChatOpenAI (model, messages, { maxTokens, temperature, topP }, functions, chunkCb) {
     if (!this.openaiApiKey) throw new Error('OpenAI API key not set')
-    let completeMessage = ''
-    let finishReason
-    const fnCalls = {}
-    await openai.getStreamingCompletion(this.openaiApiKey, {
+    const response = await openai.generateChatCompletionIn(
       model,
-      max_tokens: maxTokens,
-      messages: messages.map((entry) => {
+      messages.map((entry) => {
        const msg = structuredClone(entry)
        if (msg.role === 'model') msg.role = 'assistant'
        if (msg.role === 'guidance') msg.role = 'assistant'
        return msg
      }),
-      stream: true,
-      tools: functions || undefined,
-      tool_choice: functions ? 'auto' : undefined
-    }, (chunk) => {
-      if (!chunk) {
-        chunkCb?.({ done: true, delta: '' })
-        return
-      }
-      const choice = chunk.choices[0]
-      if (choice.finish_reason) {
-        finishReason = choice.finish_reason
-      }
-      if (choice.message) {
-        completeMessage += choice.message.content
-      } else if (choice.delta) {
-        const delta = choice.delta
-        if (delta.tool_calls) {
-          for (const call of delta.tool_calls) {
-            fnCalls[call.index] ??= {
-              id: call.id,
-              name: '',
-              args: ''
-            }
-            const entry = fnCalls[call.index]
-            if (call.function.name) {
-              entry.name = call.function.name
-            }
-            if (call.function.arguments) {
-              entry.args += call.function.arguments
-            }
-          }
-        } else if (delta.content) {
-          completeMessage += delta.content
-          chunkCb?.(choice.delta)
-        }
-      } else throw new Error('Unknown chunk type')
-    })
-    const type = finishReason === 'tool_calls' ? 'function' : 'text'
-    return { type, completeMessage, fnCalls }
+      {
+        apiKey: this.openaiApiKey,
+        functions,
+        generationConfig: { max_tokens: maxTokens, temperature, top_p: topP }
+      },
+      chunkCb
+    )
+    return response.choices.map((choice) => {
+      const choiceType = {
+        stop: 'text',
+        length: 'text',
+        function_call: 'function',
+        content_filter: 'safety', // an error would be thrown before this
+        tool_calls: 'function'
+      }[choice.finishReason] ?? 'unknown'
+      return { type: choiceType, isTruncated: choice.finishReason === 'length', ...choice }
+    })
}

-  async _requestStreamingChatGemini (model, messages, maxTokens, functions, chunkCb) {
+  async _requestStreamingChatGemini (model, messages, { maxTokens, temperature, topP, topK }, functions, chunkCb) {
if (!this.geminiApiKey) throw new Error('Gemini API key not set')
const geminiMessages = messages.map((msg) => {
const m = structuredClone(msg)
@@ -158,13 +148,14 @@
})
const response = await gemini.generateChatCompletionEx(model, geminiMessages, {
apiKey: this.geminiApiKey,
-      functions
+      functions,
+      generationConfig: { maxOutputTokens: maxTokens, temperature, topP, topK }
}, chunkCb)
if (response.text()) {
const answer = response.text()
chunkCb?.({ done: true, delta: '' })
-      const result = { type: 'text', completeMessage: answer }
-      return result
+      const result = { type: 'text', content: answer }
+      return [result]
} else if (response.functionCalls()) {
const calls = response.functionCalls()
const fnCalls = {}
@@ -177,17 +168,17 @@
}
}
const result = { type: 'function', fnCalls }
-      return result
+      return [result]
} else {
throw new Error('Unknown response from Gemini')
}
}

-  async requestStreamingChat (model, { messages, maxTokens, functions }, chunkCb) {
+  async requestChatCompletion (model, { messages, maxTokens, functions }, chunkCb) {
const { family } = getModelInfo(model)
switch (family) {
-      case 'openai': return this._requestStreamingChatOpenAI(model, messages, maxTokens, functions, chunkCb)
-      case 'gemini': return this._requestStreamingChatGemini(model, messages, maxTokens, functions, chunkCb)
+      case 'openai': return this._requestStreamingChatOpenAI(model, messages, { ...this.defaultGenerationOptions, maxTokens }, functions, chunkCb)
+      case 'gemini': return this._requestStreamingChatGemini(model, messages, { ...this.defaultGenerationOptions, maxTokens }, functions, chunkCb)
default:
throw new Error(`Model '${model}' not supported for streaming chat, available models: ${knownModels.join(', ')}`)
}
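Taken together, the new plumbing lets generation options default on the service and be overridden per request. A hypothetical sketch (the constructor's second options argument and the merge behavior are inferred from the diff, not documented API):

```js
const service = new CompletionService(
  { openai: [key], gemini: [key] },
  { generationOptions: { temperature: 0.8, topP: 0.95 } } // service-wide defaults
)

// Per-request values are spread over the defaults, as in requestCompletion's genOpts:
const [completion] = await service.requestCompletion(
  'gemini-1.0-pro',
  '',
  'Write a haiku about diffs',
  null, // chunkCb (optional streaming callback)
  { maxTokens: 64, temperature: 0.2 }
)
console.log(completion.text)
```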
3 changes: 2 additions & 1 deletion src/Flow.js
@@ -40,7 +40,8 @@ class Flow {
if (runFollowUp && runFollowUp.pastResponses[inputHash]) {
resp = structuredClone(runFollowUp.pastResponses[inputHash])
} else {
-        resp = await this.service.requestCompletion(model, systemPrompt, userPrompt, this.chunkCb, this.generationOpts)
+        const rs = await this.service.requestCompletion(model, systemPrompt, userPrompt, this.chunkCb, this.generationOpts)
+        resp = rs[0]
}
resp.inputHash = inputHash
resp.name = details.name
9 changes: 5 additions & 4 deletions src/GoogleAIStudioCompletionService.js
@@ -40,7 +40,7 @@ class GoogleAIStudioCompletionService {
if (cachedResponse) {
chunkCb?.({ done: false, delta: cachedResponse.text })
chunkCb?.({ done: true, delta: '' })
-        return cachedResponse
+        return [cachedResponse]
}
}

@@ -83,15 +83,16 @@
}
}
chunkCb?.({ done: true, delta: '\n' })
-    return saveIfCaching({ text: guidance + combinedResult })
+    return [saveIfCaching({ text: guidance + combinedResult })]
}

-  async requestStreamingChat (model, { messages, maxTokens, functions }, chunkCb) {
+  async requestChatCompletion (model, { messages, maxTokens, functions }, chunkCb) {
if (!supportedModels.includes(model)) {
throw new Error(`Model ${model} is not supported`)
}
const result = await this._studio.requestChatCompletion(model, messages, chunkCb, { maxTokens, functions })
-    return { ...result, completeMessage: result.text }
+    chunkCb?.({ done: true, delta: '\n' })
+    return [result]
}
}

8 changes: 8 additions & 0 deletions src/SafetyError.js
@@ -0,0 +1,8 @@
+class SafetyError extends Error {
+  constructor (message) {
+    super(message)
+    this.name = 'SafetyError'
+  }
+}
+
+module.exports = SafetyError
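The new `SafetyError` class pairs with the mapping of `content_filter` finishes to the `'safety'` type above, where the comment notes an error would be thrown before that branch is reached. A hypothetical usage sketch, not a call site from the PR:

```js
const SafetyError = require('./src/SafetyError')

try {
  // ... somewhere inside the completion pipeline, a blocked response ...
  throw new SafetyError('Completion blocked by the provider safety filter')
} catch (err) {
  if (err instanceof SafetyError) {
    console.warn('Safety filter triggered:', err.message) // err.name === 'SafetyError'
  } else {
    throw err // unrelated errors propagate
  }
}
```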