From 65fc3d351b9128010733a644d4a141f100e44b35 Mon Sep 17 00:00:00 2001 From: Adam Bergman Date: Thu, 23 May 2024 14:30:10 +0200 Subject: [PATCH] Support open api v3.0 --- apps/example/src/main.spec.ts | 30 ++--- .../zod-nestjs/src/lib/patch-nest-swagger.ts | 13 +- .../zod-openapi/src/lib/zod-openapi.spec.ts | 30 +++++ packages/zod-openapi/src/lib/zod-openapi.ts | 114 ++++++++++++------ 4 files changed, 130 insertions(+), 57 deletions(-) diff --git a/apps/example/src/main.spec.ts b/apps/example/src/main.spec.ts index 38fcf69..b51047f 100644 --- a/apps/example/src/main.spec.ts +++ b/apps/example/src/main.spec.ts @@ -59,7 +59,7 @@ describe('Cats', () => { // console.log(inspect(body, false, 10, true)); // }); - it(`Swagger Test`, () => { + it(`Swagger Test`, async () => { return request(app.getHttpServer()) .get('/api-json') .set('Accept', 'application/json') @@ -163,11 +163,11 @@ describe('Cats', () => { components: { schemas: { GetCatsDto: { - type: ['object'], + type: 'object', properties: { cats: { - type: ['array'], - items: { type: ['string'] }, + type: 'array', + items: { type: 'string' }, description: 'List of cats', }, }, @@ -175,30 +175,30 @@ describe('Cats', () => { title: 'Get Cat Response', }, CatDto: { - type: ['object'], + type: 'object', properties: { - name: { type: ['string'] }, - age: { type: ['number'] }, - breed: { type: ['string'] }, + name: { type: 'string' }, + age: { type: 'number' }, + breed: { type: 'string' }, }, required: ['name', 'age', 'breed'], title: 'Cat', description: 'A cat', }, CreateCatResponseDto: { - type: ['object'], + type: 'object', properties: { - success: { type: ['boolean'] }, - message: { type: ['string'] }, - name: { type: ['string'] }, + success: { type: 'boolean' }, + message: { type: 'string' }, + name: { type: 'string' }, }, required: ['success', 'message', 'name'], }, UpdateCatDto: { - type: ['object'], + type: 'object', properties: { - age: { type: ['number'] }, - breed: { type: ['string'] }, + age: { type: 'number' }, + breed: { type: 'string' }, }, required: ['age', 'breed'], }, diff --git a/packages/zod-nestjs/src/lib/patch-nest-swagger.ts b/packages/zod-nestjs/src/lib/patch-nest-swagger.ts index be285ad..3b5a205 100644 --- a/packages/zod-nestjs/src/lib/patch-nest-swagger.ts +++ b/packages/zod-nestjs/src/lib/patch-nest-swagger.ts @@ -12,11 +12,15 @@ interface Type extends Function { new (...args: any[]): T; } +type SchemaObjectFactoryModule = + typeof import('@nestjs/swagger/dist/services/schema-object-factory'); + export const patchNestjsSwagger = ( - schemaObjectFactoryModule = require('@nestjs/swagger/dist/services/schema-object-factory') + schemaObjectFactoryModule: SchemaObjectFactoryModule | undefined = undefined, + openApiVersion: '3.0' | '3.1' = '3.0' ): void => { - // eslint-disable-next-line @typescript-eslint/no-var-requires,@typescript-eslint/naming-convention - const { SchemaObjectFactory } = schemaObjectFactoryModule; + const { SchemaObjectFactory } = (schemaObjectFactoryModule ?? + require('@nestjs/swagger/dist/services/schema-object-factory')) as SchemaObjectFactoryModule; const orgExploreModelSchema = SchemaObjectFactory.prototype.exploreModelSchema; @@ -29,6 +33,7 @@ export const patchNestjsSwagger = ( // schemas: Record, // schemaRefsStack: string[] = [] ) { + // @ts-expect-error Reported as private, but since we are patching, we will be able to reach it if (this.isLazyTypeFunc(type)) { // eslint-disable-next-line @typescript-eslint/ban-types type = (type as Function)(); @@ -38,7 +43,7 @@ export const patchNestjsSwagger = ( return orgExploreModelSchema.call(this, type, schemas, schemaRefsStack); } - schemas[type.name] = generateSchema(type.zodSchema); + schemas[type.name] = generateSchema(type.zodSchema, false, openApiVersion); return type.name; }; diff --git a/packages/zod-openapi/src/lib/zod-openapi.spec.ts b/packages/zod-openapi/src/lib/zod-openapi.spec.ts index b3bc608..e831c43 100644 --- a/packages/zod-openapi/src/lib/zod-openapi.spec.ts +++ b/packages/zod-openapi/src/lib/zod-openapi.spec.ts @@ -37,6 +37,36 @@ describe('zodOpenapi', () => { }); }); + it('should support basic primitives for OpenAPI v3.0', () => { + const zodSchema = extendApi( + z.object({ + aString: z.string().describe('A test string').optional(), + aNumber: z.number().optional(), + aBigInt: z.bigint(), + aBoolean: z.boolean(), + aDate: z.date(), + }), + { + description: `Primitives also testing overwriting of "required"`, + required: ['aNumber'], // All schema settings "merge" + } + ); + const apiSchema = generateSchema(zodSchema, false, '3.0'); + + expect(apiSchema).toEqual({ + type: 'object', + properties: { + aString: { description: 'A test string', type: 'string' }, + aNumber: { type: 'number' }, + aBigInt: { type: 'integer', format: 'int64' }, + aBoolean: { type: 'boolean' }, + aDate: { type: 'string', format: 'date-time' }, + }, + required: ['aBigInt', 'aBoolean', 'aDate', 'aNumber'], + description: 'Primitives also testing overwriting of "required"', + }); + }); + it('should support empty types', () => { const zodSchema = extendApi( z.object({ diff --git a/packages/zod-openapi/src/lib/zod-openapi.ts b/packages/zod-openapi/src/lib/zod-openapi.ts index e3ab4fb..1269a15 100644 --- a/packages/zod-openapi/src/lib/zod-openapi.ts +++ b/packages/zod-openapi/src/lib/zod-openapi.ts @@ -12,11 +12,14 @@ interface OpenApiZodAnyObject extends AnyZodObject { metaOpenApi?: AnatineSchemaObject | AnatineSchemaObject[]; } +type OpenAPIVersion = '3.0' | '3.1'; + interface ParsingArgs { zodRef: T; schemas: AnatineSchemaObject[]; useOutput?: boolean; hideDefinitions?: string[]; + openApiVersion: OpenAPIVersion; } export function extendApi( @@ -37,13 +40,14 @@ function iterateZodObject({ zodRef, useOutput, hideDefinitions, + openApiVersion, }: ParsingArgs) { const reduced = Object.keys(zodRef.shape) .filter((key) => hideDefinitions?.includes(key) === false) .reduce( (carry, key) => ({ ...carry, - [key]: generateSchema(zodRef.shape[key], useOutput), + [key]: generateSchema(zodRef.shape[key], useOutput, openApiVersion), }), {} as Record ); @@ -51,12 +55,17 @@ function iterateZodObject({ return reduced; } +function typeFormat(type: T, openApiVersion: OpenAPIVersion) { + return openApiVersion === '3.0' ? type : [type]; +} + function parseTransformation({ zodRef, schemas, useOutput, + openApiVersion, }: ParsingArgs | z.ZodEffects>): SchemaObject { - const input = generateSchema(zodRef._def.schema, useOutput); + const input = generateSchema(zodRef._def.schema, useOutput, openApiVersion); let output = 'undefined'; if (useOutput && zodRef._def.effect) { @@ -64,9 +73,7 @@ function parseTransformation({ zodRef._def.effect.type === 'transform' ? zodRef._def.effect : null; if (effect && 'transform' in effect) { try { - // todo: this doesn't deal with nullable types very well - // @ts-expect-error because we try/catch for a missing type - const type = input.type[0]; + const type = Array.isArray(input.type) ? input.type[0] : input.type; output = typeof effect.transform( ['integer', 'number'].includes(`${type}`) ? 0 @@ -88,13 +95,14 @@ function parseTransformation({ } } } + const outputType = output as 'number' | 'string' | 'boolean' | 'null' return merge( { ...(zodRef.description ? { description: zodRef.description } : {}), ...input, ...(['number', 'string', 'boolean', 'null'].includes(output) ? { - type: [output as 'number' | 'string' | 'boolean' | 'null'], + type: typeFormat(outputType, openApiVersion), } : {}), }, @@ -105,9 +113,10 @@ function parseTransformation({ function parseString({ zodRef, schemas, + openApiVersion, }: ParsingArgs): SchemaObject { const baseSchema: SchemaObject = { - type: ['string'], + type: typeFormat('string', openApiVersion), }; const { checks = [] } = zodRef._def; checks.forEach((item) => { @@ -152,9 +161,10 @@ function parseString({ function parseNumber({ zodRef, schemas, + openApiVersion, }: ParsingArgs): SchemaObject { const baseSchema: SchemaObject = { - type: ['number'], + type: typeFormat('number', openApiVersion), }; const { checks = [] } = zodRef._def; checks.forEach((item) => { @@ -168,7 +178,7 @@ function parseNumber({ else baseSchema.exclusiveMinimum = item.value; break; case 'int': - baseSchema.type = ['integer']; + baseSchema.type = typeFormat('integer', openApiVersion); break; case 'multipleOf': baseSchema.multipleOf = item.value; @@ -202,6 +212,7 @@ function parseObject({ schemas, useOutput, hideDefinitions, + openApiVersion, }: ParsingArgs< z.ZodObject >): SchemaObject { @@ -214,7 +225,7 @@ function parseObject({ zodRef._def.catchall?._def.typeName === 'ZodNever' ) ) - additionalProperties = generateSchema(zodRef._def.catchall, useOutput); + additionalProperties = generateSchema(zodRef._def.catchall, useOutput, openApiVersion); else if (zodRef._def.unknownKeys === 'passthrough') additionalProperties = true; else if (zodRef._def.unknownKeys === 'strict') additionalProperties = false; @@ -241,12 +252,13 @@ function parseObject({ return merge( { - type: ['object' as SchemaObjectType], + type: typeFormat('object', openApiVersion), properties: iterateZodObject({ zodRef: zodRef as OpenApiZodAnyObject, schemas, useOutput, hideDefinitions: getExcludedDefinitionsFromSchema(schemas), + openApiVersion, }), ...required, ...additionalProperties, @@ -261,14 +273,15 @@ function parseRecord({ zodRef, schemas, useOutput, + openApiVersion, }: ParsingArgs): SchemaObject { return merge( { - type: ['object' as SchemaObjectType], + type: typeFormat('object', openApiVersion), additionalProperties: zodRef._def.valueType instanceof z.ZodUnknown ? {} - : generateSchema(zodRef._def.valueType, useOutput), + : generateSchema(zodRef._def.valueType, useOutput, openApiVersion), }, zodRef.description ? { description: zodRef.description } : {}, ...schemas @@ -278,9 +291,13 @@ function parseRecord({ function parseBigInt({ zodRef, schemas, + openApiVersion, }: ParsingArgs): SchemaObject { return merge( - { type: ['integer' as SchemaObjectType], format: 'int64' }, + { + type: typeFormat('integer', openApiVersion), + format: 'int64' + }, zodRef.description ? { description: zodRef.description } : {}, ...schemas ); @@ -289,25 +306,29 @@ function parseBigInt({ function parseBoolean({ zodRef, schemas, + openApiVersion, }: ParsingArgs): SchemaObject { return merge( - { type: ['boolean' as SchemaObjectType] }, + { type: typeFormat('boolean', openApiVersion) }, zodRef.description ? { description: zodRef.description } : {}, ...schemas ); } -function parseDate({ zodRef, schemas }: ParsingArgs): SchemaObject { +function parseDate({ zodRef, schemas, openApiVersion }: ParsingArgs): SchemaObject { return merge( - { type: ['string' as SchemaObjectType], format: 'date-time' }, + { + type: typeFormat('string', openApiVersion), + format: 'date-time' + }, zodRef.description ? { description: zodRef.description } : {}, ...schemas ); } -function parseNull({ zodRef, schemas }: ParsingArgs): SchemaObject { +function parseNull({ zodRef, schemas, openApiVersion }: ParsingArgs): SchemaObject { return merge( - { + openApiVersion === '3.0' ? { type: 'null' as SchemaObjectType } : { type: ['string', 'null'] as SchemaObjectType[], enum: ['null'], }, @@ -320,9 +341,10 @@ function parseOptional({ schemas, zodRef, useOutput, + openApiVersion, }: ParsingArgs>): SchemaObject { return merge( - generateSchema(zodRef.unwrap(), useOutput), + generateSchema(zodRef.unwrap(), useOutput, openApiVersion), zodRef.description ? { description: zodRef.description } : {}, ...schemas ); @@ -332,11 +354,12 @@ function parseNullable({ schemas, zodRef, useOutput, + openApiVersion, }: ParsingArgs>): SchemaObject { - const schema = generateSchema(zodRef.unwrap(), useOutput); + const schema = generateSchema(zodRef.unwrap(), useOutput, openApiVersion); return merge( schema, - { type: ['null'] as SchemaObjectType[] }, + { type: typeFormat('null', openApiVersion) }, zodRef.description ? { description: zodRef.description } : {}, ...schemas ); @@ -346,11 +369,12 @@ function parseDefault({ schemas, zodRef, useOutput, + openApiVersion, }: ParsingArgs>): SchemaObject { return merge( { default: zodRef._def.defaultValue(), - ...generateSchema(zodRef._def.innerType, useOutput), + ...generateSchema(zodRef._def.innerType, useOutput, openApiVersion), }, zodRef.description ? { description: zodRef.description } : {}, ...schemas @@ -361,6 +385,7 @@ function parseArray({ schemas, zodRef, useOutput, + openApiVersion, }: ParsingArgs>): SchemaObject { const constraints: SchemaObject = {}; if (zodRef._def.exactLength != null) { @@ -375,8 +400,8 @@ function parseArray({ return merge( { - type: ['array' as SchemaObjectType], - items: generateSchema(zodRef.element, useOutput), + type: typeFormat('array', openApiVersion), + items: generateSchema(zodRef.element, useOutput, openApiVersion), ...constraints, }, zodRef.description ? { description: zodRef.description } : {}, @@ -387,10 +412,12 @@ function parseArray({ function parseLiteral({ schemas, zodRef, + openApiVersion, }: ParsingArgs>): SchemaObject { + const type = typeof zodRef._def.value as 'string' | 'number' | 'boolean' return merge( { - type: [typeof zodRef._def.value as 'string' | 'number' | 'boolean'], + type: typeFormat(type, openApiVersion), enum: [zodRef._def.value], }, zodRef.description ? { description: zodRef.description } : {}, @@ -401,10 +428,12 @@ function parseLiteral({ function parseEnum({ schemas, zodRef, + openApiVersion, }: ParsingArgs | z.ZodNativeEnum>): SchemaObject { + const type = typeof Object.values(zodRef._def.values)[0] as 'string' | 'number' return merge( { - type: [typeof Object.values(zodRef._def.values)[0] as 'string' | 'number'], + type: typeFormat(type, openApiVersion), enum: Object.values(zodRef._def.values), }, zodRef.description ? { description: zodRef.description } : {}, @@ -416,12 +445,13 @@ function parseIntersection({ schemas, zodRef, useOutput, + openApiVersion, }: ParsingArgs>): SchemaObject { return merge( { allOf: [ - generateSchema(zodRef._def.left, useOutput), - generateSchema(zodRef._def.right, useOutput), + generateSchema(zodRef._def.left, useOutput, openApiVersion), + generateSchema(zodRef._def.right, useOutput, openApiVersion), ], }, zodRef.description ? { description: zodRef.description } : {}, @@ -433,6 +463,7 @@ function parseUnion({ schemas, zodRef, useOutput, + openApiVersion, }: ParsingArgs>): SchemaObject { const contents = zodRef._def.options; if ( @@ -454,7 +485,7 @@ function parseUnion({ if (type) { return merge( { - type: [type as 'string' | 'number' | 'boolean'], + type: typeFormat(type as SchemaObjectType, openApiVersion), enum: literals.map((literal) => literal._def.value), }, zodRef.description ? { description: zodRef.description } : {}, @@ -465,7 +496,7 @@ function parseUnion({ return merge( { - oneOf: contents.map((schema) => generateSchema(schema, useOutput)), + oneOf: contents.map((schema) => generateSchema(schema, useOutput, openApiVersion)), }, zodRef.description ? { description: zodRef.description } : {}, ...schemas @@ -476,6 +507,7 @@ function parseDiscriminatedUnion({ schemas, zodRef, useOutput, + openApiVersion, }: ParsingArgs< z.ZodDiscriminatedUnion[]> >): SchemaObject { @@ -496,7 +528,7 @@ function parseDiscriminatedUnion({ z.ZodDiscriminatedUnionOption[] > )._def.options.values() - ).map((schema) => generateSchema(schema, useOutput)), + ).map((schema) => generateSchema(schema, useOutput, openApiVersion)), }, zodRef.description ? { description: zodRef.description } : {}, ...schemas @@ -517,8 +549,10 @@ function parseNever({ function parseBranded({ schemas, zodRef, + useOutput, + openApiVersion, }: ParsingArgs>): SchemaObject { - return merge(generateSchema(zodRef._def.type), ...schemas); + return merge(generateSchema(zodRef._def.type, useOutput, openApiVersion), ...schemas); } function catchAllParser({ @@ -535,9 +569,10 @@ function parsePipeline({ schemas, zodRef, useOutput, + openApiVersion, }: ParsingArgs>): SchemaObject { return merge( - generateSchema(useOutput ? zodRef._def.out : zodRef._def.in, useOutput), + generateSchema(useOutput ? zodRef._def.out : zodRef._def.in, useOutput, openApiVersion), ...schemas, ); } @@ -546,9 +581,10 @@ function parseReadonly({ zodRef, useOutput, schemas, + openApiVersion, }: ParsingArgs>): SchemaObject { return merge( - generateSchema(zodRef._def.innerType, useOutput), + generateSchema(zodRef._def.innerType, useOutput, openApiVersion), zodRef.description ? { description: zodRef.description } : {}, ...schemas ); @@ -595,7 +631,8 @@ type WorkerKeys = keyof typeof workerMap; export function generateSchema( zodRef: OpenApiZodAny, - useOutput?: boolean + useOutput = false, + openApiVersion: OpenAPIVersion = '3.1', ): SchemaObject { const { metaOpenApi = {} } = zodRef; const schemas = [ @@ -608,12 +645,13 @@ export function generateSchema( zodRef: zodRef as never, schemas, useOutput, + openApiVersion, }); } - return catchAllParser({ zodRef, schemas }); + return catchAllParser({ zodRef, schemas, openApiVersion }); } catch (err) { console.error(err); - return catchAllParser({ zodRef, schemas }); + return catchAllParser({ zodRef, schemas, openApiVersion }); } }