Skip to content

Commit

Permalink
Add frequencePenalty and presencePenalty parameters support for gener… (
Browse files Browse the repository at this point in the history
#264)

* Add frequencePenalty and presencePenalty parameters support for generate content.

* Change frequence to frequency

* Update format

* Update parameter description

* Update the parameter doc.
  • Loading branch information
junyanxu authored Sep 25, 2024
1 parent 85d1eb1 commit dda0b5c
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .changeset/cyan-pants-move.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@google/generative-ai": minor
---

Add `frequencyPenalty` and `presencePenalty` parameters support for `generateContent()`
6 changes: 6 additions & 0 deletions common/api-review/generative-ai.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

// @public
export interface BaseParams {
frequencyPenalty?: number;
// (undocumented)
generationConfig?: GenerationConfig;
presencePenalty?: number;
// (undocumented)
safetySettings?: SafetySetting[];
}
Expand Down Expand Up @@ -458,13 +460,17 @@ export class GenerativeModel {
cachedContent: CachedContent;
countTokens(request: CountTokensRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<CountTokensResponse>;
embedContent(request: EmbedContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<EmbedContentResponse>;
// (undocumented)
frequencyPenalty?: number;
generateContent(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
generateContentStream(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
// (undocumented)
generationConfig: GenerationConfig;
// (undocumented)
model: string;
// (undocumented)
presencePenalty?: number;
// (undocumented)
safetySettings: SafetySetting[];
startChat(startChatParams?: StartChatParams): ChatSession;
// (undocumented)
Expand Down
13 changes: 13 additions & 0 deletions docs/reference/main/generative-ai.baseparams.frequencypenalty.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<!-- Do not edit this file. It is automatically generated by API Documenter. -->

[Home](./index.md) &gt; [@google/generative-ai](./generative-ai.md) &gt; [BaseParams](./generative-ai.baseparams.md) &gt; [frequencyPenalty](./generative-ai.baseparams.frequencypenalty.md)

## BaseParams.frequencyPenalty property

Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been seen in the respponse so far.

**Signature:**

```typescript
frequencyPenalty?: number;
```
2 changes: 2 additions & 0 deletions docs/reference/main/generative-ai.baseparams.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ export interface BaseParams

| Property | Modifiers | Type | Description |
| --- | --- | --- | --- |
| [frequencyPenalty?](./generative-ai.baseparams.frequencypenalty.md) | | number | _(Optional)_ Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been seen in the respponse so far. |
| [generationConfig?](./generative-ai.baseparams.generationconfig.md) | | [GenerationConfig](./generative-ai.generationconfig.md) | _(Optional)_ |
| [presencePenalty?](./generative-ai.baseparams.presencepenalty.md) | | number | _(Optional)_ Presence penalty applied to the next token's logprobs if the token has already been seen in the response. |
| [safetySettings?](./generative-ai.baseparams.safetysettings.md) | | [SafetySetting](./generative-ai.safetysetting.md)<!-- -->\[\] | _(Optional)_ |

13 changes: 13 additions & 0 deletions docs/reference/main/generative-ai.baseparams.presencepenalty.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<!-- Do not edit this file. It is automatically generated by API Documenter. -->

[Home](./index.md) &gt; [@google/generative-ai](./generative-ai.md) &gt; [BaseParams](./generative-ai.baseparams.md) &gt; [presencePenalty](./generative-ai.baseparams.presencepenalty.md)

## BaseParams.presencePenalty property

Presence penalty applied to the next token's logprobs if the token has already been seen in the response.

**Signature:**

```typescript
presencePenalty?: number;
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<!-- Do not edit this file. It is automatically generated by API Documenter. -->

[Home](./index.md) &gt; [@google/generative-ai](./generative-ai.md) &gt; [GenerativeModel](./generative-ai.generativemodel.md) &gt; [frequencyPenalty](./generative-ai.generativemodel.frequencypenalty.md)

## GenerativeModel.frequencyPenalty property

**Signature:**

```typescript
frequencyPenalty?: number;
```
2 changes: 2 additions & 0 deletions docs/reference/main/generative-ai.generativemodel.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ export declare class GenerativeModel
| --- | --- | --- | --- |
| [apiKey](./generative-ai.generativemodel.apikey.md) | | string | |
| [cachedContent](./generative-ai.generativemodel.cachedcontent.md) | | [CachedContent](./generative-ai.cachedcontent.md) | |
| [frequencyPenalty?](./generative-ai.generativemodel.frequencypenalty.md) | | number | _(Optional)_ |
| [generationConfig](./generative-ai.generativemodel.generationconfig.md) | | [GenerationConfig](./generative-ai.generationconfig.md) | |
| [model](./generative-ai.generativemodel.model.md) | | string | |
| [presencePenalty?](./generative-ai.generativemodel.presencepenalty.md) | | number | _(Optional)_ |
| [safetySettings](./generative-ai.generativemodel.safetysettings.md) | | [SafetySetting](./generative-ai.safetysetting.md)<!-- -->\[\] | |
| [systemInstruction?](./generative-ai.generativemodel.systeminstruction.md) | | [Content](./generative-ai.content.md) | _(Optional)_ |
| [toolConfig?](./generative-ai.generativemodel.toolconfig.md) | | [ToolConfig](./generative-ai.toolconfig.md) | _(Optional)_ |
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<!-- Do not edit this file. It is automatically generated by API Documenter. -->

[Home](./index.md) &gt; [@google/generative-ai](./generative-ai.md) &gt; [GenerativeModel](./generative-ai.generativemodel.md) &gt; [presencePenalty](./generative-ai.generativemodel.presencepenalty.md)

## GenerativeModel.presencePenalty property

**Signature:**

```typescript
presencePenalty?: number;
```
2 changes: 2 additions & 0 deletions src/methods/generate-content.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ const fakeRequestParams: GenerateContentRequest = {
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
},
],
presencePenalty: 0.5,
frequencyPenalty: 0.1,
};

describe("generateContent()", () => {
Expand Down
14 changes: 12 additions & 2 deletions src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ describe("GenerativeModel", () => {
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
},
],
presencePenalty: 0.6,
frequencyPenalty: 0.5,
tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
toolConfig: {
functionCallingConfig: { mode: FunctionCallingMode.NONE },
Expand All @@ -92,6 +94,8 @@ describe("GenerativeModel", () => {
genModel.generationConfig?.responseSchema.properties.testField.type,
).to.equal(SchemaType.STRING);
expect(genModel.safetySettings?.length).to.equal(1);
expect(genModel.presencePenalty).to.equal(0.6);
expect(genModel.frequencyPenalty).to.equal(0.5);
expect(genModel.tools?.length).to.equal(1);
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
FunctionCallingMode.NONE,
Expand All @@ -116,7 +120,9 @@ describe("GenerativeModel", () => {
value.includes("be friendly") &&
value.includes("temperature") &&
value.includes("testField") &&
value.includes(HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
value.includes(HarmBlockThreshold.BLOCK_LOW_AND_ABOVE) &&
value.includes("presencePenalty") &&
value.includes("frequencyPenalty")
);
}),
match((value) => {
Expand Down Expand Up @@ -210,6 +216,8 @@ describe("GenerativeModel", () => {
threshold: HarmBlockThreshold.BLOCK_NONE,
},
],
presencePenalty: 0.6,
frequencyPenalty: 0.5,
contents: [{ role: "user", parts: [{ text: "hello" }] }],
tools: [{ functionDeclarations: [{ name: "otherfunc" }] }],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } },
Expand All @@ -228,7 +236,9 @@ describe("GenerativeModel", () => {
value.includes("topK") &&
value.includes("newTestField") &&
!value.includes("testField") &&
value.includes(HarmCategory.HARM_CATEGORY_HARASSMENT)
value.includes(HarmCategory.HARM_CATEGORY_HARASSMENT) &&
value.includes("presencePenalty") &&
value.includes("frequencyPenalty")
);
}),
{},
Expand Down
10 changes: 10 additions & 0 deletions src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ export class GenerativeModel {
toolConfig?: ToolConfig;
systemInstruction?: Content;
cachedContent: CachedContent;
presencePenalty?: number;
frequencyPenalty?: number;

constructor(
public apiKey: string,
Expand All @@ -84,6 +86,8 @@ export class GenerativeModel {
modelParams.systemInstruction,
);
this.cachedContent = modelParams.cachedContent;
this.presencePenalty = modelParams.presencePenalty;
this.frequencyPenalty = modelParams.frequencyPenalty;
}

/**
Expand Down Expand Up @@ -113,6 +117,8 @@ export class GenerativeModel {
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
cachedContent: this.cachedContent?.name,
presencePenalty: this.presencePenalty,
frequencyPenalty: this.frequencyPenalty,
...formattedParams,
},
generativeModelRequestOptions,
Expand Down Expand Up @@ -148,6 +154,8 @@ export class GenerativeModel {
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
cachedContent: this.cachedContent?.name,
presencePenalty: this.presencePenalty,
frequencyPenalty: this.frequencyPenalty,
...formattedParams,
},
generativeModelRequestOptions,
Expand Down Expand Up @@ -194,6 +202,8 @@ export class GenerativeModel {
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
cachedContent: this.cachedContent,
presencePenalty: this.presencePenalty,
frequencyPenalty: this.frequencyPenalty,
});
const generativeModelRequestOptions: SingleRequestOptions = {
...this._requestOptions,
Expand Down
10 changes: 10 additions & 0 deletions types/requests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ import {
export interface BaseParams {
safetySettings?: SafetySetting[];
generationConfig?: GenerationConfig;
/**
* Presence penalty applied to the next token's logprobs if the token has
* already been seen in the response.
*/
presencePenalty?: number
/**
* Frequency penalty applied to the next token's logprobs, multiplied by the
* number of times each token has been seen in the respponse so far.
*/
frequencyPenalty?: number
}

/**
Expand Down

0 comments on commit dda0b5c

Please sign in to comment.