Skip to content

Commit

Permalink
feat: added support for image editing and upscaling for imagen3 (#989)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj authored Oct 3, 2024
1 parent 895249a commit 231591b
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 15 deletions.
45 changes: 44 additions & 1 deletion docs/plugins/vertex-ai.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ The Vertex AI plugin provides interfaces to several AI services:

- [Google generative AI models](https://cloud.google.com/vertex-ai/generative-ai/docs/):
- Gemini text generation
- Imagen2 image generation
- Imagen2 and Imagen3 image generation
- Text embedding generation
- A subset of evaluation metrics through the Vertex AI [Rapid Evaluation API](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/evaluation):
- [BLEU](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations/evaluateInstances#bleuinput)
Expand Down Expand Up @@ -157,6 +157,49 @@ const embedding = await embed({
});
```

Imagen3 model allows generating images from user prompt:

```js
import { imagen3 } from '@genkit-ai/vertexai';

const response = await generate({
model: imagen3,
output: { format: 'media' },
prompt: 'a banana riding a bicycle',
});

return response.media();
```

and even advanced editing of existing images:

```js
const baseImg = fs.readFileSync('base.png', { encoding: 'base64' });
const maskImg = fs.readFileSync('mask.png', { encoding: 'base64' });

const response = await generate({
model: imagen3,
output: { format: 'media' },
prompt: [
{ media: { url: `data:image/png;base64,${baseImg}` }},
{
media: { url: `data:image/png;base64,${maskImg}` },
metadata: { type: 'mask' },
},
{ text: 'replace the background with foo bar baz' },
],
config: {
editConfig: {
editMode: 'outpainting',
},
},
});

return response.media();
```

Refer to (Imagen model documentation)[https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#edit_images_2] for more detailed options.

#### Anthropic Claude 3 on Vertex AI Model Garden

If you have access to Claude 3 models ([haiku](https://console.cloud.google.com/vertex-ai/publishers/anthropic/model-garden/claude-3-haiku), [sonnet](https://console.cloud.google.com/vertex-ai/publishers/anthropic/model-garden/claude-3-sonnet) or [opus](https://console.cloud.google.com/vertex-ai/publishers/anthropic/model-garden/claude-3-opus)) in Vertex AI Model Garden you can use them with Genkit.
Expand Down
72 changes: 58 additions & 14 deletions js/plugins/vertexai/src/imagen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,40 @@ const ImagenConfigSchema = GenerationCommonConfigSchema.extend({
addWatermark: z.boolean().optional(),
/** Cloud Storage URI to store the generated images. **/
storageUri: z.string().optional(),
});
/** Mode must be set for upscaling requests. */
mode: z.enum(['upscale']).optional(),
/**
* Describes the editing intention for the request.
*
* Refer to https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#edit_images_2 for details.
*/
editConfig: z
.object({
/** Describes the editing intention for the request. */
editMode: z
.enum([
'inpainting-insert',
'inpainting-remove',
'outpainting',
'product-image',
])
.optional(),
/** Prompts the model to generate a mask instead of you needing to provide one. Consequently, when you provide this parameter you can omit a mask object. */
maskMode: z
.object({
maskType: z.enum(['background', 'foreground', 'semantic']),
classes: z.array(z.number()).optional(),
})
.optional(),
maskDilation: z.number().optional(),
guidanceScale: z.number().optional(),
productPosition: z.enum(['reposition', 'fixed']).optional(),
})
.passthrough()
.optional(),
/** Upscale config object. */
upscaleConfig: z.object({ upscaleFactor: z.enum(['x2', 'x4']) }).optional(),
}).passthrough();

export const imagen2 = modelRef({
name: 'vertexai/imagen2',
Expand All @@ -86,7 +119,7 @@ export const imagen3 = modelRef({
label: 'Vertex AI - Imagen3',
versions: ['imagen-3.0-generate-001'],
supports: {
media: false,
media: true,
multiturn: false,
tools: false,
systemRole: false,
Expand Down Expand Up @@ -144,14 +177,7 @@ function toParameters(
): ImagenParameters {
const out = {
sampleCount: request.candidates ?? 1,
aspectRatio: request.config?.aspectRatio,
negativePrompt: request.config?.negativePrompt,
seed: request.config?.seed,
language: request.config?.language,
personGeneration: request.config?.personGeneration,
safetySetting: request.config?.safetySetting,
addWatermark: request.config?.addWatermark,
storageUri: request.config?.storageUri,
...request?.config,
};

for (const k in out) {
Expand All @@ -161,10 +187,19 @@ function toParameters(
return out;
}

function extractPromptImage(request: GenerateRequest): string | undefined {
function extractMaskImage(request: GenerateRequest): string | undefined {
return request.messages
.at(-1)
?.content.find((p) => !!p.media && p.metadata?.type === 'mask')
?.media?.url.split(',')[1];
}

function extractBaseImage(request: GenerateRequest): string | undefined {
return request.messages
.at(-1)
?.content.find((p) => !!p.media)
?.content.find(
(p) => !!p.media && (!p.metadata?.type || p.metadata?.type === 'base')
)
?.media?.url.split(',')[1];
}

Expand All @@ -176,6 +211,7 @@ interface ImagenPrediction {
interface ImagenInstance {
prompt: string;
image?: { bytesBase64Encoded: string };
mask?: { image?: { bytesBase64Encoded: string } };
}

export function imagenModel(
Expand Down Expand Up @@ -222,8 +258,16 @@ export function imagenModel(
const instance: ImagenInstance = {
prompt: extractText(request),
};
if (extractPromptImage(request))
instance.image = { bytesBase64Encoded: extractPromptImage(request)! };
const baseImage = extractBaseImage(request);
if (baseImage) {
instance.image = { bytesBase64Encoded: baseImage };
}
const maskImage = extractMaskImage(request);
if (maskImage) {
instance.mask = {
image: { bytesBase64Encoded: maskImage },
};
}

const req: any = {
instances: [instance],
Expand Down

0 comments on commit 231591b

Please sign in to comment.