Skip to content

Commit

Permalink
Merge pull request #23 from jeasonstudio/feat-improve-structured-ouput
Browse files Browse the repository at this point in the history
feat: improve structured ouput
  • Loading branch information
jeasonstudio authored Aug 8, 2024
2 parents 151dd4f + d1f4727 commit 9c78c8a
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 18 deletions.
5 changes: 5 additions & 0 deletions .changeset/silver-chefs-add.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"chrome-ai": patch
---

feat: improve structured output
13 changes: 10 additions & 3 deletions src/language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {
} from '@ai-sdk/provider';
import { ChromeAISession, ChromeAISessionOptions } from './global';
import createDebug from 'debug';
import { StreamAI } from './stream-ai';
import { objectStartSequence, objectStopSequence, StreamAI } from './stream-ai';

const debug = createDebug('chromeai');

Expand Down Expand Up @@ -164,8 +164,15 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 {

const session = await this.getSession();
const message = this.formatMessages(options);
const text = await session.prompt(message);
let text = await session.prompt(message);

if (options.mode.type === 'object-json') {
text = text.replace(new RegExp('^' + objectStartSequence, 'ig'), '');
text = text.replace(new RegExp(objectStopSequence + '$', 'ig'), '');
}

debug('generate result:', text);

return {
text,
finishReason: 'stop',
Expand Down Expand Up @@ -193,7 +200,7 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 {
const session = await this.getSession();
const message = this.formatMessages(options);
const promptStream = session.promptStreaming(message);
const transformStream = new StreamAI(options.abortSignal);
const transformStream = new StreamAI(options);
const stream = promptStream.pipeThrough(transformStream);

return {
Expand Down
46 changes: 43 additions & 3 deletions src/stream-ai.test.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import { describe, it, expect } from 'vitest';
import { StreamAI } from './stream-ai';
import type { LanguageModelV1CallOptions } from '@ai-sdk/provider';

describe('stream-ai', () => {
describe('stream-ai', async () => {
const defaultOptions: LanguageModelV1CallOptions = {
prompt: [],
mode: { type: 'regular' },
inputFormat: 'messages',
};
it('should correctly transform', async () => {
const transformStream = new StreamAI();
const transformStream = new StreamAI(defaultOptions);

const writer = transformStream.writable.getWriter();
writer.write('hello');
Expand All @@ -26,7 +32,10 @@ describe('stream-ai', () => {

it('should abort when signal', async () => {
const controller = new AbortController();
const transformStream = new StreamAI(controller.signal);
const transformStream = new StreamAI({
...defaultOptions,
abortSignal: controller.signal,
});

const writer = transformStream.writable.getWriter();
const reader = transformStream.readable.getReader();
Expand All @@ -41,4 +50,35 @@ describe('stream-ai', () => {
controller.abort();
expect(await reader.read()).toMatchObject({ done: true });
});

it('should transform when object-json', async () => {
const transformStream = new StreamAI({
...defaultOptions,
mode: { type: 'object-json', schema: {} },
});

const writer = transformStream.writable.getWriter();
const reader = transformStream.readable.getReader();

for (const chunk of [
' ```',
' ```json\n',
' ```json\n{}',
' ```json\n{}\n```',
]) {
writer.write(chunk);
}
writer.close();

let output = '';
while (true) {
const item = await reader.read();
if (item.done || item.value?.type === 'finish') break;
if (item.value?.type === 'text-delta') {
output += item.value.textDelta;
}
}

expect(output).toBe('{}');
});
});
47 changes: 35 additions & 12 deletions src/stream-ai.ts
Original file line number Diff line number Diff line change
@@ -1,37 +1,60 @@
import { LanguageModelV1StreamPart } from '@ai-sdk/provider';
import {
LanguageModelV1CallOptions,
LanguageModelV1StreamPart,
} from '@ai-sdk/provider';
import createDebug from 'debug';

const debug = createDebug('chromeai');

export const objectStartSequence = ' ```json\n';
export const objectStopSequence = '\n```';

export class StreamAI extends TransformStream<
string,
LanguageModelV1StreamPart
> {
public constructor(abortSignal?: AbortSignal) {
let textTemp = '';
public constructor(options: LanguageModelV1CallOptions) {
let buffer = '';
let transforming = false;

const reset = () => {
buffer = '';
transforming = false;
};

super({
start: (controller) => {
textTemp = '';
if (!abortSignal) return;
abortSignal.addEventListener('abort', () => {
reset();
if (!options.abortSignal) return;
options.abortSignal.addEventListener('abort', () => {
debug('streamText terminate by abortSignal');
controller.terminate();
textTemp = '';
});
},
transform: (chunk, controller) => {
const textDelta = chunk.replace(textTemp, '');
textTemp += textDelta;
controller.enqueue({ type: 'text-delta', textDelta });
if (options.mode.type === 'object-json') {
transforming =
chunk.startsWith(objectStartSequence) &&
!chunk.endsWith(objectStopSequence);
chunk = chunk.replace(
new RegExp('^' + objectStartSequence, 'ig'),
''
);
chunk = chunk.replace(new RegExp(objectStopSequence + '$', 'ig'), '');
} else {
transforming = true;
}
const textDelta = chunk.replace(buffer, ''); // See: https://github.com/jeasonstudio/chrome-ai/issues/11
if (transforming) controller.enqueue({ type: 'text-delta', textDelta });
buffer = chunk;
},
flush: (controller) => {
controller.enqueue({
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 0, promptTokens: 0 },
});
debug('stream result:', textTemp);
textTemp = '';
debug('stream result:', buffer);
},
});
}
Expand Down

0 comments on commit 9c78c8a

Please sign in to comment.