Skip to content

Commit

Permalink
add support for nullable struct members when generating AWS SDKs (#2916)
Browse files Browse the repository at this point in the history
## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here -->
smithy-rs#1767 aws-sdk-rust#536

## Description
<!--- Describe your changes in detail -->
This PR adds support for nullability i.e. much less unwraps will be
required when using the AWS SDK. For generic clients, this new behavior
can be enabled in codegen by setting `nullabilityCheckMode: "Client"` in
their codegen config:
```
      "plugins": {
        "rust-client-codegen": {
          "codegen": {
            "includeFluentClient": false,
            "nullabilityCheckMode": "CLIENT_CAREFUL"
          },
     }
```


## Testing
<!--- Please describe in detail how you tested your changes -->
<!--- Include details of your testing environment, and the tests you ran
to -->
<!--- see how your change affects other areas of the code, etc. -->
Ran existing tests

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: John DiSanti <[email protected]>
Co-authored-by: Russell Cohen <[email protected]>
  • Loading branch information
3 people authored Sep 21, 2023
1 parent 1771dbd commit 1331dc5
Show file tree
Hide file tree
Showing 49 changed files with 1,110 additions and 181 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ references = ["smithy-rs#2911"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "Velfi"

[[aws-sdk-rust]]
message = "Struct members modeled as required are no longer wrapped in `Option`s [when possible](https://smithy.io/2.0/spec/aggregate-types.html#structure-member-optionality). For upgrade guidance and more info, see [here](https://github.com/awslabs/smithy-rs/discussions/2929)."
references = ["smithy-rs#2916", "aws-sdk-rust#536"]
meta = { "breaking" = true, "tada" = true, "bug" = false }
author = "Velfi"

[[smithy-rs]]
message = """
Support for Smithy IDLv2 nullability is now enabled by default. You can maintain the old behavior by setting `nullabilityCheckMode: "CLIENT_ZERO_VALUE_V1" in your codegen config.
For upgrade guidance and more info, see [here](https://github.com/awslabs/smithy-rs/discussions/2929).
"""
references = ["smithy-rs#2916", "smithy-rs#1767"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client"}
author = "Velfi"

[[aws-sdk-rust]]
message = """
All versions of SigningParams have been updated to contain an [`Identity`](https://docs.rs/aws-smithy-runtime-api/latest/aws_smithy_runtime_api/client/identity/struct.Identity.html)
Expand Down
17 changes: 4 additions & 13 deletions aws/rust-runtime/aws-config/src/sts/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,15 @@ pub(crate) fn into_credentials(
) -> provider::Result {
let sts_credentials = sts_credentials
.ok_or_else(|| CredentialsError::unhandled("STS credentials must be defined"))?;
let expiration = SystemTime::try_from(
sts_credentials
.expiration
.ok_or_else(|| CredentialsError::unhandled("missing expiration"))?,
)
.map_err(|_| {
let expiration = SystemTime::try_from(sts_credentials.expiration).map_err(|_| {
CredentialsError::unhandled(
"credential expiration time cannot be represented by a SystemTime",
)
})?;
Ok(AwsCredentials::new(
sts_credentials
.access_key_id
.ok_or_else(|| CredentialsError::unhandled("access key id missing from result"))?,
sts_credentials
.secret_access_key
.ok_or_else(|| CredentialsError::unhandled("secret access token missing"))?,
sts_credentials.session_token,
sts_credentials.access_key_id,
sts_credentials.secret_access_key,
Some(sts_credentials.session_token),
Some(expiration),
provider_name,
))
Expand Down
52 changes: 29 additions & 23 deletions aws/sdk-adhoc-test/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -35,40 +35,46 @@ dependencies {
implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-model:$smithyVersion")
}

val allCodegenTests = listOf(
CodegenTest(
"com.amazonaws.apigateway#BackplaneControlService",
"apigateway",
imports = listOf("models/apigateway-rules.smithy"),
fun getNullabilityCheckMode(): String = properties.get("nullability.check.mode") ?: "CLIENT_CAREFUL"

fun baseTest(service: String, module: String, imports: List<String> = listOf()): CodegenTest {
return CodegenTest(
service = service,
module = module,
imports = imports,
extraCodegenConfig = """
"includeFluentClient": false,
"nullabilityCheckMode": "${getNullabilityCheckMode()}"
""",
extraConfig = """
,
"codegen": {
"includeFluentClient": false
},
"customizationConfig": {
, "customizationConfig": {
"awsSdk": {
"generateReadme": false
"generateReadme": false,
"requireEndpointResolver": false
}
}
""",
)
}

val allCodegenTests = listOf(
baseTest(
"com.amazonaws.apigateway#BackplaneControlService",
"apigateway",
imports = listOf("models/apigateway-rules.smithy"),
),
CodegenTest(
baseTest(
"com.amazonaws.testservice#TestService",
"endpoint-test-service",
imports = listOf("models/single-static-endpoint.smithy"),
extraConfig = """
,
"codegen": {
"includeFluentClient": false
},
"customizationConfig": {
"awsSdk": {
"generateReadme": false
}
}
""",
),
baseTest(
"com.amazonaws.testservice#RequiredValues",
"required-values",
imports = listOf("models/required-value-test.smithy"),
),
)

Expand Down
28 changes: 28 additions & 0 deletions aws/sdk-adhoc-test/models/required-value-test.smithy
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
$version: "1.0"

namespace com.amazonaws.testservice

use aws.api#service
use aws.protocols#restJson1

@restJson1
@title("Test Service")
@service(sdkId: "Test")
@aws.auth#sigv4(name: "test-service")
service RequiredValues {
operations: [TestOperation]
}

@http(method: "GET", uri: "/")
operation TestOperation {
errors: [Error]
}

@error("client")
structure Error {
@required
requestId: String

@required
message: String
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

package software.amazon.smithy.rustsdk

import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
Expand All @@ -19,13 +21,15 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderSection
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureSection
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplSection
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait

Expand Down Expand Up @@ -72,6 +76,10 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator {
}
}

open fun asMemberShape(container: StructureShape): MemberShape? {
return container.members().firstOrNull { member -> member.memberName.lowercase() == "requestid" }
}

private inner class RequestIdOperationCustomization(private val codegenContext: ClientCodegenContext) :
OperationCustomization() {
override fun section(section: OperationSection): Writable = writable {
Expand All @@ -82,19 +90,22 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator {
"apply_to_error" to applyToError(codegenContext),
)
}

is OperationSection.MutateOutput -> {
rust(
"output._set_$fieldName(#T::$accessorFunctionName(${section.responseHeadersName}).map(str::to_string));",
accessorTrait(codegenContext),
)
}

is OperationSection.BeforeParseResponse -> {
rustTemplate(
"#{tracing}::debug!($fieldName = ?#{trait}::$accessorFunctionName(${section.responseName}));",
"tracing" to RuntimeType.Tracing,
"trait" to accessorTrait(codegenContext),
)
}

else -> {}
}
}
Expand Down Expand Up @@ -123,8 +134,17 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator {
rustBlock("fn $accessorFunctionName(&self) -> Option<&str>") {
rustBlock("match self") {
section.allErrors.forEach { error ->
val optional = asMemberShape(error)?.let { member ->
codegenContext.symbolProvider.toSymbol(member).isOptional()
} ?: true
val wrapped = writable {
when (optional) {
false -> rustTemplate("#{Some}(e.$accessorFunctionName())", *preludeScope)
true -> rustTemplate("e.$accessorFunctionName()")
}
}
val sym = codegenContext.symbolProvider.toSymbol(error)
rust("Self::${sym.name}(e) => e.$accessorFunctionName(),")
rust("Self::${sym.name}(e) => #T,", wrapped)
}
rust("Self::Unhandled(e) => e.$accessorFunctionName(),")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package software.amazon.smithy.rustsdk.customize.s3

import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rustsdk.BaseRequestIdDecorator
Expand All @@ -17,6 +19,10 @@ class S3ExtendedRequestIdDecorator : BaseRequestIdDecorator() {
override val fieldName: String = "extended_request_id"
override val accessorFunctionName: String = "extended_request_id"

override fun asMemberShape(container: StructureShape): MemberShape? {
return null
}

private val requestIdModule: RuntimeType =
RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("s3_request_id"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ class TimestreamDecorator : ClientCodegenDecorator {
client.describe_endpoints().send().await.map_err(|e| {
#{ResolveEndpointError}::from_source("failed to call describe_endpoints", e)
})?;
let endpoint = describe_endpoints.endpoints().unwrap().get(0).unwrap();
let endpoint = describe_endpoints.endpoints().get(0).unwrap();
let expiry = client.config().time_source().expect("checked when ep discovery was enabled").now()
+ #{Duration}::from_secs(endpoint.cache_period_in_minutes() as u64 * 60);
Ok((
#{Endpoint}::builder()
.url(format!("https://{}", endpoint.address().unwrap()))
.url(format!("https://{}", endpoint.address()))
.build(),
expiry,
))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,22 @@
package software.amazon.smithy.rustsdk.customize.ec2

import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.CsvSource
import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.util.lookup

internal class EC2MakePrimitivesOptionalTest {
@Test
fun `primitive shapes are boxed`() {
@ParameterizedTest
@CsvSource(
"CLIENT",
"CLIENT_CAREFUL",
"CLIENT_ZERO_VALUE_V1",
"CLIENT_ZERO_VALUE_V1_NO_INPUT",
)
fun `primitive shapes are boxed`(nullabilityCheckMode: NullableIndex.CheckMode) {
val baseModel = """
namespace test
structure Primitives {
Expand All @@ -36,7 +43,7 @@ internal class EC2MakePrimitivesOptionalTest {
val nullableIndex = NullableIndex(model)
val struct = model.lookup<StructureShape>("test#Primitives")
struct.members().forEach {
nullableIndex.isMemberNullable(it, NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1) shouldBe true
nullableIndex.isMemberNullable(it, nullabilityCheckMode) shouldBe true
}
}
}
4 changes: 3 additions & 1 deletion aws/sdk/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ val crateVersioner by lazy { aws.sdk.CrateVersioner.defaultFor(rootProject, prop

fun getRustMSRV(): String = properties.get("rust.msrv") ?: throw Exception("Rust MSRV missing")
fun getPreviousReleaseVersionManifestPath(): String? = properties.get("aws.sdk.previous.release.versions.manifest")
fun getNullabilityCheckMode(): String = properties.get("nullability.check.mode") ?: "CLIENT_CAREFUL"

fun loadServiceMembership(): Membership {
val membershipOverride = properties.get("aws.services")?.let { parseMembership(it) }
Expand Down Expand Up @@ -103,7 +104,8 @@ fun generateSmithyBuild(services: AwsServices): String {
"renameErrors": false,
"debugMode": $debugMode,
"eventStreamAllowList": [$eventStreamAllowListMembers],
"enableUserConfigurableRuntimePlugins": false
"enableUserConfigurableRuntimePlugins": false,
"nullabilityCheckMode": "${getNullabilityCheckMode()}"
},
"service": "${service.service}",
"module": "$moduleName",
Expand Down
2 changes: 1 addition & 1 deletion aws/sdk/integration-tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ members = [
"s3",
"s3control",
"sts",
"transcribestreaming",
"timestreamquery",
"transcribestreaming",
"webassembly",
]
15 changes: 10 additions & 5 deletions aws/sdk/integration-tests/dynamodb/tests/movies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,36 @@ async fn create_table(client: &Client, table_name: &str) {
KeySchemaElement::builder()
.attribute_name("year")
.key_type(KeyType::Hash)
.build(),
.build()
.unwrap(),
)
.key_schema(
KeySchemaElement::builder()
.attribute_name("title")
.key_type(KeyType::Range)
.build(),
.build()
.unwrap(),
)
.attribute_definitions(
AttributeDefinition::builder()
.attribute_name("year")
.attribute_type(ScalarAttributeType::N)
.build(),
.build()
.unwrap(),
)
.attribute_definitions(
AttributeDefinition::builder()
.attribute_name("title")
.attribute_type(ScalarAttributeType::S)
.build(),
.build()
.unwrap(),
)
.provisioned_throughput(
ProvisionedThroughput::builder()
.read_capacity_units(10)
.write_capacity_units(10)
.build(),
.build()
.unwrap(),
)
.send()
.await
Expand Down
Loading

0 comments on commit 1331dc5

Please sign in to comment.