From b43905eabc86ee6fb81ed8f2465c245d0116dd3b Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 15 Nov 2022 15:06:23 +0100 Subject: [PATCH] Builders of builders (#1342) This patchset, affectionately called "Builders of builders", lays the groundwork for fully implementing [Constraint traits] in the server SDK generator. [The RFC] illustrates what the end goal looks like, and is recommended prerrequisite reading to understanding this cover letter. This commit makes the sever deserializers work with _unconstrained_ types during request parsing, and only after the entire request is parsed are constraints enforced. Values for a constrained shape are stored in the correspondingly unconstrained shape, and right before the operation input is built, the values are constrained via a `TryFrom for ConstrainedShape` implementation that all unconstrained types enjoy. The service owner only interacts with constrained types, the unconstrained ones are `pub(crate)` and for use by the framework only. In the case of structure shapes, the corresponding unconstrained shape is their builders. This is what gives this commit its title: during request deserialization, arbitrarily nested structures are parsed into _builders that hold builders_. Builders keep track of whether their members are constrained or not by storing its members in a `MaybeConstrained` [Cow](https://doc.rust-lang.org/std/borrow/enum.Cow.html)-like `enum` type: ```rust pub(crate) trait Constrained { type Unconstrained; } #[derive(Debug, Clone)] pub(crate) enum MaybeConstrained { Constrained(T), Unconstrained(T::Unconstrained), } ``` Consult the documentation for the generator in `ServerBuilderGenerator.kt` for more implementation details and for the differences with the builder types the server has been using, generated by `BuilderGenerator.kt`, which after this commit are exclusively used by clients. Other shape types, when they are constrained, get generated with their correspondingly unconstrained counterparts. Their Rust types are essentially wrapper newtypes, and similarly enjoy `TryFrom` converters to constrain them. See the documentation in `UnconstrainedShapeSymbolProvider.kt` for details and an example. When constraints are not met, the converters raise _constraint violations_. These are currently `enum`s holding the _first_ encountered violation. When a shape is _transitively but not directly_ constrained, newtype wrappers are also generated to hold the nested constrained values. To illustrate their need, consider for example a list of `@length` strings. Upon request parsing, the server deserializers need a way to hold a vector of unconstrained regular `String`s, and a vector of the constrained newtyped `LengthString`s. The former requirement is already satisfied by the generated unconstrained types, but for the latter we need to generate an intermediate constrained `ListUnconstrained(Vec)` newtype that will eventually be unwrapped into the `Vec` the user is handed. This is the purpose of the `PubCrate*` generators: consult the documentation in `PubCrateConstrainedShapeSymbolProvider.kt`, `PubCrateConstrainedCollectionGenerator.kt`, and `PubCrateConstrainedMapGenerator.kt` for more details. As their name implies, all of these types are `pub(crate)`, and the user never interacts with them. For users that would not like their application code to make use of constrained newtypes for their modeled constrained shapes, a `codegenConfig` setting `publicConstrainedTypes` has been added. They opt out of these by setting it to `false`, and use the inner types directly: the framework will still enforce constraints upon request deserialization, but once execution enters an application handler, the user is on their own to honor (or not) the modeled constraints. No user interest has been expressed for this feature, but I expect we will see demand for it. Moreover, it's a good stepping stone for users that want their services to honor constraints, but are not ready to migrate their application code to constrained newtypes. As for how it's implemented, several parts of the codebase inspect the setting and toggle or tweak generators based on its value. Perhaps the only detail worth mentioning in this commit message is that the structure shape builder types are generated by a much simpler and entirely different generator, in `ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt`. Note that this builder _does not_ enforce constraints, except for `required` and `enum`, which are always (and already) baked into the type system. When `publicConstrainedTypes` is disabled, this is the builder that end users interact with, while the one that enforces all constraints, `ServerBuilderGenerator`, is now generated as `pub(crate)` and left for exclusive use by the deserializers. See the relevant documentation for the details and differences among the builder types. As proof that these foundations are sound, this commit also implements the `length` constraint trait on Smithy map and string shapes. Likewise, the `required` and `enum` traits, which were already baked in the generated types as non-`Option`al and `enum` Rust types, respectively, are now also treated like the rest of constraint traits upon request deserialization. See the documentation in `ConstrainedMapGenerator.kt` and `ConstrainedStringGenerator.kt` for details. The rest of the constraint traits and target shapes are left as an exercise to the reader, but hopefully the reader has been convinced that all of them can be enforced within this framework, paving the way for straightforward implementations. The diff is already large as it is. Any reamining work is being tracked in #1401; this and other issues are referenced in the code as TODOs. So as to not give users the impression that the server SDK plugin _fully_ honors constraints as per the Smithy specification, a validator in `ValidateUnsupportedConstraintsAreNotUsed.kt` has been added. This traverses the model and detects yet-unsupported parts of the spec, aborting code generation and printing informative warnings referencing the relevant tracking issues. This is a regression in that models that used constraint traits previously built fine (even though the constraint traits were silently not being honored), and now they will break. To unblock generation of these models, this commit adds another `codegenConfig` setting, `ignoreUnsupportedConstraints`, that users can opt into. Closes #1714. Testing ------- Several Kotlin unit test classes exercising the finer details of the added generators and symbol providers have been added. However, the best way to test is to generate server SDKs from models making use of constraint traits. The biggest assurances come from the newly added `constraints.smithy` model, an "academic" service that _heavily_ exercises constraint traits. It's a `restJson1` service that also tests binding of constrained shapes to different parts of the HTTP message. Deeply nested hierarchies and recursive shapes are also featured. ```sh ./gradlew -P modules='constraints' codegen-server-test:build ``` This model is _additionally_ generated in CI with the `publicConstrainedTypes` setting disabled: ```sh ./gradlew -P modules='constraints_without_public_constrained_types' codegen-server-test:build `````` Similarly, models using currently unsupported constraints are now being generated with the `ignoreUnsupportedConstraints` setting enabled. See `codegen-server-test/build.gradle.kts` for more details. [Constraint traits]: https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html [The RFC]: https://github.com/awslabs/smithy-rs/pull/1199 --- CHANGELOG.next.toml | 125 ++++ .../smithy/generators/ClientInstantiator.kt | 14 + .../protocols/HttpBoundProtocolGenerator.kt | 4 +- .../common-test-models/constraints.smithy | 453 +++++++++++++++ codegen-core/common-test-models/misc.smithy | 5 +- .../naming-obstacle-course-ops.smithy | 6 +- .../common-test-models/pokemon-common.smithy | 4 +- .../common-test-models/pokemon.smithy | 8 +- .../rest-json-extras.smithy | 10 +- codegen-core/common-test-models/simple.smithy | 4 +- .../codegen/core/rustlang/CargoDependency.kt | 3 + .../rust/codegen/core/rustlang/RustType.kt | 8 + .../codegen/core/smithy/CodegenDelegator.kt | 19 + .../rust/codegen/core/smithy/RuntimeType.kt | 5 + .../rust/codegen/core/smithy/SymbolVisitor.kt | 117 ++-- .../smithy/generators/BuilderGenerator.kt | 30 +- .../core/smithy/generators/EnumGenerator.kt | 4 +- .../core/smithy/generators/Instantiator.kt | 28 +- .../smithy/generators/StructureGenerator.kt | 21 +- .../smithy/generators/error/ErrorGenerator.kt | 25 +- .../generators/http/HttpBindingGenerator.kt | 126 ++-- .../http/RequestBindingGenerator.kt | 11 +- .../http/ResponseBindingGenerator.kt | 12 +- .../protocol/MakeOperationGenerator.kt | 2 + .../codegen/core/smithy/protocols/AwsJson.kt | 11 +- .../codegen/core/smithy/protocols/AwsQuery.kt | 10 +- .../codegen/core/smithy/protocols/Ec2Query.kt | 10 +- .../codegen/core/smithy/protocols/RestJson.kt | 9 +- .../codegen/core/smithy/protocols/RestXml.kt | 7 +- .../parse/AwsQueryParserGenerator.kt | 4 + .../parse/Ec2QueryParserGenerator.kt | 4 + .../parse/EventStreamUnmarshallerGenerator.kt | 98 ++-- .../protocols/parse/JsonParserGenerator.kt | 171 ++++-- .../protocols/parse/RestXmlParserGenerator.kt | 4 + .../parse/XmlBindingTraitParserGenerator.kt | 13 +- .../serialize/JsonSerializerGenerator.kt | 43 +- .../XmlBindingTraitSerializerGenerator.kt | 13 +- .../smithy/generators/InstantiatorTest.kt | 82 ++- .../parse/AwsQueryParserGeneratorTest.kt | 7 +- .../parse/Ec2QueryParserGeneratorTest.kt | 7 +- .../parse/JsonParserGeneratorTest.kt | 6 + .../XmlBindingTraitParserGeneratorTest.kt | 2 + codegen-server-test/build.gradle.kts | 16 +- .../smithy/PythonCodegenServerPlugin.kt | 10 +- .../smithy/PythonServerCodegenVisitor.kt | 63 +- .../generators/PythonServerEnumGenerator.kt | 19 +- .../smithy/ConstrainedShapeSymbolProvider.kt | 109 ++++ .../ConstraintViolationSymbolProvider.kt | 123 ++++ .../rust/codegen/server/smithy/Constraints.kt | 133 +++++ .../LengthTraitValidationErrorMessage.kt | 21 + .../PubCrateConstrainedShapeSymbolProvider.kt | 124 ++++ ...bCrateConstraintViolationSymbolProvider.kt | 37 ++ .../server/smithy/RustCodegenServerPlugin.kt | 21 +- .../server/smithy/ServerCodegenContext.kt | 4 + .../server/smithy/ServerCodegenVisitor.kt | 249 +++++++- .../server/smithy/ServerRuntimeType.kt | 3 - .../server/smithy/ServerRustSettings.kt | 29 +- .../server/smithy/ServerSymbolProviders.kt | 65 +++ .../UnconstrainedShapeSymbolProvider.kt | 166 ++++++ .../smithy/ValidateUnsupportedConstraints.kt | 248 ++++++++ ...BeforeIteratingOverMapJsonCustomization.kt | 38 ++ .../generators/ConstrainedMapGenerator.kt | 160 ++++++ .../ConstrainedMapGeneratorCommon.kt | 22 + .../ConstrainedShapeGeneratorCommon.kt | 24 + .../generators/ConstrainedStringGenerator.kt | 183 ++++++ .../ConstrainedTraitForEnumGenerator.kt | 51 ++ .../MapConstraintViolationGenerator.kt | 121 ++++ .../PubCrateConstrainedCollectionGenerator.kt | 148 +++++ .../PubCrateConstrainedMapGenerator.kt | 142 +++++ .../ServerBuilderConstraintViolations.kt | 218 +++++++ .../generators/ServerBuilderGenerator.kt | 542 ++++++++++++++++++ ...rGeneratorWithoutPublicConstrainedTypes.kt | 238 ++++++++ .../smithy/generators/ServerBuilderSymbol.kt | 35 ++ .../smithy/generators/ServerEnumGenerator.kt | 108 ++-- .../smithy/generators/ServerInstantiator.kt | 23 + .../generators/ServerOperationGenerator.kt | 3 +- .../ServerStructureConstrainedTraitImpl.kt | 32 ++ .../UnconstrainedCollectionGenerator.kt | 139 +++++ .../generators/UnconstrainedMapGenerator.kt | 207 +++++++ .../generators/UnconstrainedUnionGenerator.kt | 248 ++++++++ .../http/ServerRequestBindingGenerator.kt | 53 +- .../http/ServerResponseBindingGenerator.kt | 52 +- .../generators/protocol/ServerProtocol.kt | 96 +++- .../protocol/ServerProtocolTestGenerator.kt | 30 +- .../server/smithy/protocols/ServerAwsJson.kt | 16 +- .../ServerHttpBoundProtocolGenerator.kt | 219 ++++--- ...erRestJsonFactory.kt => ServerRestJson.kt} | 17 + .../smithy/testutil/ServerTestHelpers.kt | 74 ++- ...hapeReachableFromOperationInputTagTrait.kt | 42 ++ ...ToConstrainedOperationInputsInAllowList.kt | 74 +++ .../RemoveEbsModelValidationException.kt | 38 ++ ...ShapesReachableFromOperationInputTagger.kt | 72 +++ .../ConstrainedShapeSymbolProviderTest.kt | 98 ++++ .../codegen/server/smithy/ConstraintsTest.kt | 135 +++++ ...CrateConstrainedShapeSymbolProviderTest.kt | 113 ++++ .../UnconstrainedShapeSymbolProviderTest.kt | 103 ++++ ...ateUnsupportedConstraintsAreNotUsedTest.kt | 254 ++++++++ .../generators/ConstrainedMapGeneratorTest.kt | 158 +++++ .../ConstrainedStringGeneratorTest.kt | 179 ++++++ .../ServerCombinedErrorGeneratorTest.kt | 15 +- .../generators/ServerEnumGeneratorTest.kt | 28 +- .../generators/ServerInstantiatorTest.kt | 9 +- .../UnconstrainedCollectionGeneratorTest.kt | 124 ++++ .../UnconstrainedMapGeneratorTest.kt | 164 ++++++ .../UnconstrainedUnionGeneratorTest.kt | 102 ++++ .../smithy/protocols/EventStreamTestTools.kt | 6 +- .../EventStreamUnmarshallerGeneratorTest.kt | 11 +- .../aws-smithy-http-server/Cargo.toml | 3 +- .../examples/pokemon-service/src/lib.rs | 17 +- .../aws-smithy-http-server/src/rejection.rs | 16 +- .../src/runtime_error.rs | 70 ++- rust-runtime/inlineable/src/constrained.rs | 15 + rust-runtime/inlineable/src/lib.rs | 2 + 113 files changed, 7438 insertions(+), 634 deletions(-) create mode 100644 codegen-core/common-test-models/constraints.smithy create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstraintViolationSymbolProvider.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorCommon.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedShapeGeneratorCommon.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedTraitForEnumGenerator.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerStructureConstrainedTraitImpl.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt rename codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/{ServerRestJsonFactory.kt => ServerRestJson.kt} (63%) create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RemoveEbsModelValidationException.kt create mode 100644 codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProviderTest.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt create mode 100644 rust-runtime/inlineable/src/constrained.rs diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index cf1587192..f97ce2166 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -214,3 +214,128 @@ message = "Several breaking changes have been made to errors. See [the upgrade g references = ["smithy-rs#1926", "smithy-rs#1819"] meta = { "breaking" = true, "tada" = false, "bug" = false } author = "jdisanti" + +[[smithy-rs]] +message = """ +[Constraint traits](https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html) in server SDKs are beginning to be supported. The following are now supported: + +* The `length` trait on `string` shapes. +* The `length` trait on `map` shapes. + +Upon receiving a request that violates the modeled constraints, the server SDK will reject it with a message indicating why. + +Unsupported (constraint trait, target shape) combinations will now fail at code generation time, whereas previously they were just ignored. This is a breaking change to raise awareness in service owners of their server SDKs behaving differently than what was modeled. To continue generating a server SDK with unsupported constraint traits, set `codegenConfig.ignoreUnsupportedConstraints` to `true` in your `smithy-build.json`. +""" +references = ["smithy-rs#1199", "smithy-rs#1342", "smithy-rs#1401"] +meta = { "breaking" = true, "tada" = true, "bug" = false, "target" = "server" } +author = "david-perez" + +[[smithy-rs]] +message = """ +Server SDKs now generate "constrained types" for constrained shapes. Constrained types are [newtypes](https://rust-unofficial.github.io/patterns/patterns/behavioural/newtype.html) that encapsulate the modeled constraints. They constitute a [widespread pattern to guarantee domain invariants](https://www.lpalmieri.com/posts/2020-12-11-zero-to-production-6-domain-modelling/) and promote correctness in your business logic. So, for example, the model: + +```smithy +@length(min: 1, max: 69) +string NiceString +``` + +will now render a `struct NiceString(String)`. Instantiating a `NiceString` is a fallible operation: + +```rust +let data: String = ... ; +let nice_string = NiceString::try_from(data).expect("data is not nice"); +``` + +A failed attempt to instantiate a constrained type will yield a `ConstraintViolation` error type you may want to handle. This type's API is subject to change. + +Constrained types _guarantee_, by virtue of the type system, that your service's operation outputs adhere to the modeled constraints. To learn more about the motivation for constrained types and how they work, see [the RFC](https://github.com/awslabs/smithy-rs/pull/1199). + +If you'd like to opt-out of generating constrained types, you can set `codegenConfig.publicConstrainedTypes` to `false`. Note that if you do, the generated server SDK will still honor your operation input's modeled constraints upon receiving a request, but will not help you in writing business logic code that adheres to the constraints, and _will not prevent you from returning responses containing operation outputs that violate said constraints_. +""" +references = ["smithy-rs#1342", "smithy-rs#1119"] +meta = { "breaking" = true, "tada" = true, "bug" = false, "target" = "server" } +author = "david-perez" + +[[smithy-rs]] +message = """ +Structure builders in server SDKs have undergone significant changes. + +The API surface has been reduced. It is now simpler and closely follows what you would get when using the [`derive_builder`](https://docs.rs/derive_builder/latest/derive_builder/) crate: + +1. Builders no longer have `set_*` methods taking in `Option`. You must use the unprefixed method, named exactly after the structure's field name, and taking in a value _whose type matches exactly that of the structure's field_. +2. Builders no longer have convenience methods to pass in an element for a field whose type is a vector or a map. You must pass in the entire contents of the collection up front. +3. Builders no longer implement [`PartialEq`](https://doc.rust-lang.org/std/cmp/trait.PartialEq.html). + +Bug fixes: + +4. Builders now always fail to build if a value for a `required` member is not provided. Previously, builders were falling back to a default value (e.g. `""` for `String`s) for some shapes. This was a bug. + +Additions: + +5. A structure `Structure` with builder `Builder` now implements `TryFrom for Structure` or `From for Structure`, depending on whether the structure [is constrained](https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html) or not, respectively. + +To illustrate how to migrate to the new API, consider the example model below. + +```smithy +structure Pokemon { + @required + name: String, + @required + description: String, + @required + evolvesTo: PokemonList +} + +list PokemonList { + member: Pokemon +} +``` + +In the Rust code below, note the references calling out the changes described in the numbered list above. + +Before: + +```rust +let eevee_builder = Pokemon::builder() + // (1) `set_description` takes in `Some`. + .set_description(Some("Su código genético es muy inestable. Puede evolucionar en diversas razas de Pokémon.".to_owned())) + // (2) Convenience method to add one element to the `evolvesTo` list. + .evolves_to(vaporeon) + .evolves_to(jolteon) + .evolves_to(flareon); + +// (3) Builder types can be compared. +assert_ne!(eevee_builder, Pokemon::builder()); + +// (4) Builds fine even though we didn't provide a value for `name`, which is `required`! +let _eevee = eevee_builder.build(); +``` + +After: + +```rust +let eevee_builder = Pokemon::builder() + // (1) `set_description` no longer exists. Use `description`, which directly takes in `String`. + .description("Su código genético es muy inestable. Puede evolucionar en diversas razas de Pokémon.".to_owned()) + // (2) Convenience methods removed; provide the entire collection up front. + .evolves_to(vec![vaporeon, jolteon, flareon]); + +// (3) Binary operation `==` cannot be applied to `pokemon::Builder`. +// assert_ne!(eevee_builder, Pokemon::builder()); + +// (4) `required` member `name` was not set. +// (5) Builder type can be fallibly converted to the structure using `TryFrom` or `TryInto`. +let _error = Pokemon::try_from(eevee_builder).expect_err("name was not provided"); +``` +""" +references = ["smithy-rs#1714", "smithy-rs#1342"] +meta = { "breaking" = true, "tada" = true, "bug" = true, "target" = "server" } +author = "david-perez" + +[[smithy-rs]] +message = """ +Server SDKs now correctly reject operation inputs that don't set values for `required` structure members. Previously, in some scenarios, server SDKs would accept the request and set a default value for the member (e.g. `""` for a `String`), even when the member shape did not have [Smithy IDL v2's `default` trait](https://awslabs.github.io/smithy/2.0/spec/type-refinement-traits.html#smithy-api-default-trait) attached. The `default` trait is [still unsupported](https://github.com/awslabs/smithy-rs/issues/1860). +""" +references = ["smithy-rs#1714", "smithy-rs#1342", "smithy-rs#1860"] +meta = { "breaking" = true, "tada" = false, "bug" = true, "target" = "server" } +author = "david-perez" diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt index 5e47701f5..b74079cc2 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiator.kt @@ -6,20 +6,34 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator +import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName private fun enumFromStringFn(enumSymbol: Symbol, data: String): Writable = writable { rust("#T::from($data)", enumSymbol) } +class ClientBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior { + override fun hasFallibleBuilder(shape: StructureShape): Boolean = + BuilderGenerator.hasFallibleBuilder(shape, codegenContext.symbolProvider) + + override fun setterName(memberShape: MemberShape): String = memberShape.setterName() + + override fun doesSetterTakeInOption(memberShape: MemberShape): Boolean = true +} + fun clientInstantiator(codegenContext: CodegenContext) = Instantiator( codegenContext.symbolProvider, codegenContext.model, codegenContext.runtimeConfig, + ClientBuilderKindBehavior(codegenContext), ::enumFromStringFn, ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index 912a0f668..0f1eac1d6 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -26,7 +26,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations -import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.http.ResponseBindingGenerator @@ -332,7 +332,7 @@ class HttpBoundProtocolTraitImplGenerator( } } - val err = if (StructureGenerator.hasFallibleBuilder(outputShape, symbolProvider)) { + val err = if (BuilderGenerator.hasFallibleBuilder(outputShape, symbolProvider)) { ".map_err(${format(errorSymbol)}::unhandled)?" } else "" diff --git a/codegen-core/common-test-models/constraints.smithy b/codegen-core/common-test-models/constraints.smithy new file mode 100644 index 000000000..d43ea8b7b --- /dev/null +++ b/codegen-core/common-test-models/constraints.smithy @@ -0,0 +1,453 @@ +$version: "1.0" + +namespace com.amazonaws.constraints + +use aws.protocols#restJson1 +use smithy.framework#ValidationException + +/// A service to test aspects of code generation where shapes have constraint traits. +@restJson1 +@title("ConstraintsService") +service ConstraintsService { + operations: [ + // TODO Rename as {Verb}[{Qualifier}]{Noun}: https://github.com/awslabs/smithy-rs/pull/1342#discussion_r980936650 + ConstrainedShapesOperation, + ConstrainedHttpBoundShapesOperation, + ConstrainedRecursiveShapesOperation, + // `httpQueryParams` and `httpPrefixHeaders` are structurually + // exclusive, so we need one operation per target shape type + // combination. + QueryParamsTargetingLengthMapOperation, + QueryParamsTargetingMapOfLengthStringOperation, + QueryParamsTargetingMapOfEnumStringOperation, + QueryParamsTargetingMapOfListOfLengthStringOperation, + QueryParamsTargetingMapOfSetOfLengthStringOperation, + QueryParamsTargetingMapOfListOfEnumStringOperation, + HttpPrefixHeadersTargetingLengthMapOperation, + // TODO(https://github.com/awslabs/smithy-rs/issues/1431) + // HttpPrefixHeadersTargetingMapOfEnumStringOperation, + + NonStreamingBlobOperation, + + StreamingBlobOperation, + EventStreamsOperation, + ], +} + +@http(uri: "/constrained-shapes-operation", method: "POST") +operation ConstrainedShapesOperation { + input: ConstrainedShapesOperationInputOutput, + output: ConstrainedShapesOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/constrained-http-bound-shapes-operation/{lengthStringLabel}/{enumStringLabel}", method: "POST") +operation ConstrainedHttpBoundShapesOperation { + input: ConstrainedHttpBoundShapesOperationInputOutput, + output: ConstrainedHttpBoundShapesOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/constrained-recursive-shapes-operation", method: "POST") +operation ConstrainedRecursiveShapesOperation { + input: ConstrainedRecursiveShapesOperationInputOutput, + output: ConstrainedRecursiveShapesOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-length-map", method: "POST") +operation QueryParamsTargetingLengthMapOperation { + input: QueryParamsTargetingLengthMapOperationInputOutput, + output: QueryParamsTargetingLengthMapOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-map-of-length-string-operation", method: "POST") +operation QueryParamsTargetingMapOfLengthStringOperation { + input: QueryParamsTargetingMapOfLengthStringOperationInputOutput, + output: QueryParamsTargetingMapOfLengthStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-map-of-enum-string-operation", method: "POST") +operation QueryParamsTargetingMapOfEnumStringOperation { + input: QueryParamsTargetingMapOfEnumStringOperationInputOutput, + output: QueryParamsTargetingMapOfEnumStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-map-of-list-of-length-string-operation", method: "POST") +operation QueryParamsTargetingMapOfListOfLengthStringOperation { + input: QueryParamsTargetingMapOfListOfLengthStringOperationInputOutput, + output: QueryParamsTargetingMapOfListOfLengthStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-map-of-set-of-length-string-operation", method: "POST") +operation QueryParamsTargetingMapOfSetOfLengthStringOperation { + input: QueryParamsTargetingMapOfSetOfLengthStringOperationInputOutput, + output: QueryParamsTargetingMapOfSetOfLengthStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/query-params-targeting-map-of-list-of-enum-string-operation", method: "POST") +operation QueryParamsTargetingMapOfListOfEnumStringOperation { + input: QueryParamsTargetingMapOfListOfEnumStringOperationInputOutput, + output: QueryParamsTargetingMapOfListOfEnumStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/http-prefix-headers-targeting-length-map-operation", method: "POST") +operation HttpPrefixHeadersTargetingLengthMapOperation { + input: HttpPrefixHeadersTargetingLengthMapOperationInputOutput, + output: HttpPrefixHeadersTargetingLengthMapOperationInputOutput, + errors: [ValidationException], +} + +@http(uri: "/http-prefix-headers-targeting-map-of-enum-string-operation", method: "POST") +operation HttpPrefixHeadersTargetingMapOfEnumStringOperation { + input: HttpPrefixHeadersTargetingMapOfEnumStringOperationInputOutput, + output: HttpPrefixHeadersTargetingMapOfEnumStringOperationInputOutput, + errors: [ValidationException] +} + +@http(uri: "/non-streaming-blob-operation", method: "POST") +operation NonStreamingBlobOperation { + input: NonStreamingBlobOperationInputOutput, + output: NonStreamingBlobOperationInputOutput, +} + +@http(uri: "/streaming-blob-operation", method: "POST") +operation StreamingBlobOperation { + input: StreamingBlobOperationInputOutput, + output: StreamingBlobOperationInputOutput, +} + +@http(uri: "/event-streams-operation", method: "POST") +operation EventStreamsOperation { + input: EventStreamsOperationInputOutput, + output: EventStreamsOperationInputOutput, +} + +structure ConstrainedShapesOperationInputOutput { + @required + conA: ConA, +} + +structure ConstrainedHttpBoundShapesOperationInputOutput { + @required + @httpLabel + lengthStringLabel: LengthString, + + @required + @httpLabel + enumStringLabel: EnumString, + + // TODO(https://github.com/awslabs/smithy-rs/issues/1394) `@required` not working + // @required + @httpPrefixHeaders("X-Prefix-Headers-") + lengthStringHeaderMap: MapOfLengthString, + + @httpHeader("X-Length") + lengthStringHeader: LengthString, + + // @httpHeader("X-Length-MediaType") + // lengthStringHeaderWithMediaType: MediaTypeLengthString, + + @httpHeader("X-Length-Set") + lengthStringSetHeader: SetOfLengthString, + + @httpHeader("X-Length-List") + lengthStringListHeader: ListOfLengthString, + + // TODO(https://github.com/awslabs/smithy-rs/issues/1431) + // @httpHeader("X-Enum") + //enumStringHeader: EnumString, + + // @httpHeader("X-Enum-List") + // enumStringListHeader: ListOfEnumString, + + @httpQuery("lengthString") + lengthStringQuery: LengthString, + + @httpQuery("enumString") + enumStringQuery: EnumString, + + @httpQuery("lengthStringList") + lengthStringListQuery: ListOfLengthString, + + @httpQuery("lengthStringSet") + lengthStringSetQuery: SetOfLengthString, + + @httpQuery("enumStringList") + enumStringListQuery: ListOfEnumString, +} + +structure HttpPrefixHeadersTargetingLengthMapOperationInputOutput { + @httpPrefixHeaders("X-Prefix-Headers-LengthMap-") + lengthMap: ConBMap, +} + +structure HttpPrefixHeadersTargetingMapOfEnumStringOperationInputOutput { + @httpPrefixHeaders("X-Prefix-Headers-MapOfEnumString-") + mapOfEnumString: MapOfEnumString, +} + +structure QueryParamsTargetingLengthMapOperationInputOutput { + @httpQueryParams + lengthMap: ConBMap +} + +structure QueryParamsTargetingMapOfLengthStringOperationInputOutput { + @httpQueryParams + mapOfLengthString: MapOfLengthString +} + +structure QueryParamsTargetingMapOfEnumStringOperationInputOutput { + @httpQueryParams + mapOfEnumString: MapOfEnumString +} + +structure QueryParamsTargetingMapOfListOfLengthStringOperationInputOutput { + @httpQueryParams + mapOfListOfLengthString: MapOfListOfLengthString +} + +structure QueryParamsTargetingMapOfSetOfLengthStringOperationInputOutput { + @httpQueryParams + mapOfSetOfLengthString: MapOfSetOfLengthString +} + +structure QueryParamsTargetingMapOfListOfEnumStringOperationInputOutput { + @httpQueryParams + mapOfListOfEnumString: MapOfListOfEnumString +} + +structure NonStreamingBlobOperationInputOutput { + @httpPayload + nonStreamingBlob: NonStreamingBlob, +} + +structure StreamingBlobOperationInputOutput { + @httpPayload + streamingBlob: StreamingBlob, +} + +structure EventStreamsOperationInputOutput { + @httpPayload + events: Event, +} + +@streaming +union Event { + regularMessage: EventStreamRegularMessage, + errorMessage: EventStreamErrorMessage, +} + +structure EventStreamRegularMessage { + messageContent: String + // TODO(https://github.com/awslabs/smithy/issues/1388): Can't add a constraint trait here until the semantics are clarified. + // messageContent: LengthString +} + +@error("server") +structure EventStreamErrorMessage { + messageContent: String + // TODO(https://github.com/awslabs/smithy/issues/1388): Can't add a constraint trait here until the semantics are clarified. + // messageContent: LengthString +} + +// TODO(https://github.com/awslabs/smithy/issues/1389): Can't add a constraint trait here until the semantics are clarified. +@streaming +blob StreamingBlob + +blob NonStreamingBlob + +structure ConA { + @required + conB: ConB, + + optConB: ConB, + + lengthString: LengthString, + minLengthString: MinLengthString, + maxLengthString: MaxLengthString, + fixedLengthString: FixedLengthString, + + conBList: ConBList, + conBList2: ConBList2, + + conBSet: ConBSet, + + conBMap: ConBMap, + + mapOfMapOfListOfListOfConB: MapOfMapOfListOfListOfConB, + + constrainedUnion: ConstrainedUnion, + enumString: EnumString, + + listOfLengthString: ListOfLengthString, + setOfLengthString: SetOfLengthString, + mapOfLengthString: MapOfLengthString, + + nonStreamingBlob: NonStreamingBlob +} + +map MapOfLengthString { + key: LengthString, + value: LengthString, +} + +map MapOfEnumString { + key: EnumString, + value: EnumString, +} + +map MapOfListOfLengthString { + key: LengthString, + value: ListOfLengthString, +} + +map MapOfListOfEnumString { + key: EnumString, + value: ListOfEnumString, +} + +map MapOfSetOfLengthString { + key: LengthString, + value: SetOfLengthString, +} + +@length(min: 2, max: 8) +list LengthListOfLengthString { + member: LengthString +} + +@length(min: 2, max: 69) +string LengthString + +@length(min: 2) +string MinLengthString + +@length(min: 69) +string MaxLengthString + +@length(min: 69, max: 69) +string FixedLengthString + +@mediaType("video/quicktime") +@length(min: 1, max: 69) +string MediaTypeLengthString + +/// A union with constrained members. +union ConstrainedUnion { + enumString: EnumString, + lengthString: LengthString, + + constrainedStructure: ConB, + conBList: ConBList, + conBSet: ConBSet, + conBMap: ConBMap, +} + +@enum([ + { + value: "t2.nano", + name: "T2_NANO", + }, + { + value: "t2.micro", + name: "T2_MICRO", + }, + { + value: "m256.mega", + name: "M256_MEGA", + } +]) +string EnumString + +set SetOfLengthString { + member: LengthString +} + +list ListOfLengthString { + member: LengthString +} + +list ListOfEnumString { + member: EnumString +} + +structure ConB { + @required + nice: String, + @required + int: Integer, + + optNice: String, + optInt: Integer +} + +structure ConstrainedRecursiveShapesOperationInputOutput { + nested: RecursiveShapesInputOutputNested1, + + @required + recursiveList: RecursiveList +} + +structure RecursiveShapesInputOutputNested1 { + @required + recursiveMember: RecursiveShapesInputOutputNested2 +} + +structure RecursiveShapesInputOutputNested2 { + recursiveMember: RecursiveShapesInputOutputNested1, +} + +list RecursiveList { + member: RecursiveShapesInputOutputNested1 +} + +list ConBList { + member: NestedList +} + +list ConBList2 { + member: ConB +} + +list NestedList { + member: ConB +} + +set ConBSet { + member: NestedSet +} + +set NestedSet { + member: String +} + +@length(min: 1, max: 69) +map ConBMap { + key: String, + value: LengthString +} + +@error("client") +structure ErrorWithLengthStringMessage { + // TODO Doesn't work yet because constrained string types don't implement + // `AsRef`. + // @required + // message: LengthString +} + +map MapOfMapOfListOfListOfConB { + key: String, + value: MapOfListOfListOfConB +} + +map MapOfListOfListOfConB { + key: String, + value: ConBList +} diff --git a/codegen-core/common-test-models/misc.smithy b/codegen-core/common-test-models/misc.smithy index a98c0fa21..69bdcc2cd 100644 --- a/codegen-core/common-test-models/misc.smithy +++ b/codegen-core/common-test-models/misc.smithy @@ -5,6 +5,7 @@ namespace aws.protocoltests.misc use aws.protocols#restJson1 use smithy.test#httpRequestTests use smithy.test#httpResponseTests +use smithy.framework#ValidationException /// A service to test miscellaneous aspects of code generation where protocol /// selection is not relevant. If you want to test something protocol-specific, @@ -54,10 +55,11 @@ map MapA { /// This operation tests that (de)serializing required values from a nested /// shape works correctly. -@http(uri: "/innerRequiredShapeOperation", method: "POST") +@http(uri: "/requiredInnerShapeOperation", method: "POST") operation RequiredInnerShapeOperation { input: RequiredInnerShapeOperationInputOutput, output: RequiredInnerShapeOperationInputOutput, + errors: [ValidationException], } structure RequiredInnerShapeOperationInputOutput { @@ -236,6 +238,7 @@ operation AcceptHeaderStarService {} operation RequiredHeaderCollectionOperation { input: RequiredHeaderCollectionOperationInputOutput, output: RequiredHeaderCollectionOperationInputOutput, + errors: [ValidationException] } structure RequiredHeaderCollectionOperationInputOutput { diff --git a/codegen-core/common-test-models/naming-obstacle-course-ops.smithy b/codegen-core/common-test-models/naming-obstacle-course-ops.smithy index 087d99b75..f54b27e76 100644 --- a/codegen-core/common-test-models/naming-obstacle-course-ops.smithy +++ b/codegen-core/common-test-models/naming-obstacle-course-ops.smithy @@ -5,6 +5,7 @@ use smithy.test#httpRequestTests use smithy.test#httpResponseTests use aws.protocols#awsJson1_1 use aws.api#service +use smithy.framework#ValidationException /// Confounds model generation machinery with lots of problematic names @awsJson1_1 @@ -41,17 +42,20 @@ service Config { } ]) operation ReservedWordsAsMembers { - input: ReservedWords + input: ReservedWords, + errors: [ValidationException], } // tests that module names are properly escaped operation Match { input: ReservedWords + errors: [ValidationException], } // Should generate a PascalCased `RpcEchoInput` struct. operation RPCEcho { input: ReservedWords + errors: [ValidationException], } structure ReservedWords { diff --git a/codegen-core/common-test-models/pokemon-common.smithy b/codegen-core/common-test-models/pokemon-common.smithy index 3198cb8c7..d213a16b1 100644 --- a/codegen-core/common-test-models/pokemon-common.smithy +++ b/codegen-core/common-test-models/pokemon-common.smithy @@ -2,6 +2,8 @@ $version: "1.0" namespace com.aws.example +use smithy.framework#ValidationException + /// A Pokémon species forms the basis for at least one Pokémon. @title("Pokémon Species") resource PokemonSpecies { @@ -17,7 +19,7 @@ resource PokemonSpecies { operation GetPokemonSpecies { input: GetPokemonSpeciesInput, output: GetPokemonSpeciesOutput, - errors: [ResourceNotFoundException], + errors: [ResourceNotFoundException, ValidationException], } @input diff --git a/codegen-core/common-test-models/pokemon.smithy b/codegen-core/common-test-models/pokemon.smithy index e955cdd21..d42185e31 100644 --- a/codegen-core/common-test-models/pokemon.smithy +++ b/codegen-core/common-test-models/pokemon.smithy @@ -3,6 +3,7 @@ $version: "1.0" namespace com.aws.example.rust use aws.protocols#restJson1 +use smithy.framework#ValidationException use com.aws.example#PokemonSpecies use com.aws.example#GetServerStatistics use com.aws.example#DoNothing @@ -31,13 +32,13 @@ resource Storage { read: GetStorage, } -/// Retrieve information about your Pokedex. +/// Retrieve information about your Pokédex. @readonly @http(uri: "/pokedex/{user}", method: "GET") operation GetStorage { input: GetStorageInput, output: GetStorageOutput, - errors: [ResourceNotFoundException, NotAuthorized], + errors: [ResourceNotFoundException, NotAuthorized, ValidationException], } /// Not authorized to access Pokémon storage. @@ -74,7 +75,7 @@ structure GetStorageOutput { operation CapturePokemon { input: CapturePokemonEventsInput, output: CapturePokemonEventsOutput, - errors: [UnsupportedRegionError, ThrottlingError] + errors: [UnsupportedRegionError, ThrottlingError, ValidationException] } @input @@ -140,7 +141,6 @@ structure InvalidPokeballError { } @error("server") structure MasterBallUnsuccessful { - @required message: String, } diff --git a/codegen-core/common-test-models/rest-json-extras.smithy b/codegen-core/common-test-models/rest-json-extras.smithy index 10224b0c7..65b7fcc8f 100644 --- a/codegen-core/common-test-models/rest-json-extras.smithy +++ b/codegen-core/common-test-models/rest-json-extras.smithy @@ -6,6 +6,7 @@ use aws.protocols#restJson1 use aws.api#service use smithy.test#httpRequestTests use smithy.test#httpResponseTests +use smithy.framework#ValidationException apply QueryPrecedence @httpRequestTests([ { @@ -120,6 +121,7 @@ structure StringPayloadInput { documentation: "Primitive ints should not be serialized when they are unset", uri: "/primitive-document", method: "POST", + appliesTo: "client", body: "{}", headers: { "Content-Type": "application/json" }, params: {}, @@ -152,7 +154,8 @@ structure PrimitiveIntDocument { ]) @http(uri: "/primitive", method: "POST") operation PrimitiveIntHeader { - output: PrimitiveIntHeaderInput + output: PrimitiveIntHeaderInput, + errors: [ValidationException], } integer PrimitiveInt @@ -174,7 +177,8 @@ structure PrimitiveIntHeaderInput { } ]) operation EnumQuery { - input: EnumQueryInput + input: EnumQueryInput, + errors: [ValidationException], } structure EnumQueryInput { @@ -226,6 +230,7 @@ structure MapWithEnumKeyInputOutput { operation MapWithEnumKeyOp { input: MapWithEnumKeyInputOutput, output: MapWithEnumKeyInputOutput, + errors: [ValidationException], } @@ -265,6 +270,7 @@ structure EscapedStringValuesInputOutput { operation EscapedStringValues { input: EscapedStringValuesInputOutput, output: EscapedStringValuesInputOutput, + errors: [ValidationException], } list NonSparseList { diff --git a/codegen-core/common-test-models/simple.smithy b/codegen-core/common-test-models/simple.smithy index 6e094abe1..c685c02f6 100644 --- a/codegen-core/common-test-models/simple.smithy +++ b/codegen-core/common-test-models/simple.smithy @@ -5,6 +5,7 @@ namespace com.amazonaws.simple use aws.protocols#restJson1 use smithy.test#httpRequestTests use smithy.test#httpResponseTests +use smithy.framework#ValidationException @restJson1 @title("SimpleService") @@ -74,7 +75,7 @@ resource Service { operation RegisterService { input: RegisterServiceInputRequest, output: RegisterServiceOutputResponse, - errors: [ResourceAlreadyExists] + errors: [ResourceAlreadyExists, ValidationException] } @documentation("Service register input structure") @@ -116,6 +117,7 @@ structure HealthcheckOutputResponse { operation StoreServiceBlob { input: StoreServiceBlobInput, output: StoreServiceBlobOutput + errors: [ValidationException] } @documentation("Store a blob for a service id input structure") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index bd231f809..970c1a63e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -112,6 +112,9 @@ class InlineDependency( fun unwrappedXmlErrors(runtimeConfig: RuntimeConfig): InlineDependency = forRustFile("rest_xml_unwrapped_errors", CargoDependency.smithyXml(runtimeConfig)) + + fun constrained(): InlineDependency = + forRustFile("constrained") } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt index 6ff484bba..120798258 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt @@ -148,6 +148,12 @@ sealed class RustType { } } + data class MaybeConstrained(override val member: RustType) : RustType(), Container { + val runtimeType: RuntimeType = RuntimeType.MaybeConstrained() + override val name = runtimeType.name!! + override val namespace = runtimeType.namespace + } + data class Box(override val member: RustType) : RustType(), Container { override val name = "Box" override val namespace = "std::boxed" @@ -237,6 +243,7 @@ fun RustType.render(fullyQualified: Boolean = true): String { is RustType.Box -> "${this.name}<${this.member.render(fullyQualified)}>" is RustType.Dyn -> "${this.name} ${this.member.render(fullyQualified)}" is RustType.Opaque -> this.name + is RustType.MaybeConstrained -> "${this.name}<${this.member.render(fullyQualified)}>" } return "$namespace$base" } @@ -380,6 +387,7 @@ sealed class Attribute { companion object { val AllowDeadCode = Custom("allow(dead_code)") val AllowDeprecated = Custom("allow(deprecated)") + val AllowUnused = Custom("allow(unused)") val AllowUnusedMut = Custom("allow(unused_mut)") val DocHidden = Custom("doc(hidden)") val DocInline = Custom("doc(inline)") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt index bbd77b7c7..f1739324a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt @@ -171,6 +171,25 @@ open class RustCrate( } } +val ErrorsModule = RustModule.public("error", documentation = "All error types that operations can return.") +val OperationsModule = RustModule.public("operation", documentation = "All operations that this crate can perform.") +val ModelsModule = RustModule.public("model", documentation = "Data structures used by operation inputs/outputs.") +val InputsModule = RustModule.public("input", documentation = "Input structures for operations.") +val OutputsModule = RustModule.public("output", documentation = "Output structures for operations.") +val ConfigModule = RustModule.public("config", documentation = "Client configuration.") + +/** + * Allowlist of modules that will be exposed publicly in generated crates + */ +val DefaultPublicModules = setOf( + ErrorsModule, + OperationsModule, + ModelsModule, + InputsModule, + OutputsModule, + ConfigModule, +).associateBy { it.name } + /** * Finalize all the writers by: * - inlining inline dependencies that have been used diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index d79059065..4fc233df3 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -196,7 +196,9 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n val Debug = stdfmt.member("Debug") val Default: RuntimeType = RuntimeType("Default", dependency = null, namespace = "std::default") val Display = stdfmt.member("Display") + val Eq = std.member("cmp::Eq") val From = RuntimeType("From", dependency = null, namespace = "std::convert") + val Hash = std.member("hash::Hash") val TryFrom = RuntimeType("TryFrom", dependency = null, namespace = "std::convert") val PartialEq = std.member("cmp::PartialEq") val StdError = RuntimeType("Error", dependency = null, namespace = "std::error") @@ -256,6 +258,9 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n func, CargoDependency.SmithyProtocolTestHelpers(runtimeConfig), "aws_smithy_protocol_test", ) + fun ConstrainedTrait() = RuntimeType("Constrained", InlineDependency.constrained(), namespace = "crate::constrained") + fun MaybeConstrained() = RuntimeType("MaybeConstrained", InlineDependency.constrained(), namespace = "crate::constrained") + val http = CargoDependency.Http.asType() fun Http(path: String): RuntimeType = RuntimeType(name = path, dependency = CargoDependency.Http, namespace = "http") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt index d72bafb5b..238a1177b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode import software.amazon.smithy.model.shapes.BigDecimalShape import software.amazon.smithy.model.shapes.BigIntegerShape import software.amazon.smithy.model.shapes.BlobShape @@ -77,29 +78,65 @@ data class SymbolLocation(val namespace: String) { val filename = "$namespace.rs" } -val Models = SymbolLocation("model") -val Errors = SymbolLocation("error") -val Operations = SymbolLocation("operation") -val Inputs = SymbolLocation("input") -val Outputs = SymbolLocation("output") +val Models = SymbolLocation(ModelsModule.name) +val Errors = SymbolLocation(ErrorsModule.name) +val Operations = SymbolLocation(OperationsModule.name) +val Serializers = SymbolLocation("serializer") +val Inputs = SymbolLocation(InputsModule.name) +val Outputs = SymbolLocation(OutputsModule.name) +val Unconstrained = SymbolLocation("unconstrained") +val Constrained = SymbolLocation("constrained") /** * Make the Rust type of a symbol optional (hold `Option`) * * This is idempotent and will have no change if the type is already optional. */ -fun Symbol.makeOptional(): Symbol { - return if (isOptional()) { +fun Symbol.makeOptional(): Symbol = + if (isOptional()) { this } else { val rustType = RustType.Option(this.rustType()) - Symbol.builder().rustType(rustType) + Symbol.builder() + .rustType(rustType) + .addReference(this) + .name(rustType.name) + .build() + } + +/** + * Make the Rust type of a symbol boxed (hold `Box`). + * + * This is idempotent and will have no change if the type is already boxed. + */ +fun Symbol.makeRustBoxed(): Symbol = + if (isRustBoxed()) { + this + } else { + val rustType = RustType.Box(this.rustType()) + Symbol.builder() + .rustType(rustType) + .addReference(this) + .name(rustType.name) + .build() + } + +/** + * Make the Rust type of a symbol wrapped in `MaybeConstrained`. (hold `MaybeConstrained`). + * + * This is idempotent and will have no change if the type is already `MaybeConstrained`. + */ +fun Symbol.makeMaybeConstrained(): Symbol = + if (this.rustType() is RustType.MaybeConstrained) { + this + } else { + val rustType = RustType.MaybeConstrained(this.rustType()) + Symbol.builder() .rustType(rustType) .addReference(this) .name(rustType.name) .build() } -} /** * Map the [RustType] of a symbol with [f]. @@ -208,9 +245,6 @@ open class SymbolVisitor( return RuntimeType.Blob(config.runtimeConfig).toSymbol() } - private fun handleOptionality(symbol: Symbol, member: MemberShape): Symbol = - symbol.letIf(nullableIndex.isMemberNullable(member, config.nullabilityCheckMode)) { symbol.makeOptional() } - /** * Produce `Box` when the shape has the `RustBoxTrait` */ @@ -227,7 +261,7 @@ open class SymbolVisitor( } private fun simpleShape(shape: SimpleShape): Symbol { - return symbolBuilder(SimpleShapes.getValue(shape::class)).setDefault(Default.RustDefault).build() + return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).setDefault(Default.RustDefault).build() } override fun booleanShape(shape: BooleanShape): Symbol = simpleShape(shape) @@ -240,7 +274,7 @@ open class SymbolVisitor( override fun stringShape(shape: StringShape): Symbol { return if (shape.hasTrait()) { val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) - symbolBuilder(rustType).locatedIn(Models).build() + symbolBuilder(shape, rustType).locatedIn(Models).build() } else { simpleShape(shape) } @@ -248,16 +282,16 @@ open class SymbolVisitor( override fun listShape(shape: ListShape): Symbol { val inner = this.toSymbol(shape.member) - return symbolBuilder(RustType.Vec(inner.rustType())).addReference(inner).build() + return symbolBuilder(shape, RustType.Vec(inner.rustType())).addReference(inner).build() } override fun setShape(shape: SetShape): Symbol { val inner = this.toSymbol(shape.member) val builder = if (model.expectShape(shape.member.target).isStringShape) { - symbolBuilder(RustType.HashSet(inner.rustType())) + symbolBuilder(shape, RustType.HashSet(inner.rustType())) } else { // only strings get put into actual sets because floats are unhashable - symbolBuilder(RustType.Vec(inner.rustType())) + symbolBuilder(shape, RustType.Vec(inner.rustType())) } return builder.addReference(inner).build() } @@ -267,7 +301,7 @@ open class SymbolVisitor( require(target.isStringShape) { "unexpected key shape: ${shape.key}: $target [keys must be strings]" } val key = this.toSymbol(shape.key) val value = this.toSymbol(shape.value) - return symbolBuilder(RustType.HashMap(key.rustType(), value.rustType())).addReference(key) + return symbolBuilder(shape, RustType.HashMap(key.rustType(), value.rustType())).addReference(key) .addReference(value).build() } @@ -285,6 +319,7 @@ open class SymbolVisitor( override fun operationShape(shape: OperationShape): Symbol { return symbolBuilder( + shape, RustType.Opaque( shape.contextName(serviceShape) .replaceFirstChar { it.uppercase() }, @@ -309,7 +344,7 @@ open class SymbolVisitor( val name = shape.contextName(serviceShape).toPascalCase().letIf(isError && config.renameExceptions) { it.replace("Exception", "Error") } - val builder = symbolBuilder(RustType.Opaque(name)) + val builder = symbolBuilder(shape, RustType.Opaque(name)) return when { isError -> builder.locatedIn(Errors) isInput -> builder.locatedIn(Inputs) @@ -320,29 +355,50 @@ open class SymbolVisitor( override fun unionShape(shape: UnionShape): Symbol { val name = shape.contextName(serviceShape).toPascalCase() - val builder = symbolBuilder(RustType.Opaque(name)).locatedIn(Models) + val builder = symbolBuilder(shape, RustType.Opaque(name)).locatedIn(Models) return builder.build() } override fun memberShape(shape: MemberShape): Symbol { val target = model.expectShape(shape.target) - // Handle boxing first so we end up with Option>, not Box> - return handleOptionality(handleRustBoxing(toSymbol(target), shape), shape) + // Handle boxing first so we end up with Option>, not Box>. + return handleOptionality( + handleRustBoxing(toSymbol(target), shape), + shape, + nullableIndex, + config.nullabilityCheckMode, + ) } override fun timestampShape(shape: TimestampShape?): Symbol { return RuntimeType.DateTime(config.runtimeConfig).toSymbol() } +} - private fun symbolBuilder(rustType: RustType): Symbol.Builder { - return Symbol.builder().rustType(rustType).name(rustType.name) - // Every symbol that actually gets defined somewhere should set a definition file - // If we ever generate a `thisisabug.rs`, there is a bug in our symbol generation - .definitionFile("thisisabug.rs") - } +/** + * Boxes and returns [symbol], the symbol for the target of the member shape [shape], if [shape] is annotated with + * [RustBoxTrait]; otherwise returns [symbol] unchanged. + * + * See `RecursiveShapeBoxer.kt` for the model transformation pass that annotates model shapes with [RustBoxTrait]. + */ +fun handleRustBoxing(symbol: Symbol, shape: MemberShape): Symbol = + if (shape.hasTrait()) { + symbol.makeRustBoxed() + } else symbol + +fun symbolBuilder(shape: Shape?, rustType: RustType): Symbol.Builder { + val builder = Symbol.builder().putProperty(SHAPE_KEY, shape) + return builder.rustType(rustType) + .name(rustType.name) + // Every symbol that actually gets defined somewhere should set a definition file + // If we ever generate a `thisisabug.rs`, there is a bug in our symbol generation + .definitionFile("thisisabug.rs") } +fun handleOptionality(symbol: Symbol, member: MemberShape, nullableIndex: NullableIndex, nullabilityCheckMode: CheckMode): Symbol = + symbol.letIf(nullableIndex.isMemberNullable(member, nullabilityCheckMode)) { symbol.makeOptional() } + // TODO(chore): Move this to a useful place private const val RUST_TYPE_KEY = "rusttype" private const val SHAPE_KEY = "shape" @@ -388,11 +444,6 @@ fun Symbol.isOptional(): Boolean = when (this.rustType()) { else -> false } -/** - * Get the referenced symbol for T if [this] is an Option, [this] otherwise - */ -fun Symbol.extractSymbolFromOption(): Symbol = this.mapRustType { it.stripOuter() } - fun Symbol.isRustBoxed(): Boolean = rustType().stripOuter() is RustType.Box // Symbols should _always_ be created with a Rust type & shape attached diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt index 9ae2da62a..0dec0b1a6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.StructureShape @@ -28,14 +29,24 @@ import software.amazon.smithy.rust.codegen.core.smithy.Default import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault import software.amazon.smithy.rust.codegen.core.smithy.defaultValue import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeOptional import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +// TODO(https://github.com/awslabs/smithy-rs/issues/1401) This builder generator is only used by the client. +// Move this entire file, and its tests, to `codegen-client`. + +fun builderSymbolFn(symbolProvider: RustSymbolProvider): (StructureShape) -> Symbol = { structureShape -> + structureShape.builderSymbol(symbolProvider) +} + fun StructureShape.builderSymbol(symbolProvider: RustSymbolProvider): Symbol { val structureSymbol = symbolProvider.toSymbol(this) val builderNamespace = RustReservedWords.escapeIfNeeded(structureSymbol.name.toSnakeCase()) @@ -65,6 +76,23 @@ class BuilderGenerator( private val symbolProvider: RustSymbolProvider, private val shape: StructureShape, ) { + companion object { + /** + * Returns whether a structure shape, whose builder has been generated with [BuilderGenerator], requires a + * fallible builder to be constructed. + */ + fun hasFallibleBuilder(structureShape: StructureShape, symbolProvider: SymbolProvider): Boolean = + // All operation inputs should have fallible builders in case a new required field is added in the future. + structureShape.hasTrait() || + structureShape + .members() + .map { symbolProvider.toSymbol(it) }.any { + // If any members are not optional && we can't use a default, we need to + // generate a fallible builder. + !it.isOptional() && !it.canUseDefault() + } + } + private val runtimeConfig = symbolProvider.config().runtimeConfig private val members: List = shape.allMembers.values.toList() private val structureSymbol = symbolProvider.toSymbol(shape) @@ -79,7 +107,7 @@ class BuilderGenerator( } private fun renderBuildFn(implBlockWriter: RustWriter) { - val fallibleBuilder = StructureGenerator.hasFallibleBuilder(shape, symbolProvider) + val fallibleBuilder = hasFallibleBuilder(shape, symbolProvider) val outputSymbol = symbolProvider.toSymbol(shape) val returnType = when (fallibleBuilder) { true -> "Result<${implBlockWriter.format(outputSymbol)}, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt index e735ec5e1..eed5b5462 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt @@ -85,8 +85,8 @@ open class EnumGenerator( private val model: Model, private val symbolProvider: RustSymbolProvider, private val writer: RustWriter, - private val shape: StringShape, - private val enumTrait: EnumTrait, + protected val shape: StringShape, + protected val enumTrait: EnumTrait, ) { protected val symbol: Symbol = symbolProvider.toSymbol(shape) protected val enumName: String = symbol.name diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index a58c0c65b..7973af760 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -62,12 +62,14 @@ open class Instantiator( private val symbolProvider: RustSymbolProvider, private val model: Model, private val runtimeConfig: RuntimeConfig, + /** Behavior of the builder type used for structure shapes. */ + private val builderKindBehavior: BuilderKindBehavior, /** * A function that given a symbol for an enum shape and a string, returns a writable to instantiate the enum with * the string value. **/ private val enumFromStringFn: (Symbol, String) -> Writable, - /** Fill out required fields with a default value **/ + /** Fill out required fields with a default value. **/ private val defaultsForRequiredFields: Boolean = false, ) { data class Ctx( @@ -76,6 +78,20 @@ open class Instantiator( val lowercaseMapKeys: Boolean = false, ) + /** + * Client and server structures have different builder types. `Instantiator` needs to know how the builder + * type behaves to generate code for it. + */ + interface BuilderKindBehavior { + fun hasFallibleBuilder(shape: StructureShape): Boolean + + // Client structure builders have two kinds of setters: one that always takes in `Option`, and one that takes + // in the structure field's type. The latter's method name is the field's name, whereas the former is prefixed + // with `set_`. Client instantiators call the `set_*` builder setters. + fun setterName(memberShape: MemberShape): String + fun doesSetterTakeInOption(memberShape: MemberShape): Boolean + } + fun render(writer: RustWriter, shape: Shape, data: Node, ctx: Ctx = Ctx()) { when (shape) { // Compound Shapes @@ -165,7 +181,9 @@ open class Instantiator( writer.conditionalBlock( "Some(", ")", - conditional = model.expectShape(memberShape.container) is StructureShape || symbol.isOptional(), + // The conditions are not commutative: note client builders always take in `Option`. + conditional = symbol.isOptional() || + (model.expectShape(memberShape.container) is StructureShape && builderKindBehavior.doesSetterTakeInOption(memberShape)), ) { writer.conditionalBlock( "Box::new(", @@ -277,7 +295,8 @@ open class Instantiator( */ private fun renderStructure(writer: RustWriter, shape: StructureShape, data: ObjectNode, ctx: Ctx) { fun renderMemberHelper(memberShape: MemberShape, value: Node) { - writer.withBlock(".${memberShape.setterName()}(", ")") { + val setterName = builderKindBehavior.setterName(memberShape) + writer.withBlock(".$setterName(", ")") { renderMember(this, memberShape, value, ctx) } } @@ -297,8 +316,9 @@ open class Instantiator( val memberShape = shape.expectMember(key.value) renderMemberHelper(memberShape, value) } + writer.rust(".build()") - if (StructureGenerator.hasFallibleBuilder(shape, symbolProvider)) { + if (builderKindBehavior.hasFallibleBuilder(shape)) { writer.rust(".unwrap()") } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt index cbd0a1395..3972307ac 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt @@ -28,16 +28,12 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorGenerator -import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom import software.amazon.smithy.rust.codegen.core.smithy.rustType -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.getTrait -import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary fun RustWriter.implBlock(structureShape: Shape, symbolProvider: SymbolProvider, block: Writable) { @@ -68,20 +64,6 @@ open class StructureGenerator( } } - companion object { - /** Returns whether a structure shape requires a fallible builder to be generated. */ - fun hasFallibleBuilder(structureShape: StructureShape, symbolProvider: SymbolProvider): Boolean = - // All operation inputs should have fallible builders in case a new required field is added in the future. - structureShape.hasTrait() || - structureShape - .allMembers - .values.map { symbolProvider.toSymbol(it) }.any { - // If any members are not optional && we can't use a default, we need to - // generate a fallible builder - !it.isOptional() && !it.canUseDefault() - } - } - /** * Search for lifetimes used by the members of the struct and generate a declaration. * e.g. `<'a, 'b>` @@ -100,7 +82,8 @@ open class StructureGenerator( } else "" } - /** Render a custom debug implementation + /** + * Render a custom debug implementation * When [SensitiveTrait] support is required, render a custom debug implementation to redact sensitive data */ private fun renderDebugImpl() { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt index 07f9315b8..4303fa1a6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt @@ -9,10 +9,14 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.RetryableTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.asDeref +import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig @@ -20,6 +24,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.StdError import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.mapRustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.REDACTION import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.errorMessageMember @@ -82,10 +88,23 @@ class ErrorGenerator( } } if (messageShape != null) { - val (returnType, message) = if (symbolProvider.toSymbol(messageShape).isOptional()) { - "Option<&str>" to "self.${symbolProvider.toMemberName(messageShape)}.as_deref()" + val messageSymbol = symbolProvider.toSymbol(messageShape).mapRustType { t -> t.asDeref() } + val messageType = messageSymbol.rustType() + val memberName = symbolProvider.toMemberName(messageShape) + val (returnType, message) = if (messageType.stripOuter() is RustType.Opaque) { + // The string shape has a constraint trait that makes its symbol be a wrapper tuple struct. + if (messageSymbol.isOptional()) { + "Option<&${messageType.stripOuter().render()}>" to + "self.$memberName.as_ref()" + } else { + "&${messageType.render()}" to "&self.$memberName" + } } else { - "&str" to "self.${symbolProvider.toMemberName(messageShape)}.as_ref()" + if (messageSymbol.isOptional()) { + messageType.render() to "self.$memberName.as_deref()" + } else { + messageType.render() to "self.$memberName.as_ref()" + } } rust( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt index 2a54f5674..97a703ab9 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt @@ -6,15 +6,19 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.http import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.SimpleShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape @@ -38,6 +42,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedSectionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.core.smithy.makeOptional import software.amazon.smithy.rust.codegen.core.smithy.mapRustType @@ -68,6 +74,18 @@ enum class HttpMessageType { REQUEST, RESPONSE } +/** + * Class describing an HTTP binding (de)serialization section that can be used in a customization. + */ +sealed class HttpBindingSection(name: String) : Section(name) { + data class BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders(val variableName: String, val shape: MapShape) : + HttpBindingSection("BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders") + data class AfterDeserializingIntoAHashMapOfHttpPrefixHeaders(val memberShape: MemberShape) : + HttpBindingSection("AfterDeserializingIntoAHashMapOfHttpPrefixHeaders") +} + +typealias HttpBindingCustomization = NamedSectionGenerator + /** * This class generates Rust functions that (de)serialize data from/to an HTTP message. * They are useful for *both*: @@ -88,12 +106,15 @@ enum class HttpMessageType { */ class HttpBindingGenerator( private val protocol: Protocol, - codegenContext: CodegenContext, + private val codegenContext: CodegenContext, + private val symbolProvider: SymbolProvider, private val operationShape: OperationShape, + /** Function that maps a StructureShape into its builder symbol */ + private val builderSymbol: (StructureShape) -> Symbol, + private val customizations: List = listOf(), ) { private val runtimeConfig = codegenContext.runtimeConfig - private val symbolProvider = codegenContext.symbolProvider - private val target = codegenContext.target + private val codegenTarget = codegenContext.target private val model = codegenContext.model private val service = codegenContext.serviceShape private val index = HttpBindingIndex.of(model) @@ -120,7 +141,7 @@ class HttpBindingGenerator( val fnName = "deser_header_${fnName(operationShape, binding)}" return RuntimeType.forInlineFun(fnName, httpSerdeModule) { rustBlock( - "pub fn $fnName(header_map: &#T::HeaderMap) -> std::result::Result<#T, #T::ParseError>", + "pub(crate) fn $fnName(header_map: &#T::HeaderMap) -> std::result::Result<#T, #T::ParseError>", RuntimeType.http, outputT, headerUtil, @@ -134,7 +155,6 @@ class HttpBindingGenerator( fun generateDeserializePrefixHeaderFn(binding: HttpBindingDescriptor): RuntimeType { check(binding.location == HttpBinding.Location.PREFIX_HEADERS) val outputSymbol = symbolProvider.toSymbol(binding.member) - check(outputSymbol.rustType().stripOuter() is RustType.HashMap) { outputSymbol.rustType() } val target = model.expectShape(binding.member.target) check(target is MapShape) val fnName = "deser_prefix_header_${fnName(operationShape, binding)}" @@ -151,7 +171,7 @@ class HttpBindingGenerator( val returnTypeSymbol = outputSymbol.mapRustType { it.asOptional() } return RuntimeType.forInlineFun(fnName, httpSerdeModule) { rustBlock( - "pub fn $fnName(header_map: &#T::HeaderMap) -> std::result::Result<#T, #T::ParseError>", + "pub(crate) fn $fnName(header_map: &#T::HeaderMap) -> std::result::Result<#T, #T::ParseError>", RuntimeType.http, returnTypeSymbol, headerUtil, @@ -162,13 +182,19 @@ class HttpBindingGenerator( let out: std::result::Result<_, _> = headers.map(|(key, header_name)| { let values = header_map.get_all(header_name); #T(values.iter()).map(|v| (key.to_string(), v.expect( - "we have checked there is at least one value for this header name; please file a bug report under https://github.com/awslabs/smithy-rs/issues - "))) + "we have checked there is at least one value for this header name; please file a bug report under https://github.com/awslabs/smithy-rs/issues" + ))) }).collect(); - out.map(Some) """, headerUtil, inner, ) + + for (customization in customizations) { + customization.section( + HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders(binding.member), + )(this) + } + rust("out.map(Some)") } } } @@ -221,12 +247,10 @@ class HttpBindingGenerator( private fun RustWriter.bindEventStreamOutput(operationShape: OperationShape, targetShape: UnionShape) { val unmarshallerConstructorFn = EventStreamUnmarshallerGenerator( protocol, - model, - runtimeConfig, - symbolProvider, + codegenContext, operationShape, targetShape, - target, + builderSymbol, ).render() rustTemplate( """ @@ -280,7 +304,7 @@ class HttpBindingGenerator( } } if (targetShape.hasTrait()) { - if (target == CodegenTarget.SERVER) { + if (codegenTarget == CodegenTarget.SERVER) { rust( "Ok(#T::try_from(body_str)?)", symbolProvider.toSymbol(targetShape), @@ -312,19 +336,20 @@ class HttpBindingGenerator( * Parse a value from a header. * This function produces an expression which produces the precise type required by the target shape. */ - private fun RustWriter.deserializeFromHeader(targetType: Shape, memberShape: MemberShape) { - val rustType = symbolProvider.toSymbol(targetType).rustType().stripOuter() + private fun RustWriter.deserializeFromHeader(targetShape: Shape, memberShape: MemberShape) { + val rustType = symbolProvider.toSymbol(targetShape).rustType().stripOuter() // Normally, we go through a flow that looks for `,`s but that's wrong if the output // is just a single string (which might include `,`s.). // MediaType doesn't include `,` since it's base64, send that through the normal path - if (targetType is StringShape && !targetType.hasTrait()) { + if (targetShape is StringShape && !targetShape.hasTrait()) { rust("#T::one_or_none(headers)", headerUtil) return } - val (coreType, coreShape) = if (targetType is CollectionShape) { - rustType.stripOuter() to model.expectShape(targetType.member.target) + val (coreType, coreShape) = if (targetShape is CollectionShape) { + val coreShape = model.expectShape(targetShape.member.target) + symbolProvider.toSymbol(coreShape).rustType() to coreShape } else { - rustType to targetType + rustType to targetShape } val parsedValue = safeName() if (coreType == dateTime) { @@ -336,18 +361,18 @@ class HttpBindingGenerator( ) val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) rust( - "let $parsedValue: Vec<${coreType.render(true)}> = #T::many_dates(headers, #T)?;", + "let $parsedValue: Vec<${coreType.render()}> = #T::many_dates(headers, #T)?;", headerUtil, timestampFormatType, ) } else if (coreShape.isPrimitive()) { rust( - "let $parsedValue = #T::read_many_primitive::<${coreType.render(fullyQualified = true)}>(headers)?;", + "let $parsedValue = #T::read_many_primitive::<${coreType.render()}>(headers)?;", headerUtil, ) } else { rust( - "let $parsedValue: Vec<${coreType.render(fullyQualified = true)}> = #T::read_many_from_str(headers)?;", + "let $parsedValue: Vec<${coreType.render()}> = #T::read_many_from_str(headers)?;", headerUtil, ) if (coreShape.hasTrait()) { @@ -386,17 +411,36 @@ class HttpBindingGenerator( }) """, ) - else -> rustTemplate( - """ - if $parsedValue.len() > 1 { - Err(#{header_util}::ParseError::new_with_message(format!("expected one item but found {}", $parsedValue.len()))) + else -> { + if (targetShape is ListShape) { + // This is a constrained list shape and we must therefore be generating a server SDK. + check(codegenTarget == CodegenTarget.SERVER) + check(rustType is RustType.Opaque) + rust( + """ + Ok(if !$parsedValue.is_empty() { + Some(#T($parsedValue)) + } else { + None + }) + """, + symbolProvider.toSymbol(targetShape), + ) } else { - let mut $parsedValue = $parsedValue; - Ok($parsedValue.pop()) + check(targetShape is SimpleShape) + rustTemplate( + """ + if $parsedValue.len() > 1 { + Err(#{header_util}::ParseError::new_with_message(format!("expected one item but found {}", $parsedValue.len()))) + } else { + let mut $parsedValue = $parsedValue; + Ok($parsedValue.pop()) + } + """, + "header_util" to headerUtil, + ) } - """, - "header_util" to headerUtil, - ) + } } } @@ -475,16 +519,20 @@ class HttpBindingGenerator( val targetShape = model.expectShape(memberShape.target) val memberSymbol = symbolProvider.toSymbol(memberShape) val memberName = symbolProvider.toMemberName(memberShape) - ifSet(targetShape, memberSymbol, "&input.$memberName") { field -> - val isListHeader = targetShape is CollectionShape listForEach(targetShape, field) { innerField, targetId -> val innerMemberType = model.expectShape(targetId) if (innerMemberType.isPrimitive()) { val encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder") rust("let mut encoder = #T::from(${autoDeref(innerField)});", encoder) } - val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField, isListHeader) + val formatted = headerFmtFun( + this, + innerMemberType, + memberShape, + innerField, + isListHeader = targetShape is CollectionShape, + ) val safeName = safeName("formatted") write("let $safeName = $formatted;") rustBlock("if !$safeName.is_empty()") { @@ -519,6 +567,11 @@ class HttpBindingGenerator( val valueTargetShape = model.expectShape(targetShape.value.target) ifSet(targetShape, memberSymbol, "&input.$memberName") { field -> + for (customization in customizations) { + customization.section( + HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders(field, targetShape), + )(this) + } rustTemplate( """ for (k, v) in $field { @@ -539,6 +592,7 @@ class HttpBindingGenerator( })?; builder = builder.header(header_name, header_value); } + """, "build_error" to runtimeConfig.operationBuildError(), ) @@ -564,7 +618,7 @@ class HttpBindingGenerator( val func = writer.format(RuntimeType.Base64Encode(runtimeConfig)) "$func(&$targetName)" } else { - quoteValue("AsRef::::as_ref($targetName)") + quoteValue("$targetName.as_str()") } } target.isTimestampShape -> { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt index dafaeea25..c4e75d297 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/RequestBindingGenerator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.http +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.pattern.SmithyPattern @@ -12,6 +13,7 @@ import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency @@ -24,6 +26,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.OperationBuildError +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol @@ -45,6 +48,8 @@ fun SmithyPattern.rustFormatString(prefix: String, separator: String): String { return base.dq() } +// TODO(https://github.com/awslabs/smithy-rs/issues/1901) Move to `codegen-client` and update docs. +// `MakeOperationGenerator` needs to be moved to `codegen-client` first, which is not easy. /** * Generates methods to serialize and deserialize requests based on the HTTP trait. Specifically: * 1. `fn update_http_request(builder: http::request::Builder) -> Builder` @@ -62,7 +67,9 @@ class RequestBindingGenerator( private val symbolProvider = codegenContext.symbolProvider private val runtimeConfig = codegenContext.runtimeConfig private val httpTrait = protocol.httpBindingResolver.httpTrait(operationShape) - private val httpBindingGenerator = HttpBindingGenerator(protocol, codegenContext, operationShape) + private fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(symbolProvider) + private val httpBindingGenerator = + HttpBindingGenerator(protocol, codegenContext, codegenContext.symbolProvider, operationShape, ::builderSymbol) private val index = HttpBindingIndex.of(model) private val Encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder") @@ -99,7 +106,7 @@ class RequestBindingGenerator( rust( """ let builder = #{T}(input, builder)?; - """.trimIndent(), + """, addHeadersFn, ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/ResponseBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/ResponseBindingGenerator.kt index 1de4cd289..fde74bbd7 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/ResponseBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/ResponseBindingGenerator.kt @@ -5,19 +5,27 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.http +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +// TODO(https://github.com/awslabs/smithy-rs/issues/1901) Move to `codegen-client` and update docs. +// `MakeOperationGenerator` needs to be moved to `codegen-client` first, which is not easy. class ResponseBindingGenerator( protocol: Protocol, - codegenContext: CodegenContext, + private val codegenContext: CodegenContext, operationShape: OperationShape, ) { - private val httpBindingGenerator = HttpBindingGenerator(protocol, codegenContext, operationShape) + private fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(codegenContext.symbolProvider) + + private val httpBindingGenerator = + HttpBindingGenerator(protocol, codegenContext, codegenContext.symbolProvider, operationShape, ::builderSymbol) fun generateDeserializeHeaderFn(binding: HttpBindingDescriptor): RuntimeType = httpBindingGenerator.generateDeserializeHeaderFn(binding) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/MakeOperationGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/MakeOperationGenerator.kt index c77ef20c4..94d4cd28f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/MakeOperationGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/MakeOperationGenerator.kt @@ -33,6 +33,8 @@ import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.letIf +// TODO(https://github.com/awslabs/smithy-rs/issues/1901): Move to `codegen-client`. + /** Generates the `make_operation` function on input structs */ open class MakeOperationGenerator( protected val codegenContext: CodegenContext, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index 5869e77f8..d527c0bc7 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -19,6 +19,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator @@ -129,8 +130,14 @@ open class AwsJson( override fun additionalRequestHeaders(operationShape: OperationShape): List> = listOf("x-amz-target" to "${codegenContext.serviceShape.id.name}.${operationShape.id.name}") - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = - JsonParserGenerator(codegenContext, httpBindingResolver, ::awsJsonFieldName) + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { + return JsonParserGenerator( + codegenContext, + httpBindingResolver, + ::awsJsonFieldName, + builderSymbolFn(codegenContext.symbolProvider), + ) + } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = AwsJsonSerializerGenerator(codegenContext, httpBindingResolver) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt index d2bd4eb9f..5bd7c2ab6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt @@ -6,9 +6,11 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols import software.amazon.smithy.aws.traits.protocols.AwsQueryErrorTrait +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.pattern.UriPattern import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.ToShapeId import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait @@ -19,6 +21,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.AwsQueryParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.AwsQuerySerializerGenerator @@ -55,8 +58,11 @@ class AwsQueryProtocol(private val codegenContext: CodegenContext) : Protocol { override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = - AwsQueryParserGenerator(codegenContext, awsQueryErrors) + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { + fun builderSymbol(shape: StructureShape): Symbol = + shape.builderSymbol(codegenContext.symbolProvider) + return AwsQueryParserGenerator(codegenContext, awsQueryErrors, ::builderSymbol) + } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = AwsQuerySerializerGenerator(codegenContext) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt index 3f9dca4ca..5f5ab09ef 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt @@ -5,8 +5,10 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.pattern.UriPattern import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency @@ -16,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.Ec2QueryParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.Ec2QuerySerializerGenerator @@ -46,8 +49,11 @@ class Ec2QueryProtocol(private val codegenContext: CodegenContext) : Protocol { override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = - Ec2QueryParserGenerator(codegenContext, ec2QueryErrors) + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { + fun builderSymbol(shape: StructureShape): Symbol = + shape.builderSymbol(codegenContext.symbolProvider) + return Ec2QueryParserGenerator(codegenContext, ec2QueryErrors, ::builderSymbol) + } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = Ec2QuerySerializerGenerator(codegenContext) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt index 7a25fabfd..31c5ddae5 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape @@ -20,6 +21,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.asType import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerGenerator @@ -85,8 +87,11 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol { override fun additionalErrorResponseHeaders(errorShape: StructureShape): List> = listOf("x-amzn-errortype" to errorShape.id.name) - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = - JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName) + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { + fun builderSymbol(shape: StructureShape): Symbol = + shape.builderSymbol(codegenContext.symbolProvider) + return JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName, ::builderSymbol) + } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt index 268abc0d6..3108b98dd 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt @@ -6,7 +6,9 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule @@ -15,6 +17,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.RestXmlParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator @@ -45,7 +48,9 @@ open class RestXml(val codegenContext: CodegenContext) : Protocol { TimestampFormatTrait.Format.DATE_TIME override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - return RestXmlParserGenerator(codegenContext, restXmlErrors) + fun builderSymbol(shape: StructureShape): Symbol = + shape.builderSymbol(codegenContext.symbolProvider) + return RestXmlParserGenerator(codegenContext, restXmlErrors, ::builderSymbol) } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGenerator.kt index cb0569c22..a53bc2ab1 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGenerator.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -27,10 +29,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType class AwsQueryParserGenerator( codegenContext: CodegenContext, xmlErrors: RuntimeType, + builderSymbol: (shape: StructureShape) -> Symbol, private val xmlBindingTraitParserGenerator: XmlBindingTraitParserGenerator = XmlBindingTraitParserGenerator( codegenContext, xmlErrors, + builderSymbol, ) { context, inner -> val operationName = codegenContext.symbolProvider.toSymbol(context.shape).name val responseWrapperName = operationName + "Response" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGenerator.kt index c33e25737..f59f2df55 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGenerator.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -25,10 +27,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType class Ec2QueryParserGenerator( codegenContext: CodegenContext, xmlErrors: RuntimeType, + builderSymbol: (shape: StructureShape) -> Symbol, private val xmlBindingTraitParserGenerator: XmlBindingTraitParserGenerator = XmlBindingTraitParserGenerator( codegenContext, xmlErrors, + builderSymbol, ) { context, inner -> val operationName = codegenContext.symbolProvider.toSymbol(context.shape).name val responseWrapperName = operationName + "Response" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index c476fcdb9..056be65fc 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.ByteShape @@ -26,18 +25,19 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.asType +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.error.eventStreamErrorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant +import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticEventStreamUnionTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors @@ -48,19 +48,22 @@ import software.amazon.smithy.rust.codegen.core.util.toPascalCase class EventStreamUnmarshallerGenerator( private val protocol: Protocol, - private val model: Model, - runtimeConfig: RuntimeConfig, - private val symbolProvider: RustSymbolProvider, + codegenContext: CodegenContext, private val operationShape: OperationShape, private val unionShape: UnionShape, - private val target: CodegenTarget, + /** Function that maps a StructureShape into its builder symbol */ + private val builderSymbol: (StructureShape) -> Symbol, ) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val codegenTarget = codegenContext.target + private val runtimeConfig = codegenContext.runtimeConfig private val unionSymbol = symbolProvider.toSymbol(unionShape) private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig) - private val errorSymbol = if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { + private val errorSymbol = if (codegenTarget == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { RuntimeType("MessageStreamError", smithyEventStream, "aws_smithy_http::event_stream").toSymbol() } else { - unionShape.eventStreamErrorSymbol(model, symbolProvider, target).toSymbol() + unionShape.eventStreamErrorSymbol(model, symbolProvider, codegenTarget).toSymbol() } private val eventStreamSerdeModule = RustModule.private("event_stream_serde") private val codegenScope = arrayOf( @@ -149,7 +152,7 @@ class EventStreamUnmarshallerGenerator( } } rustBlock("_unknown_variant => ") { - when (target.renderUnknownVariant()) { + when (codegenTarget.renderUnknownVariant()) { true -> rustTemplate( "Ok(#{UnmarshalledMessage}::Event(#{Output}::${UnionGenerator.UnknownVariantName}))", "Output" to unionSymbol, @@ -191,7 +194,7 @@ class EventStreamUnmarshallerGenerator( ) } else -> { - rust("let mut builder = #T::builder();", symbolProvider.toSymbol(unionStruct)) + rust("let mut builder = #T::default();", builderSymbol(unionStruct)) val payloadMember = unionStruct.members().firstOrNull { it.hasTrait() } if (payloadMember != null) { renderUnmarshallEventPayload(payloadMember) @@ -225,18 +228,19 @@ class EventStreamUnmarshallerGenerator( } private fun RustWriter.renderUnmarshallEventHeader(member: MemberShape) { - val memberName = symbolProvider.toMemberName(member) - withBlock("builder = builder.$memberName(", ");") { - when (val target = model.expectShape(member.target)) { - is BooleanShape -> rustTemplate("#{expect_fns}::expect_bool(header)?", *codegenScope) - is ByteShape -> rustTemplate("#{expect_fns}::expect_byte(header)?", *codegenScope) - is ShortShape -> rustTemplate("#{expect_fns}::expect_int16(header)?", *codegenScope) - is IntegerShape -> rustTemplate("#{expect_fns}::expect_int32(header)?", *codegenScope) - is LongShape -> rustTemplate("#{expect_fns}::expect_int64(header)?", *codegenScope) - is BlobShape -> rustTemplate("#{expect_fns}::expect_byte_array(header)?", *codegenScope) - is StringShape -> rustTemplate("#{expect_fns}::expect_string(header)?", *codegenScope) - is TimestampShape -> rustTemplate("#{expect_fns}::expect_timestamp(header)?", *codegenScope) - else -> throw IllegalStateException("unsupported event stream header shape type: $target") + withBlock("builder = builder.${member.setterName()}(", ");") { + conditionalBlock("Some(", ")", member.isOptional) { + when (val target = model.expectShape(member.target)) { + is BooleanShape -> rustTemplate("#{expect_fns}::expect_bool(header)?", *codegenScope) + is ByteShape -> rustTemplate("#{expect_fns}::expect_byte(header)?", *codegenScope) + is ShortShape -> rustTemplate("#{expect_fns}::expect_int16(header)?", *codegenScope) + is IntegerShape -> rustTemplate("#{expect_fns}::expect_int32(header)?", *codegenScope) + is LongShape -> rustTemplate("#{expect_fns}::expect_int64(header)?", *codegenScope) + is BlobShape -> rustTemplate("#{expect_fns}::expect_byte_array(header)?", *codegenScope) + is StringShape -> rustTemplate("#{expect_fns}::expect_string(header)?", *codegenScope) + is TimestampShape -> rustTemplate("#{expect_fns}::expect_timestamp(header)?", *codegenScope) + else -> throw IllegalStateException("unsupported event stream header shape type: $target") + } } } } @@ -259,31 +263,33 @@ class EventStreamUnmarshallerGenerator( *codegenScope, ) } - val memberName = symbolProvider.toMemberName(member) - withBlock("builder = builder.$memberName(", ");") { - when (target) { - is BlobShape -> { - rustTemplate("#{Blob}::new(message.payload().as_ref())", *codegenScope) - } - is StringShape -> { - rustTemplate( - """ - std::str::from_utf8(message.payload()) - .map_err(|_| #{Error}::unmarshalling("message payload is not valid UTF-8"))? - """, - *codegenScope, - ) - } - is UnionShape, is StructureShape -> { - renderParseProtocolPayload(member) + withBlock("builder = builder.${member.setterName()}(", ");") { + conditionalBlock("Some(", ")", member.isOptional) { + when (target) { + is BlobShape -> { + rustTemplate("#{Blob}::new(message.payload().as_ref())", *codegenScope) + } + is StringShape -> { + rustTemplate( + """ + std::str::from_utf8(message.payload()) + .map_err(|_| #{Error}::unmarshalling("message payload is not valid UTF-8"))? + .to_owned() + """, + *codegenScope, + ) + } + is UnionShape, is StructureShape -> { + renderParseProtocolPayload(member) + } } } } } private fun RustWriter.renderParseProtocolPayload(member: MemberShape) { - val parser = protocol.structuredDataParser(operationShape).payloadParser(member) val memberName = symbolProvider.toMemberName(member) + val parser = protocol.structuredDataParser(operationShape).payloadParser(member) rustTemplate( """ #{parser}(&message.payload()[..]) @@ -297,7 +303,7 @@ class EventStreamUnmarshallerGenerator( } private fun RustWriter.renderUnmarshallError() { - when (target) { + when (codegenTarget) { CodegenTarget.CLIENT -> { rustTemplate( """ @@ -326,12 +332,12 @@ class EventStreamUnmarshallerGenerator( rustBlock("${member.memberName.dq()} $matchOperator ") { // TODO(EventStream): Errors on the operation can be disjoint with errors in the union, // so we need to generate a new top-level Error type for each event stream union. - when (target) { + when (codegenTarget) { CodegenTarget.CLIENT -> { val target = model.expectShape(member.target, StructureShape::class.java) val parser = protocol.structuredDataParser(operationShape).errorParser(target) if (parser != null) { - rust("let mut builder = #T::builder();", symbolProvider.toSymbol(target)) + rust("let mut builder = #T::default();", builderSymbol(target)) rustTemplate( """ builder = #{parser}(&message.payload()[..], builder) @@ -354,7 +360,7 @@ class EventStreamUnmarshallerGenerator( val target = model.expectShape(member.target, StructureShape::class.java) val parser = protocol.structuredDataParser(operationShape).errorParser(target) val mut = if (parser != null) { " mut" } else { "" } - rust("let$mut builder = #T::builder();", symbolProvider.toSymbol(target)) + rust("let$mut builder = #T::default();", builderSymbol(target)) if (parser != null) { rustTemplate( """ @@ -387,7 +393,7 @@ class EventStreamUnmarshallerGenerator( rust("}") } } - when (target) { + when (codegenTarget) { CodegenTarget.CLIENT -> { rustTemplate("Ok(#{UnmarshalledMessage}::Error(#{OpError}::generic(generic)))", *codegenScope) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 43c18287a..a69a8e965 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.CollectionShape @@ -13,6 +14,7 @@ import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape @@ -37,12 +39,13 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault -import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedSectionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName +import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation @@ -54,16 +57,45 @@ import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.outputShape import software.amazon.smithy.utils.StringUtils +/** + * Class describing a JSON parser section that can be used in a customization. + */ +sealed class JsonParserSection(name: String) : Section(name) { + data class BeforeBoxingDeserializedMember(val shape: MemberShape) : JsonParserSection("BeforeBoxingDeserializedMember") +} + +/** + * Customization for the JSON parser. + */ +typealias JsonParserCustomization = NamedSectionGenerator + +data class ReturnSymbolToParse(val symbol: Symbol, val isUnconstrained: Boolean) + class JsonParserGenerator( - private val codegenContext: CodegenContext, + codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, /** Function that maps a MemberShape into a JSON field name */ private val jsonName: (MemberShape) -> String, + /** Function that maps a StructureShape into its builder symbol */ + private val builderSymbol: (StructureShape) -> Symbol, + /** + * Whether we should parse a value for a shape into its associated unconstrained type. For example, when the shape + * is a `StructureShape`, we should construct and return a builder instead of building into the final `struct` the + * user gets. This is only relevant for the server, that parses the incoming request and only after enforces + * constraint traits. + * + * The function returns a data class that signals the return symbol that should be parsed, and whether it's + * unconstrained or not. + */ + private val returnSymbolToParse: (Shape) -> ReturnSymbolToParse = { shape -> + ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) + }, + private val customizations: List = listOf(), ) : StructuredDataParserGenerator { private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider private val runtimeConfig = codegenContext.runtimeConfig - private val target = codegenContext.target + private val codegenTarget = codegenContext.target private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType() private val jsonDeserModule = RustModule.private("json_deser") private val typeConversionGenerator = TypeConversionGenerator(model, symbolProvider, runtimeConfig) @@ -94,14 +126,14 @@ class JsonParserGenerator( */ private fun structureParser( fnName: String, - structureShape: StructureShape, + builderSymbol: Symbol, includedMembers: List, ): RuntimeType { return RuntimeType.forInlineFun(fnName, jsonDeserModule) { val unusedMut = if (includedMembers.isEmpty()) "##[allow(unused_mut)] " else "" rustBlockTemplate( - "pub fn $fnName(value: &[u8], ${unusedMut}mut builder: #{Builder}) -> Result<#{Builder}, #{Error}>", - "Builder" to structureShape.builderSymbol(symbolProvider), + "pub(crate) fn $fnName(value: &[u8], ${unusedMut}mut builder: #{Builder}) -> Result<#{Builder}, #{Error}>", + "Builder" to builderSymbol, *codegenScope, ) { rustTemplate( @@ -159,7 +191,7 @@ class JsonParserGenerator( } val outputShape = operationShape.outputShape(model) val fnName = symbolProvider.deserializeFunctionName(operationShape) - return structureParser(fnName, outputShape, httpDocumentMembers) + return structureParser(fnName, builderSymbol(outputShape), httpDocumentMembers) } override fun errorParser(errorShape: StructureShape): RuntimeType? { @@ -167,13 +199,13 @@ class JsonParserGenerator( return null } val fnName = symbolProvider.deserializeFunctionName(errorShape) + "_json_err" - return structureParser(fnName, errorShape, errorShape.members().toList()) + return structureParser(fnName, builderSymbol(errorShape), errorShape.members().toList()) } private fun orEmptyJson(): RuntimeType = RuntimeType.forInlineFun("or_empty_doc", jsonDeserModule) { rust( """ - pub fn or_empty_doc(data: &[u8]) -> &[u8] { + pub(crate) fn or_empty_doc(data: &[u8]) -> &[u8] { if data.is_empty() { b"{}" } else { @@ -191,7 +223,7 @@ class JsonParserGenerator( } val inputShape = operationShape.inputShape(model) val fnName = symbolProvider.deserializeFunctionName(operationShape) - return structureParser(fnName, inputShape, includedMembers) + return structureParser(fnName, builderSymbol(inputShape), includedMembers) } private fun RustWriter.expectEndOfTokenStream() { @@ -208,8 +240,29 @@ class JsonParserGenerator( rustBlock("match key.to_unescaped()?.as_ref()") { for (member in members) { rustBlock("${jsonName(member).dq()} =>") { - withBlock("builder = builder.${member.setterName()}(", ");") { - deserializeMember(member) + when (codegenTarget) { + CodegenTarget.CLIENT -> { + withBlock("builder = builder.${member.setterName()}(", ");") { + deserializeMember(member) + } + } + CodegenTarget.SERVER -> { + if (symbolProvider.toSymbol(member).isOptional()) { + withBlock("builder = builder.${member.setterName()}(", ");") { + deserializeMember(member) + } + } else { + rust("if let Some(v) = ") + deserializeMember(member) + rust( + """ + { + builder = builder.${member.setterName()}(v); + } + """, + ) + } + } } } } @@ -234,6 +287,9 @@ class JsonParserGenerator( } val symbol = symbolProvider.toSymbol(memberShape) if (symbol.isRustBoxed()) { + for (customization in customizations) { + customization.section(JsonParserSection.BeforeBoxingDeserializedMember(memberShape))(this) + } rust(".map(Box::new)") } } @@ -250,15 +306,8 @@ class JsonParserGenerator( withBlock("$escapedStrName.to_unescaped().map(|u|", ")") { when (target.hasTrait()) { true -> { - if (convertsToEnumInServer(target)) { - rustTemplate( - """ - #{EnumSymbol}::try_from(u.as_ref()) - .map_err(|e| #{Error}::custom(format!("unknown variant {}", e))) - """, - "EnumSymbol" to symbolProvider.toSymbol(target), - *codegenScope, - ) + if (returnSymbolToParse(target).isUnconstrained) { + rust("u.into_owned()") } else { rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target)) } @@ -268,12 +317,8 @@ class JsonParserGenerator( } } - private fun convertsToEnumInServer(shape: StringShape) = target == CodegenTarget.SERVER && shape.hasTrait() - private fun RustWriter.deserializeString(target: StringShape) { - // Additional `.transpose()?` because we can't use `?` inside the closures that parsed the string. - val additionalTranspose = ".transpose()?".repeat(if (convertsToEnumInServer(target)) 2 else 1) - withBlockTemplate("#{expect_string_or_null}(tokens.next())?.map(|s|", ")$additionalTranspose", *codegenScope) { + withBlockTemplate("#{expect_string_or_null}(tokens.next())?.map(|s|", ").transpose()?", *codegenScope) { deserializeStringInner(target, "s") } } @@ -287,9 +332,10 @@ class JsonParserGenerator( rustTemplate( """ #{expect_number_or_null}(tokens.next())? - .map(|v| v.try_into()) + .map(#{NumberType}::try_from) .transpose()? """, + "NumberType" to symbolProvider.toSymbol(target), *codegenScope, ) } @@ -311,16 +357,17 @@ class JsonParserGenerator( private fun RustWriter.deserializeCollection(shape: CollectionShape) { val fnName = symbolProvider.deserializeFunctionName(shape) val isSparse = shape.hasTrait() + val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) val parser = RuntimeType.forInlineFun(fnName, jsonDeserModule) { // Allow non-snake-case since some SDK models have lists with names prefixed with `__listOf__`, // which become `__list_of__`, and the Rust compiler warning doesn't like multiple adjacent underscores. rustBlockTemplate( """ - ##[allow(clippy::type_complexity, non_snake_case)] - pub fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> + ##[allow(non_snake_case)] + pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> where I: Iterator, #{Error}>> """, - "Shape" to symbolProvider.toSymbol(shape), + "ReturnType" to returnSymbol, *codegenScope, ) { startArrayOrNull { @@ -346,7 +393,11 @@ class JsonParserGenerator( } } } - rust("Ok(Some(items))") + if (returnUnconstrainedType) { + rust("Ok(Some(#{T}(items)))", returnSymbol) + } else { + rust("Ok(Some(items))") + } } } } @@ -357,16 +408,17 @@ class JsonParserGenerator( val keyTarget = model.expectShape(shape.key.target) as StringShape val fnName = symbolProvider.deserializeFunctionName(shape) val isSparse = shape.hasTrait() + val returnSymbolToParse = returnSymbolToParse(shape) val parser = RuntimeType.forInlineFun(fnName, jsonDeserModule) { // Allow non-snake-case since some SDK models have maps with names prefixed with `__mapOf__`, // which become `__map_of__`, and the Rust compiler warning doesn't like multiple adjacent underscores. rustBlockTemplate( """ - ##[allow(clippy::type_complexity, non_snake_case)] - pub fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> + ##[allow(non_snake_case)] + pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> where I: Iterator, #{Error}>> """, - "Shape" to symbolProvider.toSymbol(shape), + "ReturnType" to returnSymbolToParse.symbol, *codegenScope, ) { startObjectOrNull { @@ -378,9 +430,6 @@ class JsonParserGenerator( withBlock("let value =", ";") { deserializeMember(shape.value) } - if (convertsToEnumInServer(keyTarget)) { - rust("let key = key?;") - } if (isSparse) { rust("map.insert(key, value);") } else { @@ -389,7 +438,11 @@ class JsonParserGenerator( } } } - rust("Ok(Some(map))") + if (returnSymbolToParse.isUnconstrained) { + rust("Ok(Some(#{T}(map)))", returnSymbolToParse.symbol) + } else { + rust("Ok(Some(map))") + } } } } @@ -398,29 +451,25 @@ class JsonParserGenerator( private fun RustWriter.deserializeStruct(shape: StructureShape) { val fnName = symbolProvider.deserializeFunctionName(shape) - val symbol = symbolProvider.toSymbol(shape) + val returnSymbolToParse = returnSymbolToParse(shape) val nestedParser = RuntimeType.forInlineFun(fnName, jsonDeserModule) { rustBlockTemplate( """ - pub fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> + pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> where I: Iterator, #{Error}>> """, - "Shape" to symbol, + "ReturnType" to returnSymbolToParse.symbol, *codegenScope, ) { startObjectOrNull { Attribute.AllowUnusedMut.render(this) - rustTemplate("let mut builder = #{Shape}::builder();", *codegenScope, "Shape" to symbol) + rustTemplate("let mut builder = #{Builder}::default();", *codegenScope, "Builder" to builderSymbol(shape)) deserializeStructInner(shape.members()) - withBlock("Ok(Some(builder.build()", "))") { - if (StructureGenerator.hasFallibleBuilder(shape, symbolProvider)) { - rustTemplate( - """.map_err(|err| #{Error}::new( - #{ErrorReason}::Custom(format!("{}", err).into()), None) - )?""", - *codegenScope, - ) - } + // Only call `build()` if the builder is not fallible. Otherwise, return the builder. + if (returnSymbolToParse.isUnconstrained) { + rust("Ok(Some(builder))") + } else { + rust("Ok(Some(builder.build()))") } } } @@ -430,15 +479,15 @@ class JsonParserGenerator( private fun RustWriter.deserializeUnion(shape: UnionShape) { val fnName = symbolProvider.deserializeFunctionName(shape) - val symbol = symbolProvider.toSymbol(shape) + val returnSymbolToParse = returnSymbolToParse(shape) val nestedParser = RuntimeType.forInlineFun(fnName, jsonDeserModule) { rustBlockTemplate( """ - pub fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> + pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> where I: Iterator, #{Error}>> """, *codegenScope, - "Shape" to symbol, + "Shape" to returnSymbolToParse.symbol, ) { rust("let mut variant = None;") rustBlock("match tokens.next().transpose()?") { @@ -462,14 +511,14 @@ class JsonParserGenerator( for (member in shape.members()) { val variantName = symbolProvider.toMemberName(member) rustBlock("${jsonName(member).dq()} =>") { - withBlock("Some(#T::$variantName(", "))", symbol) { + withBlock("Some(#T::$variantName(", "))", returnSymbolToParse.symbol) { deserializeMember(member) unwrapOrDefaultOrError(member) } } } - when (target.renderUnknownVariant()) { - // in client mode, resolve an unknown union variant to the unknown variant + when (codegenTarget.renderUnknownVariant()) { + // In client mode, resolve an unknown union variant to the unknown variant. true -> rustTemplate( """ _ => { @@ -477,9 +526,11 @@ class JsonParserGenerator( Some(#{Union}::${UnionGenerator.UnknownVariantName}) } """, - "Union" to symbol, *codegenScope, + "Union" to returnSymbolToParse.symbol, + *codegenScope, ) - // in server mode, use strict parsing + // In server mode, use strict parsing. + // Consultation: https://github.com/awslabs/smithy/issues/1222 false -> rustTemplate( """variant => return Err(#{Error}::custom(format!("unexpected union variant: {}", variant)))""", *codegenScope, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt index ed41cfd85..156d025b9 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -17,10 +19,12 @@ import software.amazon.smithy.rust.codegen.core.util.orNull class RestXmlParserGenerator( codegenContext: CodegenContext, xmlErrors: RuntimeType, + builderSymbol: (shape: StructureShape) -> Symbol, private val xmlBindingTraitParserGenerator: XmlBindingTraitParserGenerator = XmlBindingTraitParserGenerator( codegenContext, xmlErrors, + builderSymbol, ) { context, inner -> val shapeName = context.outputShapeName // Get the non-synthetic version of the outputShape and check to see if it has the `AllowInvalidXmlRoot` trait diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt index 841975c22..37b0d1e61 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse import software.amazon.smithy.aws.traits.customizations.S3UnwrappedXmlOutputTrait import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex @@ -41,9 +42,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.isOptional @@ -71,6 +71,7 @@ data class OperationWrapperContext( class XmlBindingTraitParserGenerator( codegenContext: CodegenContext, private val xmlErrors: RuntimeType, + private val builderSymbol: (shape: StructureShape) -> Symbol, private val writeOperationWrapper: RustWriter.(OperationWrapperContext, OperationInnerWriteable) -> Unit, ) : StructuredDataParserGenerator { @@ -187,7 +188,7 @@ class XmlBindingTraitParserGenerator( Attribute.AllowUnusedMut.render(this) rustBlock( "pub fn $fnName(inp: &[u8], mut builder: #1T) -> Result<#1T, #2T>", - outputShape.builderSymbol(symbolProvider), + builderSymbol(outputShape), xmlError, ) { rustTemplate( @@ -220,7 +221,7 @@ class XmlBindingTraitParserGenerator( Attribute.AllowUnusedMut.render(this) rustBlock( "pub fn $fnName(inp: &[u8], mut builder: #1T) -> Result<#1T, #2T>", - errorShape.builderSymbol(symbolProvider), + builderSymbol(errorShape), xmlError, ) { val members = errorShape.errorXmlMembers() @@ -254,7 +255,7 @@ class XmlBindingTraitParserGenerator( Attribute.AllowUnusedMut.render(this) rustBlock( "pub fn $fnName(inp: &[u8], mut builder: #1T) -> Result<#1T, #2T>", - inputShape.builderSymbol(symbolProvider), + builderSymbol(inputShape), xmlError, ) { rustTemplate( @@ -476,7 +477,7 @@ class XmlBindingTraitParserGenerator( rust("let _ = decoder;") } withBlock("Ok(builder.build()", ")") { - if (StructureGenerator.hasFallibleBuilder(shape, symbolProvider)) { + if (BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) { // NOTE:(rcoh) This branch is unreachable given the current nullability rules. // Only synthetic inputs can have fallible builders, but synthetic inputs can never be parsed // (because they're inputs, only outputs will be parsed!) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 415a0a8de..4ea25612b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -19,7 +19,6 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.TimestampFormatTrait.Format.EPOCH_SECONDS import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule @@ -48,35 +47,37 @@ import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait -import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.outputShape /** - * Class describing a JSON section that can be used in a customization. + * Class describing a JSON serializer section that can be used in a customization. */ -sealed class JsonSection(name: String) : Section(name) { +sealed class JsonSerializerSection(name: String) : Section(name) { /** Mutate the server error object prior to finalization. Eg: this can be used to inject `__type` to record the error type. */ - data class ServerError(val structureShape: StructureShape, val jsonObject: String) : JsonSection("ServerError") + data class ServerError(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("ServerError") + + /** Mutate a map prior to it being serialized. **/ + data class BeforeIteratingOverMap(val shape: MapShape, val valueExpression: ValueExpression) : JsonSerializerSection("BeforeIteratingOverMap") /** Mutate the input object prior to finalization. */ - data class InputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSection("InputStruct") + data class InputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("InputStruct") /** Mutate the output object prior to finalization. */ - data class OutputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSection("OutputStruct") + data class OutputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("OutputStruct") } /** - * JSON customization. + * Customization for the JSON serializer. */ -typealias JsonCustomization = NamedSectionGenerator +typealias JsonSerializerCustomization = NamedSectionGenerator class JsonSerializerGenerator( codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, /** Function that maps a MemberShape into a JSON field name */ private val jsonName: (MemberShape) -> String, - private val customizations: List = listOf(), + private val customizations: List = listOf(), ) : StructuredDataSerializerGenerator { private data class Context( /** Expression that retrieves a JsonValueWriter from either a JsonObjectWriter or JsonArrayWriter */ @@ -154,7 +155,7 @@ class JsonSerializerGenerator( private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider - private val target = codegenContext.target + private val codegenTarget = codegenContext.target private val runtimeConfig = codegenContext.runtimeConfig private val smithyTypes = CargoDependency.SmithyTypes(runtimeConfig).asType() private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType() @@ -180,7 +181,7 @@ class JsonSerializerGenerator( fnName: String, structureShape: StructureShape, includedMembers: List, - makeSection: (StructureShape, String) -> JsonSection, + makeSection: (StructureShape, String) -> JsonSerializerSection, ): RuntimeType { return RuntimeType.forInlineFun(fnName, operationSerModule) { rustBlockTemplate( @@ -251,7 +252,7 @@ class JsonSerializerGenerator( rust("let mut out = String::new();") rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) serializeStructure(StructContext("object", "input", inputShape), httpDocumentMembers) - customizations.forEach { it.section(JsonSection.InputStruct(inputShape, "object"))(this) } + customizations.forEach { it.section(JsonSerializerSection.InputStruct(inputShape, "object"))(this) } rust("object.finish();") rustTemplate("Ok(#{SdkBody}::from(out))", *codegenScope) } @@ -293,7 +294,7 @@ class JsonSerializerGenerator( val outputShape = operationShape.outputShape(model) val fnName = symbolProvider.serializeFunctionName(outputShape) - return serverSerializer(fnName, outputShape, httpDocumentMembers, JsonSection::OutputStruct) + return serverSerializer(fnName, outputShape, httpDocumentMembers, JsonSerializerSection::OutputStruct) } override fun serverErrorSerializer(shape: ShapeId): RuntimeType { @@ -302,7 +303,7 @@ class JsonSerializerGenerator( httpBindingResolver.errorResponseBindings(shape).filter { it.location == HttpLocation.DOCUMENT } .map { it.member } val fnName = symbolProvider.serializeFunctionName(errorShape) - return serverSerializer(fnName, errorShape, includedMembers, JsonSection::ServerError) + return serverSerializer(fnName, errorShape, includedMembers, JsonSerializerSection::ServerError) } private fun RustWriter.serializeStructure( @@ -358,6 +359,7 @@ class JsonSerializerGenerator( private fun RustWriter.serializeMemberValue(context: MemberContext, target: Shape) { val writer = context.writerExpression val value = context.valueExpression + when (target) { is StringShape -> rust("$writer.string(${value.name}.as_str());") is BooleanShape -> rust("$writer.boolean(${value.asValue()});") @@ -430,12 +432,11 @@ class JsonSerializerGenerator( private fun RustWriter.serializeMap(context: Context) { val keyName = safeName("key") val valueName = safeName("value") + for (customization in customizations) { + customization.section(JsonSerializerSection.BeforeIteratingOverMap(context.shape, context.valueExpression))(this) + } rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") { - val keyTarget = model.expectShape(context.shape.key.target) - val keyExpression = when (keyTarget.hasTrait()) { - true -> "$keyName.as_str()" - else -> keyName - } + val keyExpression = "$keyName.as_str()" serializeMember(MemberContext.mapMember(context, keyExpression, valueName)) } } @@ -456,7 +457,7 @@ class JsonSerializerGenerator( serializeMember(MemberContext.unionMember(context, "inner", member, jsonName)) } } - if (target.renderUnknownVariant()) { + if (codegenTarget.renderUnknownVariant()) { rustTemplate( "#{Union}::${UnionGenerator.UnknownVariantName} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", "Union" to unionSymbol, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt index 2ca848e07..cf8a77a76 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt @@ -63,7 +63,7 @@ class XmlBindingTraitSerializerGenerator( private val runtimeConfig = codegenContext.runtimeConfig private val model = codegenContext.model private val smithyXml = CargoDependency.smithyXml(runtimeConfig).asType() - private val target = codegenContext.target + private val codegenTarget = codegenContext.target private val codegenScope = arrayOf( "XmlWriter" to smithyXml.member("encode::XmlWriter"), @@ -291,7 +291,14 @@ class XmlBindingTraitSerializerGenerator( private fun RustWriter.serializeRawMember(member: MemberShape, input: String) { when (model.expectShape(member.target)) { is StringShape -> { - rust("$input.as_str()") + // The `input` expression always evaluates to a reference type at this point, but if it does so because + // it's preceded by the `&` operator, calling `as_str()` on it will upset Clippy. + val dereferenced = if (input.startsWith("&")) { + autoDeref(input) + } else { + input + } + rust("$dereferenced.as_str()") } is BooleanShape, is NumberShape -> { rust( @@ -399,7 +406,7 @@ class XmlBindingTraitSerializerGenerator( } } - if (target.renderUnknownVariant()) { + if (codegenTarget.renderUnknownVariant()) { rustTemplate( "#{Union}::${UnionGenerator.UnknownVariantName} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", "Union" to unionSymbol, diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt index 9dc1e0592..abe650017 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.node.StringNode import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape @@ -17,13 +18,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer -import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder -import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider +import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.lookup @@ -83,15 +84,28 @@ class InstantiatorTest { } """.asSmithyModel().let { RecursiveShapeBoxer.transform(it) } - private val symbolProvider = testSymbolProvider(model) - private val runtimeConfig = TestRuntimeConfig + private val codegenContext = testCodegenContext(model) + private val symbolProvider = codegenContext.symbolProvider + private val runtimeConfig = codegenContext.runtimeConfig + // This is the exact same behavior of the client. + private class BuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior { + override fun hasFallibleBuilder(shape: StructureShape) = + BuilderGenerator.hasFallibleBuilder(shape, codegenContext.symbolProvider) + + override fun setterName(memberShape: MemberShape) = memberShape.setterName() + + override fun doesSetterTakeInOption(memberShape: MemberShape) = true + } + + // This can be empty since the actual behavior is tested in `ClientInstantiatorTest` and `ServerInstantiatorTest`. private fun enumFromStringFn(symbol: Symbol, data: String) = writable { } @Test fun `generate unions`() { val union = model.lookup("com.test#MyUnion") - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val sut = + Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) val data = Node.parse("""{ "stringVariant": "ok!" }""") val project = TestWorkspace.testProject() @@ -110,7 +124,8 @@ class InstantiatorTest { @Test fun `generate struct builders`() { val structure = model.lookup("com.test#MyStruct") - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val sut = + Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) val data = Node.parse("""{ "bar": 10, "foo": "hello" }""") val project = TestWorkspace.testProject() @@ -134,7 +149,8 @@ class InstantiatorTest { @Test fun `generate builders for boxed structs`() { val structure = model.lookup("com.test#WithBox") - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val sut = + Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) val data = Node.parse( """ { @@ -172,7 +188,8 @@ class InstantiatorTest { @Test fun `generate lists`() { val data = Node.parse("""["bar", "foo"]""") - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val sut = + Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) val project = TestWorkspace.testProject() project.withModule(RustModule.Model) { @@ -180,16 +197,21 @@ class InstantiatorTest { withBlock("let result = ", ";") { sut.render(this, model.lookup("com.test#MyList"), data) } - rust("""assert_eq!(result, vec!["bar".to_owned(), "foo".to_owned()]);""") } + project.compileAndTest() } - project.compileAndTest() } @Test fun `generate sparse lists`() { val data = Node.parse(""" [ "bar", "foo", null ] """) - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val sut = Instantiator( + symbolProvider, + model, + runtimeConfig, + BuilderKindBehavior(codegenContext), + ::enumFromStringFn, + ) val project = TestWorkspace.testProject() project.withModule(RustModule.Model) { @@ -207,14 +229,20 @@ class InstantiatorTest { fun `generate maps of maps`() { val data = Node.parse( """ - { - "k1": { "map": {} }, - "k2": { "map": { "k3": {} } }, - "k3": { } - } - """, + { + "k1": { "map": {} }, + "k2": { "map": { "k3": {} } }, + "k3": { } + } + """, + ) + val sut = Instantiator( + symbolProvider, + model, + runtimeConfig, + BuilderKindBehavior(codegenContext), + ::enumFromStringFn, ) - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) val inner = model.lookup("com.test#Inner") val project = TestWorkspace.testProject() @@ -226,11 +254,11 @@ class InstantiatorTest { } rust( """ - assert_eq!(result.len(), 3); - assert_eq!(result.get("k1").unwrap().map.as_ref().unwrap().len(), 0); - assert_eq!(result.get("k2").unwrap().map.as_ref().unwrap().len(), 1); - assert_eq!(result.get("k3").unwrap().map, None); - """, + assert_eq!(result.len(), 3); + assert_eq!(result.get("k1").unwrap().map.as_ref().unwrap().len(), 0); + assert_eq!(result.get("k2").unwrap().map.as_ref().unwrap().len(), 1); + assert_eq!(result.get("k3").unwrap().map, None); + """, ) } } @@ -241,7 +269,13 @@ class InstantiatorTest { fun `blob inputs are binary data`() { // "Parameter values that contain binary data MUST be defined using values // that can be represented in plain text (for example, use "foo" and not "Zm9vCg==")." - val sut = Instantiator(symbolProvider, model, runtimeConfig, ::enumFromStringFn) + val sut = Instantiator( + symbolProvider, + model, + runtimeConfig, + BuilderKindBehavior(codegenContext), + ::enumFromStringFn, + ) val project = TestWorkspace.testProject() project.withModule(RustModule.Model) { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt index cee3fb147..b90543bea 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig @@ -45,7 +46,11 @@ class AwsQueryParserGeneratorTest { val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider - val parserGenerator = AwsQueryParserGenerator(codegenContext, RuntimeType.wrappedXmlErrors(TestRuntimeConfig)) + val parserGenerator = AwsQueryParserGenerator( + codegenContext, + RuntimeType.wrappedXmlErrors(TestRuntimeConfig), + builderSymbolFn(symbolProvider), + ) val operationParser = parserGenerator.operationParser(model.lookup("test#SomeOperation"))!! val project = TestWorkspace.testProject(testSymbolProvider(model)) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt index b8eb1e77b..7b835d822 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig @@ -45,7 +46,11 @@ class Ec2QueryParserGeneratorTest { val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider - val parserGenerator = Ec2QueryParserGenerator(codegenContext, RuntimeType.wrappedXmlErrors(TestRuntimeConfig)) + val parserGenerator = Ec2QueryParserGenerator( + codegenContext, + RuntimeType.wrappedXmlErrors(TestRuntimeConfig), + builderSymbolFn(symbolProvider), + ) val operationParser = parserGenerator.operationParser(model.lookup("test#SomeOperation"))!! val project = TestWorkspace.testProject(testSymbolProvider(model)) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt index 049e456d8..5cd898a29 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt @@ -6,12 +6,14 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse import org.junit.jupiter.api.Test +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpTraitHttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolContentTypes import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName @@ -115,10 +117,14 @@ class JsonParserGeneratorTest { val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider + fun builderSymbol(shape: StructureShape): Symbol = + shape.builderSymbol(symbolProvider) + val parserGenerator = JsonParserGenerator( codegenContext, HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/json")), ::restJsonFieldName, + ::builderSymbol, ) val operationGenerator = parserGenerator.operationParser(model.lookup("test#Op")) val payloadGenerator = parserGenerator.payloadParser(model.lookup("test#OpOutput\$top")) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt index c4932fe71..52bb4e1e7 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig @@ -96,6 +97,7 @@ internal class XmlBindingTraitParserGeneratorTest { val parserGenerator = XmlBindingTraitParserGenerator( codegenContext, RuntimeType.wrappedXmlErrors(TestRuntimeConfig), + builderSymbolFn(symbolProvider), ) { _, inner -> inner("decoder") } val operationParser = parserGenerator.operationParser(model.lookup("test#Op"))!! val project = TestWorkspace.testProject(testSymbolProvider(model)) diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index 3d01b403a..e63401daf 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -33,6 +33,7 @@ dependencies { implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") + implementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") } val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> @@ -40,13 +41,24 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> CodegenTest("crate#Config", "naming_test_ops", imports = listOf("$commonModels/naming-obstacle-course-ops.smithy")), CodegenTest("naming_obs_structs#NamingObstacleCourseStructs", "naming_test_structs", imports = listOf("$commonModels/naming-obstacle-course-structs.smithy")), CodegenTest("com.amazonaws.simple#SimpleService", "simple", imports = listOf("$commonModels/simple.smithy")), + CodegenTest( + "com.amazonaws.constraints#ConstraintsService", "constraints_without_public_constrained_types", + imports = listOf("$commonModels/constraints.smithy"), + extraConfig = """, "codegen": { "publicConstrainedTypes": false } """, + ), + CodegenTest("com.amazonaws.constraints#ConstraintsService", "constraints", imports = listOf("$commonModels/constraints.smithy")), CodegenTest("aws.protocoltests.restjson#RestJson", "rest_json"), CodegenTest("aws.protocoltests.restjson#RestJsonExtras", "rest_json_extras", imports = listOf("$commonModels/rest-json-extras.smithy")), - CodegenTest("aws.protocoltests.restjson.validation#RestJsonValidation", "rest_json_validation"), + CodegenTest("aws.protocoltests.restjson.validation#RestJsonValidation", "rest_json_validation", + extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """, + ), CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"), CodegenTest("aws.protocoltests.json#JsonProtocol", "json_rpc11"), CodegenTest("aws.protocoltests.misc#MiscService", "misc", imports = listOf("$commonModels/misc.smithy")), - CodegenTest("com.amazonaws.ebs#Ebs", "ebs", imports = listOf("$commonModels/ebs.json")), + CodegenTest("com.amazonaws.ebs#Ebs", "ebs", + imports = listOf("$commonModels/ebs.json"), + extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """, + ), CodegenTest("com.amazonaws.s3#AmazonS3", "s3"), CodegenTest("com.aws.example.rust#PokemonService", "pokemon-service-server-sdk", imports = listOf("$commonModels/pokemon.smithy", "$commonModels/pokemon-common.smithy")), ) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt index a2fa85b72..218eb8af1 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.server.python.smithy.customizations.DECORATORS +import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator @@ -65,14 +66,17 @@ class PythonCodegenServerPlugin : SmithyBuildPlugin { model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig, + constrainedTypes: Boolean = true, ) = // Rename a set of symbols that do not implement `PyClass` and have been wrapped in // `aws_smithy_http_server_python::types`. PythonServerSymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) + // Generate public constrained types for directly constrained shapes. + // In the Python server project, this is only done to generate constrained types for simple shapes (e.g. + // a `string` shape with the `length` trait), but these always remain `pub(crate)`. + .let { if (constrainedTypes) ConstrainedShapeSymbolProvider(it, model, serviceShape) else it } // Generate different types for EventStream shapes (e.g. transcribe streaming) - .let { - EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) - } + .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf()) } // Streaming shapes need different derives (e.g. they cannot derive Eq) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt index 32a6fe388..9c38f6118 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.python.smithy import software.amazon.smithy.build.PluginContext import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.StringShape @@ -15,18 +16,17 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.client.smithy.customize.RustCodegenDecorator +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock -import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEnumGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerServiceGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerStructureGenerator import software.amazon.smithy.rust.codegen.server.smithy.DefaultServerPublicModules import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenVisitor +import software.amazon.smithy.rust.codegen.server.smithy.ServerSymbolProviders import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader @@ -61,14 +61,44 @@ class PythonServerCodegenVisitor( ) .protocolFor(context.model, service) protocolGeneratorFactory = generator + model = codegenDecorator.transformModel(service, baseModel) - symbolProvider = PythonCodegenServerPlugin.baseSymbolProvider(model, service, symbolVisitorConfig) - // Override `codegenContext` which carries the symbolProvider. - codegenContext = ServerCodegenContext(model, symbolProvider, service, protocol, settings) + // `publicConstrainedTypes` must always be `false` for the Python server, since Python generates its own + // wrapper newtypes. + settings = settings.copy(codegenConfig = settings.codegenConfig.copy(publicConstrainedTypes = false)) + + fun baseSymbolProviderFactory( + model: Model, + serviceShape: ServiceShape, + symbolVisitorConfig: SymbolVisitorConfig, + publicConstrainedTypes: Boolean, + ) = PythonCodegenServerPlugin.baseSymbolProvider(model, serviceShape, symbolVisitorConfig, publicConstrainedTypes) + + val serverSymbolProviders = ServerSymbolProviders.from( + model, + service, + symbolVisitorConfig, + settings.codegenConfig.publicConstrainedTypes, + ::baseSymbolProviderFactory, + ) + + // Override `codegenContext` which carries the various symbol providers. + codegenContext = + ServerCodegenContext( + model, + serverSymbolProviders.symbolProvider, + service, + protocol, + settings, + serverSymbolProviders.unconstrainedShapeSymbolProvider, + serverSymbolProviders.constrainedShapeSymbolProvider, + serverSymbolProviders.constraintViolationSymbolProvider, + serverSymbolProviders.pubCrateConstrainedShapeSymbolProvider, + ) // Override `rustCrate` which carries the symbolProvider. - rustCrate = RustCrate(context.fileManifest, symbolProvider, DefaultServerPublicModules, settings.codegenConfig) + rustCrate = RustCrate(context.fileManifest, codegenContext.symbolProvider, DefaultServerPublicModules, settings.codegenConfig) // Override `protocolGenerator` which carries the symbolProvider. protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext) } @@ -88,13 +118,9 @@ class PythonServerCodegenVisitor( rustCrate.useShapeWriter(shape) { // Use Python specific structure generator that adds the #[pyclass] attribute // and #[pymethods] implementation. - PythonServerStructureGenerator(model, symbolProvider, this, shape).render(CodegenTarget.SERVER) - val builderGenerator = - BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape) - builderGenerator.render(this) - implBlock(shape, symbolProvider) { - builderGenerator.renderConvenienceMethod(this) - } + PythonServerStructureGenerator(model, codegenContext.symbolProvider, this, shape).render(CodegenTarget.SERVER) + + renderStructureShapeBuilder(shape, this) } } @@ -104,12 +130,9 @@ class PythonServerCodegenVisitor( * Although raw strings require no code generation, enums are actually [EnumTrait] applied to string shapes. */ override fun stringShape(shape: StringShape) { - logger.info("[rust-server-codegen] Generating an enum $shape") - shape.getTrait()?.also { enum -> - rustCrate.useShapeWriter(shape) { - PythonServerEnumGenerator(model, symbolProvider, this, shape, enum, codegenContext.runtimeConfig).render() - } - } + fun pythonServerEnumGeneratorFactory(codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) = + PythonServerEnumGenerator(codegenContext, writer, shape) + stringShape(shape, ::pythonServerEnumGeneratorFactory) } /** diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt index b804d12bb..cad7bad67 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt @@ -5,9 +5,7 @@ package software.amazon.smithy.rust.codegen.server.python.smithy.generators -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StringShape -import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -16,10 +14,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGenerator /** @@ -28,13 +25,10 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGe * some utility functions like `__str__()` and `__repr__()`. */ class PythonServerEnumGenerator( - model: Model, - symbolProvider: RustSymbolProvider, + codegenContext: ServerCodegenContext, private val writer: RustWriter, - private val shape: StringShape, - enumTrait: EnumTrait, - runtimeConfig: RuntimeConfig, -) : ServerEnumGenerator(model, symbolProvider, writer, shape, enumTrait, runtimeConfig) { + shape: StringShape, +) : ServerEnumGenerator(codegenContext, writer, shape) { private val pyo3Symbols = listOf(PythonServerCargoDependency.PyO3.asType()) @@ -48,11 +42,6 @@ class PythonServerEnumGenerator( Attribute.Custom("pyo3::pyclass", symbols = pyo3Symbols).render(writer) } - override fun renderFromForStr() { - renderPyClass() - super.renderFromForStr() - } - private fun renderPyO3Methods() { Attribute.Custom("pyo3::pymethods", symbols = pyo3Symbols).render(writer) writer.rustTemplate( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt new file mode 100644 index 000000000..92ff5faf7 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt @@ -0,0 +1,109 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.Models +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.contextName +import software.amazon.smithy.rust.codegen.core.smithy.handleOptionality +import software.amazon.smithy.rust.codegen.core.smithy.handleRustBoxing +import software.amazon.smithy.rust.codegen.core.smithy.locatedIn +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.smithy.symbolBuilder +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.toPascalCase + +/** + * The [ConstrainedShapeSymbolProvider] returns, for a given _directly_ + * constrained shape, a symbol whose Rust type can hold the constrained values. + * + * For all shapes with supported traits directly attached to them, this type is + * a [RustType.Opaque] wrapper tuple newtype holding the inner constrained + * type. + * + * The symbols this symbol provider returns are always public and exposed to + * the end user. + * + * This symbol provider is meant to be used "deep" within the wrapped symbol + * providers chain, just above the core base symbol provider, `SymbolVisitor`. + * + * If the shape is _transitively but not directly_ constrained, use + * [PubCrateConstrainedShapeSymbolProvider] instead, which returns symbols + * whose associated types are `pub(crate)` and thus not exposed to the end + * user. + */ +class ConstrainedShapeSymbolProvider( + private val base: RustSymbolProvider, + private val model: Model, + private val serviceShape: ServiceShape, +) : WrappingSymbolProvider(base) { + private val nullableIndex = NullableIndex.of(model) + + private fun publicConstrainedSymbolForMapShape(shape: Shape): Symbol { + check(shape is MapShape) + + val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) + return symbolBuilder(shape, rustType).locatedIn(Models).build() + } + + override fun toSymbol(shape: Shape): Symbol { + return when (shape) { + is MemberShape -> { + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Member shapes can have constraint traits + // (constraint trait precedence). + val target = model.expectShape(shape.target) + val targetSymbol = this.toSymbol(target) + // Handle boxing first so we end up with `Option>`, not `Box>`. + handleOptionality(handleRustBoxing(targetSymbol, shape), shape, nullableIndex, base.config().nullabilityCheckMode) + } + is MapShape -> { + if (shape.isDirectlyConstrained(base)) { + check(shape.hasTrait()) { "Only the `length` constraint trait can be applied to maps" } + publicConstrainedSymbolForMapShape(shape) + } else { + val keySymbol = this.toSymbol(shape.key) + val valueSymbol = this.toSymbol(shape.value) + symbolBuilder(shape, RustType.HashMap(keySymbol.rustType(), valueSymbol.rustType())) + .addReference(keySymbol) + .addReference(valueSymbol) + .build() + } + } + is CollectionShape -> { + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Both arms return the same because we haven't + // implemented any constraint trait on collection shapes yet. + if (shape.isDirectlyConstrained(base)) { + val inner = this.toSymbol(shape.member) + symbolBuilder(shape, RustType.Vec(inner.rustType())).addReference(inner).build() + } else { + val inner = this.toSymbol(shape.member) + symbolBuilder(shape, RustType.Vec(inner.rustType())).addReference(inner).build() + } + } + is StringShape -> { + if (shape.isDirectlyConstrained(base)) { + val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) + symbolBuilder(shape, rustType).locatedIn(Models).build() + } else { + base.toSymbol(shape) + } + } + else -> base.toSymbol(shape) + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt new file mode 100644 index 000000000..119db9645 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt @@ -0,0 +1,123 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.Models +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.contextName +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol + +/** + * The [ConstraintViolationSymbolProvider] returns, for a given constrained + * shape, a symbol whose Rust type can hold information about constraint + * violations that may occur when building the shape from unconstrained values. + * + * So, for example, given the model: + * + * ```smithy + * @pattern("\\w+") + * @length(min: 1, max: 69) + * string NiceString + * + * structure Structure { + * @required + * niceString: NiceString + * } + * ``` + * + * A `NiceString` built from an arbitrary Rust `String` may give rise to at + * most two constraint trait violations: one for `pattern`, one for `length`. + * Similarly, the shape `Structure` can fail to be built when a value for + * `niceString` is not provided. + * + * Said type is always called `ConstraintViolation`, and resides in a bespoke + * module inside the same module as the _public_ constrained type the user is + * exposed to. When the user is _not_ exposed to the constrained type, the + * constraint violation type's module is a child of the `model` module. + * + * It is the responsibility of the caller to ensure that the shape is + * constrained (either directly or transitively) before using this symbol + * provider. This symbol provider intentionally crashes if the shape is not + * constrained. + */ +class ConstraintViolationSymbolProvider( + private val base: RustSymbolProvider, + private val model: Model, + private val serviceShape: ServiceShape, + private val publicConstrainedTypes: Boolean, +) : WrappingSymbolProvider(base) { + private val constraintViolationName = "ConstraintViolation" + + private fun constraintViolationSymbolForCollectionOrMapOrUnionShape(shape: Shape): Symbol { + check(shape is CollectionShape || shape is MapShape || shape is UnionShape) + + val symbol = base.toSymbol(shape) + val constraintViolationNamespace = + "${symbol.namespace.let { it.ifEmpty { "crate::${Models.namespace}" } }}::${ + RustReservedWords.escapeIfNeeded( + shape.contextName(serviceShape).toSnakeCase(), + ) + }" + val rustType = RustType.Opaque(constraintViolationName, constraintViolationNamespace) + return Symbol.builder() + .rustType(rustType) + .name(rustType.name) + .namespace(rustType.namespace, "::") + .definitionFile(symbol.definitionFile) + .build() + } + + override fun toSymbol(shape: Shape): Symbol { + check(shape.canReachConstrainedShape(model, base)) + + return when (shape) { + is MapShape, is CollectionShape, is UnionShape -> { + constraintViolationSymbolForCollectionOrMapOrUnionShape(shape) + } + is StructureShape -> { + val builderSymbol = shape.serverBuilderSymbol(base, pubCrate = !publicConstrainedTypes) + + val namespace = builderSymbol.namespace + val rustType = RustType.Opaque(constraintViolationName, namespace) + Symbol.builder() + .rustType(rustType) + .name(rustType.name) + .namespace(rustType.namespace, "::") + .definitionFile(builderSymbol.definitionFile) + .build() + } + is StringShape -> { + val namespace = "crate::${Models.namespace}::${ + RustReservedWords.escapeIfNeeded( + shape.contextName(serviceShape).toSnakeCase(), + ) + }" + val rustType = RustType.Opaque(constraintViolationName, namespace) + Symbol.builder() + .rustType(rustType) + .name(rustType.name) + .namespace(rustType.namespace, "::") + .definitionFile(Models.filename) + .build() + } + else -> TODO("Constraint traits on other shapes not implemented yet: $shape") + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt new file mode 100644 index 000000000..82102be18 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt @@ -0,0 +1,133 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.SimpleShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.model.traits.PatternTrait +import software.amazon.smithy.model.traits.RangeTrait +import software.amazon.smithy.model.traits.RequiredTrait +import software.amazon.smithy.model.traits.UniqueItemsTrait +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.hasTrait + +/** + * This file contains utilities to work with constrained shapes. + */ + +/** + * Whether the shape has any trait that could cause a request to be rejected with a constraint violation, _whether + * we support it or not_. + */ +fun Shape.hasConstraintTrait() = + hasTrait() || + hasTrait() || + hasTrait() || + hasTrait() || + hasTrait() || + hasTrait() + +/** + * We say a shape is _directly_ constrained if: + * + * - it has a constraint trait, or; + * - in the case of it being an aggregate shape, one of its member shapes has a constraint trait. + * + * Note that an aggregate shape whose member shapes do not have constraint traits but that has a member whose target is + * a constrained shape is _not_ directly constrained. + * + * At the moment only a subset of constraint traits are implemented on a subset of shapes; that's why we match against + * a subset of shapes in each arm, and check for a subset of constraint traits attached to the shape in the arm's + * (with these subsets being smaller than what [the spec] accounts for). + * + * [the spec]: https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html + */ +fun Shape.isDirectlyConstrained(symbolProvider: SymbolProvider): Boolean = when (this) { + is StructureShape -> { + // TODO(https://github.com/awslabs/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): + // The only reason why the functions in this file have + // to take in a `SymbolProvider` is because non-`required` blob streaming members are interpreted as + // `required`, so we can't use `member.isOptional` here. + this.members().map { symbolProvider.toSymbol(it) }.any { !it.isOptional() } + } + is MapShape -> this.hasTrait() + is StringShape -> this.hasTrait() || this.hasTrait() + else -> false +} + +fun MemberShape.hasConstraintTraitOrTargetHasConstraintTrait(model: Model, symbolProvider: SymbolProvider): Boolean = + this.isDirectlyConstrained(symbolProvider) || (model.expectShape(this.target).isDirectlyConstrained(symbolProvider)) + +fun Shape.isTransitivelyButNotDirectlyConstrained(model: Model, symbolProvider: SymbolProvider): Boolean = + !this.isDirectlyConstrained(symbolProvider) && this.canReachConstrainedShape(model, symbolProvider) + +fun Shape.canReachConstrainedShape(model: Model, symbolProvider: SymbolProvider): Boolean = + if (this is MemberShape) { + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Constraint traits on member shapes are not implemented + // yet. Also, note that a walker over a member shape can, perhaps counterintuitively, reach the _containing_ shape, + // so we can't simply delegate to the `else` branch when we implement them. + this.targetCanReachConstrainedShape(model, symbolProvider) + } else { + Walker(model).walkShapes(this).toSet().any { it.isDirectlyConstrained(symbolProvider) } + } + +fun MemberShape.targetCanReachConstrainedShape(model: Model, symbolProvider: SymbolProvider): Boolean = + model.expectShape(this.target).canReachConstrainedShape(model, symbolProvider) + +fun Shape.hasPublicConstrainedWrapperTupleType(model: Model, publicConstrainedTypes: Boolean): Boolean = when (this) { + is MapShape -> publicConstrainedTypes && this.hasTrait() + is StringShape -> !this.hasTrait() && (publicConstrainedTypes && this.hasTrait()) + is MemberShape -> model.expectShape(this.target).hasPublicConstrainedWrapperTupleType(model, publicConstrainedTypes) + else -> false +} + +fun Shape.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model: Model): Boolean = + hasPublicConstrainedWrapperTupleType(model, true) + +/** + * Helper function to determine whether a shape will map to a _public_ constrained wrapper tuple type. + * + * This function is used in core code generators, so it takes in a [CodegenContext] that is downcast + * to [ServerCodegenContext] when generating servers. + */ +fun workingWithPublicConstrainedWrapperTupleType(shape: Shape, model: Model, publicConstrainedTypes: Boolean): Boolean = + shape.hasPublicConstrainedWrapperTupleType(model, publicConstrainedTypes) + +/** + * Returns whether a shape's type _name_ contains a non-public type when `publicConstrainedTypes` is `false`. + * + * For example, a `Vec` contains a non-public type, because `crate::model::LengthString` + * is `pub(crate)` when `publicConstrainedTypes` is `false` + * + * Note that a structure shape's type _definition_ may contain non-public types, but its _name_ is always public. + * + * Note how we short-circuit on `publicConstrainedTypes = true`, but we still require it to be passed in instead of laying + * the responsibility on the caller, for API safety usage. + */ +fun Shape.typeNameContainsNonPublicType( + model: Model, + symbolProvider: SymbolProvider, + publicConstrainedTypes: Boolean, +): Boolean = !publicConstrainedTypes && when (this) { + is SimpleShape -> wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model) + is MemberShape -> model.expectShape(this.target).typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + is CollectionShape -> this.canReachConstrainedShape(model, symbolProvider) + is MapShape -> this.canReachConstrainedShape(model, symbolProvider) + is StructureShape, is UnionShape -> false + else -> UNREACHABLE("the above arms should be exhaustive, but we received shape: $this") +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt new file mode 100644 index 000000000..b15b2dc8f --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt @@ -0,0 +1,21 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.model.traits.LengthTrait + +fun LengthTrait.validationErrorMessage(): String { + val beginning = "Value with length {} at '{}' failed to satisfy constraint: Member must have length " + val ending = if (this.min.isPresent && this.max.isPresent) { + "between ${this.min.get()} and ${this.max.get()}, inclusive" + } else if (this.min.isPresent) ( + "greater than or equal to ${this.min.get()}" + ) else { + check(this.max.isPresent) + "less than or equal to ${this.max.get()}" + } + return "$beginning$ending" +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt new file mode 100644 index 000000000..e63e18c7a --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt @@ -0,0 +1,124 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.SimpleShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.Constrained +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.handleOptionality +import software.amazon.smithy.rust.codegen.core.smithy.handleRustBoxing +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.util.PANIC +import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase + +/** + * The [PubCrateConstrainedShapeSymbolProvider] returns, for a given + * _transitively but not directly_ constrained shape, a symbol whose Rust type + * can hold the constrained values. + * + * For collection and map shapes, this type is a [RustType.Opaque] wrapper + * tuple newtype holding a container over the inner constrained type. For + * member shapes, it's whatever their target shape resolves to. + * + * The class name is prefixed with `PubCrate` because the symbols it returns + * have associated types that are generated as `pub(crate)`. See the + * `PubCrate*Generator` classes to see how these types are generated. + * + * It is important that this symbol provider does _not_ wrap + * [ConstrainedShapeSymbolProvider], since otherwise it will eventually + * delegate to it and generate a symbol with a `pub` type. + * + * Note simple shapes cannot be transitively and not directly constrained at + * the same time, so this symbol provider is only implemented for aggregate shapes. + * The symbol provider will intentionally crash in such a case to avoid the caller + * incorrectly using it. + * + * Note also that for the purposes of this symbol provider, a member shape is + * transitively but not directly constrained only in the case where it itself + * is not directly constrained and its target also is not directly constrained. + * + * If the shape is _directly_ constrained, use [ConstrainedShapeSymbolProvider] + * instead. + */ +class PubCrateConstrainedShapeSymbolProvider( + private val base: RustSymbolProvider, + private val model: Model, + private val serviceShape: ServiceShape, +) : WrappingSymbolProvider(base) { + private val nullableIndex = NullableIndex.of(model) + + private fun constrainedSymbolForCollectionOrMapShape(shape: Shape): Symbol { + check(shape is CollectionShape || shape is MapShape) + + val name = constrainedTypeNameForCollectionOrMapShape(shape, serviceShape) + val namespace = "crate::${Constrained.namespace}::${RustReservedWords.escapeIfNeeded(name.toSnakeCase())}" + val rustType = RustType.Opaque(name, namespace) + return Symbol.builder() + .rustType(rustType) + .name(rustType.name) + .namespace(rustType.namespace, "::") + .definitionFile(Constrained.filename) + .build() + } + + private fun errorMessage(shape: Shape) = + "This symbol provider was called with $shape. However, it can only be called with a shape that is transitively constrained." + + override fun toSymbol(shape: Shape): Symbol { + require(shape.isTransitivelyButNotDirectlyConstrained(model, base)) { errorMessage(shape) } + + return when (shape) { + is CollectionShape, is MapShape -> { + constrainedSymbolForCollectionOrMapShape(shape) + } + is MemberShape -> { + require(model.expectShape(shape.container).isStructureShape) { + "This arm is only exercised by `ServerBuilderGenerator`" + } + require(!shape.hasConstraintTraitOrTargetHasConstraintTrait(model, base)) { errorMessage(shape) } + + val targetShape = model.expectShape(shape.target) + + if (targetShape is SimpleShape) { + base.toSymbol(shape) + } else { + val targetSymbol = this.toSymbol(targetShape) + // Handle boxing first so we end up with `Option>`, not `Box>`. + handleOptionality(handleRustBoxing(targetSymbol, shape), shape, nullableIndex, base.config().nullabilityCheckMode) + } + } + is StructureShape, is UnionShape -> { + // Structure shapes and union shapes always generate a [RustType.Opaque] constrained type. + base.toSymbol(shape) + } + else -> { + check(shape is SimpleShape) + // The rest of the shape types are simple shapes, which are impossible to be transitively but not + // directly constrained; directly constrained shapes generate public constrained types. + PANIC(errorMessage(shape)) + } + } + } +} + +fun constrainedTypeNameForCollectionOrMapShape(shape: Shape, serviceShape: ServiceShape): String { + check(shape is CollectionShape || shape is MapShape) + return "${shape.id.getName(serviceShape).toPascalCase()}Constrained" +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstraintViolationSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstraintViolationSymbolProvider.kt new file mode 100644 index 000000000..05a8d635a --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstraintViolationSymbolProvider.kt @@ -0,0 +1,37 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.rustType + +/** + * This is only used when `publicConstrainedTypes` is `false`. + * + * This must wrap [ConstraintViolationSymbolProvider]. + */ +class PubCrateConstraintViolationSymbolProvider( + private val base: ConstraintViolationSymbolProvider, +) : WrappingSymbolProvider(base) { + override fun toSymbol(shape: Shape): Symbol { + val baseSymbol = base.toSymbol(shape) + // If the shape is a structure shape, the module where its builder is hosted when `publicConstrainedTypes` is + // `false` is already suffixed with `_internal`. + if (shape is StructureShape) { + return baseSymbol + } + val baseRustType = baseSymbol.rustType() + val newNamespace = baseSymbol.namespace + "_internal" + return baseSymbol.toBuilder() + .rustType(RustType.Opaque(baseRustType.name, newNamespace)) + .namespace(newNamespace, baseSymbol.namespaceDelimiter) + .build() + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt index ac63fd9ee..5bbc07343 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt @@ -24,10 +24,12 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser import java.util.logging.Level import java.util.logging.Logger -/** Rust Codegen Plugin - * This is the entrypoint for code generation, triggered by the smithy-build plugin. - * `resources/META-INF.services/software.amazon.smithy.build.SmithyBuildPlugin` refers to this class by name which - * enables the smithy-build plugin to invoke `execute` with all of the Smithy plugin context + models. +/** + * Rust Codegen Plugin + * + * This is the entrypoint for code generation, triggered by the smithy-build plugin. + * `resources/META-INF.services/software.amazon.smithy.build.SmithyBuildPlugin` refers to this class by name which + * enables the smithy-build plugin to invoke `execute` with all of the Smithy plugin context + models. */ class RustCodegenServerPlugin : SmithyBuildPlugin { private val logger = Logger.getLogger(javaClass.name) @@ -51,8 +53,8 @@ class RustCodegenServerPlugin : SmithyBuildPlugin { } companion object { - /** SymbolProvider - * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider + /** + * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider. * * The Symbol provider is composed of a base [SymbolVisitor] which handles the core functionality, then is layered * with other symbol providers, documented inline, to handle the full scope of Smithy types. @@ -61,12 +63,13 @@ class RustCodegenServerPlugin : SmithyBuildPlugin { model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig, + constrainedTypes: Boolean = true, ) = SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) + // Generate public constrained types for directly constrained shapes. + .let { if (constrainedTypes) ConstrainedShapeSymbolProvider(it, model, serviceShape) else it } // Generate different types for EventStream shapes (e.g. transcribe streaming) - .let { - EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) - } + .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) } // Generate [ByteStream] instead of `Blob` for streaming binary shapes (e.g. S3 GetObject) .let { StreamingShapeSymbolProvider(it, model) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt index 0cc39ac64..a0ad38f04 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt @@ -27,6 +27,10 @@ data class ServerCodegenContext( override val serviceShape: ServiceShape, override val protocol: ShapeId, override val settings: ServerRustSettings, + val unconstrainedShapeSymbolProvider: UnconstrainedShapeSymbolProvider, + val constrainedShapeSymbolProvider: RustSymbolProvider, + val constraintViolationSymbolProvider: ConstraintViolationSymbolProvider, + val pubCrateConstrainedShapeSymbolProvider: PubCrateConstrainedShapeSymbolProvider, ) : CodegenContext( model, symbolProvider, serviceShape, protocol, settings, CodegenTarget.SERVER, ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index de10207c5..c0a68a26b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -6,25 +6,33 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.SetShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeVisitor import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.LengthTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.client.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.Constrained import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator +import software.amazon.smithy.rust.codegen.core.smithy.Unconstrained import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock @@ -33,13 +41,29 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamN import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.util.CommandFailed -import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.runCommand +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedMapGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedStringGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedTraitForEnumGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.MapConstraintViolationGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.PubCrateConstrainedCollectionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.PubCrateConstrainedMapGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGeneratorWithoutPublicConstrainedTypes import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerServiceGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerStructureConstrainedTraitImpl +import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedCollectionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedMapGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput +import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachValidationExceptionToConstrainedOperationInputsInAllowList +import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException +import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger import java.util.logging.Logger val DefaultServerPublicModules = setOf( @@ -47,7 +71,6 @@ val DefaultServerPublicModules = setOf( RustModule.Model, RustModule.Input, RustModule.Output, - RustModule.Config, ).associateBy { it.name } /** @@ -60,15 +83,18 @@ open class ServerCodegenVisitor( ) : ShapeVisitor.Default() { protected val logger = Logger.getLogger(javaClass.name) - protected val settings = ServerRustSettings.from(context.model, context.settings) + protected var settings = ServerRustSettings.from(context.model, context.settings) - protected var symbolProvider: RustSymbolProvider protected var rustCrate: RustCrate private val fileManifest = context.fileManifest protected var model: Model protected var codegenContext: ServerCodegenContext protected var protocolGeneratorFactory: ProtocolGeneratorFactory protected var protocolGenerator: ServerProtocolGenerator + private val unconstrainedModule = + RustModule.private(Unconstrained.namespace, "Unconstrained types for constrained shapes.") + private val constrainedModule = + RustModule.private(Constrained.namespace, "Constrained types for constrained shapes.") init { val symbolVisitorConfig = @@ -77,6 +103,7 @@ open class ServerCodegenVisitor( renameExceptions = false, nullabilityCheckMode = NullableIndex.CheckMode.SERVER, ) + val baseModel = baselineTransform(context.model) val service = settings.getService(baseModel) val (protocol, generator) = @@ -88,18 +115,30 @@ open class ServerCodegenVisitor( ) .protocolFor(context.model, service) protocolGeneratorFactory = generator + model = codegenDecorator.transformModel(service, baseModel) - symbolProvider = RustCodegenServerPlugin.baseSymbolProvider(model, service, symbolVisitorConfig) + + val serverSymbolProviders = ServerSymbolProviders.from( + model, + service, + symbolVisitorConfig, + settings.codegenConfig.publicConstrainedTypes, + RustCodegenServerPlugin::baseSymbolProvider, + ) codegenContext = ServerCodegenContext( model, - symbolProvider, + serverSymbolProviders.symbolProvider, service, protocol, settings, + serverSymbolProviders.unconstrainedShapeSymbolProvider, + serverSymbolProviders.constrainedShapeSymbolProvider, + serverSymbolProviders.constraintViolationSymbolProvider, + serverSymbolProviders.pubCrateConstrainedShapeSymbolProvider, ) - rustCrate = RustCrate(context.fileManifest, symbolProvider, DefaultServerPublicModules, settings.codegenConfig) + rustCrate = RustCrate(context.fileManifest, codegenContext.symbolProvider, DefaultServerPublicModules, settings.codegenConfig) protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext) } @@ -117,6 +156,13 @@ open class ServerCodegenVisitor( .let(RecursiveShapeBoxer::transform) // Normalize operations by adding synthetic input and output shapes to every operation .let(OperationNormalizer::transform) + // Remove the EBS model's own `ValidationException`, which collides with `smithy.framework#ValidationException` + .let(RemoveEbsModelValidationException::transform) + // Attach the `smithy.framework#ValidationException` error to operations whose inputs are constrained, + // if they belong to a service in an allowlist + .let(AttachValidationExceptionToConstrainedOperationInputsInAllowList::transform) + // Tag aggregate shapes reachable from operation input + .let(ShapesReachableFromOperationInputTagger::transform) // Normalize event stream operations .let(EventStreamNormalizer::transform) @@ -139,9 +185,26 @@ open class ServerCodegenVisitor( */ fun execute() { val service = settings.getService(model) - logger.info( + logger.warning( "[rust-server-codegen] Generating Rust server for service $service, protocol ${codegenContext.protocol}", ) + + for (validationResult in listOf( + validateOperationsWithConstrainedInputHaveValidationExceptionAttached( + model, + service, + ), + validateUnsupportedConstraints(model, service, codegenContext.settings.codegenConfig), + )) { + for (logMessage in validationResult.messages) { + // TODO(https://github.com/awslabs/smithy-rs/issues/1756): These are getting duplicated. + logger.log(logMessage.level, logMessage.message) + } + if (validationResult.shouldAbort) { + throw CodegenException("Unsupported constraints feature used; see error messages above for resolution") + } + } + val serviceShapes = Walker(model).walkShapes(service) serviceShapes.forEach { it.accept(this) } codegenDecorator.extras(codegenContext, rustCrate) @@ -159,7 +222,7 @@ open class ServerCodegenVisitor( timeout = settings.codegenConfig.formatTimeoutSeconds.toLong(), ) } catch (err: CommandFailed) { - logger.warning( + logger.info( "[rust-server-codegen] Failed to run cargo fmt: [${service.id}]\n${err.output}", ) } @@ -180,12 +243,110 @@ open class ServerCodegenVisitor( override fun structureShape(shape: StructureShape) { logger.info("[rust-server-codegen] Generating a structure $shape") rustCrate.useShapeWriter(shape) { - StructureGenerator(model, symbolProvider, this, shape).render(CodegenTarget.SERVER) - val builderGenerator = - BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape) - builderGenerator.render(this) - this.implBlock(shape, symbolProvider) { - builderGenerator.renderConvenienceMethod(this) + StructureGenerator(model, codegenContext.symbolProvider, this, shape).render(CodegenTarget.SERVER) + + renderStructureShapeBuilder(shape, this) + } + } + + protected fun renderStructureShapeBuilder( + shape: StructureShape, + writer: RustWriter, + ) { + if (codegenContext.settings.codegenConfig.publicConstrainedTypes || shape.isReachableFromOperationInput()) { + val serverBuilderGenerator = ServerBuilderGenerator(codegenContext, shape) + serverBuilderGenerator.render(writer) + + if (codegenContext.settings.codegenConfig.publicConstrainedTypes) { + writer.implBlock(shape, codegenContext.symbolProvider) { + serverBuilderGenerator.renderConvenienceMethod(this) + } + } + } + + if (shape.isReachableFromOperationInput()) { + ServerStructureConstrainedTraitImpl( + codegenContext.symbolProvider, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + shape, + writer, + ).render() + } + + if (!codegenContext.settings.codegenConfig.publicConstrainedTypes) { + val serverBuilderGeneratorWithoutPublicConstrainedTypes = + ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape) + serverBuilderGeneratorWithoutPublicConstrainedTypes.render(writer) + + writer.implBlock(shape, codegenContext.symbolProvider) { + serverBuilderGeneratorWithoutPublicConstrainedTypes.renderConvenienceMethod(this) + } + } + } + + override fun listShape(shape: ListShape) = collectionShape(shape) + override fun setShape(shape: SetShape) = collectionShape(shape) + + private fun collectionShape(shape: CollectionShape) { + if (shape.isReachableFromOperationInput() && shape.canReachConstrainedShape( + model, + codegenContext.symbolProvider, + ) + ) { + logger.info("[rust-server-codegen] Generating an unconstrained type for collection shape $shape") + rustCrate.withModule(unconstrainedModule) unconstrainedModuleWriter@{ + rustCrate.withModule(ModelsModule) modelsModuleWriter@{ + UnconstrainedCollectionGenerator( + codegenContext, + this@unconstrainedModuleWriter, + this@modelsModuleWriter, + shape, + ).render() + } + } + + logger.info("[rust-server-codegen] Generating a constrained type for collection shape $shape") + rustCrate.withModule(constrainedModule) { + PubCrateConstrainedCollectionGenerator(codegenContext, this, shape).render() + } + } + } + + override fun mapShape(shape: MapShape) { + val renderUnconstrainedMap = + shape.isReachableFromOperationInput() && shape.canReachConstrainedShape( + model, + codegenContext.symbolProvider, + ) + if (renderUnconstrainedMap) { + logger.info("[rust-server-codegen] Generating an unconstrained type for map $shape") + rustCrate.withModule(unconstrainedModule) { + UnconstrainedMapGenerator(codegenContext, this, shape).render() + } + + if (!shape.isDirectlyConstrained(codegenContext.symbolProvider)) { + logger.info("[rust-server-codegen] Generating a constrained type for map $shape") + rustCrate.withModule(constrainedModule) { + PubCrateConstrainedMapGenerator(codegenContext, this, shape).render() + } + } + } + + val isDirectlyConstrained = shape.isDirectlyConstrained(codegenContext.symbolProvider) + if (isDirectlyConstrained) { + rustCrate.withModule(ModelsModule) { + ConstrainedMapGenerator( + codegenContext, + this, + shape, + if (renderUnconstrainedMap) codegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape) else null, + ).render() + } + } + + if (isDirectlyConstrained || renderUnconstrainedMap) { + rustCrate.withModule(ModelsModule) { + MapConstraintViolationGenerator(codegenContext, this, shape).render() } } } @@ -196,10 +357,36 @@ open class ServerCodegenVisitor( * Although raw strings require no code generation, enums are actually [EnumTrait] applied to string shapes. */ override fun stringShape(shape: StringShape) { - logger.info("[rust-server-codegen] Generating an enum $shape") - shape.getTrait()?.also { enum -> + fun serverEnumGeneratorFactory(codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) = + ServerEnumGenerator(codegenContext, writer, shape) + stringShape(shape, ::serverEnumGeneratorFactory) + } + + protected fun stringShape( + shape: StringShape, + enumShapeGeneratorFactory: (codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) -> ServerEnumGenerator, + ) { + if (shape.hasTrait()) { + logger.info("[rust-server-codegen] Generating an enum $shape") rustCrate.useShapeWriter(shape) { - ServerEnumGenerator(model, symbolProvider, this, shape, enum, codegenContext.runtimeConfig).render() + enumShapeGeneratorFactory(codegenContext, this, shape).render() + ConstrainedTraitForEnumGenerator(model, codegenContext.symbolProvider, this, shape).render() + } + } + + if (shape.hasTrait() && shape.hasTrait()) { + logger.warning( + """ + String shape $shape has an `enum` trait and the `length` trait. This is valid according to the Smithy + IDL v1 spec, but it's unclear what the semantics are. In any case, the Smithy core libraries should enforce the + constraints (which it currently does not), not each code generator. + See https://github.com/awslabs/smithy/issues/1121f for more information. + """.trimIndent().replace("\n", " "), + ) + } else if (!shape.hasTrait() && shape.isDirectlyConstrained(codegenContext.symbolProvider)) { + logger.info("[rust-server-codegen] Generating a constrained string $shape") + rustCrate.withModule(ModelsModule) { + ConstrainedStringGenerator(codegenContext, this, shape).render() } } } @@ -212,9 +399,27 @@ open class ServerCodegenVisitor( * This function _does not_ generate any serializers. */ override fun unionShape(shape: UnionShape) { - logger.info("[rust-server-codegen] Generating an union $shape") + logger.info("[rust-server-codegen] Generating an union shape $shape") rustCrate.useShapeWriter(shape) { - UnionGenerator(model, symbolProvider, this, shape, renderUnknownVariant = false).render() + UnionGenerator(model, codegenContext.symbolProvider, this, shape, renderUnknownVariant = false).render() + } + + if (shape.isReachableFromOperationInput() && shape.canReachConstrainedShape( + model, + codegenContext.symbolProvider, + ) + ) { + logger.info("[rust-server-codegen] Generating an unconstrained type for union shape $shape") + rustCrate.withModule(unconstrainedModule) unconstrainedModuleWriter@{ + rustCrate.withModule(ModelsModule) modelsModuleWriter@{ + UnconstrainedUnionGenerator( + codegenContext, + this@unconstrainedModuleWriter, + this@modelsModuleWriter, + shape, + ).render() + } + } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt index 4b7f18ba8..d1d74d80b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt @@ -24,9 +24,6 @@ object ServerRuntimeType { fun Router(runtimeConfig: RuntimeConfig) = RuntimeType("Router", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::routing") - fun RequestSpecModule(runtimeConfig: RuntimeConfig) = - RuntimeType("request_spec", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::routing") - fun OperationHandler(runtimeConfig: RuntimeConfig) = forInlineDependency(ServerInlineDependency.serverOperationHandler(runtimeConfig)) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt index d9a838960..dbfc8356a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt @@ -74,20 +74,35 @@ data class ServerRustSettings( } } +/** + * [publicConstrainedTypes]: Generate constrained wrapper newtypes for constrained shapes + * [ignoreUnsupportedConstraints]: Generate model even though unsupported constraints are present + */ data class ServerCodegenConfig( override val formatTimeoutSeconds: Int = defaultFormatTimeoutSeconds, override val debugMode: Boolean = defaultDebugMode, + val publicConstrainedTypes: Boolean = defaultPublicConstrainedTypes, + val ignoreUnsupportedConstraints: Boolean = defaultIgnoreUnsupportedConstraints, ) : CoreCodegenConfig( formatTimeoutSeconds, debugMode, ) { companion object { - // Note `node` is unused, because at the moment `ServerCodegenConfig` has the same properties as - // `CodegenConfig`. In the future, the server will have server-specific codegen options just like the client - // does. + private const val defaultPublicConstrainedTypes = true + private const val defaultIgnoreUnsupportedConstraints = false + fun fromCodegenConfigAndNode(coreCodegenConfig: CoreCodegenConfig, node: Optional) = - ServerCodegenConfig( - formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, - debugMode = coreCodegenConfig.debugMode, - ) + if (node.isPresent) { + ServerCodegenConfig( + formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, + debugMode = coreCodegenConfig.debugMode, + publicConstrainedTypes = node.get().getBooleanMemberOrDefault("publicConstrainedTypes", defaultPublicConstrainedTypes), + ignoreUnsupportedConstraints = node.get().getBooleanMemberOrDefault("ignoreUnsupportedConstraints", defaultIgnoreUnsupportedConstraints), + ) + } else { + ServerCodegenConfig( + formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, + debugMode = coreCodegenConfig.debugMode, + ) + } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt new file mode 100644 index 000000000..e2b77c90f --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt @@ -0,0 +1,65 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig + +/** + * Just a handy class to centralize initialization all the symbol providers required by the server code generators, to + * make the init blocks of the codegen visitors ([ServerCodegenVisitor] and [PythonServerCodegenVisitor]), and the + * unit test setup code, shorter and DRYer. + */ +class ServerSymbolProviders private constructor( + val symbolProvider: RustSymbolProvider, + val unconstrainedShapeSymbolProvider: UnconstrainedShapeSymbolProvider, + val constrainedShapeSymbolProvider: RustSymbolProvider, + val constraintViolationSymbolProvider: ConstraintViolationSymbolProvider, + val pubCrateConstrainedShapeSymbolProvider: PubCrateConstrainedShapeSymbolProvider, +) { + companion object { + fun from( + model: Model, + service: ServiceShape, + symbolVisitorConfig: SymbolVisitorConfig, + publicConstrainedTypes: Boolean, + baseSymbolProviderFactory: (model: Model, service: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig, publicConstrainedTypes: Boolean) -> RustSymbolProvider, + ): ServerSymbolProviders { + val baseSymbolProvider = baseSymbolProviderFactory(model, service, symbolVisitorConfig, publicConstrainedTypes) + return ServerSymbolProviders( + symbolProvider = baseSymbolProvider, + constrainedShapeSymbolProvider = baseSymbolProviderFactory( + model, + service, + symbolVisitorConfig, + true, + ), + unconstrainedShapeSymbolProvider = UnconstrainedShapeSymbolProvider( + baseSymbolProviderFactory( + model, + service, + symbolVisitorConfig, + false, + ), + model, service, publicConstrainedTypes, + ), + pubCrateConstrainedShapeSymbolProvider = PubCrateConstrainedShapeSymbolProvider( + baseSymbolProvider, + model, + service, + ), + constraintViolationSymbolProvider = ConstraintViolationSymbolProvider( + baseSymbolProvider, + model, + service, + publicConstrainedTypes, + ), + ) + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt new file mode 100644 index 000000000..9fa2182e6 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt @@ -0,0 +1,166 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.Default +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.Unconstrained +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.handleOptionality +import software.amazon.smithy.rust.codegen.core.smithy.handleRustBoxing +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.smithy.setDefault +import software.amazon.smithy.rust.codegen.core.smithy.symbolBuilder +import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol + +/** + * The [UnconstrainedShapeSymbolProvider] returns, _for a given constrained + * shape_, a symbol whose Rust type can hold the corresponding unconstrained + * values. + * + * For collection and map shapes, this type is a [RustType.Opaque] wrapper + * tuple newtype holding a container over the inner unconstrained type. For + * structure shapes, it's their builder type. For union shapes, it's an enum + * whose variants are the corresponding unconstrained variants. For simple + * shapes, it's whatever the regular base symbol provider returns. + * + * So, for example, given the following model: + * + * ```smithy + * list ListA { + * member: ListB + * } + * + * list ListB { + * member: Structure + * } + * + * structure Structure { + * @required + * string: String + * } + * ``` + * + * `ListB` is not _directly_ constrained, but it is constrained, because it + * holds `Structure`s, that are constrained. So the corresponding unconstrained + * symbol has Rust type `struct + * ListBUnconstrained(std::vec::Vec)`. + * Likewise, `ListA` is also constrained. Its unconstrained symbol has Rust + * type `struct ListAUnconstrained(std::vec::Vec)`. + * + * For an _unconstrained_ shape and for simple shapes, this symbol provider + * delegates to the base symbol provider. It is therefore important that this + * symbol provider _not_ wrap [PublicConstrainedShapeSymbolProvider] (from the + * `codegen-server` subproject), because that symbol provider will return a + * constrained type for shapes that have constraint traits attached. + */ +class UnconstrainedShapeSymbolProvider( + private val base: RustSymbolProvider, + private val model: Model, + private val serviceShape: ServiceShape, + private val publicConstrainedTypes: Boolean, +) : WrappingSymbolProvider(base) { + private val nullableIndex = NullableIndex.of(model) + + private fun unconstrainedSymbolForCollectionOrMapOrUnionShape(shape: Shape): Symbol { + check(shape is CollectionShape || shape is MapShape || shape is UnionShape) + + val name = unconstrainedTypeNameForCollectionOrMapOrUnionShape(shape, serviceShape) + val namespace = "crate::${Unconstrained.namespace}::${RustReservedWords.escapeIfNeeded(name.toSnakeCase())}" + val rustType = RustType.Opaque(name, namespace) + return Symbol.builder() + .rustType(rustType) + .name(rustType.name) + .namespace(rustType.namespace, "::") + .definitionFile(Unconstrained.filename) + .build() + } + + override fun toSymbol(shape: Shape): Symbol = + when (shape) { + is CollectionShape -> { + if (shape.canReachConstrainedShape(model, base)) { + unconstrainedSymbolForCollectionOrMapOrUnionShape(shape) + } else { + base.toSymbol(shape) + } + } + is MapShape -> { + if (shape.canReachConstrainedShape(model, base)) { + unconstrainedSymbolForCollectionOrMapOrUnionShape(shape) + } else { + base.toSymbol(shape) + } + } + is StructureShape -> { + if (shape.canReachConstrainedShape(model, base)) { + shape.serverBuilderSymbol(base, !publicConstrainedTypes) + } else { + base.toSymbol(shape) + } + } + is UnionShape -> { + if (shape.canReachConstrainedShape(model, base)) { + unconstrainedSymbolForCollectionOrMapOrUnionShape(shape) + } else { + base.toSymbol(shape) + } + } + is MemberShape -> { + // There are only two cases where we use this symbol provider on a member shape. + // + // 1. When generating deserializers for HTTP-bound member shapes. See, for example: + // * how [HttpBindingGenerator] generates deserializers for a member shape with the `httpPrefixHeaders` + // trait targeting a map shape of string keys and values; or + // * how [ServerHttpBoundProtocolGenerator] deserializes for a member shape with the `httpQuery` + // trait targeting a collection shape that can reach a constrained shape. + // + // 2. When generating members for unconstrained unions. See [UnconstrainedUnionGenerator]. + if (shape.targetCanReachConstrainedShape(model, base)) { + val targetShape = model.expectShape(shape.target) + val targetSymbol = this.toSymbol(targetShape) + // Handle boxing first so we end up with `Option>`, not `Box>`. + handleOptionality(handleRustBoxing(targetSymbol, shape), shape, nullableIndex, base.config().nullabilityCheckMode) + } else { + base.toSymbol(shape) + } + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Constraint traits on member shapes are not + // implemented yet. + } + is StringShape -> { + if (shape.canReachConstrainedShape(model, base)) { + symbolBuilder(shape, RustType.String).setDefault(Default.RustDefault).build() + } else { + base.toSymbol(shape) + } + } + else -> base.toSymbol(shape) + } +} + +/** + * Unconstrained type names are always suffixed with `Unconstrained` for clarity, even though we could dispense with it + * given that they all live inside the `unconstrained` module, so they don't collide with the constrained types. + */ +fun unconstrainedTypeNameForCollectionOrMapOrUnionShape(shape: Shape, serviceShape: ServiceShape): String { + check(shape is CollectionShape || shape is MapShape || shape is UnionShape) + return "${shape.id.getName(serviceShape).toPascalCase()}Unconstrained" +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt new file mode 100644 index 000000000..d487689b2 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt @@ -0,0 +1,248 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.EnumShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.SetShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.model.traits.PatternTrait +import software.amazon.smithy.model.traits.RangeTrait +import software.amazon.smithy.model.traits.RequiredTrait +import software.amazon.smithy.model.traits.StreamingTrait +import software.amazon.smithy.model.traits.Trait +import software.amazon.smithy.model.traits.UniqueItemsTrait +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.orNull +import java.util.logging.Level + +private sealed class UnsupportedConstraintMessageKind { + private val constraintTraitsUberIssue = "https://github.com/awslabs/smithy-rs/issues/1401" + + fun intoLogMessage(ignoreUnsupportedConstraints: Boolean): LogMessage { + fun buildMessage(intro: String, willSupport: Boolean, trackingIssue: String) = + """ + $intro + This is not supported in the smithy-rs server SDK. + ${ if (willSupport) "It will be supported in the future." else "" } + See the tracking issue ($trackingIssue). + If you want to go ahead and generate the server SDK ignoring unsupported constraint traits, set the key `ignoreUnsupportedConstraintTraits` + inside the `runtimeConfig.codegenConfig` JSON object in your `smithy-build.json` to `true`. + """.trimIndent().replace("\n", " ") + + fun buildMessageShapeHasUnsupportedConstraintTrait(shape: Shape, constraintTrait: Trait, trackingIssue: String) = + buildMessage( + "The ${shape.type} shape `${shape.id}` has the constraint trait `${constraintTrait.toShapeId()}` attached.", + willSupport = true, + trackingIssue, + ) + + val level = if (ignoreUnsupportedConstraints) Level.WARNING else Level.SEVERE + + return when (this) { + is UnsupportedConstraintOnMemberShape -> LogMessage( + level, + buildMessageShapeHasUnsupportedConstraintTrait(shape, constraintTrait, constraintTraitsUberIssue), + ) + is UnsupportedConstraintOnShapeReachableViaAnEventStream -> LogMessage( + level, + buildMessage( + """ + The ${shape.type} shape `${shape.id}` has the constraint trait `${constraintTrait.toShapeId()}` attached. + This shape is also part of an event stream; it is unclear what the semantics for constrained shapes in event streams are. + """.trimIndent().replace("\n", " "), + willSupport = false, + "https://github.com/awslabs/smithy/issues/1388", + ), + ) + is UnsupportedLengthTraitOnStreamingBlobShape -> LogMessage( + level, + buildMessage( + """ + The ${shape.type} shape `${shape.id}` has both the `${lengthTrait.toShapeId()}` and `${streamingTrait.toShapeId()}` constraint traits attached. + It is unclear what the semantics for streaming blob shapes are. + """.trimIndent().replace("\n", " "), + willSupport = false, + "https://github.com/awslabs/smithy/issues/1389", + ), + ) + is UnsupportedLengthTraitOnCollectionOrOnBlobShape -> LogMessage( + level, + buildMessageShapeHasUnsupportedConstraintTrait(shape, lengthTrait, constraintTraitsUberIssue), + ) + is UnsupportedPatternTraitOnStringShape -> LogMessage( + level, + buildMessageShapeHasUnsupportedConstraintTrait(shape, patternTrait, constraintTraitsUberIssue), + ) + is UnsupportedRangeTraitOnShape -> LogMessage( + level, + buildMessageShapeHasUnsupportedConstraintTrait(shape, rangeTrait, constraintTraitsUberIssue), + ) + } + } +} +private data class OperationWithConstrainedInputWithoutValidationException(val shape: OperationShape) +private data class UnsupportedConstraintOnMemberShape(val shape: MemberShape, val constraintTrait: Trait) : UnsupportedConstraintMessageKind() +private data class UnsupportedConstraintOnShapeReachableViaAnEventStream(val shape: Shape, val constraintTrait: Trait) : UnsupportedConstraintMessageKind() +private data class UnsupportedLengthTraitOnStreamingBlobShape(val shape: BlobShape, val lengthTrait: LengthTrait, val streamingTrait: StreamingTrait) : UnsupportedConstraintMessageKind() +private data class UnsupportedLengthTraitOnCollectionOrOnBlobShape(val shape: Shape, val lengthTrait: LengthTrait) : UnsupportedConstraintMessageKind() +private data class UnsupportedPatternTraitOnStringShape(val shape: Shape, val patternTrait: PatternTrait) : UnsupportedConstraintMessageKind() +private data class UnsupportedRangeTraitOnShape(val shape: Shape, val rangeTrait: RangeTrait) : UnsupportedConstraintMessageKind() + +data class LogMessage(val level: Level, val message: String) +data class ValidationResult(val shouldAbort: Boolean, val messages: List) + +private val allConstraintTraits = setOf( + LengthTrait::class.java, + PatternTrait::class.java, + RangeTrait::class.java, + UniqueItemsTrait::class.java, + EnumTrait::class.java, + RequiredTrait::class.java, +) +private val unsupportedConstraintsOnMemberShapes = allConstraintTraits - RequiredTrait::class.java + +fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached(model: Model, service: ServiceShape): ValidationResult { + // Traverse the model and error out if an operation uses constrained input, but it does not have + // `ValidationException` attached in `errors`. https://github.com/awslabs/smithy-rs/pull/1199#discussion_r809424783 + // TODO(https://github.com/awslabs/smithy-rs/issues/1401): This check will go away once we add support for + // `disableDefaultValidation` set to `true`, allowing service owners to map from constraint violations to operation errors. + val walker = Walker(model) + val operationsWithConstrainedInputWithoutValidationExceptionSet = walker.walkShapes(service) + .filterIsInstance() + .asSequence() + .filter { operationShape -> + // Walk the shapes reachable via this operation input. + walker.walkShapes(operationShape.inputShape(model)) + .any { it is SetShape || it is EnumShape || it.hasConstraintTrait() } + } + .filter { !it.errors.contains(ShapeId.from("smithy.framework#ValidationException")) } + .map { OperationWithConstrainedInputWithoutValidationException(it) } + .toSet() + + val messages = + operationsWithConstrainedInputWithoutValidationExceptionSet.map { + LogMessage( + Level.SEVERE, + """ + Operation ${it.shape.id} takes in input that is constrained + (https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html), and as such can fail with a validation + exception. You must model this behavior in the operation shape in your model file. + """.trimIndent().replace("\n", "") + + """ + + ```smithy + use smithy.framework#ValidationException + + operation ${it.shape.id.name} { + ... + errors: [..., ValidationException] // <-- Add this. + } + ``` + """.trimIndent(), + ) + } + + return ValidationResult(shouldAbort = messages.any { it.level == Level.SEVERE }, messages) +} + +fun validateUnsupportedConstraints(model: Model, service: ServiceShape, codegenConfig: ServerCodegenConfig): ValidationResult { + // Traverse the model and error out if: + val walker = Walker(model) + + // 1. Constraint traits on member shapes are used. [Constraint trait precedence] has not been implemented yet. + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) + // [Constraint trait precedence]: https://awslabs.github.io/smithy/2.0/spec/model.html#applying-traits + val unsupportedConstraintOnMemberShapeSet = walker + .walkShapes(service) + .asSequence() + .filterIsInstance() + .filterMapShapesToTraits(unsupportedConstraintsOnMemberShapes) + .map { (shape, trait) -> UnsupportedConstraintOnMemberShape(shape as MemberShape, trait) } + .toSet() + + // 2. Constraint traits on streaming blob shapes are used. Their semantics are unclear. + // TODO(https://github.com/awslabs/smithy/issues/1389) + val unsupportedLengthTraitOnStreamingBlobShapeSet = walker + .walkShapes(service) + .asSequence() + .filterIsInstance() + .filter { it.hasTrait() && it.hasTrait() } + .map { UnsupportedLengthTraitOnStreamingBlobShape(it, it.expectTrait(), it.expectTrait()) } + .toSet() + + // 3. Constraint traits in event streams are used. Their semantics are unclear. + // TODO(https://github.com/awslabs/smithy/issues/1388) + val unsupportedConstraintOnShapeReachableViaAnEventStreamSet = walker + .walkShapes(service) + .asSequence() + .filterIsInstance() + .filter { it.hasTrait() } + .flatMap { walker.walkShapes(it) } + .filterMapShapesToTraits(allConstraintTraits) + .map { (shape, trait) -> UnsupportedConstraintOnShapeReachableViaAnEventStream(shape, trait) } + .toSet() + + // 4. Length trait on collection shapes or on blob shapes is used. It has not been implemented yet for these target types. + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) + val unsupportedLengthTraitOnCollectionOrOnBlobShapeSet = walker + .walkShapes(service) + .asSequence() + .filter { it is CollectionShape || it is BlobShape } + .filter { it.hasTrait() } + .map { UnsupportedLengthTraitOnCollectionOrOnBlobShape(it, it.expectTrait()) } + .toSet() + + // 5. Pattern trait on string shapes is used. It has not been implemented yet. + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) + val unsupportedPatternTraitOnStringShapeSet = walker + .walkShapes(service) + .asSequence() + .filterIsInstance() + .filterMapShapesToTraits(setOf(PatternTrait::class.java)) + .map { (shape, patternTrait) -> UnsupportedPatternTraitOnStringShape(shape, patternTrait as PatternTrait) } + .toSet() + + // 6. Range trait on any shape is used. It has not been implemented yet. + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) + val unsupportedRangeTraitOnShapeSet = walker + .walkShapes(service) + .asSequence() + .filterMapShapesToTraits(setOf(RangeTrait::class.java)) + .map { (shape, rangeTrait) -> UnsupportedRangeTraitOnShape(shape, rangeTrait as RangeTrait) } + .toSet() + + val messages = + unsupportedConstraintOnMemberShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + unsupportedLengthTraitOnStreamingBlobShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + unsupportedConstraintOnShapeReachableViaAnEventStreamSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + unsupportedLengthTraitOnCollectionOrOnBlobShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + unsupportedPatternTraitOnStringShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + unsupportedRangeTraitOnShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + return ValidationResult(shouldAbort = messages.any { it.level == Level.SEVERE }, messages) +} + +/** + * Returns a sequence over pairs `(shape, trait)`. + * The returned sequence contains one pair per shape in the input iterable that has attached a trait contained in [traits]. + */ +private fun Sequence.filterMapShapesToTraits(traits: Set>): Sequence> = + this.map { shape -> shape to traits.mapNotNull { shape.getTrait(it).orNull() } } + .flatMap { (shape, traits) -> traits.map { shape to it } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt new file mode 100644 index 000000000..820f5bc8b --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapJsonCustomization.kt @@ -0,0 +1,38 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerSection +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType + +/** + * A customization to, just before we iterate over a _constrained_ map shape in a JSON serializer, unwrap the wrapper + * newtype and take a shared reference to the actual `std::collections::HashMap` within it. + */ +class BeforeIteratingOverMapJsonCustomization(private val codegenContext: ServerCodegenContext) : JsonSerializerCustomization() { + override fun section(section: JsonSerializerSection): Writable = when (section) { + is JsonSerializerSection.BeforeIteratingOverMap -> writable { + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + // Note that this particular implementation just so happens to work because when the customization + // is invoked in the JSON serializer, the value expression is guaranteed to be a variable binding name. + // If the expression in the future were to be more complex, we wouldn't be able to write the left-hand + // side of this assignment. + rust("""let ${section.valueExpression.name} = &${section.valueExpression.name}.0;""") + } + } + else -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt new file mode 100644 index 000000000..677350dd6 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt @@ -0,0 +1,160 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext + +/** + * [ConstrainedMapGenerator] generates a wrapper tuple newtype holding a constrained `std::collections::HashMap`. + * This type can be built from unconstrained values, yielding a `ConstraintViolation` when the input does not satisfy + * the constraints. + * + * The [`length` trait] is the only constraint trait applicable to map shapes. + * + * If [unconstrainedSymbol] is provided, the `MaybeConstrained` trait is implemented for the constrained type, using the + * [unconstrainedSymbol]'s associated type as the associated type for the trait. + * + * [`length` trait]: https://awslabs.github.io/smithy/1.0/spec/core/constraint-traits.html#length-trait + */ +class ConstrainedMapGenerator( + val codegenContext: ServerCodegenContext, + val writer: RustWriter, + val shape: MapShape, + private val unconstrainedSymbol: Symbol? = null, +) { + private val model = codegenContext.model + private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + private val symbolProvider = codegenContext.symbolProvider + + fun render() { + // The `length` trait is the only constraint trait applicable to map shapes. + val lengthTrait = shape.expectTrait() + + val name = constrainedShapeSymbolProvider.toSymbol(shape).name + val inner = "std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>" + val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) + + val condition = if (lengthTrait.min.isPresent && lengthTrait.max.isPresent) { + "(${lengthTrait.min.get()}..=${lengthTrait.max.get()}).contains(&length)" + } else if (lengthTrait.min.isPresent) { + "${lengthTrait.min.get()} <= length" + } else { + "length <= ${lengthTrait.max.get()}" + } + + val constrainedTypeVisibility = if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } + val constrainedTypeMetadata = RustMetadata( + Attribute.Derives(setOf(RuntimeType.Debug, RuntimeType.Clone, RuntimeType.PartialEq)), + visibility = constrainedTypeVisibility, + ) + + val codegenScope = arrayOf( + "KeySymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.key.target)), + "ValueSymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.value.target)), + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + "ConstraintViolation" to constraintViolation, + ) + + writer.documentShape(shape, model, note = rustDocsNote(name)) + constrainedTypeMetadata.render(writer) + writer.rustTemplate("struct $name(pub(crate) $inner);", *codegenScope) + if (constrainedTypeVisibility == Visibility.PUBCRATE) { + Attribute.AllowUnused.render(writer) + } + writer.rustTemplate( + """ + impl $name { + /// ${rustDocsInnerMethod(inner)} + pub fn inner(&self) -> &$inner { + &self.0 + } + + /// ${rustDocsIntoInnerMethod(inner)} + pub fn into_inner(self) -> $inner { + self.0 + } + } + + impl #{TryFrom}<$inner> for $name { + type Error = #{ConstraintViolation}; + + /// ${rustDocsTryFromMethod(name, inner)} + fn try_from(value: $inner) -> Result { + let length = value.len(); + if $condition { + Ok(Self(value)) + } else { + Err(#{ConstraintViolation}::Length(length)) + } + } + } + + impl #{From}<$name> for $inner { + fn from(value: $name) -> Self { + value.into_inner() + } + } + """, + *codegenScope, + ) + + if (!publicConstrainedTypes && isValueConstrained(shape, model, symbolProvider)) { + writer.rustTemplate( + """ + impl #{From}<$name> for #{FullyUnconstrainedSymbol} { + fn from(value: $name) -> Self { + value + .into_inner() + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect() + } + } + """, + *codegenScope, + "FullyUnconstrainedSymbol" to symbolProvider.toSymbol(shape), + ) + } + + if (unconstrainedSymbol != null) { + writer.rustTemplate( + """ + impl #{ConstrainedTrait} for $name { + type Unconstrained = #{UnconstrainedSymbol}; + } + """, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "UnconstrainedSymbol" to unconstrainedSymbol, + ) + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorCommon.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorCommon.kt new file mode 100644 index 000000000..fb5ce1dae --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorCommon.kt @@ -0,0 +1,22 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained + +/** + * Common helper functions used in [UnconstrainedMapGenerator] and [MapConstraintViolationGenerator]. + */ + +fun isKeyConstrained(shape: StringShape, symbolProvider: SymbolProvider) = shape.isDirectlyConstrained(symbolProvider) + +fun isValueConstrained(shape: Shape, model: Model, symbolProvider: SymbolProvider): Boolean = + shape.canReachConstrainedShape(model, symbolProvider) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedShapeGeneratorCommon.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedShapeGeneratorCommon.kt new file mode 100644 index 000000000..d0d447cda --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedShapeGeneratorCommon.kt @@ -0,0 +1,24 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +/** + * Functions shared amongst the constrained shape generators, to keep them DRY and consistent. + */ + +fun rustDocsNote(typeName: String) = + "this is a constrained type because its corresponding modeled Smithy shape has one or more " + + "[constraint traits]. Use [`parse`] or [`$typeName::TryFrom`] to construct values of this type." + + "[constraint traits]: https://awslabs.github.io/smithy/1.0/spec/core/constraint-traits.html" + +fun rustDocsTryFromMethod(typeName: String, inner: String) = + "Constructs a `$typeName` from an [`$inner`], failing when the provided value does not satisfy the modeled constraints." + +fun rustDocsInnerMethod(inner: String) = + "Returns an immutable reference to the underlying [`$inner`]." + +fun rustDocsIntoInnerMethod(inner: String) = + "Consumes the value, returning the underlying [`$inner`]." diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt new file mode 100644 index 000000000..a63801cce --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt @@ -0,0 +1,183 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.render +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput +import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage + +/** + * [ConstrainedStringGenerator] generates a wrapper tuple newtype holding a constrained `String`. + * This type can be built from unconstrained values, yielding a `ConstraintViolation` when the input does not satisfy + * the constraints. + */ +class ConstrainedStringGenerator( + val codegenContext: ServerCodegenContext, + val writer: RustWriter, + val shape: StringShape, +) { + val model = codegenContext.model + val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + + fun render() { + val lengthTrait = shape.expectTrait() + + val symbol = constrainedShapeSymbolProvider.toSymbol(shape) + val name = symbol.name + val inner = RustType.String.render() + val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) + + val condition = if (lengthTrait.min.isPresent && lengthTrait.max.isPresent) { + "(${lengthTrait.min.get()}..=${lengthTrait.max.get()}).contains(&length)" + } else if (lengthTrait.min.isPresent) { + "${lengthTrait.min.get()} <= length" + } else { + "length <= ${lengthTrait.max.get()}" + } + + val constrainedTypeVisibility = if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } + val constrainedTypeMetadata = RustMetadata( + Attribute.Derives(setOf(RuntimeType.Debug, RuntimeType.Clone, RuntimeType.PartialEq, RuntimeType.Eq, RuntimeType.Hash)), + visibility = constrainedTypeVisibility, + ) + + // Note that we're using the linear time check `chars().count()` instead of `len()` on the input value, since the + // Smithy specification says the `length` trait counts the number of Unicode code points when applied to string shapes. + // https://awslabs.github.io/smithy/1.0/spec/core/constraint-traits.html#length-trait + writer.documentShape(shape, model, note = rustDocsNote(name)) + constrainedTypeMetadata.render(writer) + writer.rust("struct $name(pub(crate) $inner);") + if (constrainedTypeVisibility == Visibility.PUBCRATE) { + Attribute.AllowUnused.render(writer) + } + writer.rustTemplate( + """ + impl $name { + /// Extracts a string slice containing the entire underlying `String`. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// ${rustDocsInnerMethod(inner)} + pub fn inner(&self) -> &$inner { + &self.0 + } + + /// ${rustDocsIntoInnerMethod(inner)} + pub fn into_inner(self) -> $inner { + self.0 + } + } + + impl #{ConstrainedTrait} for $name { + type Unconstrained = $inner; + } + + impl #{From}<$inner> for #{MaybeConstrained} { + fn from(value: $inner) -> Self { + Self::Unconstrained(value) + } + } + + impl #{Display} for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ${shape.redactIfNecessary(model, "self.0")}.fmt(f) + } + } + + impl #{TryFrom}<$inner> for $name { + type Error = #{ConstraintViolation}; + + /// ${rustDocsTryFromMethod(name, inner)} + fn try_from(value: $inner) -> Result { + let length = value.chars().count(); + if $condition { + Ok(Self(value)) + } else { + Err(#{ConstraintViolation}::Length(length)) + } + } + } + + impl #{From}<$name> for $inner { + fn from(value: $name) -> Self { + value.into_inner() + } + } + """, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "ConstraintViolation" to constraintViolation, + "MaybeConstrained" to symbol.makeMaybeConstrained(), + "Display" to RuntimeType.Display, + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + ) + + val constraintViolationModuleName = constraintViolation.namespace.split(constraintViolation.namespaceDelimiter).last() + writer.withModule(RustModule(constraintViolationModuleName, RustMetadata(visibility = constrainedTypeVisibility))) { + rust( + """ + ##[derive(Debug, PartialEq)] + pub enum ${constraintViolation.name} { + Length(usize), + } + """, + ) + + if (shape.isReachableFromOperationInput()) { + rustBlock("impl ${constraintViolation.name}") { + rustBlockTemplate( + "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", + "String" to RuntimeType.String, + ) { + rustBlock("match self") { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${lengthTrait.validationErrorMessage()}", length, &path), + path, + }, + """, + ) + } + } + } + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedTraitForEnumGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedTraitForEnumGenerator.kt new file mode 100644 index 000000000..288065d75 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedTraitForEnumGenerator.kt @@ -0,0 +1,51 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.util.expectTrait + +/** + * [ConstrainedTraitForEnumGenerator] generates code that implements the [RuntimeType.ConstrainedTrait] trait on an + * enum shape. + */ +class ConstrainedTraitForEnumGenerator( + val model: Model, + val symbolProvider: RustSymbolProvider, + val writer: RustWriter, + val shape: StringShape, +) { + fun render() { + shape.expectTrait() + + val symbol = symbolProvider.toSymbol(shape) + val name = symbol.name + val unconstrainedType = "String" + + writer.rustTemplate( + """ + impl #{ConstrainedTrait} for $name { + type Unconstrained = $unconstrainedType; + } + + impl From<$unconstrainedType> for #{MaybeConstrained} { + fn from(value: $unconstrainedType) -> Self { + Self::Unconstrained(value) + } + } + """, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "MaybeConstrained" to symbol.makeMaybeConstrained(), + ) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt new file mode 100644 index 000000000..684d83322 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt @@ -0,0 +1,121 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput +import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage + +class MapConstraintViolationGenerator( + codegenContext: ServerCodegenContext, + private val modelsModuleWriter: RustWriter, + val shape: MapShape, +) { + private val model = codegenContext.model + private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + private val symbolProvider = codegenContext.symbolProvider + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + + fun render() { + val keyShape = model.expectShape(shape.key.target, StringShape::class.java) + val valueShape = model.expectShape(shape.value.target) + val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) + val constraintViolationName = constraintViolationSymbol.name + + val constraintViolationCodegenScopeMutableList: MutableList> = mutableListOf() + if (isKeyConstrained(keyShape, symbolProvider)) { + constraintViolationCodegenScopeMutableList.add("KeyConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(keyShape)) + } + if (isValueConstrained(valueShape, model, symbolProvider)) { + constraintViolationCodegenScopeMutableList.add("ValueConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(valueShape)) + constraintViolationCodegenScopeMutableList.add("KeySymbol" to constrainedShapeSymbolProvider.toSymbol(keyShape)) + } + val constraintViolationCodegenScope = constraintViolationCodegenScopeMutableList.toTypedArray() + + val constraintViolationVisibility = if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } + modelsModuleWriter.withModule( + RustModule( + constraintViolationSymbol.namespace.split(constraintViolationSymbol.namespaceDelimiter).last(), + RustMetadata(visibility = constraintViolationVisibility), + ), + ) { + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) We should really have two `ConstraintViolation` + // types here. One will just have variants for each constraint trait on the map shape, for use by the user. + // The other one will have variants if the shape's key or value is directly or transitively constrained, + // and is for use by the framework. + rustTemplate( + """ + ##[derive(Debug, PartialEq)] + pub${ if (constraintViolationVisibility == Visibility.PUBCRATE) " (crate) " else "" } enum $constraintViolationName { + ${if (shape.hasTrait()) "Length(usize)," else ""} + ${if (isKeyConstrained(keyShape, symbolProvider)) "##[doc(hidden)] Key(#{KeyConstraintViolationSymbol})," else ""} + ${if (isValueConstrained(valueShape, model, symbolProvider)) "##[doc(hidden)] Value(#{KeySymbol}, #{ValueConstraintViolationSymbol})," else ""} + } + """, + *constraintViolationCodegenScope, + ) + + if (shape.isReachableFromOperationInput()) { + rustBlock("impl $constraintViolationName") { + rustBlockTemplate( + "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", + "String" to RuntimeType.String, + ) { + rustBlock("match self") { + shape.getTrait()?.also { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${it.validationErrorMessage()}", length, &path), + path, + }, + """, + ) + } + if (isKeyConstrained(keyShape, symbolProvider)) { + // Note how we _do not_ append the key's member name to the path. This is intentional, as + // per the `RestJsonMalformedLengthMapKey` test. Note keys are always strings. + // https://github.com/awslabs/smithy/blob/ee0b4ff90daaaa5101f32da936c25af8c91cc6e9/smithy-aws-protocol-tests/model/restJson1/validation/malformed-length.smithy#L296-L295 + rust("""Self::Key(key_constraint_violation) => key_constraint_violation.as_validation_exception_field(path),""") + } + if (isValueConstrained(valueShape, model, symbolProvider)) { + // `as_str()` works with regular `String`s and constrained string shapes. + rust("""Self::Value(key, value_constraint_violation) => value_constraint_violation.as_validation_exception_field(path + "/" + key.as_str()),""") + } + } + } + } + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt new file mode 100644 index 000000000..b789c2166 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt @@ -0,0 +1,148 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.isTransitivelyButNotDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.typeNameContainsNonPublicType + +/** + * A generator for a wrapper tuple newtype over a collection shape's symbol + * type. + * + * This newtype is for a collection shape that is _transitively_ constrained, + * but not directly. That is, the collection shape does not have a constraint + * trait attached, but the members it holds reach a constrained shape. The + * generated newtype is therefore `pub(crate)`, as the class name indicates, + * and is not available to end users. After deserialization, upon constraint + * traits' enforcement, this type is converted into the regular `Vec` the user + * sees via the generated converters. + * + * TODO(https://github.com/awslabs/smithy-rs/issues/1401) If the collection + * shape is _directly_ constrained, use [ConstrainedCollectionGenerator] + * instead. + */ +class PubCrateConstrainedCollectionGenerator( + val codegenContext: ServerCodegenContext, + val writer: RustWriter, + val shape: CollectionShape, +) { + private val model = codegenContext.model + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider + private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider + private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + private val symbolProvider = codegenContext.symbolProvider + + fun render() { + check(shape.canReachConstrainedShape(model, symbolProvider)) + + val symbol = symbolProvider.toSymbol(shape) + val constrainedSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) + + val unconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape) + val moduleName = constrainedSymbol.namespace.split(constrainedSymbol.namespaceDelimiter).last() + val name = constrainedSymbol.name + val innerShape = model.expectShape(shape.member.target) + val innerConstrainedSymbol = if (innerShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) { + pubCrateConstrainedShapeSymbolProvider.toSymbol(innerShape) + } else { + constrainedShapeSymbolProvider.toSymbol(innerShape) + } + + val codegenScope = arrayOf( + "InnerConstrainedSymbol" to innerConstrainedSymbol, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "UnconstrainedSymbol" to unconstrainedSymbol, + "Symbol" to symbol, + "From" to RuntimeType.From, + ) + + writer.withModule(RustModule(moduleName, RustMetadata(visibility = Visibility.PUBCRATE))) { + rustTemplate( + """ + ##[derive(Debug, Clone)] + pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerConstrainedSymbol}>); + + impl #{ConstrainedTrait} for $name { + type Unconstrained = #{UnconstrainedSymbol}; + } + """, + *codegenScope, + ) + + if (publicConstrainedTypes) { + // If the target member shape is itself _not_ directly constrained, and is an aggregate non-Structure shape, + // then its corresponding constrained type is the `pub(crate)` wrapper tuple type, which needs converting into + // the public type the user is exposed to. The two types are isomorphic, and we can convert between them using + // `From`. So we track this particular case here in order to iterate over the list's members and convert + // each of them. + // + // Note that we could add the iteration code unconditionally and it would still be correct, but the `into()` calls + // would be useless. Clippy flags this as [`useless_conversion`]. We could deactivate the lint, but it's probably + // best that we just don't emit a useless iteration, lest the compiler not optimize it away (see [Godbolt]), + // and to make the generated code a little bit simpler. + // + // [`useless_conversion`]: https://rust-lang.github.io/rust-clippy/master/index.html#useless_conversion. + // [Godbolt]: https://godbolt.org/z/eheWebWMa + val innerNeedsConstraining = + !innerShape.isDirectlyConstrained(symbolProvider) && (innerShape is CollectionShape || innerShape is MapShape) + + rustTemplate( + """ + impl #{From}<#{Symbol}> for $name { + fn from(v: #{Symbol}) -> Self { + ${ if (innerNeedsConstraining) { + "Self(v.into_iter().map(|item| item.into()).collect())" + } else { + "Self(v)" + } } + } + } + + impl #{From}<$name> for #{Symbol} { + fn from(v: $name) -> Self { + ${ if (innerNeedsConstraining) { + "v.0.into_iter().map(|item| item.into()).collect()" + } else { + "v.0" + } } + } + } + """, + *codegenScope, + ) + } else { + val innerNeedsConversion = innerShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + + rustTemplate( + """ + impl #{From}<$name> for #{Symbol} { + fn from(v: $name) -> Self { + ${ if (innerNeedsConversion) { + "v.0.into_iter().map(|item| item.into()).collect()" + } else { + "v.0" + } } + } + } + """, + *codegenScope, + ) + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt new file mode 100644 index 000000000..591b11b7e --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt @@ -0,0 +1,142 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.isTransitivelyButNotDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.typeNameContainsNonPublicType + +/** + * A generator for a wrapper tuple newtype over a map shape's symbol type. + * + * This newtype is for a map shape that is _transitively_ constrained, but not + * directly. That is, the map shape does not have a constraint trait attached, + * but the keys and/or values it holds reach a constrained shape. The generated + * newtype is therefore `pub(crate)`, as the class name indicates, and is not + * available to end users. After deserialization, upon constraint traits' + * enforcement, this type is converted into the regular `HashMap` the user sees + * via the generated converters. + * + * If the map shape is _directly_ constrained, use [ConstrainedMapGenerator] + * instead. + */ +class PubCrateConstrainedMapGenerator( + val codegenContext: ServerCodegenContext, + val writer: RustWriter, + val shape: MapShape, +) { + private val model = codegenContext.model + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider + private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider + private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + private val symbolProvider = codegenContext.symbolProvider + + fun render() { + check(shape.canReachConstrainedShape(model, symbolProvider)) + + val symbol = symbolProvider.toSymbol(shape) + val unconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape) + val constrainedSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) + val moduleName = constrainedSymbol.namespace.split(constrainedSymbol.namespaceDelimiter).last() + val name = constrainedSymbol.name + val keyShape = model.expectShape(shape.key.target, StringShape::class.java) + val valueShape = model.expectShape(shape.value.target) + val keySymbol = constrainedShapeSymbolProvider.toSymbol(keyShape) + val valueSymbol = if (valueShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) { + pubCrateConstrainedShapeSymbolProvider.toSymbol(valueShape) + } else { + constrainedShapeSymbolProvider.toSymbol(valueShape) + } + + val codegenScope = arrayOf( + "KeySymbol" to keySymbol, + "ValueSymbol" to valueSymbol, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "UnconstrainedSymbol" to unconstrainedSymbol, + "Symbol" to symbol, + "From" to RuntimeType.From, + ) + + writer.withModule(RustModule(moduleName, RustMetadata(visibility = Visibility.PUBCRATE))) { + rustTemplate( + """ + ##[derive(Debug, Clone)] + pub(crate) struct $name(pub(crate) std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>); + + impl #{ConstrainedTrait} for $name { + type Unconstrained = #{UnconstrainedSymbol}; + } + """, + *codegenScope, + ) + + if (publicConstrainedTypes) { + // Unless the map holds an aggregate shape as its value shape whose symbol's type is _not_ `pub(crate)`, the + // `.into()` calls are useless. + // See the comment in [ConstrainedCollectionShape] for a more detailed explanation. + val innerNeedsConstraining = + !valueShape.isDirectlyConstrained(symbolProvider) && (valueShape is CollectionShape || valueShape is MapShape) + + rustTemplate( + """ + impl #{From}<#{Symbol}> for $name { + fn from(v: #{Symbol}) -> Self { + ${ if (innerNeedsConstraining) { + "Self(v.into_iter().map(|(k, v)| (k, v.into())).collect())" + } else { + "Self(v)" + } } + } + } + + impl #{From}<$name> for #{Symbol} { + fn from(v: $name) -> Self { + ${ if (innerNeedsConstraining) { + "v.0.into_iter().map(|(k, v)| (k, v.into())).collect()" + } else { + "v.0" + } } + } + } + """, + *codegenScope, + ) + } else { + val keyNeedsConversion = keyShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + val valueNeedsConversion = valueShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + + rustTemplate( + """ + impl #{From}<$name> for #{Symbol} { + fn from(v: $name) -> Self { + ${ if (keyNeedsConversion || valueNeedsConversion) { + val keyConversion = if (keyNeedsConversion) { ".into()" } else { "" } + val valueConversion = if (valueNeedsConversion) { ".into()" } else { "" } + "v.0.into_iter().map(|(k, v)| (k$keyConversion, v$valueConversion)).collect()" + } else { + "v.0" + } } + } + } + """, + *codegenScope, + ) + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt new file mode 100644 index 000000000..f2c572eed --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt @@ -0,0 +1,218 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.letIf +import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape + +/** + * Renders constraint violation types that arise when building a structure shape builder. + * + * Used by [ServerBuilderGenerator] and [ServerBuilderGeneratorWithoutPublicConstrainedTypes]. + */ +class ServerBuilderConstraintViolations( + codegenContext: ServerCodegenContext, + private val shape: StructureShape, + private val builderTakesInUnconstrainedTypes: Boolean, +) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (codegenContext.settings.codegenConfig.publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + private val members: List = shape.allMembers.values.toList() + val all = members.flatMap { member -> + listOfNotNull( + forMember(member), + builderConstraintViolationForMember(member), + ) + } + + fun render( + writer: RustWriter, + visibility: Visibility, + nonExhaustive: Boolean, + shouldRenderAsValidationExceptionFieldList: Boolean, + ) { + Attribute.Derives(setOf(RuntimeType.Debug, RuntimeType.PartialEq)).render(writer) + writer.docs("Holds one variant for each of the ways the builder can fail.") + if (nonExhaustive) Attribute.NonExhaustive.render(writer) + val constraintViolationSymbolName = constraintViolationSymbolProvider.toSymbol(shape).name + writer.rustBlock("pub${ if (visibility == Visibility.PUBCRATE) " (crate) " else "" } enum $constraintViolationSymbolName") { + renderConstraintViolations(writer) + } + renderImplDisplayConstraintViolation(writer) + writer.rust("impl #T for ConstraintViolation { }", RuntimeType.StdError) + + if (shouldRenderAsValidationExceptionFieldList) { + renderAsValidationExceptionFieldList(writer) + } + } + + /** + * Returns the builder failure associated with the `member` field if its target is constrained. + */ + fun builderConstraintViolationForMember(member: MemberShape) = + if (builderTakesInUnconstrainedTypes && member.targetCanReachConstrainedShape(model, symbolProvider)) { + ConstraintViolation(member, ConstraintViolationKind.CONSTRAINED_SHAPE_FAILURE) + } else { + null + } + + /** + * Returns the builder failure associated with the [member] field if it is `required`. + */ + fun forMember(member: MemberShape): ConstraintViolation? { + check(members.contains(member)) + // TODO(https://github.com/awslabs/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): See above. + return if (symbolProvider.toSymbol(member).isOptional()) { + null + } else { + ConstraintViolation(member, ConstraintViolationKind.MISSING_MEMBER) + } + } + + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) This impl does not take into account the `sensitive` trait. + // When constraint violation error messages are adjusted to match protocol tests, we should ensure it's honored. + private fun renderImplDisplayConstraintViolation(writer: RustWriter) { + writer.rustBlock("impl #T for ConstraintViolation", RuntimeType.Display) { + rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") { + rustBlock("match self") { + all.forEach { + val arm = if (it.hasInner()) { + "ConstraintViolation::${it.name()}(_)" + } else { + "ConstraintViolation::${it.name()}" + } + rust("""$arm => write!(f, "${it.message(symbolProvider, model)}"),""") + } + } + } + } + } + + private fun renderConstraintViolations(writer: RustWriter) { + for (constraintViolation in all) { + when (constraintViolation.kind) { + ConstraintViolationKind.MISSING_MEMBER -> { + writer.docs("${constraintViolation.message(symbolProvider, model).replaceFirstChar { it.uppercaseChar() }}.") + writer.rust("${constraintViolation.name()},") + } + + ConstraintViolationKind.CONSTRAINED_SHAPE_FAILURE -> { + val targetShape = model.expectShape(constraintViolation.forMember.target) + + val constraintViolationSymbol = + constraintViolationSymbolProvider.toSymbol(targetShape) + // If the corresponding structure's member is boxed, box this constraint violation symbol too. + .letIf(constraintViolation.forMember.hasTrait()) { + it.makeRustBoxed() + } + + // Note we cannot express the inner constraint violation as `>::Error`, because `T` might + // be `pub(crate)` and that would leak `T` in a public interface. + writer.docs("${constraintViolation.message(symbolProvider, model)}.".replaceFirstChar { it.uppercaseChar() }) + Attribute.DocHidden.render(writer) + writer.rust("${constraintViolation.name()}(#T),", constraintViolationSymbol) + } + } + } + } + + private fun renderAsValidationExceptionFieldList(writer: RustWriter) { + val validationExceptionFieldWritable = writable { + rustBlock("match self") { + all.forEach { + if (it.hasInner()) { + rust("""ConstraintViolation::${it.name()}(inner) => inner.as_validation_exception_field(path + "/${it.forMember.memberName}"),""") + } else { + rust( + """ + ConstraintViolation::${it.name()} => crate::model::ValidationExceptionField { + message: format!("Value null at '{}/${it.forMember.memberName}' failed to satisfy constraint: Member must not be null", path), + path: path + "/${it.forMember.memberName}", + }, + """, + ) + } + } + } + } + + writer.rustTemplate( + """ + impl ConstraintViolation { + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + #{ValidationExceptionFieldWritable:W} + } + } + """, + "ValidationExceptionFieldWritable" to validationExceptionFieldWritable, + "String" to RuntimeType.String, + ) + } +} + +/** + * The kinds of constraint violations that can occur when building the builder. + */ +enum class ConstraintViolationKind { + // A field is required but was not provided. + MISSING_MEMBER, + + // An unconstrained type was provided for a field targeting a constrained shape, but it failed to convert into the constrained type. + CONSTRAINED_SHAPE_FAILURE, +} + +data class ConstraintViolation(val forMember: MemberShape, val kind: ConstraintViolationKind) { + fun name() = when (kind) { + ConstraintViolationKind.MISSING_MEMBER -> "Missing${forMember.memberName.toPascalCase()}" + ConstraintViolationKind.CONSTRAINED_SHAPE_FAILURE -> forMember.memberName.toPascalCase() + } + + /** + * Whether the constraint violation is a Rust tuple struct with one element. + */ + fun hasInner() = kind == ConstraintViolationKind.CONSTRAINED_SHAPE_FAILURE + + /** + * A message for a `ConstraintViolation` variant. This is used in both Rust documentation and the `Display` trait implementation. + */ + fun message(symbolProvider: SymbolProvider, model: Model): String { + val memberName = symbolProvider.toMemberName(forMember) + val structureSymbol = symbolProvider.toSymbol(model.expectShape(forMember.container)) + return when (kind) { + ConstraintViolationKind.MISSING_MEMBER -> "`$memberName` was not provided but it is required when building `${structureSymbol.name}`" + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Nest errors. Adjust message following protocol tests. + ConstraintViolationKind.CONSTRAINED_SHAPE_FAILURE -> "constraint violation occurred building member `$memberName` when building `${structureSymbol.name}`" + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt new file mode 100644 index 000000000..6f319f671 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt @@ -0,0 +1,542 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.implInto +import software.amazon.smithy.rust.codegen.core.rustlang.render +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.smithy.makeOptional +import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.mapRustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.letIf +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTraitOrTargetHasConstraintTrait +import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput +import software.amazon.smithy.rust.codegen.server.smithy.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled + +/** + * Generates a builder for the Rust type associated with the [StructureShape]. + * + * This generator is meant for use by the server project. Clients use the [BuilderGenerator] from the `codegen-client` + * Gradle subproject instead. + * + * This builder is different in that it enforces [constraint traits] upon calling `.build()`. If any constraint + * violations occur, the `build` method returns them. + * + * These are the main differences with the builders generated by the client's [BuilderGenerator]: + * + * - The design of this builder is simpler and closely follows what you get when using the [derive_builder] crate: + * * The builder has one method per struct member named _exactly_ like the struct member and whose input type + * matches _exactly_ the struct's member type. This method is generated by [renderBuilderMemberFn]. + * * The builder has one _setter_ method (i.e. prefixed with `set_`) per struct member whose input type is the + * corresponding _unconstrained type_ for the member. This method is always `pub(crate)` and meant for use for + * server deserializers only. + * * There are no convenience methods to add items to vector and hash map struct members. + * - The builder is not `PartialEq`. This is because the builder's members may or may not have been constrained (their + * types hold `MaybeConstrained`), and so it doesn't make sense to compare e.g. two builders holding the same data + * values, but one builder holds the member in the constrained variant while the other one holds it in the unconstrained + * variant. + * - The builder always implements `TryFrom for Structure` or `From for Structure`, depending on whether + * the structure is constrained (and hence enforcing the constraints might yield an error) or not, respectively. + * + * The builder is `pub(crate)` when `publicConstrainedTypes` is `false`, since in this case the user is never exposed + * to constrained types, and only the server's deserializers need to enforce constraint traits upon receiving a request. + * The user is exposed to [ServerBuilderGeneratorWithoutPublicConstrainedTypes] in this case instead, which intentionally + * _does not_ enforce constraints. + * + * [constraint traits]: https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html + * [derive_builder]: https://docs.rs/derive_builder/latest/derive_builder/index.html + */ +class ServerBuilderGenerator( + codegenContext: ServerCodegenContext, + private val shape: StructureShape, +) { + companion object { + /** + * Returns whether a structure shape, whose builder has been generated with [ServerBuilderGenerator], requires a + * fallible builder to be constructed. + */ + fun hasFallibleBuilder( + structureShape: StructureShape, + model: Model, + symbolProvider: SymbolProvider, + takeInUnconstrainedTypes: Boolean, + ): Boolean = + if (takeInUnconstrainedTypes) { + structureShape.canReachConstrainedShape(model, symbolProvider) + } else { + structureShape + .members() + .map { symbolProvider.toSymbol(it) } + .any { !it.isOptional() } + } + } + + private val takeInUnconstrainedTypes = shape.isReachableFromOperationInput() + private val model = codegenContext.model + private val runtimeConfig = codegenContext.runtimeConfig + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val visibility = if (publicConstrainedTypes) Visibility.PUBLIC else Visibility.PUBCRATE + private val symbolProvider = codegenContext.symbolProvider + private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider + private val members: List = shape.allMembers.values.toList() + private val structureSymbol = symbolProvider.toSymbol(shape) + private val builderSymbol = shape.serverBuilderSymbol(codegenContext) + private val moduleName = builderSymbol.namespace.split(builderSymbol.namespaceDelimiter).last() + private val isBuilderFallible = hasFallibleBuilder(shape, model, symbolProvider, takeInUnconstrainedTypes) + private val serverBuilderConstraintViolations = + ServerBuilderConstraintViolations(codegenContext, shape, takeInUnconstrainedTypes) + + private val codegenScope = arrayOf( + "RequestRejection" to ServerRuntimeType.RequestRejection(runtimeConfig), + "Structure" to structureSymbol, + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + "MaybeConstrained" to RuntimeType.MaybeConstrained(), + ) + + fun render(writer: RustWriter) { + writer.docs("See #D.", structureSymbol) + writer.withModule(RustModule(moduleName, RustMetadata(visibility = visibility))) { + renderBuilder(this) + } + } + + private fun renderBuilder(writer: RustWriter) { + if (isBuilderFallible) { + serverBuilderConstraintViolations.render( + writer, + visibility, + nonExhaustive = true, + shouldRenderAsValidationExceptionFieldList = shape.isReachableFromOperationInput(), + ) + + // Only generate converter from `ConstraintViolation` into `RequestRejection` if the structure shape is + // an operation input shape. + if (shape.hasTrait()) { + renderImplFromConstraintViolationForRequestRejection(writer) + } + + if (takeInUnconstrainedTypes) { + renderImplFromBuilderForMaybeConstrained(writer) + } + + renderTryFromBuilderImpl(writer) + } else { + renderFromBuilderImpl(writer) + } + + writer.docs("A builder for #D.", structureSymbol) + // Matching derives to the main structure, - `PartialEq` (see class documentation for why), + `Default` + // since we are a builder and everything is optional. + val baseDerives = structureSymbol.expectRustMetadata().derives + val derives = baseDerives.derives.intersect(setOf(RuntimeType.Debug, RuntimeType.Clone)) + RuntimeType.Default + baseDerives.copy(derives = derives).render(writer) + writer.rustBlock("pub${ if (visibility == Visibility.PUBCRATE) " (crate)" else "" } struct Builder") { + members.forEach { renderBuilderMember(this, it) } + } + + writer.rustBlock("impl Builder") { + for (member in members) { + if (publicConstrainedTypes) { + renderBuilderMemberFn(this, member) + } + + if (takeInUnconstrainedTypes) { + renderBuilderMemberSetterFn(this, member) + } + } + renderBuildFn(this) + } + } + + private fun renderImplFromConstraintViolationForRequestRejection(writer: RustWriter) { + writer.rustTemplate( + """ + impl #{From} for #{RequestRejection} { + fn from(constraint_violation: ConstraintViolation) -> Self { + let first_validation_exception_field = constraint_violation.as_validation_exception_field("".to_owned()); + let validation_exception = crate::error::ValidationException { + message: format!("1 validation error detected. {}", &first_validation_exception_field.message), + field_list: Some(vec![first_validation_exception_field]), + }; + Self::ConstraintViolation( + crate::operation_ser::serialize_structure_crate_error_validation_exception(&validation_exception) + .expect("impossible") + ) + } + } + """, + *codegenScope, + ) + } + + private fun renderImplFromBuilderForMaybeConstrained(writer: RustWriter) { + writer.rustTemplate( + """ + impl #{From} for #{StructureMaybeConstrained} { + fn from(builder: Builder) -> Self { + Self::Unconstrained(builder) + } + } + """, + *codegenScope, + "StructureMaybeConstrained" to structureSymbol.makeMaybeConstrained(), + ) + } + + private fun renderBuildFn(implBlockWriter: RustWriter) { + implBlockWriter.docs("""Consumes the builder and constructs a #D.""", structureSymbol) + if (isBuilderFallible) { + implBlockWriter.docs( + """ + The builder fails to construct a #D if a [`ConstraintViolation`] occurs. + """, + structureSymbol, + ) + + if (serverBuilderConstraintViolations.all.size > 1) { + implBlockWriter.docs("If the builder fails, it will return the _first_ encountered [`ConstraintViolation`].") + } + } + implBlockWriter.rustTemplate( + """ + pub fn build(self) -> #{ReturnType:W} { + self.build_enforcing_all_constraints() + } + """, + "ReturnType" to buildFnReturnType(isBuilderFallible, structureSymbol), + ) + renderBuildEnforcingAllConstraintsFn(implBlockWriter) + } + + private fun renderBuildEnforcingAllConstraintsFn(implBlockWriter: RustWriter) { + implBlockWriter.rustBlockTemplate( + "fn build_enforcing_all_constraints(self) -> #{ReturnType:W}", + "ReturnType" to buildFnReturnType(isBuilderFallible, structureSymbol), + ) { + conditionalBlock("Ok(", ")", conditional = isBuilderFallible) { + coreBuilder(this) + } + } + } + + fun renderConvenienceMethod(implBlock: RustWriter) { + implBlock.docs("Creates a new builder-style object to manufacture #D.", structureSymbol) + implBlock.rustBlock("pub fn builder() -> #T", builderSymbol) { + write("#T::default()", builderSymbol) + } + } + + private fun renderBuilderMember(writer: RustWriter, member: MemberShape) { + val memberSymbol = builderMemberSymbol(member) + val memberName = constrainedShapeSymbolProvider.toMemberName(member) + // Builder members are crate-public to enable using them directly in serializers/deserializers. + // During XML deserialization, `builder..take` is used to append to lists and maps. + writer.write("pub(crate) $memberName: #T,", memberSymbol) + } + + /** + * Render a `foo` method to set shape member `foo`. The caller must provide a value with the exact same type + * as the shape member's type. + * + * This method is meant for use by the user; it is not used by the generated crate's (de)serializers. + * + * This method is only generated when `publicConstrainedTypes` is `true`. Otherwise, the user has at their disposal + * the method from [ServerBuilderGeneratorWithoutPublicConstrainedTypes]. + */ + private fun renderBuilderMemberFn( + writer: RustWriter, + member: MemberShape, + ) { + check(publicConstrainedTypes) + val symbol = symbolProvider.toSymbol(member) + val memberName = symbolProvider.toMemberName(member) + + val hasBox = symbol.mapRustType { it.stripOuter() }.isRustBoxed() + val wrapInMaybeConstrained = takeInUnconstrainedTypes && member.targetCanReachConstrainedShape(model, symbolProvider) + + writer.documentShape(member, model) + writer.deprecatedShape(member) + + if (hasBox && wrapInMaybeConstrained) { + // In the case of recursive shapes, the member might be boxed. If so, and the member is also constrained, the + // implementation of this function needs to immediately unbox the value to wrap it in `MaybeConstrained`, + // and then re-box. Clippy warns us that we could have just taken in an unboxed value to avoid this round-trip + // to the heap. However, that will make the builder take in a value whose type does not exactly match the + // shape member's type. + // We don't want to introduce API asymmetry just for this particular case, so we disable the lint. + Attribute.Custom("allow(clippy::boxed_local)").render(writer) + } + writer.rustBlock("pub fn $memberName(mut self, input: ${symbol.rustType().render()}) -> Self") { + withBlock("self.$memberName = ", "; self") { + conditionalBlock("Some(", ")", conditional = !symbol.isOptional()) { + val maybeConstrainedVariant = + "${symbol.makeMaybeConstrained().rustType().namespace}::MaybeConstrained::Constrained" + + var varExpr = if (symbol.isOptional()) "v" else "input" + if (hasBox) varExpr = "*$varExpr" + if (!constrainedTypeHoldsFinalType(member)) varExpr = "($varExpr).into()" + + if (wrapInMaybeConstrained) { + conditionalBlock("input.map(##[allow(clippy::redundant_closure)] |v| ", ")", conditional = symbol.isOptional()) { + conditionalBlock("Box::new(", ")", conditional = hasBox) { + rust("$maybeConstrainedVariant($varExpr)") + } + } + } else { + write("input") + } + } + } + } + } + + /** + * Returns whether the constrained builder member type (the type on which the `Constrained` trait is implemented) + * is the final type the user sees when receiving the built struct. This is true when the corresponding constrained + * type is public and not `pub(crate)`, which happens when the target is a structure shape, a union shape, or is + * directly constrained. + * + * An example where this returns false is when the member shape targets a list whose members are lists of structures + * having at least one `required` member. In this case the member shape is transitively but not directly constrained, + * so the generated constrained type is `pub(crate)` and needs converting into the final type the user will be + * exposed to. + * + * See [PubCrateConstrainedShapeSymbolProvider] too. + */ + private fun constrainedTypeHoldsFinalType(member: MemberShape): Boolean { + val targetShape = model.expectShape(member.target) + return targetShape is StructureShape || + targetShape is UnionShape || + member.hasConstraintTraitOrTargetHasConstraintTrait(model, symbolProvider) + } + + /** + * Render a `set_foo` method. + * This method is able to take in unconstrained types for constrained shapes, like builders of structs in the case + * of structure shapes. + * + * This method is only used by deserializers at the moment and is therefore `pub(crate)`. + */ + private fun renderBuilderMemberSetterFn( + writer: RustWriter, + member: MemberShape, + ) { + val builderMemberSymbol = builderMemberSymbol(member) + val inputType = builderMemberSymbol.rustType().stripOuter().implInto() + .letIf( + // TODO(https://github.com/awslabs/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): + // The only reason why this condition can't simply be `member.isOptional` + // is because non-`required` blob streaming members are interpreted as + // `required`, so we can't use `member.isOptional` here. + symbolProvider.toSymbol(member).isOptional(), + ) { "Option<$it>" } + val memberName = symbolProvider.toMemberName(member) + + writer.documentShape(member, model) + // Setter names will never hit a reserved word and therefore never need escaping. + writer.rustBlock("pub(crate) fn set_${member.memberName.toSnakeCase()}(mut self, input: $inputType) -> Self") { + rust( + """ + self.$memberName = ${ + // TODO(https://github.com/awslabs/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): See above. + if (symbolProvider.toSymbol(member).isOptional()) { + "input.map(|v| v.into())" + } else { + "Some(input.into())" + } + }; + self + """, + ) + } + } + + private fun renderTryFromBuilderImpl(writer: RustWriter) { + writer.rustTemplate( + """ + impl #{TryFrom} for #{Structure} { + type Error = ConstraintViolation; + + fn try_from(builder: Builder) -> Result { + builder.build() + } + } + """, + *codegenScope, + ) + } + + private fun renderFromBuilderImpl(writer: RustWriter) { + writer.rustTemplate( + """ + impl #{From} for #{Structure} { + fn from(builder: Builder) -> Self { + builder.build() + } + } + """, + *codegenScope, + ) + } + + /** + * Returns the symbol for a builder's member. + * All builder members are optional, but only some are `Option`s where `T` needs to be constrained. + */ + private fun builderMemberSymbol(member: MemberShape): Symbol = + if (takeInUnconstrainedTypes && member.targetCanReachConstrainedShape(model, symbolProvider)) { + val strippedOption = if (member.hasConstraintTraitOrTargetHasConstraintTrait(model, symbolProvider)) { + constrainedShapeSymbolProvider.toSymbol(member) + } else { + pubCrateConstrainedShapeSymbolProvider.toSymbol(member) + } + // Strip the `Option` in case the member is not `required`. + .mapRustType { it.stripOuter() } + + val hadBox = strippedOption.isRustBoxed() + strippedOption + // Strip the `Box` in case the member can reach itself recursively. + .mapRustType { it.stripOuter() } + // Wrap it in the Cow-like `constrained::MaybeConstrained` type, since we know the target member shape can + // reach a constrained shape. + .makeMaybeConstrained() + // Box it in case the member can reach itself recursively. + .letIf(hadBox) { it.makeRustBoxed() } + // Ensure we always end up with an `Option`. + .makeOptional() + } else { + constrainedShapeSymbolProvider.toSymbol(member).makeOptional() + } + + /** + * Writes the code to instantiate the struct the builder builds. + * + * Builder member types are either: + * 1. `Option>`; or + * 2. `Option`. + * + * Where `U` is a constrained type. + * + * The structs they build have member types: + * a) `Option`; or + * b) `T`. + * + * `U` is equal to `T` when: + * - the shape for `U` has a constraint trait and `publicConstrainedTypes` is `true`; or + * - the member shape is a structure or union shape. + * Otherwise, `U` is always a `pub(crate)` tuple newtype holding `T`. + * + * For each member, this function first safely unwraps case 1. into 2., then converts `U` into `T` if necessary, + * and then converts into b) if necessary. + */ + private fun coreBuilder(writer: RustWriter) { + writer.rustBlock("#T", structureSymbol) { + for (member in members) { + val memberName = symbolProvider.toMemberName(member) + + withBlock("$memberName: self.$memberName", ",") { + // Write the modifier(s). + serverBuilderConstraintViolations.builderConstraintViolationForMember(member)?.also { constraintViolation -> + val hasBox = builderMemberSymbol(member) + .mapRustType { it.stripOuter() } + .isRustBoxed() + if (hasBox) { + rustTemplate( + """ + .map(|v| match *v { + #{MaybeConstrained}::Constrained(x) => Ok(Box::new(x)), + #{MaybeConstrained}::Unconstrained(x) => Ok(Box::new(x.try_into()?)), + }) + .map(|res| + res${ if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())" } + .map_err(|err| ConstraintViolation::${constraintViolation.name()}(Box::new(err))) + ) + .transpose()? + """, + *codegenScope, + ) + } else { + rustTemplate( + """ + .map(|v| match v { + #{MaybeConstrained}::Constrained(x) => Ok(x), + #{MaybeConstrained}::Unconstrained(x) => x.try_into(), + }) + .map(|res| + res${if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())"} + .map_err(ConstraintViolation::${constraintViolation.name()}) + ) + .transpose()? + """, + *codegenScope, + ) + + // Constrained types are not public and this is a member shape that would have generated a + // public constrained type, were the setting to be enabled. + // We've just checked the constraints hold by going through the non-public + // constrained type, but the user wants to work with the unconstrained type, so we have to + // unwrap it. + if (!publicConstrainedTypes && member.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model)) { + rust( + ".map(|v: #T| v.into())", + constrainedShapeSymbolProvider.toSymbol(model.expectShape(member.target)), + ) + } + } + } + serverBuilderConstraintViolations.forMember(member)?.also { + rust(".ok_or(ConstraintViolation::${it.name()})?") + } + } + } + } + } +} + +fun buildFnReturnType(isBuilderFallible: Boolean, structureSymbol: Symbol) = writable { + if (isBuilderFallible) { + rust("Result<#T, ConstraintViolation>", structureSymbol) + } else { + rust("#T", structureSymbol) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt new file mode 100644 index 000000000..897bdc116 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt @@ -0,0 +1,238 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.makeOptional +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType + +/** + * Generates a builder for the Rust type associated with the [StructureShape]. + * + * This builder is similar in design to [ServerBuilderGenerator], so consult its documentation in that regard. However, + * this builder has a few differences. + * + * Unlike [ServerBuilderGenerator], this builder only enforces constraints that are baked into the type system _when + * `publicConstrainedTypes` is false_. So in terms of honoring the Smithy spec, this builder only enforces enums + * and the `required` trait. + * + * Unlike [ServerBuilderGenerator], this builder is always public. It is the only builder type the user is exposed to + * when `publicConstrainedTypes` is false. + */ +class ServerBuilderGeneratorWithoutPublicConstrainedTypes( + codegenContext: ServerCodegenContext, + shape: StructureShape, +) { + companion object { + /** + * Returns whether a structure shape, whose builder has been generated with + * [ServerBuilderGeneratorWithoutPublicConstrainedTypes], requires a fallible builder to be constructed. + * + * This builder only enforces the `required` trait. + */ + fun hasFallibleBuilder( + structureShape: StructureShape, + symbolProvider: SymbolProvider, + ): Boolean = + structureShape + .members() + .map { symbolProvider.toSymbol(it) } + .any { !it.isOptional() } + } + + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val members: List = shape.allMembers.values.toList() + private val structureSymbol = symbolProvider.toSymbol(shape) + + private val builderSymbol = shape.serverBuilderSymbol(symbolProvider, false) + private val moduleName = builderSymbol.namespace.split("::").last() + private val isBuilderFallible = hasFallibleBuilder(shape, symbolProvider) + private val serverBuilderConstraintViolations = + ServerBuilderConstraintViolations(codegenContext, shape, builderTakesInUnconstrainedTypes = false) + + private val codegenScope = arrayOf( + "RequestRejection" to ServerRuntimeType.RequestRejection(codegenContext.runtimeConfig), + "Structure" to structureSymbol, + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + "MaybeConstrained" to RuntimeType.MaybeConstrained(), + ) + + fun render(writer: RustWriter) { + writer.docs("See #D.", structureSymbol) + writer.withModule(RustModule.public(moduleName)) { + renderBuilder(this) + } + } + + private fun renderBuilder(writer: RustWriter) { + if (isBuilderFallible) { + serverBuilderConstraintViolations.render( + writer, + Visibility.PUBLIC, + nonExhaustive = false, + shouldRenderAsValidationExceptionFieldList = false, + ) + + renderTryFromBuilderImpl(writer) + } else { + renderFromBuilderImpl(writer) + } + + writer.docs("A builder for #D.", structureSymbol) + // Matching derives to the main structure, - `PartialEq` (to be consistent with [ServerBuilderGenerator]), + `Default` + // since we are a builder and everything is optional. + val baseDerives = structureSymbol.expectRustMetadata().derives + val derives = baseDerives.derives.intersect(setOf(RuntimeType.Debug, RuntimeType.Clone)) + RuntimeType.Default + baseDerives.copy(derives = derives).render(writer) + writer.rustBlock("pub struct Builder") { + members.forEach { renderBuilderMember(this, it) } + } + + writer.rustBlock("impl Builder") { + for (member in members) { + renderBuilderMemberFn(this, member) + } + renderBuildFn(this) + } + } + + private fun renderBuildFn(implBlockWriter: RustWriter) { + implBlockWriter.docs("""Consumes the builder and constructs a #D.""", structureSymbol) + if (isBuilderFallible) { + implBlockWriter.docs( + """ + The builder fails to construct a #D if you do not provide a value for all non-`Option`al members. + """, + structureSymbol, + ) + } + implBlockWriter.rustTemplate( + """ + pub fn build(self) -> #{ReturnType:W} { + self.build_enforcing_required_and_enum_traits() + } + """, + "ReturnType" to buildFnReturnType(isBuilderFallible, structureSymbol), + ) + renderBuildEnforcingRequiredAndEnumTraitsFn(implBlockWriter) + } + + private fun renderBuildEnforcingRequiredAndEnumTraitsFn(implBlockWriter: RustWriter) { + implBlockWriter.rustBlockTemplate( + "fn build_enforcing_required_and_enum_traits(self) -> #{ReturnType:W}", + "ReturnType" to buildFnReturnType(isBuilderFallible, structureSymbol), + ) { + conditionalBlock("Ok(", ")", conditional = isBuilderFallible) { + coreBuilder(this) + } + } + } + + private fun coreBuilder(writer: RustWriter) { + writer.rustBlock("#T", structureSymbol) { + for (member in members) { + val memberName = symbolProvider.toMemberName(member) + + withBlock("$memberName: self.$memberName", ",") { + serverBuilderConstraintViolations.forMember(member)?.also { + rust(".ok_or(ConstraintViolation::${it.name()})?") + } + } + } + } + } + + fun renderConvenienceMethod(implBlock: RustWriter) { + implBlock.docs("Creates a new builder-style object to manufacture #D.", structureSymbol) + implBlock.rustBlock("pub fn builder() -> #T", builderSymbol) { + write("#T::default()", builderSymbol) + } + } + + private fun renderBuilderMember(writer: RustWriter, member: MemberShape) { + val memberSymbol = builderMemberSymbol(member) + val memberName = symbolProvider.toMemberName(member) + // Builder members are crate-public to enable using them directly in serializers/deserializers. + // During XML deserialization, `builder..take` is used to append to lists and maps. + writer.write("pub(crate) $memberName: #T,", memberSymbol) + } + + /** + * Render a `foo` method to set shape member `foo`. The caller must provide a value with the exact same type + * as the shape member's type. + * + * This method is meant for use by the user; it is not used by the generated crate's (de)serializers. + */ + private fun renderBuilderMemberFn(writer: RustWriter, member: MemberShape) { + val memberSymbol = symbolProvider.toSymbol(member) + val memberName = symbolProvider.toMemberName(member) + + writer.documentShape(member, model) + writer.deprecatedShape(member) + + writer.rustBlock("pub fn $memberName(mut self, input: #T) -> Self", memberSymbol) { + withBlock("self.$memberName = ", "; self") { + conditionalBlock("Some(", ")", conditional = !memberSymbol.isOptional()) { + rust("input") + } + } + } + } + + private fun renderTryFromBuilderImpl(writer: RustWriter) { + writer.rustTemplate( + """ + impl #{TryFrom} for #{Structure} { + type Error = ConstraintViolation; + + fn try_from(builder: Builder) -> Result { + builder.build() + } + } + """, + *codegenScope, + ) + } + + private fun renderFromBuilderImpl(writer: RustWriter) { + writer.rustTemplate( + """ + impl #{From} for #{Structure} { + fn from(builder: Builder) -> Self { + builder.build() + } + } + """, + *codegenScope, + ) + } + + /** + * Returns the symbol for a builder's member. + */ + private fun builderMemberSymbol(member: MemberShape): Symbol = symbolProvider.toSymbol(member).makeOptional() +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt new file mode 100644 index 000000000..a8ee7fd8f --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt @@ -0,0 +1,35 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext + +fun StructureShape.serverBuilderSymbol(codegenContext: ServerCodegenContext): Symbol = + this.serverBuilderSymbol(codegenContext.symbolProvider, !codegenContext.settings.codegenConfig.publicConstrainedTypes) + +fun StructureShape.serverBuilderSymbol(symbolProvider: SymbolProvider, pubCrate: Boolean): Symbol { + val structureSymbol = symbolProvider.toSymbol(this) + val builderNamespace = RustReservedWords.escapeIfNeeded(structureSymbol.name.toSnakeCase()) + + if (pubCrate) { + "_internal" + } else { + "" + } + val rustType = RustType.Opaque("Builder", "${structureSymbol.namespace}::$builderNamespace") + return Symbol.builder() + .rustType(rustType) + .name(rustType.name) + .namespace(rustType.namespace, "::") + .definitionFile(structureSymbol.definitionFile) + .build() +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt index 323a40ebf..1514750cd 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt @@ -4,81 +4,111 @@ */ package software.amazon.smithy.rust.codegen.server.smithy.generators -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StringShape -import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput open class ServerEnumGenerator( - model: Model, - symbolProvider: RustSymbolProvider, + val codegenContext: ServerCodegenContext, private val writer: RustWriter, shape: StringShape, - enumTrait: EnumTrait, - private val runtimeConfig: RuntimeConfig, -) : EnumGenerator(model, symbolProvider, writer, shape, enumTrait) { +) : EnumGenerator(codegenContext.model, codegenContext.symbolProvider, writer, shape, shape.expectTrait()) { override var target: CodegenTarget = CodegenTarget.SERVER - private val errorStruct = "${enumName}UnknownVariantError" + + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) + private val constraintViolationName = constraintViolationSymbol.name + private val codegenScope = arrayOf( + "String" to RuntimeType.String, + ) override fun renderFromForStr() { - writer.rust( - """ - ##[derive(Debug, PartialEq, Eq, Hash)] - pub struct $errorStruct(String); - """, - ) + writer.withModule( + RustModule.public(constraintViolationSymbol.namespace.split(constraintViolationSymbol.namespaceDelimiter).last()), + ) { + rustTemplate( + """ + ##[derive(Debug, PartialEq)] + pub struct $constraintViolationName(pub(crate) #{String}); + """, + *codegenScope, + ) + + if (shape.isReachableFromOperationInput()) { + val enumValueSet = enumTrait.enumDefinitionValues.joinToString(", ") + val message = "Value {} at '{}' failed to satisfy constraint: Member must satisfy enum value set: [$enumValueSet]" + + rustTemplate( + """ + impl $constraintViolationName { + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + crate::model::ValidationExceptionField { + message: format!(r##"$message"##, &self.0, &path), + path, + } + } + } + """, + *codegenScope, + ) + } + } writer.rustBlock("impl #T<&str> for $enumName", RuntimeType.TryFrom) { - write("type Error = $errorStruct;") - writer.rustBlock("fn try_from(s: &str) -> Result>::Error>", RuntimeType.TryFrom) { - writer.rustBlock("match s") { + rust("type Error = #T;", constraintViolationSymbol) + rustBlock("fn try_from(s: &str) -> Result>::Error>", RuntimeType.TryFrom) { + rustBlock("match s") { sortedMembers.forEach { member -> - write("${member.value.dq()} => Ok($enumName::${member.derivedName()}),") + rust("${member.value.dq()} => Ok($enumName::${member.derivedName()}),") } - write("_ => Err($errorStruct(s.to_owned()))") + rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol) } } } writer.rustTemplate( """ - impl #{From}<$errorStruct> for #{RequestRejection} { - fn from(e: $errorStruct) -> Self { - Self::EnumVariantNotFound(Box::new(e)) - } - } - impl #{StdError} for $errorStruct { } - impl #{Display} for $errorStruct { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) + impl #{TryFrom}<#{String}> for $enumName { + type Error = #{UnknownVariantSymbol}; + fn try_from(s: #{String}) -> std::result::Result>::Error> { + s.as_str().try_into() } } """, - "Display" to RuntimeType.Display, - "From" to RuntimeType.From, - "StdError" to RuntimeType.StdError, - "RequestRejection" to ServerRuntimeType.RequestRejection(runtimeConfig), + "String" to RuntimeType.String, + "TryFrom" to RuntimeType.TryFrom, + "UnknownVariantSymbol" to constraintViolationSymbol, ) } override fun renderFromStr() { - writer.rust( + writer.rustTemplate( """ impl std::str::FromStr for $enumName { - type Err = $errorStruct; - fn from_str(s: &str) -> std::result::Result { - $enumName::try_from(s) + type Err = #{UnknownVariantSymbol}; + fn from_str(s: &str) -> std::result::Result::Err> { + Self::try_from(s) } } """, + "UnknownVariantSymbol" to constraintViolationSymbol, ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt index 9189380dd..13901e521 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -6,11 +6,15 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput /** * Server enums do not have an `Unknown` variant like client enums do, so constructing an enum from @@ -24,11 +28,30 @@ private fun enumFromStringFn(enumSymbol: Symbol, data: String): Writable = writa ) } +class ServerBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior { + override fun hasFallibleBuilder(shape: StructureShape): Boolean { + // Only operation input builders take in unconstrained types. + val takesInUnconstrainedTypes = shape.isReachableFromOperationInput() + return ServerBuilderGenerator.hasFallibleBuilder( + shape, + codegenContext.model, + codegenContext.symbolProvider, + takesInUnconstrainedTypes, + ) + } + + override fun setterName(memberShape: MemberShape): String = codegenContext.symbolProvider.toMemberName(memberShape) + + override fun doesSetterTakeInOption(memberShape: MemberShape): Boolean = + codegenContext.symbolProvider.toSymbol(memberShape).isOptional() +} + fun serverInstantiator(codegenContext: CodegenContext) = Instantiator( codegenContext.symbolProvider, codegenContext.model, codegenContext.runtimeConfig, + ServerBuilderKindBehavior(codegenContext), ::enumFromStringFn, defaultsForRequiredFields = true, ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt index b61b1baa4..b34685f7b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt @@ -38,7 +38,8 @@ class ServerOperationGenerator( if (operation.errors.isEmpty()) { rust("std::convert::Infallible") } else { - rust("crate::error::${operationName}Error") + // Name comes from [ServerCombinedErrorGenerator]. + rust("crate::error::${symbolProvider.toSymbol(operation).name}Error") } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerStructureConstrainedTraitImpl.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerStructureConstrainedTraitImpl.kt new file mode 100644 index 000000000..2812c59d7 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerStructureConstrainedTraitImpl.kt @@ -0,0 +1,32 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider + +class ServerStructureConstrainedTraitImpl( + private val symbolProvider: RustSymbolProvider, + private val publicConstrainedTypes: Boolean, + private val shape: StructureShape, + private val writer: RustWriter, +) { + fun render() { + writer.rustTemplate( + """ + impl #{ConstrainedTrait} for #{Structure} { + type Unconstrained = #{Builder}; + } + """, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "Structure" to symbolProvider.toSymbol(shape), + "Builder" to shape.serverBuilderSymbol(symbolProvider, !publicConstrainedTypes), + ) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt new file mode 100644 index 000000000..602cbb7aa --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt @@ -0,0 +1,139 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput + +/** + * Generates a Rust type for a constrained collection shape that is able to hold values for the corresponding + * _unconstrained_ shape. This type is a [RustType.Opaque] wrapper tuple newtype holding a `Vec`. Upon request parsing, + * server deserializers use this type to store the incoming values without enforcing the modeled constraints. Only after + * the full request has been parsed are constraints enforced, via the `impl TryFrom for + * ConstrainedSymbol`. + * + * This type is never exposed to the user; it is always `pub(crate)`. Only the deserializers use it. + * + * Consult [UnconstrainedShapeSymbolProvider] for more details and for an example. + */ +class UnconstrainedCollectionGenerator( + val codegenContext: ServerCodegenContext, + private val unconstrainedModuleWriter: RustWriter, + private val modelsModuleWriter: RustWriter, + val shape: CollectionShape, +) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider + private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + + fun render() { + check(shape.canReachConstrainedShape(model, symbolProvider)) + + val symbol = unconstrainedShapeSymbolProvider.toSymbol(shape) + val module = symbol.namespace.split(symbol.namespaceDelimiter).last() + val name = symbol.name + val innerShape = model.expectShape(shape.member.target) + val innerUnconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(innerShape) + val constrainedSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) + val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) + val constraintViolationName = constraintViolationSymbol.name + val innerConstraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(innerShape) + + unconstrainedModuleWriter.withModule(RustModule(module, RustMetadata(visibility = Visibility.PUBCRATE))) { + rustTemplate( + """ + ##[derive(Debug, Clone)] + pub(crate) struct $name(pub(crate) std::vec::Vec<#{InnerUnconstrainedSymbol}>); + + impl From<$name> for #{MaybeConstrained} { + fn from(value: $name) -> Self { + Self::Unconstrained(value) + } + } + + impl #{TryFrom}<$name> for #{ConstrainedSymbol} { + type Error = #{ConstraintViolationSymbol}; + + fn try_from(value: $name) -> Result { + let res: Result<_, (usize, #{InnerConstraintViolationSymbol})> = value + .0 + .into_iter() + .enumerate() + .map(|(idx, inner)| inner.try_into().map_err(|inner_violation| (idx, inner_violation))) + .collect(); + res.map(Self) + .map_err(|(idx, inner_violation)| #{ConstraintViolationSymbol}(idx, inner_violation)) + } + } + """, + "InnerUnconstrainedSymbol" to innerUnconstrainedSymbol, + "InnerConstraintViolationSymbol" to innerConstraintViolationSymbol, + "ConstrainedSymbol" to constrainedSymbol, + "ConstraintViolationSymbol" to constraintViolationSymbol, + "MaybeConstrained" to constrainedSymbol.makeMaybeConstrained(), + "TryFrom" to RuntimeType.TryFrom, + ) + } + + val constraintViolationVisibility = if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } + modelsModuleWriter.withModule( + RustModule( + constraintViolationSymbol.namespace.split(constraintViolationSymbol.namespaceDelimiter).last(), + RustMetadata(visibility = constraintViolationVisibility), + ), + ) { + // The first component of the tuple struct is the index in the collection where the first constraint + // violation was found. + rustTemplate( + """ + ##[derive(Debug, PartialEq)] + pub struct $constraintViolationName( + pub(crate) usize, + pub(crate) #{InnerConstraintViolationSymbol} + ); + """, + "InnerConstraintViolationSymbol" to innerConstraintViolationSymbol, + ) + + if (shape.isReachableFromOperationInput()) { + rustTemplate( + """ + impl $constraintViolationName { + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + self.1.as_validation_exception_field(format!("{}/{}", path, &self.0)) + } + } + """, + "String" to RuntimeType.String, + ) + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt new file mode 100644 index 000000000..4d47eb622 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt @@ -0,0 +1,207 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.join +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained + +/** + * Generates a Rust type for a constrained map shape that is able to hold values for the corresponding + * _unconstrained_ shape. This type is a [RustType.Opaque] wrapper tuple newtype holding a `HashMap`. Upon request parsing, + * server deserializers use this type to store the incoming values without enforcing the modeled constraints. Only after + * the full request has been parsed are constraints enforced, via the `impl TryFrom for + * ConstrainedSymbol`. + * + * This type is never exposed to the user; it is always `pub(crate)`. Only the deserializers use it. + * + * Consult [UnconstrainedShapeSymbolProvider] for more details and for an example. + */ +class UnconstrainedMapGenerator( + val codegenContext: ServerCodegenContext, + private val unconstrainedModuleWriter: RustWriter, + val shape: MapShape, +) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider + private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider + private val symbol = unconstrainedShapeSymbolProvider.toSymbol(shape) + private val name = symbol.name + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) + private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider + private val constrainedSymbol = if (shape.isDirectlyConstrained(symbolProvider)) { + constrainedShapeSymbolProvider.toSymbol(shape) + } else { + pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) + } + private val keyShape = model.expectShape(shape.key.target, StringShape::class.java) + private val valueShape = model.expectShape(shape.value.target) + + fun render() { + check(shape.canReachConstrainedShape(model, symbolProvider)) + + val module = symbol.namespace.split(symbol.namespaceDelimiter).last() + val keySymbol = unconstrainedShapeSymbolProvider.toSymbol(keyShape) + val valueSymbol = unconstrainedShapeSymbolProvider.toSymbol(valueShape) + + unconstrainedModuleWriter.withModule(RustModule(module, RustMetadata(visibility = Visibility.PUBCRATE))) { + rustTemplate( + """ + ##[derive(Debug, Clone)] + pub(crate) struct $name(pub(crate) std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}>); + + impl From<$name> for #{MaybeConstrained} { + fn from(value: $name) -> Self { + Self::Unconstrained(value) + } + } + + """, + "KeySymbol" to keySymbol, + "ValueSymbol" to valueSymbol, + "MaybeConstrained" to constrainedSymbol.makeMaybeConstrained(), + ) + + renderTryFromUnconstrainedForConstrained(this) + } + } + + private fun renderTryFromUnconstrainedForConstrained(writer: RustWriter) { + writer.rustBlock("impl std::convert::TryFrom<$name> for #{T}", constrainedSymbol) { + rust("type Error = #T;", constraintViolationSymbol) + + rustBlock("fn try_from(value: $name) -> Result") { + if (isKeyConstrained(keyShape, symbolProvider) || isValueConstrained(valueShape, model, symbolProvider)) { + val resolveToNonPublicConstrainedValueType = + isValueConstrained(valueShape, model, symbolProvider) && + !valueShape.isDirectlyConstrained(symbolProvider) && + !valueShape.isStructureShape + val constrainedValueSymbol = if (resolveToNonPublicConstrainedValueType) { + pubCrateConstrainedShapeSymbolProvider.toSymbol(valueShape) + } else { + constrainedShapeSymbolProvider.toSymbol(valueShape) + } + + val constrainedKeySymbol = constrainedShapeSymbolProvider.toSymbol(keyShape) + val constrainKeyWritable = writable { + rustTemplate( + "let k: #{ConstrainedKeySymbol} = k.try_into().map_err(Self::Error::Key)?;", + "ConstrainedKeySymbol" to constrainedKeySymbol, + ) + } + val constrainValueWritable = writable { + rustTemplate( + """ + match #{ConstrainedValueSymbol}::try_from(v) { + Ok(v) => Ok((k, v)), + Err(inner_constraint_violation) => Err(Self::Error::Value(k, inner_constraint_violation)), + } + """, + "ConstrainedValueSymbol" to constrainedValueSymbol, + ) + } + val epilogueWritable = writable { rust("Ok((k, v))") } + + val constrainKVWritable = if ( + isKeyConstrained(keyShape, symbolProvider) && + isValueConstrained(valueShape, model, symbolProvider) + ) { + listOf(constrainKeyWritable, constrainValueWritable).join("\n") + } else if (isKeyConstrained(keyShape, symbolProvider)) { + listOf(constrainKeyWritable, epilogueWritable).join("\n") + } else if (isValueConstrained(valueShape, model, symbolProvider)) { + constrainValueWritable + } else { + epilogueWritable + } + + rustTemplate( + """ + let res: Result, Self::Error> = value.0 + .into_iter() + .map(|(k, v)| { + #{ConstrainKVWritable:W} + }) + .collect(); + let hm = res?; + """, + "ConstrainedKeySymbol" to constrainedKeySymbol, + "ConstrainedValueSymbol" to constrainedValueSymbol, + "ConstrainKVWritable" to constrainKVWritable, + ) + + val constrainedValueTypeIsNotFinalType = + resolveToNonPublicConstrainedValueType && shape.isDirectlyConstrained(symbolProvider) + if (constrainedValueTypeIsNotFinalType) { + // The map is constrained. Its value shape reaches a constrained shape, but the value shape itself + // is not directly constrained. The value shape must be an aggregate shape. But it is not a + // structure shape. So it must be a collection or map shape. In this case the type for the value + // shape that implements the `Constrained` trait _does not_ coincide with the regular type the user + // is exposed to. The former will be the `pub(crate)` wrapper tuple type created by a + // `Constrained*Generator`, whereas the latter will be an stdlib container type. Both types are + // isomorphic though, and we can convert between them using `From`, so that's what we do here. + // + // As a concrete example of this particular case, consider the model: + // + // ```smithy + // @length(min: 1) + // map Map { + // key: String, + // value: List, + // } + // + // list List { + // member: NiceString + // } + // + // @length(min: 1, max: 69) + // string NiceString + // ``` + rustTemplate( + """ + let hm: std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}> = + hm.into_iter().map(|(k, v)| (k, v.into())).collect(); + """, + "KeySymbol" to symbolProvider.toSymbol(keyShape), + "ValueSymbol" to symbolProvider.toSymbol(valueShape), + ) + } + } else { + rust("let hm = value.0;") + } + + if (shape.isDirectlyConstrained(symbolProvider)) { + rust("Self::try_from(hm)") + } else { + rust("Ok(Self(hm))") + } + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt new file mode 100644 index 000000000..dd470daf9 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt @@ -0,0 +1,248 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained +import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed +import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.letIf +import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput + +/** + * Generates a Rust type for a constrained union shape that is able to hold values for the corresponding _unconstrained_ + * shape. This type is a [RustType.Opaque] enum newtype, with each variant holding the corresponding unconstrained type. + * Upon request parsing, server deserializers use this type to store the incoming values without enforcing the modeled + * constraints. Only after the full request has been parsed are constraints enforced, via the `impl + * TryFrom for ConstrainedSymbol`. + * + * This type is never exposed to the user; it is always `pub(crate)`. Only the deserializers use it. + * + * Consult [UnconstrainedShapeSymbolProvider] for more details and for an example. + */ +class UnconstrainedUnionGenerator( + val codegenContext: ServerCodegenContext, + private val unconstrainedModuleWriter: RustWriter, + private val modelsModuleWriter: RustWriter, + val shape: UnionShape, +) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider + private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider + private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes + private val constraintViolationSymbolProvider = + with(codegenContext.constraintViolationSymbolProvider) { + if (publicConstrainedTypes) { + this + } else { + PubCrateConstraintViolationSymbolProvider(this) + } + } + private val symbol = unconstrainedShapeSymbolProvider.toSymbol(shape) + private val sortedMembers: List = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) } + + fun render() { + check(shape.canReachConstrainedShape(model, symbolProvider)) + + val moduleName = symbol.namespace.split(symbol.namespaceDelimiter).last() + val name = symbol.name + val constrainedSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) + val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) + val constraintViolationName = constraintViolationSymbol.name + + unconstrainedModuleWriter.withModule(RustModule(moduleName, RustMetadata(visibility = Visibility.PUBCRATE))) { + rustBlock( + """ + ##[allow(clippy::enum_variant_names)] + ##[derive(Debug, Clone)] + pub(crate) enum $name + """, + ) { + sortedMembers.forEach { member -> + rust( + "${unconstrainedShapeSymbolProvider.toMemberName(member)}(#T),", + unconstrainedShapeSymbolProvider.toSymbol(member), + ) + } + } + + rustTemplate( + """ + impl #{TryFrom}<$name> for #{ConstrainedSymbol} { + type Error = #{ConstraintViolationSymbol}; + + fn try_from(value: $name) -> Result { + #{body:W} + } + } + """, + "TryFrom" to RuntimeType.TryFrom, + "ConstrainedSymbol" to constrainedSymbol, + "ConstraintViolationSymbol" to constraintViolationSymbol, + "body" to generateTryFromUnconstrainedUnionImpl(), + ) + } + + modelsModuleWriter.rustTemplate( + """ + impl #{ConstrainedTrait} for #{ConstrainedSymbol} { + type Unconstrained = #{UnconstrainedSymbol}; + } + + impl From<#{UnconstrainedSymbol}> for #{MaybeConstrained} { + fn from(value: #{UnconstrainedSymbol}) -> Self { + Self::Unconstrained(value) + } + } + """, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait(), + "MaybeConstrained" to constrainedSymbol.makeMaybeConstrained(), + "ConstrainedSymbol" to constrainedSymbol, + "UnconstrainedSymbol" to symbol, + ) + + val constraintViolationVisibility = if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } + modelsModuleWriter.withModule( + RustModule( + constraintViolationSymbol.namespace.split(constraintViolationSymbol.namespaceDelimiter).last(), + RustMetadata(visibility = constraintViolationVisibility), + ), + ) { + Attribute.Derives(setOf(RuntimeType.Debug, RuntimeType.PartialEq)).render(this) + rustBlock("pub${ if (constraintViolationVisibility == Visibility.PUBCRATE) " (crate)" else "" } enum $constraintViolationName") { + constraintViolations().forEach { renderConstraintViolation(this, it) } + } + + if (shape.isReachableFromOperationInput()) { + rustBlock("impl $constraintViolationName") { + rustBlockTemplate( + "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", + "String" to RuntimeType.String, + ) { + withBlock("match self {", "}") { + for (constraintViolation in constraintViolations()) { + rust("""Self::${constraintViolation.name()}(inner) => inner.as_validation_exception_field(path + "/${constraintViolation.forMember.memberName}"),""") + } + } + } + } + } + } + } + + data class ConstraintViolation(val forMember: MemberShape) { + fun name() = forMember.memberName.toPascalCase() + } + + private fun constraintViolations() = + sortedMembers + .filter { it.targetCanReachConstrainedShape(model, symbolProvider) } + .map { ConstraintViolation(it) } + + private fun renderConstraintViolation(writer: RustWriter, constraintViolation: ConstraintViolation) { + val targetShape = model.expectShape(constraintViolation.forMember.target) + + val constraintViolationSymbol = + constraintViolationSymbolProvider.toSymbol(targetShape) + // If the corresponding union's member is boxed, box this constraint violation symbol too. + .letIf(constraintViolation.forMember.hasTrait()) { + it.makeRustBoxed() + } + + writer.rust( + "${constraintViolation.name()}(#T),", + constraintViolationSymbol, + ) + } + + private fun generateTryFromUnconstrainedUnionImpl() = writable { + withBlock("Ok(", ")") { + withBlock("match value {", "}") { + sortedMembers.forEach { member -> + val memberName = unconstrainedShapeSymbolProvider.toMemberName(member) + withBlockTemplate( + "#{UnconstrainedUnion}::$memberName(unconstrained) => Self::$memberName(", + "),", + "UnconstrainedUnion" to symbol, + ) { + if (!member.canReachConstrainedShape(model, symbolProvider)) { + rust("unconstrained") + } else { + val targetShape = model.expectShape(member.target) + val resolveToNonPublicConstrainedType = + targetShape !is StructureShape && targetShape !is UnionShape && !targetShape.hasTrait() && + (!publicConstrainedTypes || !targetShape.isDirectlyConstrained(symbolProvider)) + + val (unconstrainedVar, boxIt) = if (member.hasTrait()) { + "(*unconstrained)" to ".map(Box::new).map_err(Box::new)" + } else { + "unconstrained" to "" + } + + if (resolveToNonPublicConstrainedType) { + val constrainedSymbol = + if (!publicConstrainedTypes && targetShape.isDirectlyConstrained(symbolProvider)) { + codegenContext.constrainedShapeSymbolProvider.toSymbol(targetShape) + } else { + pubCrateConstrainedShapeSymbolProvider.toSymbol(targetShape) + } + rustTemplate( + """ + { + let constrained: #{ConstrainedSymbol} = $unconstrainedVar + .try_into() + $boxIt + .map_err(Self::Error::${ConstraintViolation(member).name()})?; + constrained.into() + } + """, + "ConstrainedSymbol" to constrainedSymbol, + ) + } else { + rust( + """ + $unconstrainedVar + .try_into() + $boxIt + .map_err(Self::Error::${ConstraintViolation(member).name()})? + """, + ) + } + } + } + } + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt index 55ff0fbe7..f866d83e3 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt @@ -5,21 +5,49 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.http +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingSection import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType +import software.amazon.smithy.rust.codegen.core.smithy.mapRustType import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol +import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape class ServerRequestBindingGenerator( protocol: Protocol, - codegenContext: CodegenContext, + private val codegenContext: ServerCodegenContext, operationShape: OperationShape, ) { - private val httpBindingGenerator = HttpBindingGenerator(protocol, codegenContext, operationShape) + private fun serverBuilderSymbol(shape: StructureShape): Symbol = shape.serverBuilderSymbol( + codegenContext.symbolProvider, + !codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + private val httpBindingGenerator = + HttpBindingGenerator( + protocol, + codegenContext, + codegenContext.unconstrainedShapeSymbolProvider, + operationShape, + ::serverBuilderSymbol, + listOf( + ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUnconstrainedMapHttpBindingCustomization( + codegenContext, + ), + ), + ) fun generateDeserializeHeaderFn(binding: HttpBindingDescriptor): RuntimeType = httpBindingGenerator.generateDeserializeHeaderFn(binding) @@ -39,3 +67,22 @@ class ServerRequestBindingGenerator( binding: HttpBindingDescriptor, ): RuntimeType = httpBindingGenerator.generateDeserializePrefixHeaderFn(binding) } + +/** + * A customization to, just after we've deserialized HTTP request headers bound to a map shape via `@httpPrefixHeaders`, + * wrap the `std::collections::HashMap` in an unconstrained type wrapper newtype. + */ +class ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUnconstrainedMapHttpBindingCustomization(val codegenContext: ServerCodegenContext) : + HttpBindingCustomization() { + override fun section(section: HttpBindingSection): Writable = when (section) { + is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders -> emptySection + is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders -> writable { + if (section.memberShape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.unconstrainedShapeSymbolProvider)) { + rust( + "let out = out.map(#T);", + codegenContext.unconstrainedShapeSymbolProvider.toSymbol(section.memberShape).mapRustType { it.stripOuter() }, + ) + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt index 1967e4304..020f558ce 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt @@ -5,21 +5,67 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.http +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingSection import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol +import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType class ServerResponseBindingGenerator( protocol: Protocol, - codegenContext: CodegenContext, + private val codegenContext: ServerCodegenContext, operationShape: OperationShape, ) { - private val httpBindingGenerator = HttpBindingGenerator(protocol, codegenContext, operationShape) + private fun builderSymbol(shape: StructureShape): Symbol = shape.serverBuilderSymbol(codegenContext) + + private val httpBindingGenerator = + HttpBindingGenerator( + protocol, + codegenContext, + codegenContext.symbolProvider, + operationShape, + ::builderSymbol, + listOf( + ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstrainedMapHttpBindingCustomization( + codegenContext, + ), + ), + ) fun generateAddHeadersFn(shape: Shape): RuntimeType? = httpBindingGenerator.generateAddHeadersFn(shape, HttpMessageType.RESPONSE) } + +/** + * A customization to, just before we iterate over a _constrained_ map shape that is bound to HTTP response headers via + * `@httpPrefixHeaders`, unwrap the wrapper newtype and take a shared reference to the actual `std::collections::HashMap` + * within it. + */ +class ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstrainedMapHttpBindingCustomization(val codegenContext: ServerCodegenContext) : + HttpBindingCustomization() { + override fun section(section: HttpBindingSection): Writable = when (section) { + is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders -> writable { + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + rust("let ${section.variableName} = &${section.variableName}.0;") + } + } + is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders -> emptySection + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 2dea2153b..86b33c617 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -5,8 +5,11 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.asType import software.amazon.smithy.rust.codegen.core.rustlang.rust @@ -21,10 +24,22 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.smithy.protocols.awsJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserCustomization +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.ReturnSymbolToParse +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator +import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape private fun allOperations(codegenContext: CodegenContext): List { val index = TopDownIndex.of(codegenContext.model) @@ -79,9 +94,9 @@ interface ServerProtocol : Protocol { } class ServerAwsJsonProtocol( - codegenContext: CodegenContext, + private val serverCodegenContext: ServerCodegenContext, awsJsonVersion: AwsJsonVersion, -) : AwsJson(codegenContext, awsJsonVersion), ServerProtocol { +) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol { private val runtimeConfig = codegenContext.runtimeConfig private val codegenScope = arrayOf( "SmithyHttpServer" to ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType(), @@ -89,11 +104,33 @@ class ServerAwsJsonProtocol( private val symbolProvider = codegenContext.symbolProvider private val service = codegenContext.serviceShape + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { + fun builderSymbol(shape: StructureShape): Symbol = + shape.serverBuilderSymbol(serverCodegenContext) + fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse = + if (shape.canReachConstrainedShape(codegenContext.model, symbolProvider)) { + ReturnSymbolToParse(serverCodegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape), true) + } else { + ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) + } + return JsonParserGenerator( + codegenContext, + httpBindingResolver, + ::awsJsonFieldName, + ::builderSymbol, + ::returnSymbolToParse, + listOf( + ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(serverCodegenContext), + ), + ) + } + override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = - ServerAwsJsonSerializerGenerator(codegenContext, httpBindingResolver, awsJsonVersion) + ServerAwsJsonSerializerGenerator(serverCodegenContext, httpBindingResolver, awsJsonVersion) companion object { - fun fromCoreProtocol(awsJson: AwsJson): ServerAwsJsonProtocol = ServerAwsJsonProtocol(awsJson.codegenContext, awsJson.version) + fun fromCoreProtocol(awsJson: AwsJson): ServerAwsJsonProtocol = + ServerAwsJsonProtocol(awsJson.codegenContext as ServerCodegenContext, awsJson.version) } override fun markerStruct(): RuntimeType { @@ -203,12 +240,38 @@ private fun restRouterConstruction( } class ServerRestJsonProtocol( - codegenContext: CodegenContext, -) : RestJson(codegenContext), ServerProtocol { + private val serverCodegenContext: ServerCodegenContext, +) : RestJson(serverCodegenContext), ServerProtocol { val runtimeConfig = codegenContext.runtimeConfig + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { + fun builderSymbol(shape: StructureShape): Symbol = + shape.serverBuilderSymbol(serverCodegenContext) + fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse = + if (shape.canReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider)) { + ReturnSymbolToParse(serverCodegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape), true) + } else { + ReturnSymbolToParse(serverCodegenContext.symbolProvider.toSymbol(shape), false) + } + return JsonParserGenerator( + codegenContext, + httpBindingResolver, + ::restJsonFieldName, + ::builderSymbol, + ::returnSymbolToParse, + listOf( + ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization( + serverCodegenContext, + ), + ), + ) + } + + override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = + ServerRestJsonSerializerGenerator(serverCodegenContext, httpBindingResolver) + companion object { - fun fromCoreProtocol(restJson: RestJson): ServerRestJsonProtocol = ServerRestJsonProtocol(restJson.codegenContext) + fun fromCoreProtocol(restJson: RestJson): ServerRestJsonProtocol = ServerRestJsonProtocol(restJson.codegenContext as ServerCodegenContext) } override fun markerStruct() = ServerRuntimeType.Protocol("RestJson1", "rest_json_1", runtimeConfig) @@ -257,3 +320,22 @@ class ServerRestXmlProtocol( override fun serverContentTypeCheckNoModeledInput() = true } + +/** + * A customization to, just before we box a recursive member that we've deserialized into `Option`, convert it into + * `MaybeConstrained` if the target shape can reach a constrained shape. + */ +class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(val codegenContext: ServerCodegenContext) : + JsonParserCustomization() { + override fun section(section: JsonParserSection): Writable = when (section) { + is JsonParserSection.BeforeBoxingDeserializedMember -> writable { + // We're only interested in _structure_ member shapes that can reach constrained shapes. + if ( + codegenContext.model.expectShape(section.shape.container) is StructureShape && + section.shape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider) + ) { + rust(".map(|x| x.into())") + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index ea8792e40..3fbdc2ebc 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -935,18 +935,8 @@ class ServerProtocolTestGenerator( FailingTest(RestJson, "RestJsonMalformedUnionNoFieldsSet", TestType.MalformedRequest), - // Tests involving constraint traits, which are not yet implemented. - // See https://github.com/awslabs/smithy-rs/pull/1342. - FailingTest(RestJsonValidation, "RestJsonMalformedEnumList_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumList_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapKey_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapKey_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapValue_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapValue_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumString_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumString_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumUnion_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumUnion_case1", TestType.MalformedRequest), + // Tests involving constraint traits, which are not yet fully implemented. + // See https://github.com/awslabs/smithy-rs/issues/1401. FailingTest(RestJsonValidation, "RestJsonMalformedLengthBlobOverride_case0", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthBlobOverride_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthListOverride_case0", TestType.MalformedRequest), @@ -960,17 +950,8 @@ class ServerProtocolTestGenerator( FailingTest(RestJsonValidation, "RestJsonMalformedLengthBlob_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthList_case0", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthList_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthListValue_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthListValue_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMap_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMap_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMapKey_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMapKey_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthMapValue_case0", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthMapValue_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthString_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthString_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthString_case2", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedPatternListOverride_case0", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedPatternListOverride_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedPatternMapKeyOverride_case0", TestType.MalformedRequest), @@ -1010,9 +991,6 @@ class ServerProtocolTestGenerator( FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloat_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthMaxStringOverride", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedLengthMinStringOverride", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthQueryStringNoValue", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMaxString", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMinString", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxByteOverride", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxFloatOverride", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinByteOverride", TestType.MalformedRequest), @@ -1021,10 +999,6 @@ class ServerProtocolTestGenerator( FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxFloat", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinByte", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinFloat", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRequiredBodyExplicitNull", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRequiredBodyUnset", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRequiredHeaderUnset", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRecursiveStructures", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedPatternSensitiveString", TestType.MalformedRequest), // Some tests for the S3 service (restXml). diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt index 8215d485b..815608fb2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt @@ -10,18 +10,18 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.escape import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.core.smithy.protocols.awsJsonFieldName -import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonCustomization -import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSection +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeIteratingOverMapJsonCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerAwsJsonProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol @@ -56,9 +56,9 @@ class ServerAwsJsonFactory(private val version: AwsJsonVersion) : * AwsJson requires errors to be serialized in server responses with an additional `__type` field. This * customization writes the right field depending on the version of the AwsJson protocol. */ -class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonCustomization() { - override fun section(section: JsonSection): Writable = when (section) { - is JsonSection.ServerError -> writable { +class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonSerializerCustomization() { + override fun section(section: JsonSerializerSection): Writable = when (section) { + is JsonSerializerSection.ServerError -> writable { if (section.structureShape.hasTrait()) { val typeId = when (awsJsonVersion) { // AwsJson 1.0 wants the whole shape ID (namespace#Shape). @@ -82,7 +82,7 @@ class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonCusto * https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#operation-error-serialization */ class ServerAwsJsonSerializerGenerator( - private val codegenContext: CodegenContext, + private val codegenContext: ServerCodegenContext, private val httpBindingResolver: HttpBindingResolver, private val awsJsonVersion: AwsJsonVersion, private val jsonSerializerGenerator: JsonSerializerGenerator = @@ -90,6 +90,6 @@ class ServerAwsJsonSerializerGenerator( codegenContext, httpBindingResolver, ::awsJsonFieldName, - customizations = listOf(ServerAwsJsonError(awsJsonVersion)), + customizations = listOf(ServerAwsJsonError(awsJsonVersion), BeforeIteratingOverMapJsonCustomization(codegenContext)), ), ) : StructuredDataSerializerGenerator by jsonSerializerGenerator diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 2cf86be5e..4cc16278c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.node.ExpectationNotMetException import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape @@ -32,32 +33,29 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.asType import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock -import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization -import software.amazon.smithy.rust.codegen.core.smithy.extractSymbolFromOption -import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.MakeOperationGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTraitImplGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.mapRustType import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator -import software.amazon.smithy.rust.codegen.core.smithy.toOptional import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.core.smithy.wrapOptional @@ -74,16 +72,19 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.http.ServerRequestBindingGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.http.ServerResponseBindingGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol import java.util.logging.Logger /** * Implement operations' input parsing and output serialization. Protocols can plug their own implementations * and overrides by creating a protocol factory inheriting from this class and feeding it to the [ServerProtocolLoader]. - * See `ServerRestJsonFactory.kt` for more info. + * See `ServerRestJson.kt` for more info. */ class ServerHttpBoundProtocolGenerator( codegenContext: ServerCodegenContext, @@ -117,6 +118,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) : ProtocolTraitImplGenerator { private val logger = Logger.getLogger(javaClass.name) private val symbolProvider = codegenContext.symbolProvider + private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider private val model = codegenContext.model private val runtimeConfig = codegenContext.runtimeConfig private val httpBindingResolver = protocol.httpBindingResolver @@ -592,7 +594,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( * case it will generate response headers for the given error shape. * * It sets three groups of headers in order. Headers from one group take precedence over headers in a later group. - * 1. Headers bound by the `httpHeader` and `httpPrefixHeader` traits. + * 1. Headers bound by the `httpHeader` and `httpPrefixHeader` traits. = null * 2. The protocol-specific `Content-Type` header for the operation. * 3. Additional protocol-specific headers for errors, if [errorShape] is non-null. */ @@ -712,7 +714,10 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) val structuredDataParser = protocol.structuredDataParser(operationShape) Attribute.AllowUnusedMut.render(this) - rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider)) + rust( + "let mut input = #T::default();", + inputShape.serverBuilderSymbol(codegenContext), + ) val parser = structuredDataParser.serverInputParser(operationShape) val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (parser != null) { @@ -732,9 +737,21 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val member = binding.member val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) if (parsedValue != null) { - withBlock("input = input.${member.setterName()}(", ");") { - parsedValue(this) - } + rust("if let Some(value) = ") + parsedValue(this) + rust( + """ + { + input = input.${member.setterName()}(${ + if (symbolProvider.toSymbol(binding.member).isOptional()) { + "Some(value)" + } else { + "value" + } + }); + } + """, + ) } } serverRenderUriPathParser(this, operationShape) @@ -750,7 +767,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) } } - val err = if (StructureGenerator.hasFallibleBuilder(inputShape, symbolProvider)) { + val err = if (ServerBuilderGenerator.hasFallibleBuilder( + inputShape, + model, + symbolProvider, + takeInUnconstrainedTypes = true, + ) + ) { "?" } else "" rustTemplate("input.build()$err", *codegenScope) @@ -884,13 +907,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator( .forEachIndexed { index, segment -> val binding = pathBindings.find { it.memberName == segment.content } if (binding != null && segment.isLabel) { - val deserializer = generateParseFn(binding, true) + val deserializer = generateParseStrFn(binding, true) rustTemplate( """ input = input.${binding.member.setterName()}( - ${symbolProvider.toOptional(binding.member, "#{deserializer}(m$index)?")} + #{deserializer}(m$index)? ); - """.trimIndent(), + """, *codegenScope, "deserializer" to deserializer, ) @@ -905,13 +928,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator( // * a map of set of string. enum class QueryParamsTargetMapValueType { STRING, LIST, SET; - - fun asRustType(): RustType = - when (this) { - STRING -> RustType.String - LIST -> RustType.Vec(RustType.String) - SET -> RustType.HashSet(RustType.String) - } } private fun queryParamsTargetMapValueType(targetMapValue: Shape): QueryParamsTargetMapValueType = @@ -924,8 +940,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( } else { throw ExpectationNotMetException( """ - @httpQueryParams trait applied to non-supported target - $targetMapValue of type ${targetMapValue.type} + @httpQueryParams trait applied to non-supported target $targetMapValue of type ${targetMapValue.type} """.trimIndent(), targetMapValue.sourceLocation, ) @@ -947,9 +962,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator( fun HttpBindingDescriptor.queryParamsBindingTargetMapValueType(): QueryParamsTargetMapValueType { check(this.location == HttpLocation.QUERY_PARAMS) - val queryParamsTarget = model.expectShape(this.member.target) - val mapTarget = queryParamsTarget.asMapShape().get() - return queryParamsTargetMapValueType(model.expectShape(mapTarget.value.target)) + val queryParamsTarget = model.expectShape(this.member.target, MapShape::class.java) + return queryParamsTargetMapValueType(model.expectShape(queryParamsTarget.value.target)) } with(writer) { @@ -962,11 +976,16 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) if (queryParamsBinding != null) { - rustTemplate( - "let mut query_params: #{HashMap} = #{HashMap}::new();", - "HashMap" to software.amazon.smithy.rust.codegen.core.rustlang.RustType.HashMap.RuntimeType, - ) + val target = model.expectShape(queryParamsBinding.member.target, MapShape::class.java) + val hasConstrainedTarget = target.canReachConstrainedShape(model, symbolProvider) + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Here we only check the target shape; + // constraint traits on member shapes are not implemented yet. + val targetSymbol = unconstrainedShapeSymbolProvider.toSymbol(target) + withBlock("let mut query_params: #T = ", ";", targetSymbol) { + conditionalBlock("#T(", ")", conditional = hasConstrainedTarget, targetSymbol) { + rust("#T::new()", RustType.HashMap.RuntimeType) + } + } } val (queryBindingsTargettingCollection, queryBindingsTargettingSimple) = queryBindings.partition { model.expectShape(it.member.target) is CollectionShape } @@ -979,13 +998,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator( rustBlock("for (k, v) in pairs") { queryBindingsTargettingSimple.forEach { - val deserializer = generateParseFn(it, false) + val deserializer = generateParseStrFn(it, false) val memberName = symbolProvider.toMemberName(it.member) rustTemplate( """ if !seen_$memberName && k == "${it.locationName}" { input = input.${it.member.setterName()}( - ${symbolProvider.toOptional(it.member, "#{deserializer}(&v)?")} + #{deserializer}(&v)? ); seen_$memberName = true; } @@ -993,22 +1012,20 @@ private class ServerHttpBoundProtocolTraitImplGenerator( "deserializer" to deserializer, ) } - queryBindingsTargettingCollection.forEach { - rustBlock("if k == ${it.locationName.dq()}") { + queryBindingsTargettingCollection.forEachIndexed { idx, it -> + rustBlock("${if (idx > 0) "else " else ""}if k == ${it.locationName.dq()}") { val targetCollectionShape = model.expectShape(it.member.target, CollectionShape::class.java) val memberShape = model.expectShape(targetCollectionShape.member.target) when { memberShape.isStringShape -> { - // NOTE: This path is traversed with or without @enum applied. The `try_from` is used - // as a common conversion. - rustTemplate( - """ - let v = <#{memberShape}>::try_from(v.as_ref())?; - """, - *codegenScope, - "memberShape" to symbolProvider.toSymbol(memberShape), - ) + if (queryParamsBinding != null) { + // If there's an `@httpQueryParams` binding, it will want to consume the parsed data + // too further down, so we need to clone it. + rust("let v = v.clone().into_owned();") + } else { + rust("let v = v.into_owned();") + } } memberShape.isTimestampShape -> { val index = HttpBindingIndex.of(model) @@ -1042,47 +1059,79 @@ private class ServerHttpBoundProtocolTraitImplGenerator( } if (queryParamsBinding != null) { + val target = model.expectShape(queryParamsBinding.member.target, MapShape::class.java) + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Here we only check the target shape; + // constraint traits on member shapes are not implemented yet. + val hasConstrainedTarget = target.canReachConstrainedShape(model, symbolProvider) when (queryParamsBinding.queryParamsBindingTargetMapValueType()) { QueryParamsTargetMapValueType.STRING -> { - rust("query_params.entry(String::from(k)).or_insert_with(|| String::from(v));") - } else -> { - rustTemplate( - """ - let entry = query_params.entry(String::from(k)).or_default(); - entry.push(String::from(v)); - """.trimIndent(), - ) + rust("query_params.${if (hasConstrainedTarget) "0." else ""}entry(String::from(k)).or_insert_with(|| String::from(v));") + } + QueryParamsTargetMapValueType.LIST, QueryParamsTargetMapValueType.SET -> { + if (hasConstrainedTarget) { + val collectionShape = model.expectShape(target.value.target, CollectionShape::class.java) + val collectionSymbol = unconstrainedShapeSymbolProvider.toSymbol(collectionShape) + rust( + // `or_insert_with` instead of `or_insert` to avoid the allocation when the entry is + // not empty. + """ + let entry = query_params.0.entry(String::from(k)).or_insert_with(|| #T(std::vec::Vec::new())); + entry.0.push(String::from(v)); + """, + collectionSymbol, + ) + } else { + rust( + """ + let entry = query_params.entry(String::from(k)).or_default(); + entry.push(String::from(v)); + """, + ) + } } } } } if (queryParamsBinding != null) { - rust("input = input.${queryParamsBinding.member.setterName()}(Some(query_params));") + val isOptional = unconstrainedShapeSymbolProvider.toSymbol(queryParamsBinding.member).isOptional() + withBlock("input = input.${queryParamsBinding.member.setterName()}(", ");") { + conditionalBlock("Some(", ")", conditional = isOptional) { + write("query_params") + } + } } - queryBindingsTargettingCollection.forEach { - val memberName = symbolProvider.toMemberName(it.member) - rustTemplate( - """ - input = input.${it.member.setterName()}( - if $memberName.is_empty() { - None - } else { - Some($memberName) + queryBindingsTargettingCollection.forEach { binding -> + // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Constraint traits on member shapes are not + // implemented yet. + val hasConstrainedTarget = + model.expectShape(binding.member.target, CollectionShape::class.java).canReachConstrainedShape(model, symbolProvider) + val memberName = unconstrainedShapeSymbolProvider.toMemberName(binding.member) + val isOptional = unconstrainedShapeSymbolProvider.toSymbol(binding.member).isOptional() + rustBlock("if !$memberName.is_empty()") { + withBlock( + "input = input.${ + binding.member.setterName() + }(", + ");", + ) { + conditionalBlock("Some(", ")", conditional = isOptional) { + conditionalBlock( + "#T(", + ")", + conditional = hasConstrainedTarget, + unconstrainedShapeSymbolProvider.toSymbol(binding.member).mapRustType { it.stripOuter() }, + ) { + write(memberName) + } } - ); - """.trimIndent(), - ) + } + } } } } private fun serverRenderHeaderParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) { - val httpBindingGenerator = - ServerRequestBindingGenerator( - protocol, - codegenContext, - operationShape, - ) + val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding) writer.rustTemplate( """ @@ -1096,12 +1145,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( private fun serverRenderPrefixHeadersParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) { check(binding.location == HttpLocation.PREFIX_HEADERS) - val httpBindingGenerator = - ServerRequestBindingGenerator( - protocol, - codegenContext, - operationShape, - ) + val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) val deserializer = httpBindingGenerator.generateDeserializePrefixHeadersFn(binding) writer.rustTemplate( """ @@ -1112,10 +1156,9 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) } - private fun generateParseFn(binding: HttpBindingDescriptor, percentDecoding: Boolean): RuntimeType { - val output = symbolProvider.toSymbol(binding.member) + private fun generateParseStrFn(binding: HttpBindingDescriptor, percentDecoding: Boolean): RuntimeType { + val output = unconstrainedShapeSymbolProvider.toSymbol(binding.member) val fnName = generateParseStrFnName(binding) - val symbol = output.extractSymbolFromOption() return RuntimeType.forInlineFun(fnName, operationDeserModule) { rustBlockTemplate( "pub fn $fnName(value: &str) -> std::result::Result<#{O}, #{RequestRejection}>", @@ -1126,24 +1169,15 @@ private class ServerHttpBoundProtocolTraitImplGenerator( when { target.isStringShape -> { - // NOTE: This path is traversed with or without @enum applied. The `try_from` is used as a - // common conversion. if (percentDecoding) { rustTemplate( """ - let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?; - let value = #{T}::try_from(value.as_ref())?; + let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?.into_owned(); """, *codegenScope, - "T" to symbol, ) } else { - rustTemplate( - """ - let value = #{T}::try_from(value)?; - """, - "T" to symbol, - ) + rust("let value = value.to_owned();") } } target.isTimestampShape -> { @@ -1187,7 +1221,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) } } - rust( """ Ok(${symbolProvider.wrapOptional(binding.member, "value")}) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJsonFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt similarity index 63% rename from codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJsonFactory.kt rename to codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt index d3d0fea63..a913b806d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJsonFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt @@ -6,9 +6,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory +import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeIteratingOverMapJsonCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRestJsonProtocol /** @@ -36,3 +41,15 @@ class ServerRestJsonFactory : ProtocolGeneratorFactory { + shape.hasTrait() + } else -> PANIC("this method does not support shape type ${shape.type}") +} + +fun StringShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) +fun StructureShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) +fun CollectionShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) +fun UnionShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) +fun MapShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt new file mode 100644 index 000000000..257bce1f0 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt @@ -0,0 +1,74 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.EnumShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.SetShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTrait + +/** + * Attach the `smithy.framework#ValidationException` error to operations whose inputs are constrained, if they belong + * to a service in an allowlist. + * + * Some of the models we generate in CI have constrained operation inputs, but the operations don't have + * `smithy.framework#ValidationException` in their list of errors. This is a codegen error, unless + * `disableDefaultValidation` is set to `true`, a code generation mode we don't support yet. See [1] for more details. + * Until we implement said mode, we manually attach the error to build these models, since we don't own them (they're + * either actual AWS service model excerpts, or they come from the `awslabs/smithy` library. + * + * [1]: https://github.com/awslabs/smithy-rs/pull/1199#discussion_r809424783 + * + * TODO(https://github.com/awslabs/smithy-rs/issues/1401): This transformer will go away once we add support for + * `disableDefaultValidation` set to `true`, allowing service owners to map from constraint violations to operation errors. + */ +object AttachValidationExceptionToConstrainedOperationInputsInAllowList { + private val sherviceShapeIdAllowList = + setOf( + // These we currently generate server SDKs for. + ShapeId.from("aws.protocoltests.restjson#RestJson"), + ShapeId.from("aws.protocoltests.json10#JsonRpc10"), + ShapeId.from("aws.protocoltests.json#JsonProtocol"), + ShapeId.from("com.amazonaws.s3#AmazonS3"), + ShapeId.from("com.amazonaws.ebs#Ebs"), + + // These are only loaded in the classpath and need this model transformer, but we don't generate server + // SDKs for them. Here they are for reference. + // ShapeId.from("aws.protocoltests.restxml#RestXml"), + // ShapeId.from("com.amazonaws.glacier#Glacier"), + // ShapeId.from("aws.protocoltests.ec2#AwsEc2"), + // ShapeId.from("aws.protocoltests.query#AwsQuery"), + // ShapeId.from("com.amazonaws.machinelearning#AmazonML_20141212"), + ) + + fun transform(model: Model): Model { + val walker = Walker(model) + + val operationsWithConstrainedInputWithoutValidationException = model.serviceShapes + .filter { sherviceShapeIdAllowList.contains(it.toShapeId()) } + .flatMap { it.operations } + .map { model.expectShape(it, OperationShape::class.java) } + .filter { operationShape -> + // Walk the shapes reachable via this operation input. + walker.walkShapes(operationShape.inputShape(model)) + .any { it is SetShape || it is EnumShape || it.hasConstraintTrait() } + } + .filter { !it.errors.contains(ShapeId.from("smithy.framework#ValidationException")) } + + return ModelTransformer.create().mapShapes(model) { shape -> + if (shape is OperationShape && operationsWithConstrainedInputWithoutValidationException.contains(shape)) { + shape.toBuilder().addError("smithy.framework#ValidationException").build() + } else { + shape + } + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RemoveEbsModelValidationException.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RemoveEbsModelValidationException.kt new file mode 100644 index 000000000..d4c6feaed --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RemoveEbsModelValidationException.kt @@ -0,0 +1,38 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.rust.codegen.core.util.orNull + +/** + * The Amazon Elastic Block Store (Amazon EBS) model is one model that we generate in CI. + * Unfortunately, it defines its own `ValidationException` shape, that conflicts with + * `smithy.framework#ValidationException` [0]. + * + * So this is a model that a service owner would generate when "disabling default validation": in such a code generation + * mode, the service owner is responsible for mapping an operation input-level constraint violation into a modeled + * operation error. This mode, as well as what the end goal for validation exception responses looks like, is described + * in more detail in [1]. We don't support this mode yet. + * + * So this transformer simply removes the EBB model's `ValidationException`. A subsequent model transformer, + * [AttachValidationExceptionToConstrainedOperationInputsInAllowList], ensures that it is replaced by + * `smithy.framework#ValidationException`. + * + * [0]: https://github.com/awslabs/smithy-rs/blob/274adf155042cde49251a0e6b8842d6f56cd5b6d/codegen-core/common-test-models/ebs.json#L1270-L1288 + * [1]: https://github.com/awslabs/smithy-rs/pull/1199#discussion_r809424783 + * + * TODO(https://github.com/awslabs/smithy-rs/issues/1401): This transformer will go away once we implement + * `disableDefaultValidation` set to `true`, allowing service owners to map from constraint violations to operation errors. + */ +object RemoveEbsModelValidationException { + fun transform(model: Model): Model { + val shapeToRemove = model.getShape(ShapeId.from("com.amazonaws.ebs#ValidationException")).orNull() + return ModelTransformer.create().removeShapes(model, listOfNotNull(shapeToRemove)) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt new file mode 100644 index 000000000..cf58f3f9d --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt @@ -0,0 +1,72 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.neighbor.Walker +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.server.smithy.traits.ShapeReachableFromOperationInputTagTrait + +/** + * Tag shapes reachable from operation input with the + * [ShapeReachableFromOperationInputTagTrait] tag. + * + * This is useful to determine whether we need to generate code to + * enforce constraints upon request deserialization in the server. + * + * This needs to be a model transformer; it cannot be lazily calculated + * when needed. This is because other model transformers may transform + * the model such that shapes that were reachable from operation + * input are no longer so. For example, [EventStreamNormalizer] pulls + * event stream error variants out of the union shape where they are defined. + * As such, [ShapesReachableFromOperationInputTagger] needs to run + * before these model transformers. + * + * WARNING: This transformer tags _all_ [aggregate shapes], and _some_ [simple shapes], + * but not all of them. Read the implementation to find out what shape types it + * currently tags. + * + * [simple shapes]: https://awslabs.github.io/smithy/2.0/spec/simple-types.html + * [aggregate shapes]: https://awslabs.github.io/smithy/2.0/spec/aggregate-types.html#aggregate-types + */ +object ShapesReachableFromOperationInputTagger { + fun transform(model: Model): Model { + val inputShapes = model.operationShapes.map { + model.expectShape(it.inputShape, StructureShape::class.java) + } + val walker = Walker(model) + val shapesReachableFromOperationInputs = inputShapes + .flatMap { walker.walkShapes(it) } + .toSet() + + return ModelTransformer.create().mapShapes(model) { shape -> + when (shape) { + is StructureShape, is UnionShape, is ListShape, is MapShape, is StringShape -> { + if (shapesReachableFromOperationInputs.contains(shape)) { + val builder = when (shape) { + is StructureShape -> shape.toBuilder() + is UnionShape -> shape.toBuilder() + is ListShape -> shape.toBuilder() + is MapShape -> shape.toBuilder() + is StringShape -> shape.toBuilder() + else -> UNREACHABLE("the `when` is exhaustive") + } + builder.addTrait(ShapeReachableFromOperationInputTagTrait()).build() + } else { + shape + } + } + else -> shape + } + } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt new file mode 100644 index 000000000..bcf7fe34c --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt @@ -0,0 +1,98 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.matchers.shouldBe +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider + +const val baseModelString = + """ + namespace test + + service TestService { + version: "123", + operations: [TestOperation] + } + + operation TestOperation { + input: TestInputOutput, + output: TestInputOutput, + } + + structure TestInputOutput { + constrainedString: ConstrainedString, + constrainedMap: ConstrainedMap, + unconstrainedMap: TransitivelyConstrainedMap + } + + @length(min: 1, max: 69) + string ConstrainedString + + string UnconstrainedString + + @length(min: 1, max: 69) + map ConstrainedMap { + key: String, + value: String + } + + map TransitivelyConstrainedMap { + key: String, + value: ConstrainedMap + } + + @length(min: 1, max: 69) + list ConstrainedCollection { + member: String + } + """ + +class ConstrainedShapeSymbolProviderTest { + private val model = baseModelString.asSmithyModel() + private val serviceShape = model.lookup("test#TestService") + private val symbolProvider = serverTestSymbolProvider(model, serviceShape) + private val constrainedShapeSymbolProvider = ConstrainedShapeSymbolProvider(symbolProvider, model, serviceShape) + + private val constrainedMapShape = model.lookup("test#ConstrainedMap") + private val constrainedMapType = constrainedShapeSymbolProvider.toSymbol(constrainedMapShape).rustType() + + @Test + fun `it should return a constrained string type for a constrained string shape`() { + val constrainedStringShape = model.lookup("test#ConstrainedString") + val constrainedStringType = constrainedShapeSymbolProvider.toSymbol(constrainedStringShape).rustType() + + constrainedStringType shouldBe RustType.Opaque("ConstrainedString", "crate::model") + } + + @Test + fun `it should return a constrained map type for a constrained map shape`() { + constrainedMapType shouldBe RustType.Opaque("ConstrainedMap", "crate::model") + } + + @Test + fun `it should not blindly delegate to the base symbol provider when the shape is an aggregate shape and is not directly constrained`() { + val unconstrainedMapShape = model.lookup("test#TransitivelyConstrainedMap") + val unconstrainedMapType = constrainedShapeSymbolProvider.toSymbol(unconstrainedMapShape).rustType() + + unconstrainedMapType shouldBe RustType.HashMap(RustType.String, constrainedMapType) + } + + @Test + fun `it should delegate to the base symbol provider for unconstrained simple shapes`() { + val unconstrainedStringShape = model.lookup("test#UnconstrainedString") + val unconstrainedStringSymbol = constrainedShapeSymbolProvider.toSymbol(unconstrainedStringShape) + + unconstrainedStringSymbol shouldBe symbolProvider.toSymbol(unconstrainedStringShape) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt new file mode 100644 index 000000000..80e2d93da --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -0,0 +1,135 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.inspectors.forAll +import io.kotest.matchers.shouldBe +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider + +class ConstraintsTest { + private val model = + """ + namespace test + + service TestService { + version: "123", + operations: [TestOperation] + } + + operation TestOperation { + input: TestInputOutput, + output: TestInputOutput, + } + + structure TestInputOutput { + map: MapA, + + recursive: RecursiveShape + } + + structure RecursiveShape { + shape: RecursiveShape, + mapB: MapB + } + + @length(min: 1, max: 69) + map MapA { + key: String, + value: MapB + } + + map MapB { + key: String, + value: StructureA + } + + @uniqueItems + list ListA { + member: MyString + } + + @pattern("\\w+") + string MyString + + @length(min: 1, max: 69) + string LengthString + + structure StructureA { + @range(min: 1, max: 69) + int: Integer, + + @required + string: String + } + + // This shape is not in the service closure. + structure StructureB { + @pattern("\\w+") + patternString: String, + + @required + requiredString: String, + + mapA: MapA, + + @length(min: 1, max: 5) + mapAPrecedence: MapA + } + """.asSmithyModel() + private val symbolProvider = serverTestSymbolProvider(model) + + private val testInputOutput = model.lookup("test#TestInputOutput") + private val recursiveShape = model.lookup("test#RecursiveShape") + private val mapA = model.lookup("test#MapA") + private val mapB = model.lookup("test#MapB") + private val listA = model.lookup("test#ListA") + private val myString = model.lookup("test#MyString") + private val lengthString = model.lookup("test#LengthString") + private val structA = model.lookup("test#StructureA") + private val structAInt = model.lookup("test#StructureA\$int") + private val structAString = model.lookup("test#StructureA\$string") + + @Test + fun `it should not recognize uniqueItems as a constraint trait because it's deprecated`() { + listA.isDirectlyConstrained(symbolProvider) shouldBe false + } + + @Test + fun `it should detect supported constrained traits as constrained`() { + listOf(mapA, structA, lengthString).forAll { + it.isDirectlyConstrained(symbolProvider) shouldBe true + } + } + + @Test + fun `it should not detect unsupported constrained traits as constrained`() { + listOf(structAInt, structAString, myString).forAll { + it.isDirectlyConstrained(symbolProvider) shouldBe false + } + } + + @Test + fun `it should evaluate reachability of constrained shapes`() { + mapA.canReachConstrainedShape(model, symbolProvider) shouldBe true + structAInt.canReachConstrainedShape(model, symbolProvider) shouldBe false + + // This should be true when we start supporting the `pattern` trait on string shapes. + listA.canReachConstrainedShape(model, symbolProvider) shouldBe false + + // All of these eventually reach `StructureA`, which is constrained because one of its members is `required`. + testInputOutput.canReachConstrainedShape(model, symbolProvider) shouldBe true + mapB.canReachConstrainedShape(model, symbolProvider) shouldBe true + recursiveShape.canReachConstrainedShape(model, symbolProvider) shouldBe true + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt new file mode 100644 index 000000000..21baefe74 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt @@ -0,0 +1,113 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.shouldBe +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProviders + +class PubCrateConstrainedShapeSymbolProviderTest { + private val model = """ + $baseModelString + + list TransitivelyConstrainedCollection { + member: Structure + } + + structure Structure { + @required + requiredMember: String + } + + structure StructureWithMemberTargetingAggregateShape { + member: TransitivelyConstrainedCollection + } + + union Union { + structure: Structure + } + """.asSmithyModel() + + private val serverTestSymbolProviders = serverTestSymbolProviders(model) + private val symbolProvider = serverTestSymbolProviders.symbolProvider + private val pubCrateConstrainedShapeSymbolProvider = serverTestSymbolProviders.pubCrateConstrainedShapeSymbolProvider + + @Test + fun `it should crash when provided with a shape that is directly constrained`() { + val constrainedStringShape = model.lookup("test#ConstrainedString") + shouldThrow { pubCrateConstrainedShapeSymbolProvider.toSymbol(constrainedStringShape) } + } + + @Test + fun `it should crash when provided with a shape that is unconstrained`() { + val unconstrainedStringShape = model.lookup("test#UnconstrainedString") + shouldThrow { pubCrateConstrainedShapeSymbolProvider.toSymbol(unconstrainedStringShape) } + } + + @Test + fun `it should return an opaque type for transitively constrained collection shapes`() { + val transitivelyConstrainedCollectionShape = model.lookup("test#TransitivelyConstrainedCollection") + val transitivelyConstrainedCollectionType = + pubCrateConstrainedShapeSymbolProvider.toSymbol(transitivelyConstrainedCollectionShape).rustType() + + transitivelyConstrainedCollectionType shouldBe RustType.Opaque( + "TransitivelyConstrainedCollectionConstrained", + "crate::constrained::transitively_constrained_collection_constrained", + ) + } + + @Test + fun `it should return an opaque type for transitively constrained map shapes`() { + val transitivelyConstrainedMapShape = model.lookup("test#TransitivelyConstrainedMap") + val transitivelyConstrainedMapType = + pubCrateConstrainedShapeSymbolProvider.toSymbol(transitivelyConstrainedMapShape).rustType() + + transitivelyConstrainedMapType shouldBe RustType.Opaque( + "TransitivelyConstrainedMapConstrained", + "crate::constrained::transitively_constrained_map_constrained", + ) + } + + @Test + fun `it should not blindly delegate to the base symbol provider when provided with a transitively constrained structure member shape targeting an aggregate shape`() { + val memberShape = model.lookup("test#StructureWithMemberTargetingAggregateShape\$member") + val memberType = pubCrateConstrainedShapeSymbolProvider.toSymbol(memberShape).rustType() + + memberType shouldBe RustType.Option( + RustType.Opaque( + "TransitivelyConstrainedCollectionConstrained", + "crate::constrained::transitively_constrained_collection_constrained", + ), + ) + } + + @Test + fun `it should delegate to the base symbol provider when provided with a structure shape`() { + val structureShape = model.lookup("test#TestInputOutput") + val structureSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(structureShape) + + structureSymbol shouldBe symbolProvider.toSymbol(structureShape) + } + + @Test + fun `it should delegate to the base symbol provider when provided with a union shape`() { + val unionShape = model.lookup("test#Union") + val unionSymbol = pubCrateConstrainedShapeSymbolProvider.toSymbol(unionShape) + + unionSymbol shouldBe symbolProvider.toSymbol(unionShape) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProviderTest.kt new file mode 100644 index 000000000..7c8efe9c1 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProviderTest.kt @@ -0,0 +1,103 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.matchers.shouldBe +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.render +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProviders + +/** + * While [UnconstrainedShapeSymbolProvider] _must_ be in the `codegen` subproject, these tests need to be in the + * `codegen-server` subproject, because they use [serverTestSymbolProvider]. + */ +class UnconstrainedShapeSymbolProviderTest { + private val baseModelString = + """ + namespace test + + service TestService { + version: "123", + operations: [TestOperation] + } + + operation TestOperation { + input: TestInputOutput, + output: TestInputOutput, + } + + structure TestInputOutput { + list: ListA + } + """ + + @Test + fun `it should adjust types for unconstrained shapes`() { + val model = + """ + $baseModelString + + list ListA { + member: ListB + } + + list ListB { + member: StructureC + } + + structure StructureC { + @required + string: String + } + """.asSmithyModel() + + val unconstrainedShapeSymbolProvider = serverTestSymbolProviders(model).unconstrainedShapeSymbolProvider + + val listAShape = model.lookup("test#ListA") + val listAType = unconstrainedShapeSymbolProvider.toSymbol(listAShape).rustType() + + val listBShape = model.lookup("test#ListB") + val listBType = unconstrainedShapeSymbolProvider.toSymbol(listBShape).rustType() + + val structureCShape = model.lookup("test#StructureC") + val structureCType = unconstrainedShapeSymbolProvider.toSymbol(structureCShape).rustType() + + listAType shouldBe RustType.Opaque("ListAUnconstrained", "crate::unconstrained::list_a_unconstrained") + listBType shouldBe RustType.Opaque("ListBUnconstrained", "crate::unconstrained::list_b_unconstrained") + structureCType shouldBe RustType.Opaque("Builder", "crate::model::structure_c") + } + + @Test + fun `it should delegate to the base symbol provider if called with a shape that cannot reach a constrained shape`() { + val model = + """ + $baseModelString + + list ListA { + member: StructureB + } + + structure StructureB { + string: String + } + """.asSmithyModel() + + val unconstrainedShapeSymbolProvider = serverTestSymbolProviders(model).unconstrainedShapeSymbolProvider + + val listAShape = model.lookup("test#ListA") + val structureBShape = model.lookup("test#StructureB") + + unconstrainedShapeSymbolProvider.toSymbol(structureBShape).rustType().render() shouldBe "crate::model::StructureB" + unconstrainedShapeSymbolProvider.toSymbol(listAShape).rustType().render() shouldBe "std::vec::Vec" + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt new file mode 100644 index 000000000..0624358ba --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt @@ -0,0 +1,254 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.inspectors.forSome +import io.kotest.inspectors.shouldForAll +import io.kotest.matchers.collections.shouldHaveAtLeastSize +import io.kotest.matchers.collections.shouldHaveSize +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import java.util.logging.Level + +internal class ValidateUnsupportedConstraintsAreNotUsedTest { + private val baseModel = + """ + namespace test + + service TestService { + version: "123", + operations: [TestOperation] + } + + operation TestOperation { + input: TestInputOutput, + output: TestInputOutput, + } + """ + + private fun validateModel(model: Model, serverCodegenConfig: ServerCodegenConfig = ServerCodegenConfig()): ValidationResult { + val service = model.lookup("test#TestService") + return validateUnsupportedConstraints(model, service, serverCodegenConfig) + } + + @Test + fun `it should detect when an operation with constrained input but that does not have ValidationException attached in errors`() { + val model = + """ + $baseModel + + structure TestInputOutput { + @required + requiredString: String + } + """.asSmithyModel() + val service = model.lookup("test#TestService") + val validationResult = validateOperationsWithConstrainedInputHaveValidationExceptionAttached(model, service) + + validationResult.messages shouldHaveSize 1 + validationResult.messages[0].message shouldContain "Operation test#TestOperation takes in input that is constrained" + } + + @Test + fun `it should detect when unsupported constraint traits on member shapes are used`() { + val model = + """ + $baseModel + + structure TestInputOutput { + @length(min: 1, max: 69) + lengthString: String + } + """.asSmithyModel() + val validationResult = validateModel(model) + + validationResult.messages shouldHaveSize 1 + validationResult.messages[0].message shouldContain "The member shape `test#TestInputOutput\$lengthString` has the constraint trait `smithy.api#length` attached" + } + + @Test + fun `it should not detect when the required trait on a member shape is used`() { + val model = + """ + $baseModel + + structure TestInputOutput { + @required + string: String + } + """.asSmithyModel() + val validationResult = validateModel(model) + + validationResult.messages shouldHaveSize 0 + } + + private val constraintTraitOnStreamingBlobShapeModel = + """ + $baseModel + + structure TestInputOutput { + @required + streamingBlob: StreamingBlob + } + + @streaming + @length(min: 69) + blob StreamingBlob + """.asSmithyModel() + + @Test + fun `it should detect when constraint traits on streaming blob shapes are used`() { + val validationResult = validateModel(constraintTraitOnStreamingBlobShapeModel) + + validationResult.messages shouldHaveSize 2 + validationResult.messages.forSome { + it.message shouldContain + """ + The blob shape `test#StreamingBlob` has both the `smithy.api#length` and `smithy.api#streaming` constraint traits attached. + It is unclear what the semantics for streaming blob shapes are. + """.trimIndent().replace("\n", " ") + } + } + + @Test + fun `it should detect when constraint traits in event streams are used`() { + val model = + """ + $baseModel + + structure TestInputOutput { + eventStream: EventStream + } + + @streaming + union EventStream { + message: Message + } + + structure Message { + lengthString: LengthString + } + + @length(min: 1) + string LengthString + """.asSmithyModel() + val validationResult = validateModel(model) + + validationResult.messages shouldHaveSize 1 + validationResult.messages[0].message shouldContain + """ + The string shape `test#LengthString` has the constraint trait `smithy.api#length` attached. + This shape is also part of an event stream; it is unclear what the semantics for constrained shapes in event streams are. + """.trimIndent().replace("\n", " ") + } + + @Test + fun `it should detect when the length trait on collection shapes or on blob shapes is used`() { + val model = + """ + $baseModel + + structure TestInputOutput { + collection: LengthCollection, + blob: LengthBlob + } + + @length(min: 1) + list LengthCollection { + member: String + } + + @length(min: 1) + blob LengthBlob + """.asSmithyModel() + val validationResult = validateModel(model) + + validationResult.messages shouldHaveSize 2 + validationResult.messages.forSome { it.message shouldContain "The list shape `test#LengthCollection` has the constraint trait `smithy.api#length` attached" } + validationResult.messages.forSome { it.message shouldContain "The blob shape `test#LengthBlob` has the constraint trait `smithy.api#length` attached" } + } + + @Test + fun `it should detect when the pattern trait on string shapes is used`() { + val model = + """ + $baseModel + + structure TestInputOutput { + patternString: PatternString + } + + @pattern("^[A-Za-z]+$") + string PatternString + """.asSmithyModel() + val validationResult = validateModel(model) + + validationResult.messages shouldHaveSize 1 + validationResult.messages[0].message shouldContain "The string shape `test#PatternString` has the constraint trait `smithy.api#pattern` attached" + } + + @Test + fun `it should detect when the range trait is used`() { + val model = + """ + $baseModel + + structure TestInputOutput { + rangeInteger: RangeInteger + } + + @range(min: 1) + integer RangeInteger + """.asSmithyModel() + val validationResult = validateModel(model) + + validationResult.messages shouldHaveSize 1 + validationResult.messages[0].message shouldContain "The integer shape `test#RangeInteger` has the constraint trait `smithy.api#range` attached" + } + + @Test + fun `it should abort when ignoreUnsupportedConstraints is false and unsupported constraints are used`() { + val validationResult = validateModel(constraintTraitOnStreamingBlobShapeModel, ServerCodegenConfig()) + + validationResult.messages shouldHaveAtLeastSize 1 + validationResult.shouldAbort shouldBe true + } + + @Test + fun `it should not abort when ignoreUnsupportedConstraints is true and unsupported constraints are used`() { + val validationResult = validateModel( + constraintTraitOnStreamingBlobShapeModel, + ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), + ) + + validationResult.messages shouldHaveAtLeastSize 1 + validationResult.shouldAbort shouldBe false + } + + @Test + fun `it should set log level to error when ignoreUnsupportedConstraints is false and unsupported constraints are used`() { + val validationResult = validateModel(constraintTraitOnStreamingBlobShapeModel, ServerCodegenConfig()) + + validationResult.messages shouldHaveAtLeastSize 1 + validationResult.messages.shouldForAll { it.level shouldBe Level.SEVERE } + } + + @Test + fun `it should set log level to warn when ignoreUnsupportedConstraints is true and unsupported constraints are used`() { + val validationResult = validateModel( + constraintTraitOnStreamingBlobShapeModel, + ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), + ) + + validationResult.messages shouldHaveAtLeastSize 1 + validationResult.messages.shouldForAll { it.level shouldBe Level.WARNING } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt new file mode 100644 index 000000000..3fa955994 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt @@ -0,0 +1,158 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger +import java.util.stream.Stream + +class ConstrainedMapGeneratorTest { + + data class TestCase(val model: Model, val validMap: ObjectNode, val invalidMap: ObjectNode) + + class ConstrainedMapGeneratorTestProvider : ArgumentsProvider { + private val testCases = listOf( + // Min and max. + Triple("@length(min: 11, max: 12)", 11, 13), + // Min equal to max. + Triple("@length(min: 11, max: 11)", 11, 12), + // Only min. + Triple("@length(min: 11)", 15, 10), + // Only max. + Triple("@length(max: 11)", 11, 12), + ).map { + val validStringMap = List(it.second) { index -> index.toString() to "value" }.toMap() + val inValidStringMap = List(it.third) { index -> index.toString() to "value" }.toMap() + Triple(it.first, ObjectNode.fromStringMap(validStringMap), ObjectNode.fromStringMap(inValidStringMap)) + }.map { (trait, validMap, invalidMap) -> + TestCase( + """ + namespace test + + $trait + map ConstrainedMap { + key: String, + value: String + } + """.asSmithyModel().let(ShapesReachableFromOperationInputTagger::transform), + validMap, + invalidMap, + ) + } + + override fun provideArguments(context: ExtensionContext?): Stream = + testCases.map { Arguments.of(it) }.stream() + } + + @ParameterizedTest + @ArgumentsSource(ConstrainedMapGeneratorTestProvider::class) + fun `it should generate constrained map types`(testCase: TestCase) { + val constrainedMapShape = testCase.model.lookup("test#ConstrainedMap") + + val codegenContext = serverTestCodegenContext(testCase.model) + val symbolProvider = codegenContext.symbolProvider + + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(ModelsModule) { + render(codegenContext, this, constrainedMapShape) + + val instantiator = serverInstantiator(codegenContext) + rustBlock("##[cfg(test)] fn build_valid_map() -> std::collections::HashMap") { + instantiator.render(this, constrainedMapShape, testCase.validMap) + } + rustBlock("##[cfg(test)] fn build_invalid_map() -> std::collections::HashMap") { + instantiator.render(this, constrainedMapShape, testCase.invalidMap) + } + + unitTest( + name = "try_from_success", + test = """ + let map = build_valid_map(); + let _constrained: ConstrainedMap = map.try_into().unwrap(); + """, + ) + unitTest( + name = "try_from_fail", + test = """ + let map = build_invalid_map(); + let constrained_res: Result = map.try_into(); + constrained_res.unwrap_err(); + """, + ) + unitTest( + name = "inner", + test = """ + let map = build_valid_map(); + let constrained = ConstrainedMap::try_from(map.clone()).unwrap(); + + assert_eq!(constrained.inner(), &map); + """, + ) + unitTest( + name = "into_inner", + test = """ + let map = build_valid_map(); + let constrained = ConstrainedMap::try_from(map.clone()).unwrap(); + + assert_eq!(constrained.into_inner(), map); + """, + ) + } + + project.compileAndTest() + } + + @Test + fun `type should not be constructible without using a constructor`() { + val model = """ + namespace test + + @length(min: 1, max: 69) + map ConstrainedMap { + key: String, + value: String + } + """.asSmithyModel().let(ShapesReachableFromOperationInputTagger::transform) + val constrainedMapShape = model.lookup("test#ConstrainedMap") + + val writer = RustWriter.forModule(ModelsModule.name) + + val codegenContext = serverTestCodegenContext(model) + render(codegenContext, writer, constrainedMapShape) + + // Check that the wrapped type is `pub(crate)`. + writer.toString() shouldContain "pub struct ConstrainedMap(pub(crate) std::collections::HashMap);" + } + + private fun render( + codegenContext: ServerCodegenContext, + writer: RustWriter, + constrainedMapShape: MapShape, + ) { + ConstrainedMapGenerator(codegenContext, writer, constrainedMapShape).render() + MapConstraintViolationGenerator(codegenContext, writer, constrainedMapShape).render() + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt new file mode 100644 index 000000000..75db6303f --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt @@ -0,0 +1,179 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext +import java.util.stream.Stream + +class ConstrainedStringGeneratorTest { + + data class TestCase(val model: Model, val validString: String, val invalidString: String) + + class ConstrainedStringGeneratorTestProvider : ArgumentsProvider { + private val testCases = listOf( + // Min and max. + Triple("@length(min: 11, max: 12)", "validString", "invalidString"), + // Min equal to max. + Triple("@length(min: 11, max: 11)", "validString", "invalidString"), + // Only min. + Triple("@length(min: 11)", "validString", ""), + // Only max. + Triple("@length(max: 11)", "", "invalidString"), + // Count Unicode scalar values, not `.len()`. + Triple( + "@length(min: 3, max: 3)", + "👍👍👍", // These three emojis are three Unicode scalar values. + "👍👍👍👍", + ), + ).map { + TestCase( + """ + namespace test + + ${it.first} + string ConstrainedString + """.asSmithyModel(), + it.second, + it.third, + ) + } + + override fun provideArguments(context: ExtensionContext?): Stream = + testCases.map { Arguments.of(it) }.stream() + } + + @ParameterizedTest + @ArgumentsSource(ConstrainedStringGeneratorTestProvider::class) + fun `it should generate constrained string types`(testCase: TestCase) { + val constrainedStringShape = testCase.model.lookup("test#ConstrainedString") + + val codegenContext = serverTestCodegenContext(testCase.model) + val symbolProvider = codegenContext.symbolProvider + + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(ModelsModule) { + ConstrainedStringGenerator(codegenContext, this, constrainedStringShape).render() + + unitTest( + name = "try_from_success", + test = """ + let string = "${testCase.validString}".to_owned(); + let _constrained: ConstrainedString = string.try_into().unwrap(); + """, + ) + unitTest( + name = "try_from_fail", + test = """ + let string = "${testCase.invalidString}".to_owned(); + let constrained_res: Result = string.try_into(); + constrained_res.unwrap_err(); + """, + ) + unitTest( + name = "inner", + test = """ + let string = "${testCase.validString}".to_owned(); + let constrained = ConstrainedString::try_from(string).unwrap(); + + assert_eq!(constrained.inner(), "${testCase.validString}"); + """, + ) + unitTest( + name = "into_inner", + test = """ + let string = "${testCase.validString}".to_owned(); + let constrained = ConstrainedString::try_from(string.clone()).unwrap(); + + assert_eq!(constrained.into_inner(), string); + """, + ) + } + + project.compileAndTest() + } + + @Test + fun `type should not be constructible without using a constructor`() { + val model = """ + namespace test + + @length(min: 1, max: 69) + string ConstrainedString + """.asSmithyModel() + val constrainedStringShape = model.lookup("test#ConstrainedString") + + val codegenContext = serverTestCodegenContext(model) + + val writer = RustWriter.forModule(ModelsModule.name) + + ConstrainedStringGenerator(codegenContext, writer, constrainedStringShape).render() + + // Check that the wrapped type is `pub(crate)`. + writer.toString() shouldContain "pub struct ConstrainedString(pub(crate) std::string::String);" + } + + @Test + fun `Display implementation`() { + val model = """ + namespace test + + @length(min: 1, max: 69) + string ConstrainedString + + @sensitive + @length(min: 1, max: 78) + string SensitiveConstrainedString + """.asSmithyModel() + val constrainedStringShape = model.lookup("test#ConstrainedString") + val sensitiveConstrainedStringShape = model.lookup("test#SensitiveConstrainedString") + + val codegenContext = serverTestCodegenContext(model) + + val project = TestWorkspace.testProject(codegenContext.symbolProvider) + + project.withModule(ModelsModule) { + ConstrainedStringGenerator(codegenContext, this, constrainedStringShape).render() + ConstrainedStringGenerator(codegenContext, this, sensitiveConstrainedStringShape).render() + + unitTest( + name = "non_sensitive_string_display_implementation", + test = """ + let string = "a non-sensitive string".to_owned(); + let constrained = ConstrainedString::try_from(string).unwrap(); + assert_eq!(format!("{}", constrained), "a non-sensitive string") + """, + ) + + unitTest( + name = "sensitive_string_display_implementation", + test = """ + let string = "That is how heavy a secret can become. It can make blood flow easier than ink.".to_owned(); + let constrained = SensitiveConstrainedString::try_from(string).unwrap(); + assert_eq!(format!("{}", constrained), "*** Sensitive Data Redacted ***") + """, + ) + } + + project.compileAndTest() + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGeneratorTest.kt index da3ea23d8..d414aa63f 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGeneratorTest.kt @@ -8,15 +8,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ServerCombinedErrorGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest -import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider class ServerCombinedErrorGeneratorTest { @@ -55,16 +54,20 @@ class ServerCombinedErrorGeneratorTest { val project = TestWorkspace.testProject(symbolProvider) project.withModule(RustModule.public("error")) { listOf("FooException", "ComplexError", "InvalidGreeting", "Deprecated").forEach { - model.lookup("error#$it").renderWithModelBuilder(model, symbolProvider, this, CodegenTarget.SERVER) + model.lookup("error#$it").serverRenderWithModelBuilder(model, symbolProvider, this) } val errors = listOf("FooException", "ComplexError", "InvalidGreeting").map { model.lookup("error#$it") } - val generator = ServerCombinedErrorGenerator(model, symbolProvider, symbolProvider.toSymbol(model.lookup("error#Greeting")), errors) - generator.render(this) + ServerCombinedErrorGenerator( + model, + symbolProvider, + symbolProvider.toSymbol(model.lookup("error#Greeting")), + errors, + ).render(this) unitTest( name = "generates_combined_error_enums", test = """ - let variant = InvalidGreeting::builder().message("an error").build(); + let variant = InvalidGreeting { message: String::from("an error") }; assert_eq!(format!("{}", variant), "InvalidGreeting: an error"); assert_eq!(variant.message(), "an error"); assert_eq!( diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt index a07447a71..0e813cebb 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt @@ -9,12 +9,10 @@ import io.kotest.matchers.string.shouldNotContain import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup -import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext class ServerEnumGeneratorTest { private val model = """ @@ -36,30 +34,26 @@ class ServerEnumGeneratorTest { string InstanceType """.asSmithyModel() + private val codegenContext = serverTestCodegenContext(model) + private val writer = RustWriter.forModule("model") + private val shape = model.lookup("test#InstanceType") + @Test fun `it generates TryFrom, FromStr and errors for enums`() { - val provider = serverTestSymbolProvider(model) - val writer = RustWriter.forModule("model") - val shape = model.lookup("test#InstanceType") - val generator = ServerEnumGenerator(model, provider, writer, shape, shape.expectTrait(), TestRuntimeConfig) - generator.render() + ServerEnumGenerator(codegenContext, writer, shape).render() writer.compileAndTest( """ use std::str::FromStr; assert_eq!(InstanceType::try_from("t2.nano").unwrap(), InstanceType::T2Nano); assert_eq!(InstanceType::from_str("t2.nano").unwrap(), InstanceType::T2Nano); - assert_eq!(InstanceType::try_from("unknown").unwrap_err(), InstanceTypeUnknownVariantError("unknown".to_string())); + assert_eq!(InstanceType::try_from("unknown").unwrap_err(), crate::model::instance_type::ConstraintViolation(String::from("unknown"))); """, ) } @Test fun `it generates enums without the unknown variant`() { - val provider = serverTestSymbolProvider(model) - val writer = RustWriter.forModule("model") - val shape = model.lookup("test#InstanceType") - val generator = ServerEnumGenerator(model, provider, writer, shape, shape.expectTrait(), TestRuntimeConfig) - generator.render() + ServerEnumGenerator(codegenContext, writer, shape).render() writer.compileAndTest( """ // check no unknown @@ -74,11 +68,7 @@ class ServerEnumGeneratorTest { @Test fun `it generates enums without non_exhaustive`() { - val provider = serverTestSymbolProvider(model) - val writer = RustWriter.forModule("model") - val shape = model.lookup("test#InstanceType") - val generator = ServerEnumGenerator(model, provider, writer, shape, shape.expectTrait(), TestRuntimeConfig) - generator.render() + ServerEnumGenerator(codegenContext, writer, shape).render() writer.toString() shouldNotContain "#[non_exhaustive]" } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt index 945dac55d..1bfbd26e2 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt @@ -13,18 +13,17 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest -import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext class ServerInstantiatorTest { @@ -140,9 +139,9 @@ class ServerInstantiatorTest { val project = TestWorkspace.testProject() project.withModule(RustModule.Model) { - structure.renderWithModelBuilder(model, symbolProvider, this, CodegenTarget.SERVER) - inner.renderWithModelBuilder(model, symbolProvider, this, CodegenTarget.SERVER) - nestedStruct.renderWithModelBuilder(model, symbolProvider, this, CodegenTarget.SERVER) + structure.serverRenderWithModelBuilder(model, symbolProvider, this) + inner.serverRenderWithModelBuilder(model, symbolProvider, this) + nestedStruct.serverRenderWithModelBuilder(model, symbolProvider, this) UnionGenerator(model, symbolProvider, this, union).render() unitTest("server_instantiator_test") { diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt new file mode 100644 index 000000000..42774274d --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt @@ -0,0 +1,124 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext + +class UnconstrainedCollectionGeneratorTest { + @Test + fun `it should generate unconstrained lists`() { + val model = + """ + namespace test + + list ListA { + member: ListB + } + + list ListB { + member: StructureC + } + + structure StructureC { + @required + int: Integer, + + @required + string: String + } + """.asSmithyModel() + val codegenContext = serverTestCodegenContext(model) + val symbolProvider = codegenContext.symbolProvider + + val listA = model.lookup("test#ListA") + val listB = model.lookup("test#ListB") + + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(RustModule.public("model")) { + model.lookup("test#StructureC").serverRenderWithModelBuilder(model, symbolProvider, this) + } + + project.withModule(RustModule.private("constrained")) { + listOf(listA, listB).forEach { + PubCrateConstrainedCollectionGenerator(codegenContext, this, it).render() + } + } + project.withModule(RustModule.private("unconstrained")) unconstrainedModuleWriter@{ + project.withModule(ModelsModule) modelsModuleWriter@{ + listOf(listA, listB).forEach { + UnconstrainedCollectionGenerator( + codegenContext, + this@unconstrainedModuleWriter, + this@modelsModuleWriter, + it, + ).render() + } + + this@unconstrainedModuleWriter.unitTest( + name = "list_a_unconstrained_fail_to_constrain_with_first_error", + test = """ + let c_builder1 = crate::model::StructureC::builder().int(69); + let c_builder2 = crate::model::StructureC::builder().string("david".to_owned()); + let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder1, c_builder2]); + let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]); + + let expected_err = + crate::model::list_a::ConstraintViolation(0, crate::model::list_b::ConstraintViolation( + 0, crate::model::structure_c::ConstraintViolation::MissingString, + )); + + assert_eq!( + expected_err, + crate::constrained::list_a_constrained::ListAConstrained::try_from(list_a_unconstrained).unwrap_err() + ); + """, + ) + + this@unconstrainedModuleWriter.unitTest( + name = "list_a_unconstrained_succeed_to_constrain", + test = """ + let c_builder = crate::model::StructureC::builder().int(69).string(String::from("david")); + let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder]); + let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]); + + let expected: Vec> = vec![vec![crate::model::StructureC { + string: "david".to_owned(), + int: 69 + }]]; + let actual: Vec> = + crate::constrained::list_a_constrained::ListAConstrained::try_from(list_a_unconstrained).unwrap().into(); + + assert_eq!(expected, actual); + """, + ) + + this@unconstrainedModuleWriter.unitTest( + name = "list_a_unconstrained_converts_into_constrained", + test = """ + let c_builder = crate::model::StructureC::builder(); + let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder]); + let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]); + + let _list_a: crate::constrained::MaybeConstrained = list_a_unconstrained.into(); + """, + ) + project.compileAndTest() + } + } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt new file mode 100644 index 000000000..a5877b7c0 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt @@ -0,0 +1,164 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext + +class UnconstrainedMapGeneratorTest { + @Test + fun `it should generate unconstrained maps`() { + val model = + """ + namespace test + + map MapA { + key: String, + value: MapB + } + + map MapB { + key: String, + value: StructureC + } + + structure StructureC { + @required + int: Integer, + + @required + string: String + } + """.asSmithyModel() + val codegenContext = serverTestCodegenContext(model) + val symbolProvider = codegenContext.symbolProvider + + val mapA = model.lookup("test#MapA") + val mapB = model.lookup("test#MapB") + + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(RustModule.public("model")) { + model.lookup("test#StructureC").serverRenderWithModelBuilder(model, symbolProvider, this) + } + + project.withModule(RustModule.private("constrained")) { + listOf(mapA, mapB).forEach { + PubCrateConstrainedMapGenerator(codegenContext, this, it).render() + } + } + project.withModule(RustModule.private("unconstrained")) unconstrainedModuleWriter@{ + project.withModule(ModelsModule) modelsModuleWriter@{ + listOf(mapA, mapB).forEach { + UnconstrainedMapGenerator(codegenContext, this@unconstrainedModuleWriter, it).render() + + MapConstraintViolationGenerator(codegenContext, this@modelsModuleWriter, it).render() + } + + this@unconstrainedModuleWriter.unitTest( + name = "map_a_unconstrained_fail_to_constrain_with_some_error", + test = """ + let c_builder1 = crate::model::StructureC::builder().int(69); + let c_builder2 = crate::model::StructureC::builder().string(String::from("david")); + let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyB1"), c_builder1), + (String::from("KeyB2"), c_builder2), + ]) + ); + let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyA"), map_b_unconstrained), + ]) + ); + + // Any of these two errors could be returned; it depends on the order in which the maps are visited. + let missing_string_expected_err = crate::model::map_a::ConstraintViolation::Value( + "KeyA".to_owned(), + crate::model::map_b::ConstraintViolation::Value( + "KeyB1".to_owned(), + crate::model::structure_c::ConstraintViolation::MissingString, + ) + ); + let missing_int_expected_err = crate::model::map_a::ConstraintViolation::Value( + "KeyA".to_owned(), + crate::model::map_b::ConstraintViolation::Value( + "KeyB2".to_owned(), + crate::model::structure_c::ConstraintViolation::MissingInt, + ) + ); + + let actual_err = crate::constrained::map_a_constrained::MapAConstrained::try_from(map_a_unconstrained).unwrap_err(); + + assert!(actual_err == missing_string_expected_err || actual_err == missing_int_expected_err); + """, + ) + + this@unconstrainedModuleWriter.unitTest( + name = "map_a_unconstrained_succeed_to_constrain", + test = """ + let c_builder = crate::model::StructureC::builder().int(69).string(String::from("david")); + let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyB"), c_builder), + ]) + ); + let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyA"), map_b_unconstrained), + ]) + ); + + let expected = std::collections::HashMap::from([ + (String::from("KeyA"), std::collections::HashMap::from([ + (String::from("KeyB"), crate::model::StructureC { + int: 69, + string: String::from("david") + }), + ])) + ]); + + assert_eq!( + expected, + crate::constrained::map_a_constrained::MapAConstrained::try_from(map_a_unconstrained).unwrap().into() + ); + """, + ) + + this@unconstrainedModuleWriter.unitTest( + name = "map_a_unconstrained_converts_into_constrained", + test = """ + let c_builder = crate::model::StructureC::builder(); + let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyB"), c_builder), + ]) + ); + let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained( + std::collections::HashMap::from([ + (String::from("KeyA"), map_b_unconstrained), + ]) + ); + + let _map_a: crate::constrained::MaybeConstrained = map_a_unconstrained.into(); + """, + ) + + project.compileAndTest() + } + } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt new file mode 100644 index 000000000..f31285de9 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt @@ -0,0 +1,102 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext + +class UnconstrainedUnionGeneratorTest { + @Test + fun `it should generate unconstrained unions`() { + val model = + """ + namespace test + + union Union { + structure: Structure + } + + structure Structure { + @required + requiredMember: String + } + """.asSmithyModel() + val codegenContext = serverTestCodegenContext(model) + val symbolProvider = codegenContext.symbolProvider + + val unionShape = model.lookup("test#Union") + + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(RustModule.public("model")) { + model.lookup("test#Structure").serverRenderWithModelBuilder(model, symbolProvider, this) + } + + project.withModule(ModelsModule) { + UnionGenerator(model, symbolProvider, this, unionShape, renderUnknownVariant = false).render() + } + project.withModule(RustModule.private("unconstrained")) unconstrainedModuleWriter@{ + project.withModule(ModelsModule) modelsModuleWriter@{ + UnconstrainedUnionGenerator(codegenContext, this@unconstrainedModuleWriter, this@modelsModuleWriter, unionShape).render() + + this@unconstrainedModuleWriter.unitTest( + name = "unconstrained_union_fail_to_constrain", + test = """ + let builder = crate::model::Structure::builder(); + let union_unconstrained = union_unconstrained::UnionUnconstrained::Structure(builder); + + let expected_err = crate::model::union::ConstraintViolation::Structure( + crate::model::structure::ConstraintViolation::MissingRequiredMember, + ); + + assert_eq!( + expected_err, + crate::model::Union::try_from(union_unconstrained).unwrap_err() + ); + """, + ) + + this@unconstrainedModuleWriter.unitTest( + name = "unconstrained_union_succeed_to_constrain", + test = """ + let builder = crate::model::Structure::builder().required_member(String::from("david")); + let union_unconstrained = union_unconstrained::UnionUnconstrained::Structure(builder); + + let expected: crate::model::Union = crate::model::Union::Structure(crate::model::Structure { + required_member: String::from("david"), + }); + let actual: crate::model::Union = crate::model::Union::try_from(union_unconstrained).unwrap(); + + assert_eq!(expected, actual); + """, + ) + + this@unconstrainedModuleWriter.unitTest( + name = "unconstrained_union_converts_into_constrained", + test = """ + let builder = crate::model::Structure::builder(); + let union_unconstrained = union_unconstrained::UnionUnconstrained::Structure(builder); + + let _union: crate::constrained::MaybeConstrained = + union_unconstrained.into(); + """, + ) + project.compileAndTest() + } + } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt index 3a917af0d..52d312ffa 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/EventStreamTestTools.kt @@ -313,7 +313,11 @@ object EventStreamTestModels { """.trimIndent(), ) { Ec2QueryProtocol(it) }, - ).flatMap { listOf(it, it.copy(target = CodegenTarget.SERVER)) } + ) + // TODO(https://github.com/awslabs/smithy-rs/issues/1442) Server tests + // should be run from the server subproject using the + // `serverTestSymbolProvider()`. + // .flatMap { listOf(it, it.copy(target = CodegenTarget.SERVER)) } class UnmarshallTestCasesProvider : ArgumentsProvider { override fun provideArguments(context: ExtensionContext?): Stream = diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt index 0b219a569..5094426d7 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt @@ -7,12 +7,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.parse import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator -import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings import software.amazon.smithy.rust.codegen.core.testutil.unitTest @@ -34,14 +36,13 @@ class EventStreamUnmarshallerGeneratorTest { target = testCase.target, ) val protocol = testCase.protocolBuilder(codegenContext) + fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(codegenContext.symbolProvider) val generator = EventStreamUnmarshallerGenerator( protocol, - test.model, - TestRuntimeConfig, - test.symbolProvider, + codegenContext, test.operationShape, test.streamShape, - target = testCase.target, + ::builderSymbol, ) test.project.lib { diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index 10888a09e..46d5c5223 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -34,7 +34,8 @@ once_cell = "1.13" regex = "1.5.5" serde_urlencoded = "0.7" strum_macros = "0.24" -thiserror = "1" +# TODO Investigate. +thiserror = "<=1.0.36" tracing = "0.1.35" tokio = { version = "1.8.4", features = ["full"] } tower = { version = "0.4.11", features = ["util", "make"], default-features = false } diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs index 5804a3fb7..318dac909 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs @@ -214,7 +214,7 @@ pub async fn capture_pokemon( ) -> Result { if input.region != "Kanto" { return Err(error::CapturePokemonError::UnsupportedRegionError( - error::UnsupportedRegionError::builder().build(), + error::UnsupportedRegionError { region: input.region }, )); } let output_stream = stream! { @@ -230,7 +230,9 @@ pub async fn capture_pokemon( if ! matches!(pokeball, "Master Ball" | "Great Ball" | "Fast Ball") { yield Err( crate::error::CapturePokemonEventsError::InvalidPokeballError( - crate::error::InvalidPokeballError::builder().pokeball(pokeball).build() + crate::error::InvalidPokeballError { + pokeball: pokeball.to_owned() + } ) ); } else { @@ -253,11 +255,12 @@ pub async fn capture_pokemon( .to_string(); let pokedex: Vec = (0..255).collect(); yield Ok(crate::model::CapturePokemonEvents::Event( - crate::model::CaptureEvent::builder() - .name(pokemon) - .shiny(shiny) - .pokedex_update(Blob::new(pokedex)) - .build(), + crate::model::CaptureEvent { + name: Some(pokemon), + shiny: Some(shiny), + pokedex_update: Some(Blob::new(pokedex)), + captured: Some(true), + } )); } } diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index 0d6e656db..65d4497ff 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -179,15 +179,12 @@ pub enum RequestRejection { FloatParse(crate::Error), BoolParse(crate::Error), - // TODO(https://github.com/awslabs/smithy-rs/issues/1243): In theory, we could get rid of this - // error, but it would be a lot of effort for comparatively low benefit. - /// Used when consuming the input struct builder. - Build(crate::Error), - - /// Used by the server when the enum variant sent by a client is not known. - /// Unlike the rejections above, the inner type is code generated, - /// with each enum having its own generated error type. - EnumVariantNotFound(Box), + /// Used when consuming the input struct builder, and constraint violations occur. + // Unlike the rejections above, this does not take in `crate::Error`, since it is constructed + // directly in the code-generated SDK instead of in this crate. + // TODO(https://github.com/awslabs/smithy-rs/issues/1703): this will hold a type that can be + // rendered into a protocol-specific response later on. + ConstraintViolation(String), } #[derive(Debug, Display)] @@ -237,7 +234,6 @@ impl From for RequestRejection { convert_to_request_rejection!(aws_smithy_json::deserialize::Error, JsonDeserialize); convert_to_request_rejection!(aws_smithy_xml::decode::XmlError, XmlDeserialize); -convert_to_request_rejection!(aws_smithy_http::operation::BuildError, Build); convert_to_request_rejection!(aws_smithy_http::header::ParseError, HeaderParse); convert_to_request_rejection!(aws_smithy_types::date_time::DateTimeParseError, DateTimeParse); convert_to_request_rejection!(aws_smithy_types::primitive::PrimitiveParseError, PrimitiveParse); diff --git a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs index e389240f8..7503af192 100644 --- a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs @@ -21,14 +21,13 @@ //! and converts into the corresponding `RuntimeError`, and then it uses the its //! [`RuntimeError::into_response`] method to render and send a response. -use http::StatusCode; - use crate::extension::RuntimeErrorExtension; use crate::proto::aws_json_10::AwsJson1_0; use crate::proto::aws_json_11::AwsJson1_1; use crate::proto::rest_json_1::RestJson1; use crate::proto::rest_xml::RestXml; use crate::response::IntoResponse; +use http::StatusCode; #[derive(Debug)] pub enum RuntimeError { @@ -40,6 +39,10 @@ pub enum RuntimeError { // TODO(https://github.com/awslabs/smithy-rs/issues/1663) NotAcceptable, UnsupportedMediaType, + + // TODO(https://github.com/awslabs/smithy-rs/issues/1703): this will hold a type that can be + // rendered into a protocol-specific response later on. + Validation(String), } /// String representation of the runtime error type. @@ -52,6 +55,7 @@ impl RuntimeError { Self::InternalFailure(_) => "InternalFailureException", Self::NotAcceptable => "NotAcceptableException", Self::UnsupportedMediaType => "UnsupportedMediaTypeException", + Self::Validation(_) => "ValidationException", } } @@ -61,6 +65,7 @@ impl RuntimeError { Self::InternalFailure(_) => StatusCode::INTERNAL_SERVER_ERROR, Self::NotAcceptable => StatusCode::NOT_ACCEPTABLE, Self::UnsupportedMediaType => StatusCode::UNSUPPORTED_MEDIA_TYPE, + Self::Validation(_) => StatusCode::BAD_REQUEST, } } } @@ -93,48 +98,78 @@ impl IntoResponse for InternalFailureException { impl IntoResponse for RuntimeError { fn into_response(self) -> http::Response { - http::Response::builder() + let res = http::Response::builder() .status(self.status_code()) .header("Content-Type", "application/json") .header("X-Amzn-Errortype", self.name()) - .extension(RuntimeErrorExtension::new(self.name().to_string())) - // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization - .body(crate::body::to_boxed("{}")) + .extension(RuntimeErrorExtension::new(self.name().to_string())); + + let body = match self { + RuntimeError::Validation(reason) => crate::body::to_boxed(reason), + _ => crate::body::to_boxed("{}"), + }; + + res + .body(body) .expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } } impl IntoResponse for RuntimeError { fn into_response(self) -> http::Response { - http::Response::builder() + let res = http::Response::builder() .status(self.status_code()) .header("Content-Type", "application/xml") - .extension(RuntimeErrorExtension::new(self.name().to_string())) - .body(crate::body::to_boxed("")) + .extension(RuntimeErrorExtension::new(self.name().to_string())); + + let body = match self { + // TODO(https://github.com/awslabs/smithy/issues/1446) The Smithy spec does not yet + // define constraint violation HTTP body responses for RestXml. + RuntimeError::Validation(_reason) => todo!("https://github.com/awslabs/smithy/issues/1446"), + // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization + _ => crate::body::to_boxed("{}"), + }; + + res + .body(body) .expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } } impl IntoResponse for RuntimeError { fn into_response(self) -> http::Response { - http::Response::builder() + let res = http::Response::builder() .status(self.status_code()) .header("Content-Type", "application/x-amz-json-1.0") - .extension(RuntimeErrorExtension::new(self.name().to_string())) - // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization - .body(crate::body::to_boxed("")) + .extension(RuntimeErrorExtension::new(self.name().to_string())); + + let body = match self { + RuntimeError::Validation(reason) => crate::body::to_boxed(reason), + // See https://awslabs.github.io/smithy/2.0/aws/protocols/aws-json-1_0-protocol.html#empty-body-serialization + _ => crate::body::to_boxed("{}"), + }; + + res + .body(body) .expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } } impl IntoResponse for RuntimeError { fn into_response(self) -> http::Response { - http::Response::builder() + let res = http::Response::builder() .status(self.status_code()) .header("Content-Type", "application/x-amz-json-1.1") - .extension(RuntimeErrorExtension::new(self.name().to_string())) - // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization - .body(crate::body::to_boxed("")) + .extension(RuntimeErrorExtension::new(self.name().to_string())); + + let body = match self { + RuntimeError::Validation(reason) => crate::body::to_boxed(reason), + // https://awslabs.github.io/smithy/2.0/aws/protocols/aws-json-1_1-protocol.html#empty-body-serialization + _ => crate::body::to_boxed(""), + }; + + res + .body(body) .expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } } @@ -155,6 +190,7 @@ impl From for RuntimeError { fn from(err: crate::rejection::RequestRejection) -> Self { match err { crate::rejection::RequestRejection::MissingContentType(_reason) => Self::UnsupportedMediaType, + crate::rejection::RequestRejection::ConstraintViolation(reason) => Self::Validation(reason), _ => Self::Serialization(crate::Error::new(err)), } } diff --git a/rust-runtime/inlineable/src/constrained.rs b/rust-runtime/inlineable/src/constrained.rs new file mode 100644 index 000000000..1276eccbc --- /dev/null +++ b/rust-runtime/inlineable/src/constrained.rs @@ -0,0 +1,15 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +pub(crate) trait Constrained { + type Unconstrained; +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub(crate) enum MaybeConstrained { + Constrained(T), + Unconstrained(T::Unconstrained), +} diff --git a/rust-runtime/inlineable/src/lib.rs b/rust-runtime/inlineable/src/lib.rs index 3cbcd5e5f..2c2634110 100644 --- a/rust-runtime/inlineable/src/lib.rs +++ b/rust-runtime/inlineable/src/lib.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#[allow(unused)] +mod constrained; #[allow(dead_code)] mod ec2_query_errors; #[allow(dead_code)] -- GitLab