diff --git a/.changes/a7f82a33-11f1-4184-97af-ff713e922dfc.json b/.changes/a7f82a33-11f1-4184-97af-ff713e922dfc.json new file mode 100644 index 000000000..8697c4d46 --- /dev/null +++ b/.changes/a7f82a33-11f1-4184-97af-ff713e922dfc.json @@ -0,0 +1,9 @@ +{ + "id": "a7f82a33-11f1-4184-97af-ff713e922dfc", + "type": "bugfix", + "description": "⚠️ **IMPORTANT**: Fix codegen for map shapes which use string enums as map keys. See the [**Map key changes** breaking change announcement](https://github.com/awslabs/aws-sdk-kotlin/discussions/1258) for more details", + "issues": [ + "awslabs/smithy-kotlin#1045" + ], + "requiresMinorVersionBump": true +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettings.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettings.kt index 5e7911972..25942d3cc 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettings.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettings.kt @@ -43,6 +43,7 @@ data class KotlinSettings( val sdkId: String, val build: BuildSettings = BuildSettings.Default, val api: ApiSettings = ApiSettings.Default, + val debug: Boolean = false, ) { /** @@ -104,12 +105,14 @@ data class KotlinSettings( val sdkId = config.getStringMemberOrDefault(SDK_ID, serviceId.name) val build = config.getObjectMember(BUILD_SETTINGS) val api = config.getObjectMember(API_SETTINGS) + val debug = config.getBooleanMemberOrDefault("debug", false) return KotlinSettings( serviceId, PackageSettings(packageName, version, desc), sdkId, BuildSettings.fromNode(build), ApiSettings.fromNode(api), + debug, ) } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegator.kt index f185a66d6..7b5b871e6 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegator.kt @@ -150,6 +150,7 @@ class KotlinDelegator( val needsNewline = writers.containsKey(formattedFilename) val writer = writers.getOrPut(formattedFilename) { val kotlinWriter = KotlinWriter(namespace) + if (settings.debug) kotlinWriter.enableStackTraceComments(true) // Register all integrations [SectionWriterBindings] on the writer. integrations.forEach { integration -> diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinSymbolProvider.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinSymbolProvider.kt index 10dd958ce..b2848653d 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinSymbolProvider.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinSymbolProvider.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.kotlin.codegen.core import software.amazon.smithy.codegen.core.* import software.amazon.smithy.kotlin.codegen.KotlinSettings -import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.lang.kotlinReservedWords import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.utils.dq @@ -162,15 +161,18 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli } override fun mapShape(shape: MapShape): Symbol { - val reference = toSymbol(shape.value) - val valueSuffix = if (reference.isNullable) "?" else "" - val valueType = "${reference.name}$valueSuffix" - val fullyQualifiedValueType = "${reference.fullName}$valueSuffix" + val keyReference = toSymbol(shape.key) + val keyType = keyReference.name + val fullyQualifiedKeyType = keyReference.fullName + + val valueReference = toSymbol(shape.value) + val valueSuffix = if (valueReference.isNullable) "?" else "" + val valueType = "${valueReference.name}$valueSuffix" + val fullyQualifiedValueType = "${valueReference.fullName}$valueSuffix" - val keyType = KotlinTypes.String.name - val fullyQualifiedKeyType = KotlinTypes.String.fullName return createSymbolBuilder(shape, "Map<$keyType, $valueType>") - .addReferences(reference) + .addReferences(keyReference) + .addReferences(valueReference) .putProperty(SymbolProperty.FULLY_QUALIFIED_NAME_HINT, "Map<$fullyQualifiedKeyType, $fullyQualifiedValueType>") .putProperty(SymbolProperty.MUTABLE_COLLECTION_FUNCTION, "mutableMapOf<$keyType, $valueType>") .putProperty(SymbolProperty.IMMUTABLE_COLLECTION_FUNCTION, "mapOf<$keyType, $valueType>") diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ShapeValueGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ShapeValueGenerator.kt index 885febc11..3ac683da8 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ShapeValueGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ShapeValueGenerator.kt @@ -188,7 +188,16 @@ class ShapeValueGenerator( } is MapShape -> { memberShape = generator.model.expectShape(currShape.value.target) - writer.writeInline("#S to ", keyNode.value) + + val keyTarget = generator.model.expectShape(currShape.key.target) + if (keyTarget.isEnum) { + val keySymbol = generator.symbolProvider.toSymbol(currShape.key) + writer.writeInline("#T.fromValue(#S)", keySymbol, keyNode.value) + } else { + writer.writeInline("#S", keyNode.value) + } + + writer.writeInline(" to ") if (valueNode is NullNode) { writer.write("null") diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt index 2f8e8de51..ea91afc9f 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.kotlin.codegen.rendering.serde import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.kotlin.codegen.core.* import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.model.* @@ -145,9 +146,10 @@ open class DeserializeStructGenerator( .indent() .withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) { write( - "val #L = #T()", + "val #L = #T<#T, #T#L>()", mutableCollectionName, KotlinTypes.Collections.mutableMapOf, + ctx.symbolProvider.toSymbol(targetShape.key), ctx.symbolProvider.toSymbol(targetShape.value), nullabilitySuffix(targetShape.isSparse), ) @@ -168,6 +170,8 @@ open class DeserializeStructGenerator( nestingLevel: Int, parentMemberName: String, ) { + val keyShape = ctx.model.expectShape(mapShape.key.target) + val keySymbol = ctx.symbolProvider.toSymbol(keyShape) val elementShape = ctx.model.expectShape(mapShape.value.target) val isSparse = mapShape.isSparse @@ -187,21 +191,47 @@ open class DeserializeStructGenerator( ShapeType.TIMESTAMP, ShapeType.ENUM, ShapeType.INT_ENUM, - -> renderEntry(elementShape, nestingLevel, isSparse, parentMemberName) + -> renderEntry(keyShape, keySymbol, elementShape, nestingLevel, isSparse, parentMemberName) ShapeType.SET, ShapeType.LIST, - -> renderListEntry(rootMemberShape, elementShape as CollectionShape, nestingLevel, isSparse, parentMemberName) + -> renderListEntry( + rootMemberShape, + keyShape, + keySymbol, + elementShape as CollectionShape, + nestingLevel, + isSparse, + parentMemberName, + ) + + ShapeType.MAP -> renderMapEntry( + rootMemberShape, + keyShape, + keySymbol, + elementShape as MapShape, + nestingLevel, + isSparse, + parentMemberName, + ) - ShapeType.MAP -> renderMapEntry(rootMemberShape, elementShape as MapShape, nestingLevel, isSparse, parentMemberName) ShapeType.UNION, ShapeType.STRUCTURE, - -> renderNestedStructureEntry(elementShape, nestingLevel, isSparse, parentMemberName) + -> renderNestedStructureEntry(keyShape, keySymbol, elementShape, nestingLevel, isSparse, parentMemberName) else -> error("Unhandled type ${elementShape.type}") } } + private fun writeKeyVal(keyShape: Shape, keySymbol: Symbol, keyName: String) { + writer.writeInline("val $keyName = ") + if (keyShape.isEnum) { + writer.write("#T.fromValue(key())", keySymbol) + } else { + writer.write("key()") + } + } + /** * Renders the deserialization of a nested structure contained in a map. Example: * @@ -212,6 +242,8 @@ open class DeserializeStructGenerator( * ``` */ private fun renderNestedStructureEntry( + keyShape: Shape, + keySymbol: Symbol, elementShape: Shape, nestingLevel: Int, isSparse: Boolean, @@ -226,7 +258,7 @@ open class DeserializeStructGenerator( writer.addImport(symbol) } - writer.write("val $keyName = key()") + writeKeyVal(keyShape, keySymbol, keyName) writer.write("val $valueName = if (nextHasValue()) { $deserializerFn } else { deserializeNull()$populateNullValuePostfix }") writer.write("$parentMemberName[$keyName] = $valueName") } @@ -247,6 +279,8 @@ open class DeserializeStructGenerator( */ private fun renderMapEntry( rootMemberShape: MemberShape, + keyShape: Shape, + keySymbol: Symbol, mapShape: MapShape, nestingLevel: Int, isSparse: Boolean, @@ -260,14 +294,15 @@ open class DeserializeStructGenerator( val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.MAP) val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName) - writer.write("val $keyName = key()") + writeKeyVal(keyShape, keySymbol, keyName) writer.withBlock("val $valueName =", "") { withBlock("if (nextHasValue()) {", "} else { deserializeNull()$populateNullValuePostfix }") { withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) { write( - "val #L = #T()", + "val #L = #T<#T, #T#L>()", memberName, KotlinTypes.Collections.mutableMapOf, + keySymbol, ctx.symbolProvider.toSymbol(mapShape.value), nullabilitySuffix(mapShape.isSparse), ) @@ -298,6 +333,8 @@ open class DeserializeStructGenerator( */ private fun renderListEntry( rootMemberShape: MemberShape, + keyShape: Shape, + keySymbol: Symbol, collectionShape: CollectionShape, nestingLevel: Int, isSparse: Boolean, @@ -311,7 +348,7 @@ open class DeserializeStructGenerator( val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.COLLECTION) val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName) - writer.write("val $keyName = key()") + writeKeyVal(keyShape, keySymbol, keyName) writer.withBlock("val $valueName =", "") { withBlock("if (nextHasValue()) {", "} else { deserializeNull()$populateNullValuePostfix }") { withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) { @@ -340,13 +377,20 @@ open class DeserializeStructGenerator( * map0[k0] = el0 * ``` */ - private fun renderEntry(elementShape: Shape, nestingLevel: Int, isSparse: Boolean, parentMemberName: String) { + private fun renderEntry( + keyShape: Shape, + keySymbol: Symbol, + elementShape: Shape, + nestingLevel: Int, + isSparse: Boolean, + parentMemberName: String, + ) { val deserializerFn = deserializerForShape(elementShape) val keyName = nestingLevel.variableNameFor(NestedIdentifierType.KEY) val valueName = nestingLevel.variableNameFor(NestedIdentifierType.VALUE) val populateNullValuePostfix = if (isSparse) "" else "; continue" - writer.write("val $keyName = key()") + writeKeyVal(keyShape, keySymbol, keyName) writer.write("val $valueName = if (nextHasValue()) { $deserializerFn } else { deserializeNull()$populateNullValuePostfix }") writer.write("$parentMemberName[$keyName] = $valueName") } @@ -476,9 +520,10 @@ open class DeserializeStructGenerator( writer.withBlock("val $elementName = deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) { write( - "val #L = #T()", + "val #L = #T<#T, #T#L>()", mapName, KotlinTypes.Collections.mutableMapOf, + ctx.symbolProvider.toSymbol(mapShape.key), ctx.symbolProvider.toSymbol(mapShape.value), nullabilitySuffix(mapShape.isSparse), ) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGenerator.kt index ff099b0c9..20be3a6b6 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGenerator.kt @@ -175,6 +175,7 @@ open class SerializeStructGenerator( * Delegates to other functions based on the type of value target of map. */ protected fun delegateMapSerialization(rootMemberShape: MemberShape, mapShape: MapShape, nestingLevel: Int, parentMemberName: String) { + val keyShape = ctx.model.expectShape(mapShape.key.target) val elementShape = ctx.model.expectShape(mapShape.value.target) val isSparse = mapShape.isSparse @@ -192,18 +193,41 @@ open class SerializeStructGenerator( ShapeType.BIG_INTEGER, ShapeType.ENUM, ShapeType.INT_ENUM, - -> renderPrimitiveEntry(elementShape, nestingLevel, parentMemberName) + -> renderPrimitiveEntry(keyShape, elementShape, nestingLevel, parentMemberName) + + ShapeType.BLOB -> renderBlobEntry(keyShape, nestingLevel, parentMemberName) + + ShapeType.TIMESTAMP -> renderTimestampEntry( + keyShape, + mapShape.value, + elementShape, + nestingLevel, + parentMemberName, + ) - ShapeType.BLOB -> renderBlobEntry(nestingLevel, parentMemberName) - ShapeType.TIMESTAMP -> renderTimestampEntry(mapShape.value, elementShape, nestingLevel, parentMemberName) ShapeType.SET, ShapeType.LIST, - -> renderListEntry(rootMemberShape, elementShape as CollectionShape, nestingLevel, isSparse, parentMemberName) + -> renderListEntry( + rootMemberShape, + keyShape, + elementShape as CollectionShape, + nestingLevel, + isSparse, + parentMemberName, + ) + + ShapeType.MAP -> renderMapEntry( + rootMemberShape, + keyShape, + elementShape as MapShape, + nestingLevel, + isSparse, + parentMemberName, + ) - ShapeType.MAP -> renderMapEntry(rootMemberShape, elementShape as MapShape, nestingLevel, isSparse, parentMemberName) ShapeType.UNION, ShapeType.STRUCTURE, - -> renderNestedStructureEntry(elementShape, nestingLevel, parentMemberName, isSparse) + -> renderNestedStructureEntry(keyShape, elementShape, nestingLevel, parentMemberName, isSparse) else -> error("Unhandled type ${elementShape.type}") } @@ -276,6 +300,7 @@ open class SerializeStructGenerator( * ``` */ private fun renderNestedStructureEntry( + keyShape: Shape, structureShape: Shape, nestingLevel: Int, parentMemberName: String, @@ -283,13 +308,14 @@ open class SerializeStructGenerator( ) { val serializerTypeName = ctx.symbolProvider.toSymbol(structureShape).documentSerializerName() val (keyName, valueName) = keyValueNames(nestingLevel) + val keyValue = keyValue(keyShape, keyName) val containerName = if (nestingLevel == 0) "input." else "" val value = "asSdkSerializable($valueName, ::$serializerTypeName)" when (isSparse) { - true -> writer.write("$containerName$parentMemberName.forEach { ($keyName, $valueName) -> if ($valueName != null) entry($keyName, $value) else entry($keyName, null as String?) }") - false -> writer.write("$containerName$parentMemberName.forEach { ($keyName, $valueName) -> entry($keyName, $value) }") + true -> writer.write("$containerName$parentMemberName.forEach { ($keyName, $valueName) -> if ($valueName != null) entry($keyValue, $value) else entry($keyValue, null as String?) }") + false -> writer.write("$containerName$parentMemberName.forEach { ($keyName, $valueName) -> entry($keyValue, $value) }") } } @@ -337,6 +363,7 @@ open class SerializeStructGenerator( */ private fun renderMapEntry( rootMemberShape: MemberShape, + keyShape: Shape, mapShape: MapShape, nestingLevel: Int, isSparse: Boolean, @@ -345,11 +372,12 @@ open class SerializeStructGenerator( val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName()) val containerName = if (nestingLevel == 0) "input." else "" val (keyName, valueName) = keyValueNames(nestingLevel) + val keyValue = keyValue(keyShape, keyName) val parentName = parentName(valueName) writer.withBlock("$containerName$parentMemberName.forEach { ($keyName, $valueName) ->", "}") { - writer.wrapBlockIf(isSparse, "if ($valueName != null) {", "} else entry($keyName, null as String?)") { - writer.withBlock("mapEntry($keyName, $descriptorName) {", "}") { + writer.wrapBlockIf(isSparse, "if ($valueName != null) {", "} else entry($keyValue, null as String?)") { + writer.withBlock("mapEntry($keyValue, $descriptorName) {", "}") { delegateMapSerialization(rootMemberShape, mapShape, nestingLevel + 1, parentName) } } @@ -367,6 +395,7 @@ open class SerializeStructGenerator( */ private fun renderListEntry( rootMemberShape: MemberShape, + keyShape: Shape, elementShape: CollectionShape, nestingLevel: Int, isSparse: Boolean, @@ -376,10 +405,11 @@ open class SerializeStructGenerator( val containerName = if (nestingLevel == 0) "input." else "" val (keyName, valueName) = keyValueNames(nestingLevel) val parentName = parentName(valueName) + val keyValue = keyValue(keyShape, keyName) writer.withBlock("$containerName$parentMemberName.forEach { ($keyName, $valueName) ->", "}") { - writer.wrapBlockIf(isSparse, "if ($valueName != null) {", "} else entry($keyName, null as String?)") { - writer.withBlock("listEntry($keyName, $descriptorName) {", "}") { + writer.wrapBlockIf(isSparse, "if ($valueName != null) {", "} else entry($keyValue, null as String?)") { + writer.withBlock("listEntry($keyValue, $descriptorName) {", "}") { delegateListSerialization(rootMemberShape, elementShape, nestingLevel + 1, parentName) } } @@ -417,12 +447,13 @@ open class SerializeStructGenerator( * c0.forEach { (key1, value1) -> entry(key1, value1) } * ``` */ - private fun renderPrimitiveEntry(elementShape: Shape, nestingLevel: Int, listMemberName: String) { + private fun renderPrimitiveEntry(keyShape: Shape, elementShape: Shape, nestingLevel: Int, listMemberName: String) { val containerName = if (nestingLevel == 0) "input." else "" val enumPostfix = if (elementShape.isEnum) ".value" else "" val (keyName, valueName) = keyValueNames(nestingLevel) + val keyValue = keyValue(keyShape, keyName) - writer.write("$containerName$listMemberName.forEach { ($keyName, $valueName) -> entry($keyName, $valueName$enumPostfix) }") + writer.write("$containerName$listMemberName.forEach { ($keyName, $valueName) -> entry($keyValue, $valueName$enumPostfix) }") } /** @@ -432,12 +463,13 @@ open class SerializeStructGenerator( * input.fooBlobMap.forEach { (key, value) -> entry(key, value.encodeBase64String()) } * ``` */ - private fun renderBlobEntry(nestingLevel: Int, listMemberName: String) { + private fun renderBlobEntry(keyShape: Shape, nestingLevel: Int, listMemberName: String) { val containerName = if (nestingLevel == 0) "input." else "" val (keyName, valueName) = keyValueNames(nestingLevel) + val keyValue = keyValue(keyShape, keyName) writer.write( - "$containerName$listMemberName.forEach { ($keyName, $valueName) -> entry($keyName, $valueName.#T()) }", + "$containerName$listMemberName.forEach { ($keyName, $valueName) -> entry($keyValue, $valueName.#T()) }", RuntimeTypes.Core.Text.Encoding.encodeBase64String, ) } @@ -449,7 +481,13 @@ open class SerializeStructGenerator( * input.fooTimestampMap.forEach { (key, value) -> entry(key, it, TimestampFormat.EPOCH_SECONDS) } * ``` */ - private fun renderTimestampEntry(memberShape: Shape, elementShape: Shape, nestingLevel: Int, listMemberName: String) { + private fun renderTimestampEntry( + keyShape: Shape, + memberShape: Shape, + elementShape: Shape, + nestingLevel: Int, + listMemberName: String, + ) { writer.addImport(RuntimeTypes.Core.TimestampFormat) // favor the member shape if it overrides the value shape trait @@ -466,9 +504,10 @@ open class SerializeStructGenerator( .toRuntimeEnum() val (keyName, valueName) = keyValueNames(nestingLevel) + val keyValue = keyValue(keyShape, keyName) val containerName = if (nestingLevel == 0) "input." else "" - writer.write("$containerName$listMemberName.forEach { ($keyName, $valueName) -> entry($keyName, it, $tsFormat) }") + writer.write("$containerName$listMemberName.forEach { ($keyName, $valueName) -> entry($keyValue, it, $tsFormat) }") } /** @@ -659,6 +698,8 @@ open class SerializeStructGenerator( return keyName to valueName } + private fun keyValue(keyShape: Shape, keyName: String) = keyName + if (keyShape.isEnum) ".value" else "" + /** * Get the name of the `PrimitiveSerializer` function name for the corresponding shape type * @throws CodegenException when no known function name for the given type is known to exist diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt index 4c4e59894..35051bec5 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlParserGenerator.kt @@ -463,7 +463,7 @@ open class XmlParserGenerator( writer: KotlinWriter, ) { val target = ctx.model.expectShape(member.target) - val keySymbol = KotlinTypes.String + val keySymbol = ctx.symbolProvider.toSymbol(target.key) val valueSymbol = ctx.symbolProvider.toSymbol(target.value) writer.addImportReferences(valueSymbol, SymbolReference.ContextOption.USE) val isSparse = target.hasTrait() @@ -512,7 +512,7 @@ open class XmlParserGenerator( map: MapShape, ): Symbol { val shapeName = StringUtils.capitalize(map.id.getName(ctx.service)) - val keySymbol = KotlinTypes.String + val keySymbol = ctx.symbolProvider.toSymbol(map.key) val valueSymbol = ctx.symbolProvider.toSymbol(map.value) val isSparse = map.hasTrait() val serdeCtx = SerdeCtx("reader") @@ -541,14 +541,6 @@ open class XmlParserGenerator( val keyName = map.key.getTrait()?.value ?: map.key.memberName writeInline("#S -> key = ", keyName) deserializeMember(ctx, innerCtx, map.key, this) - // FIXME - We re-use deserializeMember here but key types targeting enums - // have to pull the raw string value back out because of - // https://github.com/awslabs/smithy-kotlin/issues/1045 - val targetValueShape = ctx.model.expectShape(map.key.target) - if (targetValueShape.type == ShapeType.ENUM) { - writer.indent() - .write(".value") - } val valueName = map.value.getTrait()?.value ?: map.value.memberName if (isSparse) { diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/core/SymbolProviderTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/core/SymbolProviderTest.kt index 07e248119..0cc4d2cb1 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/core/SymbolProviderTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/core/SymbolProviderTest.kt @@ -529,21 +529,58 @@ class SymbolProviderTest { assertEquals("Map", mapSymbol.name) - // collections should contain a reference to the member type - assertEquals("Record", mapSymbol.references[0].symbol.name) + // collections should contain a reference to the member types + val refNames = mapSymbol.references.map { it.symbol.fullName } + assertTrue("kotlin.String" in refNames) + assertTrue("com.test.model.Record" in refNames) val sparseMapSymbol = provider.toSymbol(model.expectShape("${TestModelDefault.NAMESPACE}#MySparseMap")) assertEquals("Map", sparseMapSymbol.name) // collections should contain a reference to the member type - assertEquals("Record", sparseMapSymbol.references[0].symbol.name) + val sparseRefNames = sparseMapSymbol.references.map { it.symbol.fullName } + assertTrue("kotlin.String" in sparseRefNames) + assertTrue("com.test.model.Record" in sparseRefNames) // check the fully qualified name hint is set assertEquals("Map", mapSymbol.fullNameHint) assertEquals("Map", sparseMapSymbol.fullNameHint) } + @Test + fun `creates maps with enum keys`() { + val model = """ + @enum([ + { + value: "FOO", + }, + { + value: "BAR", + }, + ]) + string Type + + structure Record {} + + map MyMap { + key: Type, + value: Record, + } + """.prependNamespaceAndService().toSmithyModel() + + val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model) + + val mapSymbol = provider.toSymbol(model.expectShape("${TestModelDefault.NAMESPACE}#MyMap")) + + assertEquals("Map", mapSymbol.name) + + // collections should contain a reference to the member types + val refNames = mapSymbol.references.map { it.symbol.fullName } + assertTrue("com.test.model.Type" in refNames) + assertTrue("com.test.model.Record" in refNames) + } + @DisplayName("creates bigNumbers") @ParameterizedTest(name = "{index} ==> ''{0}''") @ValueSource(strings = ["BigInteger", "BigDecimal"]) diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ShapeValueGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ShapeValueGeneratorTest.kt index 8e6a738f1..101c3e950 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ShapeValueGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ShapeValueGeneratorTest.kt @@ -50,6 +50,52 @@ mapOf( contents.shouldContainOnlyOnceWithDiff(expected) } + @Test + fun `it renders maps with enum keys`() { + val model = """ + @enum([ + { + value: "k1", + }, + { + value: "k2", + }, + { + value: "k3", + }, + ]) + string KeyType + + map MyMap { + key: KeyType, + value: Integer, + } + """.prependNamespaceAndService(namespace = "foo.bar").toSmithyModel() + + val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model, rootNamespace = "foo.bar") + val mapShape = model.expectShape(ShapeId.from("foo.bar#MyMap")) + val writer = KotlinWriter("test") + + val params = Node.objectNodeBuilder() + .withMember("k1", 1) + .withMember("k2", 2) + .withMember("k3", 3) + .build() + + ShapeValueGenerator(model, provider).instantiateShapeInline(writer, mapShape, params) + val contents = writer.toString() + + val expected = """ +mapOf( + KeyType.fromValue("k1") to 1, + KeyType.fromValue("k2") to 2, + KeyType.fromValue("k3") to 3 +) +""" + + contents.shouldContainOnlyOnceWithDiff(expected) + } + @Test fun `it renders lists`() { val model = """ diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGeneratorTest.kt index 0d0349f21..abd959dda 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGeneratorTest.kt @@ -1026,6 +1026,57 @@ class DeserializeStructGeneratorTest { actual.shouldContainOnlyOnceWithDiff(expected) } + @Test + fun `it deserializes a structure containing a map with an enum key`() { + val model = ( + modelPrefix + """ + structure FooResponse { + payload: KeyValuedMap + } + + map KeyValuedMap { + key: KeyType, + value: String + } + + @enum([ + { + value: "FOO", + }, + { + value: "BAR", + }, + ]) + string KeyType + """ + ).toSmithyModel() + + val expected = """ + deserializer.deserializeStruct(OBJ_DESCRIPTOR) { + loop@while (true) { + when (findNextFieldIndex()) { + PAYLOAD_DESCRIPTOR.index -> builder.payload = + deserializer.deserializeMap(PAYLOAD_DESCRIPTOR) { + val map0 = mutableMapOf() + while (hasNextEntry()) { + val k0 = KeyType.fromValue(key()) + val v0 = if (nextHasValue()) { deserializeString() } else { deserializeNull(); continue } + map0[k0] = v0 + } + map0 + } + null -> break@loop + else -> skipValue() + } + } + } + """.trimIndent() + + val actual = codegenDeserializerForShape(model, "com.test#Foo") + + actual.shouldContainOnlyOnceWithDiff(expected) + } + @Test fun `it deserializes a structure containing a map of a union of primitive values`() { val model = ( diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGeneratorTest.kt index 066f2fef6..6c74c5e48 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGeneratorTest.kt @@ -1034,6 +1034,46 @@ class SerializeStructGeneratorTest { actual.shouldContainOnlyOnceWithDiff(expected) } + @Test + fun `it serializes a structure containing a map with enum keys`() { + val model = ( + modelPrefix + """ + structure FooRequest { + payload: KeyValuedMap + } + + map KeyValuedMap { + key: KeyType, + value: String + } + + @enum([ + { + value: "FOO", + }, + { + value: "BAR", + }, + ]) + string KeyType + """ + ).toSmithyModel() + + val expected = """ + serializer.serializeStruct(OBJ_DESCRIPTOR) { + if (input.payload != null) { + mapField(PAYLOAD_DESCRIPTOR) { + input.payload.forEach { (key, value) -> entry(key.value, value) } + } + } + } + """.trimIndent() + + val actual = codegenSerializerForShape(model, "com.test#Foo").stripCodegenPrefix() + + actual.shouldContainOnlyOnceWithDiff(expected) + } + @Test fun `it serializes a structure containing a required map`() { val model = ( diff --git a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt index 530df9186..0cd961424 100644 --- a/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt +++ b/tests/codegen/serde-tests/src/test/kotlin/aws/smithy/kotlin/tests/serde/XmlMapTest.kt @@ -249,9 +249,9 @@ class XmlMapTest : AbstractXmlTest() { // see also https://github.com/awslabs/smithy-kotlin/issues/1045 val expected = StructType { enumKeyMap = mapOf( - FooEnum.Foo.value to 1, - "Bar" to 2, - "Unknown" to 3, + FooEnum.Foo to 1, + FooEnum.Bar to 2, + FooEnum.SdkUnknown("Unknown") to 3, ) } val payload = """ diff --git a/tests/compile/src/test/resources/kitchen-sink-model.smithy b/tests/compile/src/test/resources/kitchen-sink-model.smithy index c1fb3747d..b5643d3fe 100644 --- a/tests/compile/src/test/resources/kitchen-sink-model.smithy +++ b/tests/compile/src/test/resources/kitchen-sink-model.smithy @@ -239,6 +239,11 @@ map StringMap { value: String } +map EnumKeyedStringMap { + key: MyEnum, + value: String +} + // only exists as value of a map through MapInputRequest::structMap structure ReachableOnlyThroughMap { prop1: Integer @@ -280,6 +285,8 @@ structure MapInputRequest { structMap: StructMap, enumMap: EnumMap, blobMap: BlobMap, + stringMap: StringMap, + enumKeyedStringMap: EnumKeyedStringMap, mapOfLists: MapOfLists, nestedMap: NestedMap } @@ -289,6 +296,8 @@ structure MapOutputResponse { structMap: StructMap, enumMap: EnumMap, blobMap: BlobMap, + stringMap: StringMap, + enumKeyedStringMap: EnumKeyedStringMap, nestedMap: NestedMap }