Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/awslabs/smithy-kotlin into …
Browse files Browse the repository at this point in the history
…multi-auth
  • Loading branch information
0marperez committed Mar 6, 2024
2 parents efe0a79 + 2dd9abc commit 0afa170
Show file tree
Hide file tree
Showing 96 changed files with 3,059 additions and 4,084 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
# Changelog

## [1.0.16] - 02/28/2024

### Features
* Add support for S3 Express One Zone

### Fixes
* [#1220](https://github.com/awslabs/aws-sdk-kotlin/issues/1220) Refactor XML deserialization to handle flat collections

### Miscellaneous
* Refactor exception codegen to delegate message field to exception base class

## [1.0.15] - 02/19/2024

### Features
Expand Down
3 changes: 2 additions & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ apiValidation {
"channel-benchmarks",
"http-benchmarks",
"serde-benchmarks",
"serde-benchmarks-codegen",
"serde-codegen-support",
"serde-tests",
"nullability-tests",
"paginator-tests",
"waiter-tests",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,7 @@ fun <T : AbstractCodeWriter<T>> T.callIf(test: Boolean, runnable: Runnable): T {
}
return this
}

/** Escape the [expressionStart] character to avoid problems during formatting */
fun <T : AbstractCodeWriter<T>> T.escape(text: String): String =
text.replace("$expressionStart", "$expressionStart$expressionStart")
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ object RuntimeTypes {
val Attributes = symbol("Attributes")
val attributesOf = symbol("attributesOf")
val AttributeKey = symbol("AttributeKey")
val createOrAppend = symbol("createOrAppend")
val get = symbol("get")
val mutableMultiMapOf = symbol("mutableMultiMapOf")
val putIfAbsent = symbol("putIfAbsent")
Expand Down Expand Up @@ -231,6 +232,7 @@ object RuntimeTypes {
val SerialKind = symbol("SerialKind")
val SerializationException = symbol("SerializationException")
val DeserializationException = symbol("DeserializationException")
val getOrDeserializeErr = symbol("getOrDeserializeErr")

val serializeStruct = symbol("serializeStruct")
val serializeList = symbol("serializeList")
Expand All @@ -242,6 +244,18 @@ object RuntimeTypes {
val asSdkSerializable = symbol("asSdkSerializable")
val field = symbol("field")

val parse = symbol("parse")
val parseInt = symbol("parseInt")
val parseShort = symbol("parseShort")
val parseLong = symbol("parseLong")
val parseFloat = symbol("parseFloat")
val parseDouble = symbol("parseDouble")
val parseByte = symbol("parseByte")
val parseBoolean = symbol("parseBoolean")
val parseTimestamp = symbol("parseTimestamp")
val parseBigInteger = symbol("parseBigInteger")
val parseBigDecimal = symbol("parseBigDecimal")

object SerdeJson : RuntimeTypePackage(KotlinDependency.SERDE_JSON) {
val JsonSerialName = symbol("JsonSerialName")
val JsonSerializer = symbol("JsonSerializer")
Expand All @@ -261,8 +275,13 @@ object RuntimeTypes {
val XmlMapName = symbol("XmlMapName")
val XmlError = symbol("XmlError")
val XmlSerializer = symbol("XmlSerializer")
val XmlDeserializer = symbol("XmlDeserializer")
val XmlUnwrappedOutput = symbol("XmlUnwrappedOutput")

val XmlTagReader = symbol("XmlTagReader")
val xmlStreamReader = symbol("xmlStreamReader")
val xmlRootTagReader = symbol("xmlTagReader")
val data = symbol("data")
val tryData = symbol("tryData")
}

object SerdeFormUrl : RuntimeTypePackage(KotlinDependency.SERDE_FORM_URL) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ object KotlinTypes {
val List: Symbol = stdlibSymbol("List")
val listOf: Symbol = stdlibSymbol("listOf")
val MutableList: Symbol = stdlibSymbol("MutableList")
val MutableMap: Symbol = stdlibSymbol("MutableMap")
val Map: Symbol = stdlibSymbol("Map")
val mutableListOf: Symbol = stdlibSymbol("mutableListOf")
val mutableMapOf: Symbol = stdlibSymbol("mutableMapOf")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,33 +66,29 @@ class StructureGenerator(
// generate the immutable properties that are set from a builder
sortedMembers.forEach {
val (memberName, memberSymbol) = memberNameSymbolIndex[it]!!
// Throwable.message is handled special and passed as a constructor parameter to the parent exception base class
if (shape.isError && memberName == "message") {
val targetShape = model.expectShape(it.target)
if (!targetShape.isStringShape) {
throw CodegenException("message is a reserved name for exception types and cannot be used for any other property")
}
return@forEach
}
writer.renderMemberDocumentation(model, it)
writer.renderAnnotations(it)
renderImmutableProperty(it, memberName, memberSymbol)
renderImmutableProperty(memberName, memberSymbol)
}
}

private fun renderImmutableProperty(memberShape: MemberShape, memberName: String, memberSymbol: Symbol) {
// override Throwable's message property
val prefix = if (shape.isError && memberName == "message") {
val targetShape = model.expectShape(memberShape.target)
if (!targetShape.isStringShape) {
throw CodegenException("message is a reserved name for exception types and cannot be used for any other property")
}
"override"
} else {
"public"
}

private fun renderImmutableProperty(memberName: String, memberSymbol: Symbol) {
if (memberSymbol.isRequiredWithNoDefault) {
writer.write(
"""#1L val #2L: #3F = requireNotNull(builder.#2L) { "A non-null value must be provided for #2L" }""",
prefix,
"""public val #1L: #2F = requireNotNull(builder.#1L) { "A non-null value must be provided for #1L" }""",
memberName,
memberSymbol,
)
} else {
writer.write("#1L val #2L: #3F = builder.#2L", prefix, memberName, memberSymbol)
writer.write("public val #1L: #2F = builder.#1L", memberName, memberSymbol)
}
}

Expand Down Expand Up @@ -316,10 +312,16 @@ class StructureGenerator(
val exceptionBaseClass = ExceptionBaseClassGenerator.baseExceptionSymbol(ctx.settings)
writer.addImport(exceptionBaseClass)

val superParam = shape.members().find {
symbolProvider.toMemberName(it) == "message"
}?.let { "builder.message" } ?: ""

writer.openBlock(
"#L class #T private constructor(builder: Builder) : ${exceptionBaseClass.name}() {",
"#L class #T private constructor(builder: Builder) : #L(#L) {",
ctx.settings.api.visibility,
symbol,
exceptionBaseClass.name,
superParam,
)
.write("")
.call { renderImmutableProperties() }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ package software.amazon.smithy.kotlin.codegen.rendering.auth

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.kotlin.codegen.KotlinSettings
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
import software.amazon.smithy.kotlin.codegen.core.clientName
import software.amazon.smithy.kotlin.codegen.core.withBlock
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
import software.amazon.smithy.kotlin.codegen.model.knowledge.AuthIndex
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,29 +186,36 @@ private object Sigv4EndpointCustomization : EndpointCustomization {
// SigV4a requires SigV4 so SigV4 integration renders SigV4a auth scheme.
// See comment in example model: https://smithy.io/2.0/aws/aws-auth.html?highlight=sigv4#aws-auth-sigv4a-trait
private fun renderAuthSchemes(writer: KotlinWriter, authSchemes: Expression, expressionRenderer: ExpressionRenderer) {
writer.writeInline("#T to ", RuntimeTypes.SmithyClient.Endpoints.SigningContextAttributeKey)
writer.withBlock("listOf(", ")") {
authSchemes.toNode().expectArrayNode().forEach {
val scheme = it.expectObjectNode()
val schemeName = scheme.expectStringMember("name").value

val authFactoryFn = when (schemeName) {
"sigv4" -> RuntimeTypes.Auth.HttpAuthAws.sigV4
"sigv4a" -> RuntimeTypes.Auth.HttpAuthAws.sigV4A
else -> return@forEach
}

withBlock("#T(", "),", authFactoryFn) {
// we delegate back to the expression visitor for each of these fields because it's possible to
// encounter template strings throughout

writeInline("serviceName = ")
renderOrElse(expressionRenderer, scheme.getStringMember("signingName"), "null")

writeInline("disableDoubleUriEncode = ")
renderOrElse(expressionRenderer, scheme.getBooleanMember("disableDoubleEncoding"), "false")

renderFieldsForScheme(writer, scheme, expressionRenderer)
val schemes = authSchemes.toNode().expectArrayNode().filter {
val name = it.expectObjectNode().expectStringMember("name").value
name == "sigv4" || name == "sigv4a"
}.takeIf { it.isNotEmpty() }

schemes?.let {
writer.writeInline("#T to ", RuntimeTypes.SmithyClient.Endpoints.SigningContextAttributeKey)
writer.withBlock("listOf(", ")") {
schemes.forEach {
val scheme = it.expectObjectNode()
val schemeName = scheme.expectStringMember("name").value

val authFactoryFn = when (schemeName) {
"sigv4" -> RuntimeTypes.Auth.HttpAuthAws.sigV4
"sigv4a" -> RuntimeTypes.Auth.HttpAuthAws.sigV4A
else -> return@forEach
}

withBlock("#T(", "),", authFactoryFn) {
// we delegate back to the expression visitor for each of these fields because it's possible to
// encounter template strings throughout

writeInline("serviceName = ")
renderOrElse(expressionRenderer, scheme.getStringMember("signingName"), "null")

writeInline("disableDoubleUriEncode = ")
renderOrElse(expressionRenderer, scheme.getBooleanMember("disableDoubleEncoding"), "false")

renderFieldsForScheme(writer, scheme, expressionRenderer)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ class DefaultEndpointProviderGenerator(

private val propertyRenderers = endpointCustomizations
.map { it.propertyRenderers }
.fold(mutableMapOf<String, EndpointPropertyRenderer>()) { acc, propRenderers ->
acc.putAll(propRenderers)
.fold(mutableMapOf<String, MutableList<EndpointPropertyRenderer>>()) { acc, propRenderers ->
propRenderers.forEach { (key, propRenderer) ->
acc[key] = acc.getOrDefault(key, mutableListOf()).also { it.add(propRenderer) }
}
acc
}

Expand Down Expand Up @@ -190,7 +192,9 @@ class DefaultEndpointProviderGenerator(

// caller has a chance to generate their own value for a recognized property
if (kStr in propertyRenderers) {
propertyRenderers[kStr]!!(writer, v, this@DefaultEndpointProviderGenerator)
propertyRenderers[kStr]!!.forEach { renderer ->
renderer(writer, v, this@DefaultEndpointProviderGenerator)
}
return@forEach
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ class DefaultEndpointProviderTestGenerator(
private val endpointCustomizations = ctx.integrations.mapNotNull { it.customizeEndpointResolution(ctx) }
private val propertyRenderers = endpointCustomizations
.map { it.propertyRenderers }
.fold(mutableMapOf<String, EndpointPropertyRenderer>()) { acc, propRenderers ->
acc.putAll(propRenderers)
.fold(mutableMapOf<String, MutableList<EndpointPropertyRenderer>>()) { acc, propRenderers ->
propRenderers.forEach { (key, propRenderer) ->
acc[key] = acc.getOrDefault(key, mutableListOf()).also { it.add(propRenderer) }
}
acc
}

Expand Down Expand Up @@ -131,7 +133,9 @@ class DefaultEndpointProviderTestGenerator(
withBlock("attributes = #T {", "},", RuntimeTypes.Core.Collections.attributesOf) {
endpoint.properties.entries.forEach { (k, v) ->
if (k in propertyRenderers) {
propertyRenderers[k]!!(writer, Expression.fromNode(v), this@DefaultEndpointProviderTestGenerator)
propertyRenderers[k]!!.forEach { renderer ->
renderer(writer, Expression.fromNode(v), this@DefaultEndpointProviderTestGenerator)
}
return@forEach
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import software.amazon.smithy.kotlin.codegen.lang.toEscapedLiteral
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.rendering.serde.deserializerName
import software.amazon.smithy.kotlin.codegen.rendering.serde.formatInstant
import software.amazon.smithy.kotlin.codegen.rendering.serde.parseInstant
import software.amazon.smithy.kotlin.codegen.rendering.serde.parseInstantExpr
import software.amazon.smithy.kotlin.codegen.rendering.serde.serializerName
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
import software.amazon.smithy.model.Model
Expand Down Expand Up @@ -813,14 +813,12 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
HttpBinding.Location.HEADER,
defaultTimestampFormat,
)
writer
.addImport(RuntimeTypes.Core.Instant)
.write(
"builder.#L = response.headers[#S]?.let { #L }",
memberName,
headerName,
parseInstant("it", tsFormat),
)
writer.write(
"builder.#L = response.headers[#S]?.let { #L }",
memberName,
headerName,
writer.parseInstantExpr("it", tsFormat),
)
}
is ListShape -> {
// member > boolean, number, string, or timestamp
Expand Down Expand Up @@ -849,8 +847,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
if (tsFormat == TimestampFormatTrait.Format.HTTP_DATE) {
splitFn = "splitHttpDateHeaderListValues"
}
writer.addImport(RuntimeTypes.Core.Instant)
parseInstant("it", tsFormat)
writer.parseInstantExpr("it", tsFormat)
}
is StringShape -> {
when {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,19 +348,12 @@ open class HttpProtocolClientGenerator(
return
}

val requestAlgorithmMember = ctx.model.getShape(input.get()).getOrNull()
?.members()
?.firstOrNull { it.memberName == httpChecksumTrait?.requestAlgorithmMember?.getOrNull() }

if (hasTrait<HttpChecksumRequiredTrait>() || httpChecksumTrait?.isRequestChecksumRequired == true) {
val interceptorSymbol = RuntimeTypes.HttpClient.Interceptors.Md5ChecksumInterceptor
val inputSymbol = ctx.symbolProvider.toSymbol(ctx.model.expectShape(inputShape))

requestAlgorithmMember?.let {
writer.withBlock("op.interceptors.add(#T<#T> { ", "})", interceptorSymbol, inputSymbol) {
writer.write("it.#L?.value == null", requestAlgorithmMember.defaultName())
}
} ?: writer.write("op.interceptors.add(#T<#T>())", interceptorSymbol, inputSymbol)
writer.withBlock("op.interceptors.add(#T<#T> {", "})", interceptorSymbol, inputSymbol) {
writer.write("op.context.getOrNull(#T.ChecksumAlgorithm) == null", RuntimeTypes.HttpClient.Operation.HttpOperationContext)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,5 +605,3 @@ open class DeserializeStructGenerator(
}
}
}

private fun nullabilitySuffix(isSparse: Boolean): String = if (isSparse) "?" else ""
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolReference
import software.amazon.smithy.kotlin.codegen.KotlinSettings
import software.amazon.smithy.kotlin.codegen.core.SymbolRenderer
import software.amazon.smithy.kotlin.codegen.core.defaultName
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.core.mangledSuffix
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
import software.amazon.smithy.model.Model
Expand Down Expand Up @@ -216,11 +215,31 @@ fun formatInstant(paramName: String, tsFmt: TimestampFormatTrait.Format, forceSt
* @param paramName The name of the local identifier to convert to an `Instant`
* @param tsFmt The timestamp format [paramName] is expected to be converted from
*/
fun parseInstant(paramName: String, tsFmt: TimestampFormatTrait.Format): String = when (tsFmt) {
TimestampFormatTrait.Format.EPOCH_SECONDS -> "Instant.fromEpochSeconds($paramName)"
TimestampFormatTrait.Format.DATE_TIME -> "Instant.fromIso8601($paramName)"
TimestampFormatTrait.Format.HTTP_DATE -> "Instant.fromRfc5322($paramName)"
else -> throw CodegenException("unknown timestamp format: $tsFmt")
fun KotlinWriter.parseInstantExpr(paramName: String, tsFmt: TimestampFormatTrait.Format): String {
val fn = when (tsFmt) {
TimestampFormatTrait.Format.EPOCH_SECONDS -> "fromEpochSeconds"
TimestampFormatTrait.Format.DATE_TIME -> "fromIso8601"
TimestampFormatTrait.Format.HTTP_DATE -> "fromRfc5322"
else -> throw CodegenException("unknown timestamp format: $tsFmt")
}
return format("#T.#L(#L)", RuntimeTypes.Core.Instant, fn, paramName)
}

fun TimestampFormatTrait.Format.toRuntimeEnum(): String = when (this) {
TimestampFormatTrait.Format.EPOCH_SECONDS -> "TimestampFormat.EPOCH_SECONDS"
TimestampFormatTrait.Format.DATE_TIME -> "TimestampFormat.ISO_8601"
TimestampFormatTrait.Format.HTTP_DATE -> "TimestampFormat.RFC_5322"
else -> throw CodegenException("unknown timestamp format: $this")
}

fun TimestampFormatTrait.Format.toRuntimeEnum(writer: KotlinWriter): String {
val enum = when (this) {
TimestampFormatTrait.Format.EPOCH_SECONDS -> "EPOCH_SECONDS"
TimestampFormatTrait.Format.DATE_TIME -> "ISO_8601"
TimestampFormatTrait.Format.HTTP_DATE -> "RFC_5322"
else -> throw CodegenException("unknown timestamp format: $this")
}
return writer.format("#T.#L", RuntimeTypes.Core.TimestampFormat, enum)
}

/**
Expand Down Expand Up @@ -289,3 +308,5 @@ internal fun Shape.childShape(model: Model): Shape? = when (this) {
is MapShape -> model.expectShape(this.value.target)
else -> null
}

internal fun nullabilitySuffix(isSparse: Boolean): String = if (isSparse) "?" else ""
Loading

0 comments on commit 0afa170

Please sign in to comment.