Unverified Commit 1331dc54 authored by Zelda Hessler's avatar Zelda Hessler Committed by GitHub
Browse files

add support for nullable struct members when generating AWS SDKs (#2916)



## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here -->
smithy-rs#1767 aws-sdk-rust#536

## Description
<!--- Describe your changes in detail -->
This PR adds support for nullability i.e. much less unwraps will be
required when using the AWS SDK. For generic clients, this new behavior
can be enabled in codegen by setting `nullabilityCheckMode: "Client"` in
their codegen config:
```
      "plugins": {
        "rust-client-codegen": {
          "codegen": {
            "includeFluentClient": false,
            "nullabilityCheckMode": "CLIENT_CAREFUL"
          },
     }
```


## Testing
<!--- Please describe in detail how you tested your changes -->
<!--- Include details of your testing environment, and the tests you ran
to -->
<!--- see how your change affects other areas of the code, etc. -->
Ran existing tests

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: default avatarJohn DiSanti <jdisanti@amazon.com>
Co-authored-by: default avatarRussell Cohen <rcoh@amazon.com>
parent 1771dbdc
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -34,6 +34,21 @@ references = ["smithy-rs#2911"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "Velfi"

[[aws-sdk-rust]]
message = "Struct members modeled as required are no longer wrapped in `Option`s [when possible](https://smithy.io/2.0/spec/aggregate-types.html#structure-member-optionality). For upgrade guidance and more info, see [here](https://github.com/awslabs/smithy-rs/discussions/2929)."
references = ["smithy-rs#2916", "aws-sdk-rust#536"]
meta = { "breaking" = true, "tada" = true, "bug" = false }
author = "Velfi"

[[smithy-rs]]
message = """
Support for Smithy IDLv2 nullability is now enabled by default. You can maintain the old behavior by setting `nullabilityCheckMode: "CLIENT_ZERO_VALUE_V1" in your codegen config.
For upgrade guidance and more info, see [here](https://github.com/awslabs/smithy-rs/discussions/2929).
"""
references = ["smithy-rs#2916", "smithy-rs#1767"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client"}
author = "Velfi"

[[aws-sdk-rust]]
message = """
All versions of SigningParams have been updated to contain an [`Identity`](https://docs.rs/aws-smithy-runtime-api/latest/aws_smithy_runtime_api/client/identity/struct.Identity.html)
+4 −13
Original line number Diff line number Diff line
@@ -17,24 +17,15 @@ pub(crate) fn into_credentials(
) -> provider::Result {
    let sts_credentials = sts_credentials
        .ok_or_else(|| CredentialsError::unhandled("STS credentials must be defined"))?;
    let expiration = SystemTime::try_from(
        sts_credentials
            .expiration
            .ok_or_else(|| CredentialsError::unhandled("missing expiration"))?,
    )
    .map_err(|_| {
    let expiration = SystemTime::try_from(sts_credentials.expiration).map_err(|_| {
        CredentialsError::unhandled(
            "credential expiration time cannot be represented by a SystemTime",
        )
    })?;
    Ok(AwsCredentials::new(
        sts_credentials
            .access_key_id
            .ok_or_else(|| CredentialsError::unhandled("access key id missing from result"))?,
        sts_credentials
            .secret_access_key
            .ok_or_else(|| CredentialsError::unhandled("secret access token missing"))?,
        sts_credentials.session_token,
        sts_credentials.access_key_id,
        sts_credentials.secret_access_key,
        Some(sts_credentials.session_token),
        Some(expiration),
        provider_name,
    ))
+29 −23
Original line number Diff line number Diff line
@@ -35,40 +35,46 @@ dependencies {
    implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion")
    implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
    implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
    implementation("software.amazon.smithy:smithy-model:$smithyVersion")
}

val allCodegenTests = listOf(
    CodegenTest(
        "com.amazonaws.apigateway#BackplaneControlService",
        "apigateway",
        imports = listOf("models/apigateway-rules.smithy"),
fun getNullabilityCheckMode(): String = properties.get("nullability.check.mode") ?: "CLIENT_CAREFUL"

fun baseTest(service: String, module: String, imports: List<String> = listOf()): CodegenTest {
    return CodegenTest(
        service = service,
        module = module,
        imports = imports,
        extraCodegenConfig = """
            "includeFluentClient": false,
            "nullabilityCheckMode": "${getNullabilityCheckMode()}"
        """,
        extraConfig = """
            ,
            "codegen": {
                "includeFluentClient": false
            },
            "customizationConfig": {
            , "customizationConfig": {
                "awsSdk": {
                    "generateReadme": false
                    "generateReadme": false,
                    "requireEndpointResolver": false
                }
            }
        """,
    )
}

val allCodegenTests = listOf(
    baseTest(
        "com.amazonaws.apigateway#BackplaneControlService",
        "apigateway",
        imports = listOf("models/apigateway-rules.smithy"),
    ),
    CodegenTest(
    baseTest(
        "com.amazonaws.testservice#TestService",
        "endpoint-test-service",
        imports = listOf("models/single-static-endpoint.smithy"),
        extraConfig = """
            ,
            "codegen": {
                "includeFluentClient": false
            },
            "customizationConfig": {
                "awsSdk": {
                    "generateReadme": false
                }
            }
        """,
    ),
    baseTest(
        "com.amazonaws.testservice#RequiredValues",
        "required-values",
        imports = listOf("models/required-value-test.smithy"),
    ),
)

+28 −0
Original line number Diff line number Diff line
$version: "1.0"

namespace com.amazonaws.testservice

use aws.api#service
use aws.protocols#restJson1

@restJson1
@title("Test Service")
@service(sdkId: "Test")
@aws.auth#sigv4(name: "test-service")
service RequiredValues {
    operations: [TestOperation]
}

@http(method: "GET", uri: "/")
operation TestOperation {
    errors: [Error]
}

@error("client")
structure Error {
    @required
    requestId: String

    @required
    message: String
}
+21 −1
Original line number Diff line number Diff line
@@ -5,7 +5,9 @@

package software.amazon.smithy.rustsdk

import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
@@ -19,6 +21,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderSection
@@ -26,6 +29,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureCusto
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureSection
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplSection
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait

@@ -72,6 +76,10 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator {
        }
    }

    open fun asMemberShape(container: StructureShape): MemberShape? {
        return container.members().firstOrNull { member -> member.memberName.lowercase() == "requestid" }
    }

    private inner class RequestIdOperationCustomization(private val codegenContext: ClientCodegenContext) :
        OperationCustomization() {
        override fun section(section: OperationSection): Writable = writable {
@@ -82,12 +90,14 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator {
                        "apply_to_error" to applyToError(codegenContext),
                    )
                }

                is OperationSection.MutateOutput -> {
                    rust(
                        "output._set_$fieldName(#T::$accessorFunctionName(${section.responseHeadersName}).map(str::to_string));",
                        accessorTrait(codegenContext),
                    )
                }

                is OperationSection.BeforeParseResponse -> {
                    rustTemplate(
                        "#{tracing}::debug!($fieldName = ?#{trait}::$accessorFunctionName(${section.responseName}));",
@@ -95,6 +105,7 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator {
                        "trait" to accessorTrait(codegenContext),
                    )
                }

                else -> {}
            }
        }
@@ -123,8 +134,17 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator {
                        rustBlock("fn $accessorFunctionName(&self) -> Option<&str>") {
                            rustBlock("match self") {
                                section.allErrors.forEach { error ->
                                    val optional = asMemberShape(error)?.let { member ->
                                        codegenContext.symbolProvider.toSymbol(member).isOptional()
                                    } ?: true
                                    val wrapped = writable {
                                        when (optional) {
                                            false -> rustTemplate("#{Some}(e.$accessorFunctionName())", *preludeScope)
                                            true -> rustTemplate("e.$accessorFunctionName()")
                                        }
                                    }
                                    val sym = codegenContext.symbolProvider.toSymbol(error)
                                    rust("Self::${sym.name}(e) => e.$accessorFunctionName(),")
                                    rust("Self::${sym.name}(e) => #T,", wrapped)
                                }
                                rust("Self::Unhandled(e) => e.$accessorFunctionName(),")
                            }
Loading