diff --git a/CHANGELOG.md b/CHANGELOG.md index eefe506129abf33bb413b39a33f61d35fc209860..1dc8ecb340b6002b9a092bfe273d037fe6803aef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ vNext (Month Day, Year) ======================= **Breaking changes** - :warning: MSRV increased from 1.52.1 to 1.53.0 per our 3-behind MSRV policy. +- :warning: Several classes in the codegen module were renamed and/or refactored (smithy-rs#735): + - `ProtocolConfig` became `CodegenContext` and moved to `software.amazon.smithy.rust.codegen.smithy` + - `HttpProtocolGenerator` became `ProtocolGenerator` and was refactored + to rely on composition instead of inheritance + - `HttpProtocolTestGenerator` became `ProtocolTestGenerator` + - `Protocol` moved into `software.amazon.smithy.rust.codegen.smithy.protocols` **New this week** - :bug: Fix an issue where `smithy-xml` may have generated invalid XML (smithy-rs#719) diff --git a/aws/SDK_CHANGELOG.md b/aws/SDK_CHANGELOG.md index 651cbef869a7f6d3f5d342718d41cabfdf8b9ac9..e44d96685607bca7676abdaed9cabaa8c9419186 100644 --- a/aws/SDK_CHANGELOG.md +++ b/aws/SDK_CHANGELOG.md @@ -1,5 +1,6 @@ vNext (Month Day, Year) ======================= + **Breaking changes** - :warning: MSRV increased from 1.52.1 to 1.53.0 per our 3-behind MSRV policy. @@ -8,7 +9,8 @@ vNext (Month Day, Year) **New This Week** -- :tada: Add presigned request support and examples for S3 GetObject and PutObject (smithy-rs#731) +- :tada: Add presigned request support and examples for S3 GetObject and PutObject (smithy-rs#731, aws-sdk-rust#139) +- :tada: Add presigned request support and example for Polly SynthesizeSpeech (smithy-rs#735, aws-sdk-rust#139) - :bug: Fix error when receiving `Cont` event from S3 SelectObjectContent (smithy-rs#736) - :bug: Fix bug in event stream receiver that could cause the last events in the response stream to be lost when using S3 SelectObjectContent (smithy-rs#736) - Updated Transcribe code example to take an audio file as a command-line option and added readme. diff --git a/aws/rust-runtime/aws-sig-auth/src/signer.rs b/aws/rust-runtime/aws-sig-auth/src/signer.rs index 1ea3bc7b19c84bec33f9529c5d488d65841e8118..651bcf2a7079d369e4e6e576fba2cbbc49cf5f50 100644 --- a/aws/rust-runtime/aws-sig-auth/src/signer.rs +++ b/aws/rust-runtime/aws-sig-auth/src/signer.rs @@ -188,25 +188,20 @@ impl SigV4Signer { let signing_params = Self::signing_params(settings, credentials, request_config); let (signing_instructions, signature) = { - // A body that is already in memory can be signed directly. A body that is not in memory + // A body that is already in memory can be signed directly. A body that is not in memory // (any sort of streaming body or presigned request) will be signed via UNSIGNED-PAYLOAD. - let signable_body = - if operation_config.signature_type == HttpSignatureType::HttpRequestQueryParams { - SignableBody::UnsignedPayload - } else { - request_config - .payload_override - // the payload_override is a cheap clone because it contains either a - // reference or a short checksum (we're not cloning the entire body) - .cloned() - .unwrap_or_else(|| { - request - .body() - .bytes() - .map(SignableBody::Bytes) - .unwrap_or(SignableBody::UnsignedPayload) - }) - }; + let signable_body = request_config + .payload_override + // the payload_override is a cheap clone because it contains either a + // reference or a short checksum (we're not cloning the entire body) + .cloned() + .unwrap_or_else(|| { + request + .body() + .bytes() + .map(SignableBody::Bytes) + .unwrap_or(SignableBody::UnsignedPayload) + }); let signable_request = SignableRequest::new( request.method(), diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsEndpointDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsEndpointDecorator.kt index 55f91616c0f55c9fb37888ecaa38f83e773c1915..a5eeac83f4e77560152caa800f1352002c5cbc29 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsEndpointDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsEndpointDecorator.kt @@ -22,6 +22,7 @@ import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization @@ -29,7 +30,6 @@ import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.smithy.generators.LibRsSection -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfig import software.amazon.smithy.rust.codegen.util.dq @@ -46,31 +46,31 @@ class AwsEndpointDecorator : RustCodegenDecorator { } override fun configCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { - return baseCustomizations + EndpointConfigCustomization(protocolConfig, endpoints) + return baseCustomizations + EndpointConfigCustomization(codegenContext, endpoints) } override fun operationCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, operation: OperationShape, baseCustomizations: List ): List { - return baseCustomizations + EndpointResolverFeature(protocolConfig.runtimeConfig, operation) + return baseCustomizations + EndpointResolverFeature(codegenContext.runtimeConfig, operation) } override fun libRsCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { - return baseCustomizations + PubUseEndpoint(protocolConfig.runtimeConfig) + return baseCustomizations + PubUseEndpoint(codegenContext.runtimeConfig) } } -class EndpointConfigCustomization(private val protocolConfig: ProtocolConfig, private val endpointData: ObjectNode) : +class EndpointConfigCustomization(private val codegenContext: CodegenContext, private val endpointData: ObjectNode) : ConfigCustomization() { - private val runtimeConfig = protocolConfig.runtimeConfig + private val runtimeConfig = codegenContext.runtimeConfig private val resolveAwsEndpoint = runtimeConfig.awsEndpointDependency().asType().copy(name = "ResolveAwsEndpoint") override fun section(section: ServiceConfig): Writable = writable { when (section) { @@ -92,7 +92,7 @@ class EndpointConfigCustomization(private val protocolConfig: ProtocolConfig, pr resolveAwsEndpoint ) ServiceConfig.BuilderBuild -> { - val resolverGenerator = EndpointResolverGenerator(protocolConfig, endpointData) + val resolverGenerator = EndpointResolverGenerator(codegenContext, endpointData) rust( """endpoint_resolver: self.endpoint_resolver.unwrap_or_else(|| ::std::sync::Arc::new( @@ -140,9 +140,9 @@ class PubUseEndpoint(private val runtimeConfig: RuntimeConfig) : LibRsCustomizat } } -class EndpointResolverGenerator(protocolConfig: ProtocolConfig, private val endpointData: ObjectNode) { - private val runtimeConfig = protocolConfig.runtimeConfig - private val endpointPrefix = protocolConfig.serviceShape.expectTrait().endpointPrefix +class EndpointResolverGenerator(codegenContext: CodegenContext, private val endpointData: ObjectNode) { + private val runtimeConfig = codegenContext.runtimeConfig + private val endpointPrefix = codegenContext.serviceShape.expectTrait().endpointPrefix private val awsEndpoint = runtimeConfig.awsEndpointDependency().asType() private val codegenScope = arrayOf( diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt index 3c7983c52f684d28065970371ac469b9757148de..d979bf7cc75d9303a2145a6e802783f8d37196a8 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustCrate @@ -24,7 +25,6 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ClientGenerics import software.amazon.smithy.rust.codegen.smithy.generators.FluentClientGenerator import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.smithy.generators.LibRsSection -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig private class Types(runtimeConfig: RuntimeConfig) { private val smithyClientDep = CargoDependency.SmithyClient(runtimeConfig).copy(optional = true) @@ -44,12 +44,12 @@ class AwsFluentClientDecorator : RustCodegenDecorator { // Must run after the AwsPresigningDecorator so that the presignable trait is correctly added to operations override val order: Byte = (AwsPresigningDecorator.ORDER + 1).toByte() - override fun extras(protocolConfig: ProtocolConfig, rustCrate: RustCrate) { - val types = Types(protocolConfig.runtimeConfig) + override fun extras(codegenContext: CodegenContext, rustCrate: RustCrate) { + val types = Types(codegenContext.runtimeConfig) val module = RustMetadata(additionalAttributes = listOf(Attribute.Cfg.feature("client")), public = true) rustCrate.withModule(RustModule("client", module)) { writer -> FluentClientGenerator( - protocolConfig, + codegenContext, includeSmithyGenericClientDocs = false, generics = ClientGenerics( connectorDefault = "#{AwsFluentClient_DynConnector}", @@ -61,7 +61,7 @@ class AwsFluentClientDecorator : RustCodegenDecorator { "AwsFluentClient_retry" to types.smithyClientRetry, ) ), - customizations = listOf(AwsPresignedFluentBuilderMethod(protocolConfig.runtimeConfig)) + customizations = listOf(AwsPresignedFluentBuilderMethod(codegenContext.runtimeConfig)) ).render(writer) AwsFluentClientExtensions(types).render(writer) } @@ -72,7 +72,7 @@ class AwsFluentClientDecorator : RustCodegenDecorator { } override fun libRsCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { return baseCustomizations + object : LibRsCustomization() { diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt index 59f494f98ff04d9cc57c6d25e8cfe966e22bcdfd..e06fc6ff1361360c975561db9b22749fcd8e5e17 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt @@ -6,37 +6,82 @@ package software.amazon.smithy.rustsdk import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.HttpBinding +import software.amazon.smithy.model.knowledge.HttpBindingIndex +import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.ToShapeId +import software.amazon.smithy.model.traits.HttpQueryTrait +import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.Feature import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.asType +import software.amazon.smithy.rust.codegen.rustlang.rust +import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RustCrate -import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.FluentClientCustomization import software.amazon.smithy.rust.codegen.smithy.generators.FluentClientSection -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.MakeOperationGenerator +import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolBodyGenerator +import software.amazon.smithy.rust.codegen.util.cloneOperation +import software.amazon.smithy.rust.codegen.util.expectTrait import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rustsdk.traits.PresignableTrait +import kotlin.streams.toList -private val PRESIGNABLE_OPERATIONS = listOf( - ShapeId.from("com.amazonaws.s3#GetObject"), - ShapeId.from("com.amazonaws.s3#PutObject"), -) +internal enum class PayloadSigningType { + EMPTY, + UNSIGNED_PAYLOAD, +} + +private fun syntheticShapeId(shape: ToShapeId): ShapeId = + shape.toShapeId().let { id -> ShapeId.fromParts(id.namespace + ".synthetic.aws.presigned", id.name) } + +internal class PresignableOperation( + val payloadSigningType: PayloadSigningType, + val modelTransforms: List = emptyList() +) { + fun hasModelTransforms(): Boolean = modelTransforms.isNotEmpty() +} + +private val SYNTHESIZE_SPEECH_OP = ShapeId.from("com.amazonaws.polly#SynthesizeSpeech") +internal val PRESIGNABLE_OPERATIONS by lazy { + mapOf( + // S3 + ShapeId.from("com.amazonaws.s3#GetObject") to PresignableOperation(PayloadSigningType.UNSIGNED_PAYLOAD), + ShapeId.from("com.amazonaws.s3#PutObject") to PresignableOperation(PayloadSigningType.UNSIGNED_PAYLOAD), + + // Polly + SYNTHESIZE_SPEECH_OP to PresignableOperation( + PayloadSigningType.EMPTY, + // Polly's SynthesizeSpeech operation has the HTTP method overridden to GET, + // and the document members changed to query param members. + modelTransforms = listOf( + OverrideHttpMethodTransform(mapOf(SYNTHESIZE_SPEECH_OP to "GET")), + MoveDocumentMembersToQueryParamsTransform(listOf(SYNTHESIZE_SPEECH_OP)), + ) + ), + ) +} -class AwsPresigningDecorator : RustCodegenDecorator { +class AwsPresigningDecorator internal constructor( + private val presignableOperations: Map = PRESIGNABLE_OPERATIONS +) : RustCodegenDecorator { companion object { const val ORDER: Byte = 0 } @@ -44,9 +89,9 @@ class AwsPresigningDecorator : RustCodegenDecorator { override val name: String = "AwsPresigning" override val order: Byte = ORDER - override fun extras(protocolConfig: ProtocolConfig, rustCrate: RustCrate) { - val hasPresignedOps = protocolConfig.model.shapes().anyMatch { shape -> - shape is OperationShape && PRESIGNABLE_OPERATIONS.contains(shape.id) + override fun extras(codegenContext: CodegenContext, rustCrate: RustCrate) { + val hasPresignedOps = codegenContext.model.shapes().anyMatch { shape -> + shape is OperationShape && presignableOperations.containsKey(shape.id) } if (hasPresignedOps) { rustCrate.mergeFeature(Feature("client", default = true, listOf("tower"))) @@ -54,30 +99,48 @@ class AwsPresigningDecorator : RustCodegenDecorator { } override fun operationCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, operation: OperationShape, baseCustomizations: List - ): List = baseCustomizations + listOf( - AwsInputPresignedMethod(protocolConfig.runtimeConfig, protocolConfig.symbolProvider, operation) - ) + ): List = baseCustomizations + listOf(AwsInputPresignedMethod(codegenContext, operation)) - /** Adds presignable trait to known presignable operations */ + /** + * Adds presignable trait to known presignable operations and creates synthetic presignable shapes for codegen + */ override fun transformModel(service: ServiceShape, model: Model): Model { - return ModelTransformer.create().mapShapes(model) { shape -> - if (shape is OperationShape && PRESIGNABLE_OPERATIONS.contains(shape.id)) { - shape.toBuilder().addTrait(PresignableTrait()).build() + val modelWithSynthetics = addSyntheticOperations(model) + val presignableTransforms = mutableListOf() + val intermediate = ModelTransformer.create().mapShapes(modelWithSynthetics) { shape -> + if (shape is OperationShape && presignableOperations.containsKey(shape.id)) { + presignableTransforms.addAll(presignableOperations.getValue(shape.id).modelTransforms) + shape.toBuilder().addTrait(PresignableTrait(syntheticShapeId(shape))).build() } else { shape } } + // Apply operation-specific model transformations + return presignableTransforms.fold(intermediate) { m, t -> t.transform(m) } + } + + private fun addSyntheticOperations(model: Model): Model { + val presignableOps = model.shapes() + .filter { shape -> shape is OperationShape && presignableOperations.containsKey(shape.id) } + .toList() + return model.toBuilder().also { builder -> + for (op in presignableOps) { + builder.cloneOperation(model, op, ::syntheticShapeId) + } + }.build() } } class AwsInputPresignedMethod( - runtimeConfig: RuntimeConfig, - private val symbolProvider: RustSymbolProvider, + private val codegenContext: CodegenContext, private val operationShape: OperationShape ) : OperationCustomization() { + private val runtimeConfig = codegenContext.runtimeConfig + private val symbolProvider = codegenContext.symbolProvider + private val codegenScope = arrayOf( "aws_hyper" to runtimeConfig.awsRuntimeDependency("aws-hyper").copy(optional = true).asType(), "Error" to AwsRuntimeType.Presigning.member("config::Error"), @@ -85,22 +148,42 @@ class AwsInputPresignedMethod( "PresignedRequestService" to AwsRuntimeType.Presigning.member("service::PresignedRequestService"), "PresigningConfig" to AwsRuntimeType.Presigning.member("config::PresigningConfig"), "SdkError" to CargoDependency.SmithyHttp(runtimeConfig).asType().member("result::SdkError"), + "aws_sigv4" to runtimeConfig.awsRuntimeDependency("aws-sigv4").asType(), "sig_auth" to runtimeConfig.sigAuth().asType(), "tower" to CargoDependency.Tower.asType(), ) override fun section(section: OperationSection): Writable = writable { if (section is OperationSection.InputImpl && section.operationShape.hasTrait()) { - writeInputPresignedMethod() + writeInputPresignedMethod(section) } } - private fun RustWriter.writeInputPresignedMethod() { + private fun RustWriter.writeInputPresignedMethod(section: OperationSection.InputImpl) { val operationError = operationShape.errorSymbol(symbolProvider) + val presignableOp = PRESIGNABLE_OPERATIONS.getValue(operationShape.id) + + var makeOperationFn = "make_operation" + if (presignableOp.hasModelTransforms()) { + makeOperationFn = "_make_presigned_operation" + + val syntheticOp = + codegenContext.model.expectShape(syntheticShapeId(operationShape.id), OperationShape::class.java) + val protocol = section.protocol + MakeOperationGenerator( + codegenContext, + protocol, + HttpBoundProtocolBodyGenerator(codegenContext, protocol), + // Prefixed with underscore to avoid colliding with modeled functions + functionName = makeOperationFn, + public = false, + ).generateMakeOperation(this, syntheticOp, section.customizations) + } + rustBlockTemplate( """ /// Creates a presigned request for this operation. The credentials provider from the `config` - /// will be used to generate the request's signature, and the `presignining_config` provides additional + /// will be used to generate the request's signature, and the `presigning_config` provides additional /// presigning-specific config values, such as the amount of time the request should be valid for after /// creation. /// @@ -118,21 +201,42 @@ class AwsInputPresignedMethod( ) { rustTemplate( """ - let (mut request, _) = self.make_operation(config) + let (mut request, _) = self.$makeOperationFn(config) .map_err(|err| #{SdkError}::ConstructionFailure(err.into()))? .into_request_response(); - - // Change signature type to query params and wire up presigning config - { + """, + *codegenScope + ) + rustBlock("") { + rust( + """ + // Change signature type to query params and wire up presigning config let mut props = request.properties_mut(); props.insert(presigning_config.start_time()); - + """ + ) + withBlock("props.insert(", ");") { + rustTemplate( + "#{aws_sigv4}::http_request::SignableBody::" + + when (presignableOp.payloadSigningType) { + PayloadSigningType.EMPTY -> "Bytes(b\"\")" + PayloadSigningType.UNSIGNED_PAYLOAD -> "UnsignedPayload" + }, + *codegenScope + ) + } + rustTemplate( + """ let mut config = props.get_mut::<#{sig_auth}::signer::OperationSigningConfig>() .expect("signing config added by make_operation()"); config.signature_type = #{sig_auth}::signer::HttpSignatureType::HttpRequestQueryParams; config.expires_in = Some(presigning_config.expires()); - } - + """, + *codegenScope + ) + } + rustTemplate( + """ let middleware = #{aws_hyper}::AwsMiddleware::default(); let mut svc = #{tower}::builder::ServiceBuilder::new() .layer(&middleware) @@ -180,3 +284,73 @@ class AwsPresignedFluentBuilderMethod( } } } + +interface PresignModelTransform { + fun transform(model: Model): Model +} + +/** + * Model transform that overrides HTTP request methods for the given map of operations. + * + * Note: this doesn't work for non-REST protocols. The protocol generators will need to be refactored + * to respect HTTP traits or synthetic equivalents if this is needed for AwsQuery, Ec2Query, or AwsJson. + */ +class OverrideHttpMethodTransform( + httpMethodOverrides: Map, +) : PresignModelTransform { + private val overrides = httpMethodOverrides.mapKeys { entry -> syntheticShapeId(entry.key) } + + override fun transform(model: Model): Model { + return ModelTransformer.create().mapShapes(model) { shape -> + if (shape is OperationShape && overrides.containsKey(shape.id)) { + val newMethod = overrides.getValue(shape.id) + check(shape.hasTrait(HttpTrait.ID)) { + "OverrideHttpMethodTransform can only be used with REST protocols" + } + val originalHttpTrait = shape.expectTrait() + shape.toBuilder() + .removeTrait(HttpTrait.ID) + .addTrait(originalHttpTrait.toBuilder().method(newMethod).build()) + .build() + } else { + shape + } + } + } +} + +/** + * Model transform that moves document members into query parameters for the given list of operations. + * + * Note: this doesn't work for non-REST protocols. The protocol generators will need to be refactored + * to respect HTTP traits or synthetic equivalents if this is needed for AwsQuery, Ec2Query, or AwsJson. + */ +class MoveDocumentMembersToQueryParamsTransform( + private val presignableOperations: List, +) : PresignModelTransform { + override fun transform(model: Model): Model { + val index = HttpBindingIndex(model) + val operations = presignableOperations.map { id -> + model.expectShape(syntheticShapeId(id), OperationShape::class.java).also { shape -> + check(shape.hasTrait(HttpTrait.ID)) { + "MoveDocumentMembersToQueryParamsTransform can only be used with REST protocols" + } + } + } + + // Find document members of the presignable operations + val membersToUpdate = operations.map { operation -> + val payloadBindings = index.getRequestBindings(operation, HttpBinding.Location.DOCUMENT) + payloadBindings.map { binding -> binding.member } + }.flatten() + + // Transform found shapes for presigning + return ModelTransformer.create().mapShapes(model) { shape -> + if (shape is MemberShape && membersToUpdate.contains(shape)) { + shape.toBuilder().addTrait(HttpQueryTrait(shape.memberName)).build() + } else { + shape + } + } + } +} diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CrateLicenseDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CrateLicenseDecorator.kt index bde07bd8783c5ee6fe9b3573bdb87a08c8d535e9..191e3b60493f05e800edf9556147a91af515758f 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CrateLicenseDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CrateLicenseDecorator.kt @@ -6,16 +6,16 @@ package software.amazon.smithy.rustsdk import software.amazon.smithy.rust.codegen.rustlang.raw +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RustCrate import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig class CrateLicenseDecorator : RustCodegenDecorator { override val name: String = "CrateLicense" override val order: Byte = 0 - override fun extras(protocolConfig: ProtocolConfig, rustCrate: RustCrate) { + override fun extras(codegenContext: CodegenContext, rustCrate: RustCrate) { rustCrate.withFile("LICENSE") { val license = this::class.java.getResource("/LICENSE").readText() it.raw(license) diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CredentialProviders.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CredentialProviders.kt index 489b153afeed843f9b4dac99f3e58f3f58e46c1f..a1c58994cf9ac33ecbd2e463498ae235ec47da72 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CredentialProviders.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CredentialProviders.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization @@ -18,7 +19,6 @@ import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.smithy.generators.LibRsSection -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfig @@ -27,25 +27,25 @@ class CredentialsProviderDecorator : RustCodegenDecorator { override val order: Byte = 0 override fun configCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { - return baseCustomizations + CredentialProviderConfig(protocolConfig.runtimeConfig) + return baseCustomizations + CredentialProviderConfig(codegenContext.runtimeConfig) } override fun operationCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, operation: OperationShape, baseCustomizations: List ): List { - return baseCustomizations + CredentialsProviderFeature(protocolConfig.runtimeConfig) + return baseCustomizations + CredentialsProviderFeature(codegenContext.runtimeConfig) } override fun libRsCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { - return baseCustomizations + PubUseCredentials(protocolConfig.runtimeConfig) + return baseCustomizations + PubUseCredentials(codegenContext.runtimeConfig) } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt index 4766c9d848929d8bd6ea184fd4212521aee68545..e44c6d7fe9d130db2bef4b2e2eff29ea87beb99a 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt @@ -10,11 +10,11 @@ import software.amazon.smithy.rust.codegen.rustlang.CratesIo import software.amazon.smithy.rust.codegen.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.smithy.generators.LibRsSection -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import java.nio.file.Files import java.nio.file.Paths @@ -23,7 +23,7 @@ class IntegrationTestDecorator : RustCodegenDecorator { override val order: Byte = 0 override fun libRsCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { val integrationTestPath = Paths.get("aws/sdk/integration-tests") @@ -31,14 +31,14 @@ class IntegrationTestDecorator : RustCodegenDecorator { "IntegrationTestDecorator expects to be run from the smithy-rs package root" } - val moduleName = protocolConfig.moduleName.substring("aws-sdk-".length) + val moduleName = codegenContext.moduleName.substring("aws-sdk-".length) val testPackagePath = integrationTestPath.resolve(moduleName) return if (Files.exists(testPackagePath) && Files.exists(testPackagePath.resolve("Cargo.toml"))) { val hasTests = Files.exists(testPackagePath.resolve("tests")) val hasBenches = Files.exists(testPackagePath.resolve("benches")) baseCustomizations + IntegrationTestDependencies( moduleName, - protocolConfig.runtimeConfig, + codegenContext.runtimeConfig, hasTests, hasBenches ) diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt index 725605974601972348f83d6c1ba9f190f36175ef..854f5019382945c76c8eb32a45e52d3c2cab8366 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization @@ -17,7 +18,6 @@ import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.smithy.generators.LibRsSection -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfig @@ -52,14 +52,14 @@ class RegionDecorator : RustCodegenDecorator { override val order: Byte = 0 override fun configCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { - return baseCustomizations + RegionProviderConfig(protocolConfig.runtimeConfig) + return baseCustomizations + RegionProviderConfig(codegenContext.runtimeConfig) } override fun operationCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, operation: OperationShape, baseCustomizations: List ): List { @@ -67,10 +67,10 @@ class RegionDecorator : RustCodegenDecorator { } override fun libRsCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { - return baseCustomizations + PubUseRegion(protocolConfig.runtimeConfig) + return baseCustomizations + PubUseRegion(codegenContext.runtimeConfig) } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryPolicyDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryPolicyDecorator.kt index e82c10ac766d64b133f0879db1fbe5f1e6da6af4..1a663b0ebd5a3af0d2c19a0b3cf231fafdcd790b 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryPolicyDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryPolicyDecorator.kt @@ -9,23 +9,23 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig class RetryPolicyDecorator : RustCodegenDecorator { override val name: String = "RetryPolicy" override val order: Byte = 0 override fun operationCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, operation: OperationShape, baseCustomizations: List ): List { - return baseCustomizations + RetryPolicyFeature(protocolConfig.runtimeConfig) + return baseCustomizations + RetryPolicyFeature(codegenContext.runtimeConfig) } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SharedConfigDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SharedConfigDecorator.kt index 5f28beefc79154fe981e8948c552c07964912f99..37d6d4310554b51f9c7a5f2566105b01d2a70174 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SharedConfigDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SharedConfigDecorator.kt @@ -10,10 +10,10 @@ import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RustCrate import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfig @@ -28,15 +28,15 @@ class SharedConfigDecorator : RustCodegenDecorator { override val order: Byte = 0 override fun configCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { - return baseCustomizations + NewFromShared(protocolConfig.runtimeConfig) + return baseCustomizations + NewFromShared(codegenContext.runtimeConfig) } - override fun extras(protocolConfig: ProtocolConfig, rustCrate: RustCrate) { + override fun extras(codegenContext: CodegenContext, rustCrate: RustCrate) { val codegenScope = arrayOf( - "Config" to awsTypes(runtimeConfig = protocolConfig.runtimeConfig).asType().member("config::Config") + "Config" to awsTypes(runtimeConfig = codegenContext.runtimeConfig).asType().member("config::Config") ) rustCrate.withModule(RustModule.Config) { // TODO(sharedconfig): As more items are added to aws_types::Config, use them here to configure the config builder diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt index 72d0fe8faa513328d4162052be6b4261e4a3cf6d..fd9caa483dcd539ed7d7c823e028e6278c4ce69d 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt @@ -19,12 +19,12 @@ import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfig import software.amazon.smithy.rust.codegen.smithy.letIf @@ -46,32 +46,32 @@ class SigV4SigningDecorator : RustCodegenDecorator { override val name: String = "SigV4Signing" override val order: Byte = 0 - private fun applies(protocolConfig: ProtocolConfig): Boolean = protocolConfig.serviceShape.hasTrait() + private fun applies(codegenContext: CodegenContext): Boolean = codegenContext.serviceShape.hasTrait() override fun configCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { - return baseCustomizations.letIf(applies(protocolConfig)) { customizations -> + return baseCustomizations.letIf(applies(codegenContext)) { customizations -> customizations + SigV4SigningConfig( - protocolConfig.runtimeConfig, - protocolConfig.serviceShape.hasEventStreamOperations(protocolConfig.model), - protocolConfig.serviceShape.expectTrait() + codegenContext.runtimeConfig, + codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model), + codegenContext.serviceShape.expectTrait() ) } } override fun operationCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, operation: OperationShape, baseCustomizations: List ): List { - return baseCustomizations.letIf(applies(protocolConfig)) { + return baseCustomizations.letIf(applies(codegenContext)) { it + SigV4SigningFeature( - protocolConfig.model, + codegenContext.model, operation, - protocolConfig.runtimeConfig, - protocolConfig.serviceShape, + codegenContext.runtimeConfig, + codegenContext.serviceShape, ) } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/UserAgentDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/UserAgentDecorator.kt index c8379c8e01535511940754c3c042c9fcc3449356..b9a62170cc65ab5cae71d92bb5e4234b703fe865 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/UserAgentDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/UserAgentDecorator.kt @@ -12,13 +12,13 @@ import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.smithy.generators.LibRsSection -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectTrait @@ -30,20 +30,20 @@ class UserAgentDecorator : RustCodegenDecorator { override val order: Byte = 10 override fun libRsCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { // We are generating an AWS SDK, the service needs to have the AWS service trait - val serviceTrait = protocolConfig.serviceShape.expectTrait() - return baseCustomizations + ApiVersion(protocolConfig.runtimeConfig, serviceTrait) + val serviceTrait = codegenContext.serviceShape.expectTrait() + return baseCustomizations + ApiVersion(codegenContext.runtimeConfig, serviceTrait) } override fun operationCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, operation: OperationShape, baseCustomizations: List ): List { - return baseCustomizations + UserAgentFeature(protocolConfig.runtimeConfig) + return baseCustomizations + UserAgentFeature(codegenContext.runtimeConfig) } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/apigateway/ApiGatewayDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/apigateway/ApiGatewayDecorator.kt index 8ca736cd7d82465d56894fc394fc832cafbb527e..a5924a89f4898e4fbd638bf4e9df4da173f0053f 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/apigateway/ApiGatewayDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/apigateway/ApiGatewayDecorator.kt @@ -10,24 +10,24 @@ import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.letIf class ApiGatewayDecorator : RustCodegenDecorator { override val name: String = "ApiGateway" override val order: Byte = 0 - private fun applies(protocolConfig: ProtocolConfig) = protocolConfig.serviceShape.id == ShapeId.from("com.amazonaws.apigateway#BackplaneControlService") + private fun applies(codegenContext: CodegenContext) = codegenContext.serviceShape.id == ShapeId.from("com.amazonaws.apigateway#BackplaneControlService") override fun operationCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, operation: OperationShape, baseCustomizations: List ): List { - return baseCustomizations.letIf(applies(protocolConfig)) { + return baseCustomizations.letIf(applies(codegenContext)) { it + ApiGatewayAddAcceptHeader() } } @@ -36,7 +36,7 @@ class ApiGatewayDecorator : RustCodegenDecorator { class ApiGatewayAddAcceptHeader : OperationCustomization() { override fun section(section: OperationSection): Writable = when (section) { is OperationSection.FinalizeOperation -> emptySection - OperationSection.OperationImplBlock -> emptySection + is OperationSection.OperationImplBlock -> emptySection is OperationSection.MutateRequest -> writable { rust( """${section.request} diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt index 3b160f16dc5e33aac324b6d85d8b3f4d905aafb2..b39550304d0a2825754f6d2871842cb4584677c3 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt @@ -20,11 +20,11 @@ import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.smithy.generators.LibRsSection -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.letIf import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.smithy.protocols.RestXml @@ -52,10 +52,10 @@ class S3Decorator : RustCodegenDecorator { } override fun libRsCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { - return baseCustomizations.letIf(applies(protocolConfig.serviceShape.id)) { + return baseCustomizations.letIf(applies(codegenContext.serviceShape.id)) { it + S3PubUse() } } @@ -66,8 +66,8 @@ class S3Decorator : RustCodegenDecorator { // S3CorrectSizeIntegerType().transform(model) } -class S3(protocolConfig: ProtocolConfig) : RestXml(protocolConfig) { - private val runtimeConfig = protocolConfig.runtimeConfig +class S3(codegenContext: CodegenContext) : RestXml(codegenContext) { + private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, "Error" to RuntimeType.GenericError(runtimeConfig), diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/traits/PresignableTrait.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/traits/PresignableTrait.kt index 2102b15f7155e71cbd2fcf84dbbc6dc63c8a17f8..0c3b5f12f1e94168976a7e573c984c18161b56de 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/traits/PresignableTrait.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/traits/PresignableTrait.kt @@ -11,7 +11,7 @@ import software.amazon.smithy.model.traits.AnnotationTrait /** Synthetic trait that indicates an operation is presignable. */ // TODO: This can be replaced once https://github.com/awslabs/smithy/pull/897 is merged. -class PresignableTrait : AnnotationTrait(ID, Node.objectNode()) { +class PresignableTrait(val syntheticOperationId: ShapeId) : AnnotationTrait(ID, Node.objectNode()) { companion object { val ID = ShapeId.from("smithy.api.aws.internal#presignable") } diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecoratorTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecoratorTest.kt index fa83f28199c3a6881b1c5ac1ac712546ba9d8421..de297522b9babb2d2fc410d680f3c0e1c21ecb61 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecoratorTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecoratorTest.kt @@ -8,9 +8,14 @@ package software.amazon.smithy.rustsdk import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.HttpBinding +import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.rust.codegen.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.util.expectTrait import software.amazon.smithy.rust.codegen.util.orNull import software.amazon.smithy.rustsdk.traits.PresignableTrait @@ -25,11 +30,11 @@ class AwsPresigningDecoratorTest { val decorator = AwsPresigningDecorator() val model = testOperation(namespace, name) val transformed = decorator.transformModel(serviceShape(model), model) - hasPresignableTrait(transformed) shouldBe presignable + hasPresignableTrait(transformed, namespace, name) shouldBe presignable } - private fun hasPresignableTrait(model: Model): Boolean = - model.shapes().filter { shape -> shape is OperationShape } + private fun hasPresignableTrait(model: Model, namespace: String, name: String): Boolean = + model.shapes().filter { shape -> shape is OperationShape && shape.id == ShapeId.fromParts(namespace, name) } .findFirst() .orNull()!! .hasTrait(PresignableTrait.ID) @@ -56,3 +61,129 @@ class AwsPresigningDecoratorTest { } """.asSmithyModel() } + +class OverrideHttpMethodTransformTest { + @Test + fun `it should override the HTTP method for the listed operations`() { + val model = """ + namespace test + use aws.protocols#restJson1 + + @restJson1 + service TestService { + version: "2019-12-16", + operations: ["One", "Two", "Three"], + } + + structure TestInput { } + structure TestOutput { } + + @http(uri: "/one", method: "POST") + operation One { input: TestInput, output: TestOutput } + + @http(uri: "/two", method: "GET") + operation Two { input: TestInput, output: TestOutput } + + @http(uri: "/three", method: "POST") + operation Three { input: TestInput, output: TestOutput } + """.asSmithyModel() + + val serviceShape = model.expectShape(ShapeId.from("test#TestService"), ServiceShape::class.java) + val presignableOp = PresignableOperation( + PayloadSigningType.EMPTY, + listOf( + OverrideHttpMethodTransform( + mapOf( + ShapeId.from("test#One") to "GET", + ShapeId.from("test#Two") to "POST", + ) + ) + ) + ) + val transformed = AwsPresigningDecorator( + mapOf( + ShapeId.from("test#One") to presignableOp, + ShapeId.from("test#Two") to presignableOp + ) + ).transformModel(serviceShape, model) + + val synthNamespace = "test.synthetic.aws.presigned" + transformed.expectShape(ShapeId.from("$synthNamespace#One")).expectTrait().method shouldBe "GET" + transformed.expectShape(ShapeId.from("$synthNamespace#Two")).expectTrait().method shouldBe "POST" + transformed.getShape(ShapeId.from("$synthNamespace#Three")).orNull() shouldBe null + } +} + +class MoveDocumentMembersToQueryParamsTransformTest { + @Test + fun `it should move document members to query parameters for the listed operations`() { + val model = """ + namespace test + use aws.protocols#restJson1 + + @restJson1 + service TestService { + version: "2019-12-16", + operations: ["One", "Two"], + } + + structure OneInputOutput { + @httpHeader("one") + one: String, + @httpQuery("two") + two: String, + + three: String, + four: String, + } + structure TwoInputOutput { + @httpHeader("one") + one: String, + @httpQuery("two") + two: String, + + three: String, + four: String, + } + + @http(uri: "/one", method: "POST") + operation One { input: OneInputOutput, output: OneInputOutput } + + @http(uri: "/two", method: "POST") + operation Two { input: TwoInputOutput, output: TwoInputOutput } + """.asSmithyModel() + + val serviceShape = model.expectShape(ShapeId.from("test#TestService"), ServiceShape::class.java) + val presignableOp = PresignableOperation( + PayloadSigningType.EMPTY, + listOf( + MoveDocumentMembersToQueryParamsTransform( + listOf(ShapeId.from("test#One")) + ) + ) + ) + val transformed = AwsPresigningDecorator( + mapOf(ShapeId.from("test#One") to presignableOp) + ).transformModel(serviceShape, model) + + val index = HttpBindingIndex(transformed) + index.getRequestBindings(ShapeId.from("test.synthetic.aws.presigned#One")).map { (key, value) -> + key to value.location + }.toMap() shouldBe mapOf( + "one" to HttpBinding.Location.HEADER, + "two" to HttpBinding.Location.QUERY, + "three" to HttpBinding.Location.QUERY, + "four" to HttpBinding.Location.QUERY, + ) + + transformed.getShape(ShapeId.from("test.synthetic.aws.presigned#Two")).orNull() shouldBe null + index.getRequestBindings(ShapeId.from("test#Two")).map { (key, value) -> + key to value.location + }.toMap() shouldBe mapOf( + "one" to HttpBinding.Location.HEADER, + "two" to HttpBinding.Location.QUERY, + "three" to HttpBinding.Location.DOCUMENT, + "four" to HttpBinding.Location.DOCUMENT, + ) + } +} diff --git a/aws/sdk/README.md b/aws/sdk/README.md index 558ec4ebf2d10a1ae6d05c18a09d58f83d003b32..45cdad7c6dad22a8890e02332380cfcd4cda2557 100644 --- a/aws/sdk/README.md +++ b/aws/sdk/README.md @@ -1,10 +1,12 @@ -# AWS SDK Generator +AWS SDK Generator +================= This directory contains a gradle project to generate an AWS SDK. It uses the Smithy Build Plugin combined with the customizations specified in `aws/sdk-codegen` to generate an AWS SDK from Smithy models. `build.gradle.kts` will generate a `smithy-build.json` dynamically from all models in the `models` directory. -## Usage +Usage +----- Generate an SDK: `./gradlew :aws:sdk:assemble` @@ -15,7 +17,9 @@ Generate, compile, and test an SDK: Run an SDK example: `./gradlew :aws:sdk:runExample --example dynamo-helloworld` -## Controlling service generation +Controlling service generation +------------------------------ + You can use gradle properties to opt/out of generating specific services: ```bash # generate only s3,ec2,sts @@ -29,3 +33,20 @@ The generation logic is as follows: 1. If `aws.services` is specified, generate an SDK based on the inclusion/exclusion list. 2. Otherwise, if `aws.fullsdk` is specified generate an SDK based on `aws.services.fullsdk`. 3. Otherwise, generate an SDK based on `aws.services.smoketest` + +Debugging with IntelliJ +----------------------- + +The easiest way to debug codegen is to attach a remote debugger to the Gradle daemon and then run `aws:sdk:assemble`. +To do this: + +1. Temporarily modify the root `gradle.properties` file to have the following: +``` +org.gradle.jvmargs=-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=localhost:5006 +``` +2. Run `./gradlew --stop` to kill any Gradle daemons that are running without that property. +3. Configure IntelliJ to remote debug on port 5006 (or whichever port was configured above). +4. Run `./gradlew aws:sdk:assemble` (with any additional properties to limit the services generated) +5. It will hang on "Starting Daemon". This is because the Gradle daemon is waiting for a remote debugger + to start up. Attaching IntelliJ's debugger will make the build proceed, but now you can stop execution + on breakpoints and examine values. diff --git a/aws/sdk/examples/polly/src/bin/synthesize-speech-presigned.rs b/aws/sdk/examples/polly/src/bin/synthesize-speech-presigned.rs new file mode 100644 index 0000000000000000000000000000000000000000..11149511c7d3b25f8329343cdad333c825dc4816 --- /dev/null +++ b/aws/sdk/examples/polly/src/bin/synthesize-speech-presigned.rs @@ -0,0 +1,98 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +use aws_config::meta::region::RegionProviderChain; +use polly::input::SynthesizeSpeechInput; +use polly::model::{OutputFormat, VoiceId}; +use polly::presigning::config::PresigningConfig; +use polly::{Client, Config, Region, PKG_VERSION}; +use std::error::Error; +use std::fs; +use std::time::Duration; +use structopt::StructOpt; + +#[derive(Debug, StructOpt)] +struct Opt { + /// The default AWS Region. + #[structopt(short, long)] + region: Option, + + /// The file containing the text to synthesize. + #[structopt(short, long)] + filename: String, + + /// How long in seconds before the presigned request should expire. + #[structopt(short, long)] + expires_in: Option, + + /// Whether to display additional information. + #[structopt(short, long)] + verbose: bool, +} + +/// Generates a presigned request to synthesize UTF-8 input, plain text or SSML, to a stream of bytes in a file. +/// # Arguments +/// +/// * `-f FILENAME` - The name of the file containing the text to synthesize. +/// The output is saved in MP3 format in a file with the same basename, but with an __mp3__ extension. +/// * `[-r REGION]` - The Region in which the client is created. +/// If not supplied, uses the value of the **AWS_REGION** environment variable. +/// If the environment variable is not set, defaults to **us-west-2**. +/// * `[-e EXPIRES_IN]` - The amount of time the presigned request should be valid for. +/// If not given, this defaults to 15 minutes. +/// * `[-v]` - Whether to display additional information. +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt::init(); + + let Opt { + filename, + region, + expires_in, + verbose, + } = Opt::from_args(); + let expires_in = Duration::from_secs(expires_in.unwrap_or(900)); + + let region_provider = RegionProviderChain::first_try(region.map(Region::new)) + .or_default_provider() + .or_else(Region::new("us-west-2")); + let shared_config = aws_config::from_env().region(region_provider).load().await; + let client = Client::new(&shared_config); + + println!(); + + if verbose { + println!("Polly version: {}", PKG_VERSION); + println!("Region: {:?}", shared_config.region().unwrap()); + println!("Filename: {}", &filename); + println!(); + } + + let content = fs::read_to_string(&filename).unwrap(); + + // Presigned requests can be made with the client directly + let presigned_request = client + .synthesize_speech() + .output_format(OutputFormat::Mp3) + .text(content.clone()) + .voice_id(VoiceId::Joanna) + .presigned(PresigningConfig::expires_in(expires_in)?) + .await?; + println!("From client: {:?}", presigned_request); + + // Or, they can be made directly from an operation input + let presigned_request = SynthesizeSpeechInput::builder() + .output_format(OutputFormat::Mp3) + .text(content) + .voice_id(VoiceId::Joanna) + .build()? + .presigned( + &Config::from(&shared_config), + PresigningConfig::expires_in(expires_in)?, + ) + .await?; + println!("From operation input: {:?}", presigned_request); + + Ok(()) +} diff --git a/aws/sdk/integration-tests/Cargo.toml b/aws/sdk/integration-tests/Cargo.toml index f5bce40f232e040f24e1a1c03f858e4b38943e4d..96876182d602fe8847d688a9a23bb526593a1e4a 100644 --- a/aws/sdk/integration-tests/Cargo.toml +++ b/aws/sdk/integration-tests/Cargo.toml @@ -5,6 +5,7 @@ members = [ "dynamodb", "iam", "kms", + "polly", "qldbsession", "s3", "sts", diff --git a/aws/sdk/integration-tests/polly/Cargo.toml b/aws/sdk/integration-tests/polly/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..a78468f5bf5dcfc8b5a21311c06c2e67e90d970f --- /dev/null +++ b/aws/sdk/integration-tests/polly/Cargo.toml @@ -0,0 +1,21 @@ +# This Cargo.toml is unused in generated code. It exists solely to enable these tests to compile in-situ +[package] +name = "polly-tests" +version = "0.1.0" +authors = ["John DiSanti "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +aws-sdk-polly = { path = "../../build/aws-sdk/polly" } +smithy-client = { path = "../../build/aws-sdk/smithy-client", features = ["test-util"] } +smithy-http = { path = "../../build/aws-sdk/smithy-http" } +tracing-subscriber = "0.2.18" + +[dev-dependencies] +tokio = { version = "1", features = ["full"]} +http = "0.2.3" +bytes = "1" +aws-hyper = { path = "../../build/aws-sdk/aws-hyper"} +aws-http = { path = "../../build/aws-sdk/aws-http"} diff --git a/aws/sdk/integration-tests/polly/tests/presigning.rs b/aws/sdk/integration-tests/polly/tests/presigning.rs new file mode 100644 index 0000000000000000000000000000000000000000..3028e2b39258acc11dc8a434b765ec5b8973d90b --- /dev/null +++ b/aws/sdk/integration-tests/polly/tests/presigning.rs @@ -0,0 +1,67 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +use aws_sdk_polly as polly; +use aws_sdk_polly::model::{OutputFormat, VoiceId}; +use polly::presigning::config::PresigningConfig; +use std::error::Error; +use std::time::{Duration, SystemTime}; + +#[tokio::test] +async fn test_presigning() -> Result<(), Box> { + let creds = polly::Credentials::from_keys( + "ANOTREAL", + "notrealrnrELgWzOk3IfjzDKtFBhDby", + Some("notarealsessiontoken".to_string()), + ); + let config = polly::Config::builder() + .credentials_provider(creds) + .region(polly::Region::new("us-east-1")) + .build(); + + let input = polly::input::SynthesizeSpeechInput::builder() + .output_format(OutputFormat::Mp3) + .text("hello, world") + .voice_id(VoiceId::Joanna) + .build()?; + + let presigned = input + .presigned( + &config, + PresigningConfig::builder() + .start_time(SystemTime::UNIX_EPOCH + Duration::from_secs(1234567891)) + .expires_in(Duration::from_secs(30)) + .build() + .unwrap(), + ) + .await?; + + let pq = presigned.uri().path_and_query().unwrap(); + let path = pq.path(); + let query = pq.query().unwrap(); + let mut query_params: Vec<&str> = query.split('&').collect(); + query_params.sort(); + + assert_eq!("GET", presigned.method().as_str()); + assert_eq!("/v1/speech", path); + assert_eq!( + &[ + "OutputFormat=mp3", + "Text=hello%2C%20world", + "VoiceId=Joanna", + "X-Amz-Algorithm=AWS4-HMAC-SHA256", + "X-Amz-Credential=ANOTREAL%2F20090213%2Fus-east-1%2Fpolly%2Faws4_request", + "X-Amz-Date=20090213T233131Z", + "X-Amz-Expires=30", + "X-Amz-Security-Token=notarealsessiontoken", + "X-Amz-Signature=79fcf765b063aa29d852fa9d7c2a9ccff94d727d743adbff84a0be6afe9a92e8", + "X-Amz-SignedHeaders=host", + ][..], + &query_params + ); + assert!(presigned.headers().is_empty()); + + Ok(()) +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/CodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/CodegenVisitor.kt index 5e21113221ca68d1fcc36060c1a17903627eca85..28b6a5629d261f31d56e29a68d6771ff79c8bdf3 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/CodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/CodegenVisitor.kt @@ -26,6 +26,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.ServiceGener import software.amazon.smithy.rust.codegen.server.smithy.protocols.RestJson1HttpDeserializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.RestJson1HttpSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerGenerator +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.DefaultPublicModules import software.amazon.smithy.rust.codegen.smithy.RustCrate import software.amazon.smithy.rust.codegen.smithy.RustSettings @@ -34,17 +35,16 @@ import software.amazon.smithy.rust.codegen.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator -import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.smithy.generators.implBlock +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator import software.amazon.smithy.rust.codegen.smithy.letIf import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.smithy.protocols.HttpTraitHttpBindingResolver import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolContentTypes +import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator @@ -72,9 +72,9 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC private val rustCrate: RustCrate private val fileManifest = context.fileManifest private val model: Model - private val protocolConfig: ProtocolConfig - private val protocolGenerator: ProtocolGeneratorFactory - private val httpGenerator: HttpProtocolGenerator + private val codegenContext: CodegenContext + private val protocolGeneratorFactory: ProtocolGeneratorFactory + private val protocolGenerator: ProtocolGenerator private val serializerGenerator: JsonSerializerGenerator private val deserializerGenerator: JsonParserGenerator @@ -99,43 +99,42 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC ) ) .protocolFor(context.model, service) - protocolGenerator = generator + protocolGeneratorFactory = generator model = generator.transformModel(codegenDecorator.transformModel(service, baseModel)) val baseProvider = RustCodegenPlugin.baseSymbolProvider(model, service, symbolVisitorConfig) symbolProvider = codegenDecorator.symbolProvider(generator.symbolProvider(model, baseProvider)) - protocolConfig = - ProtocolConfig( - model, - symbolProvider, - settings.runtimeConfig, - service, - protocol, - settings.moduleName - ) + codegenContext = CodegenContext( + model, + symbolProvider, + settings.runtimeConfig, + service, + protocol, + settings.moduleName + ) rustCrate = RustCrate(context.fileManifest, symbolProvider, DefaultPublicModules) - httpGenerator = protocolGenerator.buildProtocolGenerator(protocolConfig) + protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext) httpBindingResolver = HttpTraitHttpBindingResolver( - protocolConfig.model, + codegenContext.model, ProtocolContentTypes.consistent("application/json"), ) - serializerGenerator = JsonSerializerGenerator(protocolConfig, httpBindingResolver) - deserializerGenerator = JsonParserGenerator(protocolConfig, httpBindingResolver) - when (protocolConfig.protocol) { + serializerGenerator = JsonSerializerGenerator(codegenContext, httpBindingResolver) + deserializerGenerator = JsonParserGenerator(codegenContext, httpBindingResolver) + when (codegenContext.protocol) { RestJson1Trait.ID -> { httpSerializerGenerator = - RestJson1HttpSerializerGenerator(protocolConfig, httpBindingResolver) + RestJson1HttpSerializerGenerator(codegenContext, httpBindingResolver) httpDeserializerGenerator = - RestJson1HttpDeserializerGenerator(protocolConfig, httpBindingResolver) + RestJson1HttpDeserializerGenerator(codegenContext, httpBindingResolver) } else -> { // TODO: support other protocols - throw Exception("Protocol ${protocolConfig.protocol} not support yet") + throw Exception("Protocol ${codegenContext.protocol} not support yet") } } } @@ -151,14 +150,14 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC fun execute() { val service = settings.getService(model) logger.info( - "[rust-server-codegen] Generating Rust server for service $service, protocol ${protocolConfig.protocol}..." + "[rust-server-codegen] Generating Rust server for service $service, protocol ${codegenContext.protocol}..." ) val serviceShapes = Walker(model).walkShapes(service) serviceShapes.forEach { it.accept(this) } - codegenDecorator.extras(protocolConfig, rustCrate) + codegenDecorator.extras(codegenContext, rustCrate) val module = RustMetadata(public = true) rustCrate.withModule(RustModule("error", module)) { writer -> renderSerdeError(writer) } - rustCrate.finalize(settings, codegenDecorator.libRsCustomizations(protocolConfig, listOf())) + rustCrate.finalize(settings, codegenDecorator.libRsCustomizations(codegenContext, listOf())) try { "cargo fmt".runCommand( fileManifest.baseDir, @@ -215,7 +214,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC StructureGenerator(model, symbolProvider, writer, shape).render() if (!shape.hasTrait()) { val builderGenerator = - BuilderGenerator(protocolConfig.model, protocolConfig.symbolProvider, shape) + BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape) builderGenerator.render(writer) writer.implBlock(shape, symbolProvider) { builderGenerator.renderConvenienceMethod(this) @@ -241,9 +240,9 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC override fun serviceShape(shape: ServiceShape) { ServiceGenerator( rustCrate, - httpGenerator, - protocolGenerator.support(), - protocolConfig, + protocolGenerator, + protocolGeneratorFactory.support(), + codegenContext, codegenDecorator ) .render() diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceGenerator.kt index 453b1601dcb7f1f861180e17017ce7766fe746cf..e9ba6f37ad783599679d929217113e02c98a209b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceGenerator.kt @@ -7,52 +7,52 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.rust.codegen.rustlang.RustModule +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RustCrate import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator -import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfigGenerator import software.amazon.smithy.rust.codegen.smithy.generators.error.CombinedErrorGenerator import software.amazon.smithy.rust.codegen.smithy.generators.error.TopLevelErrorGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.util.inputShape class ServiceGenerator( - private val rustCrate: RustCrate, - private val protocolGenerator: HttpProtocolGenerator, - private val protocolSupport: ProtocolSupport, - private val config: ProtocolConfig, - private val decorator: RustCodegenDecorator, + private val rustCrate: RustCrate, + private val protocolGenerator: ProtocolGenerator, + private val protocolSupport: ProtocolSupport, + private val context: CodegenContext, + private val decorator: RustCodegenDecorator, ) { - private val index = TopDownIndex.of(config.model) + private val index = TopDownIndex.of(context.model) fun render() { - val operations = index.getContainedOperations(config.serviceShape).sortedBy { it.id } + val operations = index.getContainedOperations(context.serviceShape).sortedBy { it.id } operations.map { operation -> rustCrate.useShapeWriter(operation) { operationWriter -> - rustCrate.useShapeWriter(operation.inputShape(config.model)) { inputWriter -> + rustCrate.useShapeWriter(operation.inputShape(context.model)) { inputWriter -> protocolGenerator.renderOperation( - operationWriter, - inputWriter, - operation, - decorator.operationCustomizations(config, operation, listOf()) + operationWriter, + inputWriter, + operation, + decorator.operationCustomizations(context, operation, listOf()) ) } } rustCrate.withModule(RustModule.Error) { writer -> - CombinedErrorGenerator(config.model, config.symbolProvider, operation) - .render(writer) + CombinedErrorGenerator(context.model, context.symbolProvider, operation) + .render(writer) } } - TopLevelErrorGenerator(config, operations).render(rustCrate) + TopLevelErrorGenerator(context, operations).render(rustCrate) rustCrate.withModule(RustModule.Config) { writer -> ServiceConfigGenerator.withBaseBehavior( - config, - extraCustomizations = decorator.configCustomizations(config, listOf()) - ) - .render(writer) + context, + extraCustomizations = decorator.configCustomizations(context, listOf()) + ) + .render(writer) } rustCrate.lib { it.write("pub use config::Config;") } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RestJson1.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RestJson1.kt index f2b433c5c47d4d77847d77d1502b9c90bfe5c5cd..5883fb0940127163be9b97f3255993a65d957c2b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RestJson1.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RestJson1.kt @@ -32,9 +32,9 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.Instantiator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.smithy.generators.http.ResponseBindingGenerator @@ -59,29 +59,29 @@ import software.amazon.smithy.rust.codegen.util.toSnakeCase import java.util.logging.Logger abstract class ServerGenerator( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, private val httpBindingResolver: HttpTraitHttpBindingResolver, ) { public val logger = Logger.getLogger(javaClass.name) public val error = RuntimeType("error", null, "crate") public val operation = RuntimeType("operation", null, "crate") - public val runtimeConfig = protocolConfig.runtimeConfig - public val model = protocolConfig.model - public val symbolProvider = protocolConfig.symbolProvider + public val runtimeConfig = codegenContext.runtimeConfig + public val model = codegenContext.model + public val symbolProvider = codegenContext.symbolProvider public val instantiator = - with(protocolConfig) { Instantiator(symbolProvider, model, runtimeConfig) } + with(codegenContext) { Instantiator(symbolProvider, model, runtimeConfig) } public val smithyHttp = CargoDependency.SmithyHttp(runtimeConfig).asType() public val index = HttpBindingIndex.of(model) - public val service = protocolConfig.serviceShape + public val service = codegenContext.serviceShape public val defaultTimestampFormat = TimestampFormatTrait.Format.EPOCH_SECONDS abstract fun render(writer: RustWriter, operationShape: OperationShape) } class RestJson1HttpSerializerGenerator( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, private val httpBindingResolver: HttpTraitHttpBindingResolver, -) : ServerGenerator(protocolConfig, httpBindingResolver) { +) : ServerGenerator(codegenContext, httpBindingResolver) { private val serde = RuntimeType("json_ser", null, "crate") private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType() private val codegenScope = @@ -336,12 +336,12 @@ class RestJson1HttpSerializerGenerator( } class RestJson1HttpDeserializerGenerator( - private val protocolConfig: ProtocolConfig, + private val codegenContext: CodegenContext, private val httpBindingResolver: HttpTraitHttpBindingResolver, -) : ServerGenerator(protocolConfig, httpBindingResolver) { +) : ServerGenerator(codegenContext, httpBindingResolver) { override fun render(writer: RustWriter, operationShape: OperationShape) { RestJson1HttpRequestDeserializerGenerator( - protocolConfig, + codegenContext, httpBindingResolver, operationShape ) @@ -350,15 +350,15 @@ class RestJson1HttpDeserializerGenerator( } class RestJson1HttpRequestDeserializerGenerator( - private val protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, private val httpBindingResolver: HttpTraitHttpBindingResolver, private val operationShape: OperationShape, -) : ServerGenerator(protocolConfig, httpBindingResolver) { +) : ServerGenerator(codegenContext, httpBindingResolver) { private val deserFnName = "deser_${operationShape.id.name.toSnakeCase()}_request" private val httpBindingGenerator = ResponseBindingGenerator( - RestJson(protocolConfig), - protocolConfig, + RestJson(codegenContext), + codegenContext, operationShape, ) private val httpTrait = httpBindingResolver.httpTrait(operationShape) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustReservedWords.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustReservedWords.kt index 0339b5ee801186183e10dc6616dbfb7c8d0a7dfb..06d06d0ce5152b7b46ea2342aa55f469f064a75a 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustReservedWords.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustReservedWords.kt @@ -26,6 +26,9 @@ class RustReservedWordSymbolProvider(private val base: RustSymbolProvider) : Wra "build" -> "build_value" "default" -> "default_value" "send" -> "send_value" + // To avoid conflicts with the `make_operation` and `presigned` functions on generated inputs + "make_operation" -> "make_operation_value" + "presigned" -> "presigned_value" else -> baseName } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenContext.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenContext.kt new file mode 100644 index 0000000000000000000000000000000000000000..3a1ec3c17e5b13534f69a7eb3ff882e03d8c3357 --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenContext.kt @@ -0,0 +1,22 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.smithy + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId + +/** + * Configuration needed to generate the client for a given Service<->Protocol pair + */ +data class CodegenContext( + val model: Model, + val symbolProvider: RustSymbolProvider, + val runtimeConfig: RuntimeConfig, + val serviceShape: ServiceShape, + val protocol: ShapeId, + val moduleName: String +) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt index 3a8ea8589a4df677318ec91e4dd23bf68539c439..832bc9defbe7459d7f5bd59873fd6c8a905391ab 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt @@ -18,13 +18,12 @@ import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator -import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.generators.ServiceGenerator import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.smithy.generators.implBlock +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator +import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.smithy.transformers.AddErrorMessage @@ -48,9 +47,9 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC private val rustCrate: RustCrate private val fileManifest = context.fileManifest private val model: Model - private val protocolConfig: ProtocolConfig - private val protocolGenerator: ProtocolGeneratorFactory - private val httpGenerator: HttpProtocolGenerator + private val codegenContext: CodegenContext + private val protocolGenerator: ProtocolGeneratorFactory + private val httpGenerator: ProtocolGenerator init { val symbolVisitorConfig = @@ -65,14 +64,14 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC val baseProvider = RustCodegenPlugin.baseSymbolProvider(model, service, symbolVisitorConfig) symbolProvider = codegenDecorator.symbolProvider(generator.symbolProvider(model, baseProvider)) - protocolConfig = - ProtocolConfig(model, symbolProvider, settings.runtimeConfig, service, protocol, settings.moduleName) + codegenContext = + CodegenContext(model, symbolProvider, settings.runtimeConfig, service, protocol, settings.moduleName) rustCrate = RustCrate( context.fileManifest, symbolProvider, DefaultPublicModules ) - httpGenerator = protocolGenerator.buildProtocolGenerator(protocolConfig) + httpGenerator = protocolGenerator.buildProtocolGenerator(codegenContext) } private fun baselineTransform(model: Model) = @@ -87,11 +86,11 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC val service = settings.getService(model) val serviceShapes = Walker(model).walkShapes(service) serviceShapes.forEach { it.accept(this) } - codegenDecorator.extras(protocolConfig, rustCrate) + codegenDecorator.extras(codegenContext, rustCrate) rustCrate.finalize( settings, codegenDecorator.libRsCustomizations( - protocolConfig, + codegenContext, listOf() ) ) @@ -112,7 +111,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC rustCrate.useShapeWriter(shape) { writer -> StructureGenerator(model, symbolProvider, writer, shape).render() if (!shape.hasTrait()) { - val builderGenerator = BuilderGenerator(protocolConfig.model, protocolConfig.symbolProvider, shape) + val builderGenerator = BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape) builderGenerator.render(writer) writer.implBlock(shape, symbolProvider) { builderGenerator.renderConvenienceMethod(this) @@ -140,7 +139,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC rustCrate, httpGenerator, protocolGenerator.support(), - protocolConfig, + codegenContext, codegenDecorator ).render() } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customizations/EndpointPrefixGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customizations/EndpointPrefixGenerator.kt index 9991ddffee2950784b8955fb287cbe399a714e9e..0df9ff4b8bf652a4b6ac59adbcee797e6eb4e363 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customizations/EndpointPrefixGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customizations/EndpointPrefixGenerator.kt @@ -12,25 +12,25 @@ import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.smithy.generators.EndpointTraitBindings import software.amazon.smithy.rust.codegen.smithy.generators.OperationBuildError -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig -class EndpointPrefixGenerator(private val protocolConfig: ProtocolConfig, private val shape: OperationShape) : +class EndpointPrefixGenerator(private val codegenContext: CodegenContext, private val shape: OperationShape) : OperationCustomization() { override fun section(section: OperationSection): Writable = when (section) { is OperationSection.MutateRequest -> writable { shape.getTrait(EndpointTrait::class.java).map { epTrait -> val endpointTraitBindings = EndpointTraitBindings( - protocolConfig.model, - protocolConfig.symbolProvider, - protocolConfig.runtimeConfig, + codegenContext.model, + codegenContext.symbolProvider, + codegenContext.runtimeConfig, shape, epTrait ) - val buildError = OperationBuildError(protocolConfig.runtimeConfig) + val buildError = OperationBuildError(codegenContext.runtimeConfig) withBlock("let endpoint_prefix = ", ";") { endpointTraitBindings.render(this, "self") } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customizations/HttpChecksumRequiredGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customizations/HttpChecksumRequiredGenerator.kt index b237278308040c0b37f86660339f60c2b92b9b97..79dd365bfaa96d88e8e50e8a519789b4c64cb292 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customizations/HttpChecksumRequiredGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customizations/HttpChecksumRequiredGenerator.kt @@ -13,24 +13,24 @@ import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.util.hasStreamingMember import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.inputShape class HttpChecksumRequiredGenerator( - private val protocolConfig: ProtocolConfig, + private val codegenContext: CodegenContext, private val operationShape: OperationShape ) : OperationCustomization() { override fun section(section: OperationSection): Writable { if (!operationShape.hasTrait()) { return emptySection } - if (operationShape.inputShape(protocolConfig.model).hasStreamingMember(protocolConfig.model)) { + if (operationShape.inputShape(codegenContext.model).hasStreamingMember(codegenContext.model)) { throw CodegenException("HttpChecksum required cannot be applied to a streaming shape") } return when (section) { @@ -54,8 +54,8 @@ class HttpChecksumRequiredGenerator( """, "md5" to CargoDependency.Md5.asType(), "http" to CargoDependency.Http.asType(), - "base64_encode" to RuntimeType.Base64Encode(protocolConfig.runtimeConfig), - "BuildError" to protocolConfig.runtimeConfig.operationBuildError() + "base64_encode" to RuntimeType.Base64Encode(codegenContext.runtimeConfig), + "BuildError" to codegenContext.runtimeConfig.operationBuildError() ) } else -> emptySection diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customizations/IdempotencyTokenGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customizations/IdempotencyTokenGenerator.kt index cb766872fe03d21ce23cab03c81378b806c4d8dd..5148d70382ccb461d5e067446ffd927e301636c8 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customizations/IdempotencyTokenGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customizations/IdempotencyTokenGenerator.kt @@ -10,16 +10,16 @@ import software.amazon.smithy.model.traits.IdempotencyTokenTrait import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.util.findMemberWithTrait import software.amazon.smithy.rust.codegen.util.inputShape -class IdempotencyTokenGenerator(protocolConfig: ProtocolConfig, private val operationShape: OperationShape) : +class IdempotencyTokenGenerator(codegenContext: CodegenContext, private val operationShape: OperationShape) : OperationCustomization() { - private val model = protocolConfig.model - private val symbolProvider = protocolConfig.symbolProvider + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider private val idempotencyTokenMember = operationShape.inputShape(model).findMemberWithTrait(model) override fun section(section: OperationSection): Writable { if (idempotencyTokenMember == null) { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/OperationCustomization.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/OperationCustomization.kt index 0c80507e161734b466b8d98e5f7df3deba714cd5..74297d3c380df53ac168b7b2bb49dfa94d852f1b 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/OperationCustomization.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/OperationCustomization.kt @@ -8,18 +8,28 @@ package software.amazon.smithy.rust.codegen.smithy.customize import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol sealed class OperationSection(name: String) : Section(name) { + abstract val customizations: List + /** Write custom code into the `impl` block of this operation */ - object OperationImplBlock : OperationSection("OperationImplBlock") + data class OperationImplBlock(override val customizations: List) : + OperationSection("OperationImplBlock") /** Write additional functions inside the Input's impl block */ data class InputImpl( + override val customizations: List, val operationShape: OperationShape, - val inputShape: StructureShape + val inputShape: StructureShape, + val protocol: Protocol, ) : OperationSection("InputImpl") - data class MutateInput(val input: String, val config: String) : OperationSection("MutateInput") + data class MutateInput( + override val customizations: List, + val input: String, + val config: String + ) : OperationSection("MutateInput") /** Write custom code into the block that builds an operation * @@ -27,9 +37,17 @@ sealed class OperationSection(name: String) : Section(name) { * [config]: Name of the variable holding the service config. * * */ - data class MutateRequest(val request: String, val config: String) : OperationSection("Feature") + data class MutateRequest( + override val customizations: List, + val request: String, + val config: String + ) : OperationSection("Feature") - data class FinalizeOperation(val operation: String, val config: String) : OperationSection("Finalize") + data class FinalizeOperation( + override val customizations: List, + val operation: String, + val config: String + ) : OperationSection("Finalize") } abstract class OperationCustomization : NamedSectionGenerator() { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/RequiredCustomizations.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/RequiredCustomizations.kt index b60f2f697f7227ef768db94c601cfba33a13af23..3c317cf865b2ece55bff9a9fd90bb2a9ebfcf82e 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/RequiredCustomizations.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/RequiredCustomizations.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.rust.codegen.smithy.customize import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.customizations.AllowLintsGenerator import software.amazon.smithy.rust.codegen.smithy.customizations.CrateVersionGenerator import software.amazon.smithy.rust.codegen.smithy.customizations.EndpointPrefixGenerator @@ -13,7 +14,6 @@ import software.amazon.smithy.rust.codegen.smithy.customizations.HttpChecksumReq import software.amazon.smithy.rust.codegen.smithy.customizations.IdempotencyTokenGenerator import software.amazon.smithy.rust.codegen.smithy.customizations.SmithyTypesPubUseGenerator import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig /** A set of customizations that are included in all protocols. * @@ -24,20 +24,20 @@ class RequiredCustomizations : RustCodegenDecorator { override val order: Byte = -1 override fun operationCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, operation: OperationShape, baseCustomizations: List ): List { - return baseCustomizations + IdempotencyTokenGenerator(protocolConfig, operation) + EndpointPrefixGenerator( - protocolConfig, + return baseCustomizations + IdempotencyTokenGenerator(codegenContext, operation) + EndpointPrefixGenerator( + codegenContext, operation - ) + HttpChecksumRequiredGenerator(protocolConfig, operation) + ) + HttpChecksumRequiredGenerator(codegenContext, operation) } override fun libRsCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { - return baseCustomizations + CrateVersionGenerator() + SmithyTypesPubUseGenerator(protocolConfig.runtimeConfig) + AllowLintsGenerator() + return baseCustomizations + CrateVersionGenerator() + SmithyTypesPubUseGenerator(codegenContext.runtimeConfig) + AllowLintsGenerator() } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/RustCodegenDecorator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/RustCodegenDecorator.kt index 507ba7deac1afd230d16b7cac59a9def6b543780..74a4df6e564082e26218b54ec30c26a2c92b0cf7 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/RustCodegenDecorator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/RustCodegenDecorator.kt @@ -10,11 +10,11 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RustCrate import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.generators.FluentClientDecorator import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap import java.util.ServiceLoader @@ -39,22 +39,22 @@ interface RustCodegenDecorator { val order: Byte fun configCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List = baseCustomizations fun operationCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, operation: OperationShape, baseCustomizations: List ): List = baseCustomizations fun libRsCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List = baseCustomizations - fun extras(protocolConfig: ProtocolConfig, rustCrate: RustCrate) {} + fun extras(codegenContext: CodegenContext, rustCrate: RustCrate) {} fun protocols(serviceId: ShapeId, currentProtocols: ProtocolMap): ProtocolMap = currentProtocols @@ -76,31 +76,31 @@ open class CombinedCodegenDecorator(decorators: List) : Ru get() = 0 override fun configCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { return orderedDecorators.foldRight(baseCustomizations) { decorator: RustCodegenDecorator, customizations -> - decorator.configCustomizations(protocolConfig, customizations) + decorator.configCustomizations(codegenContext, customizations) } } override fun operationCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, operation: OperationShape, baseCustomizations: List ): List { return orderedDecorators.foldRight(baseCustomizations) { decorator: RustCodegenDecorator, customizations -> - decorator.operationCustomizations(protocolConfig, operation, customizations) + decorator.operationCustomizations(codegenContext, operation, customizations) } } override fun libRsCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { return orderedDecorators.foldRight(baseCustomizations) { decorator, customizations -> decorator.libRsCustomizations( - protocolConfig, + codegenContext, customizations ) } @@ -118,8 +118,8 @@ open class CombinedCodegenDecorator(decorators: List) : Ru } } - override fun extras(protocolConfig: ProtocolConfig, rustCrate: RustCrate) { - return orderedDecorators.forEach { it.extras(protocolConfig, rustCrate) } + override fun extras(codegenContext: CodegenContext, rustCrate: RustCrate) { + return orderedDecorators.forEach { it.extras(codegenContext, rustCrate) } } override fun transformModel(service: ServiceShape, baseModel: Model): Model { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/FluentClientDecorator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/FluentClientDecorator.kt index f229a3c82f6bb2e8b801930895b6aa8b18c0cf5b..1e48195b36131f4fc07f4649f51623f337177f9c 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/FluentClientDecorator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/FluentClientDecorator.kt @@ -26,6 +26,7 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.stripOuter import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustCrate import software.amazon.smithy.rust.codegen.smithy.customize.NamedSectionGenerator @@ -42,29 +43,29 @@ class FluentClientDecorator : RustCodegenDecorator { override val name: String = "FluentClient" override val order: Byte = 0 - private fun applies(protocolConfig: ProtocolConfig): Boolean = - protocolConfig.symbolProvider.config().codegenConfig.includeFluentClient + private fun applies(codegenContext: CodegenContext): Boolean = + codegenContext.symbolProvider.config().codegenConfig.includeFluentClient - override fun extras(protocolConfig: ProtocolConfig, rustCrate: RustCrate) { - if (!applies(protocolConfig)) { + override fun extras(codegenContext: CodegenContext, rustCrate: RustCrate) { + if (!applies(codegenContext)) { return } val module = RustMetadata(additionalAttributes = listOf(Attribute.Cfg.feature("client")), public = true) rustCrate.withModule(RustModule("client", module)) { writer -> - FluentClientGenerator(protocolConfig, includeSmithyGenericClientDocs = true).render(writer) + FluentClientGenerator(codegenContext, includeSmithyGenericClientDocs = true).render(writer) } - val smithyClient = CargoDependency.SmithyClient(protocolConfig.runtimeConfig) + val smithyClient = CargoDependency.SmithyClient(codegenContext.runtimeConfig) rustCrate.mergeFeature(Feature("client", true, listOf(smithyClient.name))) rustCrate.mergeFeature(Feature("rustls", default = true, listOf("smithy-client/rustls"))) rustCrate.mergeFeature(Feature("native-tls", default = false, listOf("smithy-client/native-tls"))) } override fun libRsCustomizations( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, baseCustomizations: List ): List { - if (!applies(protocolConfig)) { + if (!applies(codegenContext)) { return baseCustomizations } @@ -120,21 +121,21 @@ data class ClientGenerics( } class FluentClientGenerator( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, // Whether to include Client construction details that are relevant to generic Smithy generated clients, // but not necessarily relevant to customized clients, such as the ones with the AWS SDK. private val includeSmithyGenericClientDocs: Boolean, private val generics: ClientGenerics = ClientGenerics(), private val customizations: List = emptyList(), ) { - private val serviceShape = protocolConfig.serviceShape + private val serviceShape = codegenContext.serviceShape private val operations = - TopDownIndex.of(protocolConfig.model).getContainedOperations(serviceShape).sortedBy { it.id } - private val symbolProvider = protocolConfig.symbolProvider - private val model = protocolConfig.model - private val clientDep = CargoDependency.SmithyClient(protocolConfig.runtimeConfig).copy(optional = true) - private val runtimeConfig = protocolConfig.runtimeConfig - private val moduleName = protocolConfig.moduleName + TopDownIndex.of(codegenContext.model).getContainedOperations(serviceShape).sortedBy { it.id } + private val symbolProvider = codegenContext.symbolProvider + private val model = codegenContext.model + private val clientDep = CargoDependency.SmithyClient(codegenContext.runtimeConfig).copy(optional = true) + private val runtimeConfig = codegenContext.runtimeConfig + private val moduleName = codegenContext.moduleName private val moduleUseName = moduleName.replace("-", "_") private val humanName = serviceShape.id.name private val core = FluentClientCore(model) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt deleted file mode 100644 index 0a3adbce01406588516fe6179512607900eabe9b..0000000000000000000000000000000000000000 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -package software.amazon.smithy.rust.codegen.smithy.generators - -import software.amazon.smithy.aws.traits.ServiceTrait -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.rustlang.Attribute -import software.amazon.smithy.rust.codegen.rustlang.CargoDependency -import software.amazon.smithy.rust.codegen.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.rustlang.asType -import software.amazon.smithy.rust.codegen.rustlang.docs -import software.amazon.smithy.rust.codegen.rustlang.documentShape -import software.amazon.smithy.rust.codegen.rustlang.rust -import software.amazon.smithy.rust.codegen.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate -import software.amazon.smithy.rust.codegen.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.rustlang.withBlock -import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization -import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection -import software.amazon.smithy.rust.codegen.smithy.customize.writeCustomizations -import software.amazon.smithy.rust.codegen.smithy.letIf -import software.amazon.smithy.rust.codegen.util.dq -import software.amazon.smithy.rust.codegen.util.getTrait -import software.amazon.smithy.rust.codegen.util.inputShape - -/** - * Configuration needed to generate the client for a given Service<->Protocol pair - */ -data class ProtocolConfig( - val model: Model, - val symbolProvider: RustSymbolProvider, - val runtimeConfig: RuntimeConfig, - val serviceShape: ServiceShape, - val protocol: ShapeId, - val moduleName: String -) - -interface ProtocolGeneratorFactory { - fun buildProtocolGenerator(protocolConfig: ProtocolConfig): T - fun transformModel(model: Model): Model - fun symbolProvider(model: Model, base: RustSymbolProvider): RustSymbolProvider = base - fun support(): ProtocolSupport -} - -/** - * Abstract class providing scaffolding for HTTP based protocols that must build an HTTP request (headers / URL) and - * a body. - */ -abstract class HttpProtocolGenerator(protocolConfig: ProtocolConfig) { - private val runtimeConfig = protocolConfig.runtimeConfig - private val symbolProvider = protocolConfig.symbolProvider - private val model = protocolConfig.model - - private val sdkId = - protocolConfig.serviceShape.getTrait()?.sdkId?.toLowerCase()?.replace(" ", "") - ?: protocolConfig.serviceShape.id.getName(protocolConfig.serviceShape) - - private val codegenScope = arrayOf( - "HttpRequestBuilder" to RuntimeType.HttpRequestBuilder, - "OpBuildError" to protocolConfig.runtimeConfig.operationBuildError(), - "Request" to RuntimeType.Http("request::Request"), - "RequestBuilder" to RuntimeType.HttpRequestBuilder, - "SdkBody" to RuntimeType.sdkBody(protocolConfig.runtimeConfig), - "config" to RuntimeType.Config, - "header_util" to CargoDependency.SmithyHttp(protocolConfig.runtimeConfig).asType().member("header"), - "http" to RuntimeType.http, - "operation" to RuntimeType.operationModule(runtimeConfig), - ) - - data class BodyMetadata(val takesOwnership: Boolean) - - abstract fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata - - abstract fun traitImplementations(operationWriter: RustWriter, operationShape: OperationShape) - - /** Write code into the impl block for [operationShape] */ - open fun operationImplBlock(implBlockWriter: RustWriter, operationShape: OperationShape) {} - - /** - * Add necessary methods to the impl block for the input shape. - * - * Your implementation MUST call [generateRequestBuilderBase] to create the public method. - */ - abstract fun toHttpRequestImpl( - implBlockWriter: RustWriter, - operationShape: OperationShape, - inputShape: StructureShape - ) - - fun renderOperation( - operationWriter: RustWriter, - inputWriter: RustWriter, - operationShape: OperationShape, - customizations: List - ) { - val inputShape = operationShape.inputShape(model) - val builderGenerator = BuilderGenerator(model, symbolProvider, operationShape.inputShape(model)) - builderGenerator.render(inputWriter) - - // TODO: One day, it should be possible for callers to invoke - // buildOperationType* directly to get the type rather than depending - // on these aliases. - val operationTypeOutput = buildOperationTypeOutput(inputWriter, operationShape) - val operationTypeRetry = buildOperationTypeRetry(inputWriter, customizations) - val inputPrefix = symbolProvider.toSymbol(inputShape).name - inputWriter.rust( - """ - ##[doc(hidden)] pub type ${inputPrefix}OperationOutputAlias = $operationTypeOutput; - ##[doc(hidden)] pub type ${inputPrefix}OperationRetryAlias = $operationTypeRetry; - """ - ) - - // impl OperationInputShape { ... } - val operationName = symbolProvider.toSymbol(operationShape).name - inputWriter.implBlock(inputShape, symbolProvider) { - writeCustomizations(customizations, OperationSection.InputImpl(operationShape, inputShape)) - generateMakeOperation(this, operationShape, operationName, customizations) - toHttpRequestImpl(this, operationShape, inputShape) - rustBlockTemplate( - "fn assemble(mut builder: #{RequestBuilder}, body: #{SdkBody}) -> #{Request}<#{SdkBody}>", - *codegenScope - ) { - rustTemplate( - """ - if let Some(content_length) = body.content_length() { - builder = #{header_util}::set_header_if_absent( - builder, - #{http}::header::CONTENT_LENGTH, - content_length - ); - } - builder.body(body).expect("should be valid request") - """, - *codegenScope - ) - } - - // pub fn builder() -> ... { } - builderGenerator.renderConvenienceMethod(this) - } - - // pub struct Operation { ... } - operationWriter.documentShape(operationShape, model) - Attribute.Derives(setOf(RuntimeType.Clone, RuntimeType.Default, RuntimeType.Debug)).render(operationWriter) - operationWriter.rustBlock("pub struct $operationName") { - write("_private: ()") - } - operationWriter.implBlock(operationShape, symbolProvider) { - builderGenerator.renderConvenienceMethod(this) - - operationImplBlock(this, operationShape) - - rustBlock("pub fn new() -> Self") { - rust("Self { _private: () }") - } - - writeCustomizations(customizations, OperationSection.OperationImplBlock) - } - traitImplementations(operationWriter, operationShape) - } - - protected fun generateRequestBuilderBase(implBlockWriter: RustWriter, f: RustWriter.() -> Unit) { - Attribute.Custom("allow(clippy::unnecessary_wraps)").render(implBlockWriter) - implBlockWriter.rustBlockTemplate( - "fn request_builder_base(&self) -> std::result::Result<#{HttpRequestBuilder}, #{OpBuildError}>", - *codegenScope, - ) { - f(this) - } - } - - private fun buildOperationType( - writer: RustWriter, - shape: OperationShape, - customizations: List, - ): String { - val operationT = RuntimeType.operation(runtimeConfig) - val output = buildOperationTypeOutput(writer, shape) - val retry = buildOperationTypeRetry(writer, customizations) - - return with(writer) { "${format(operationT)}<$output, $retry>" } - } - - private fun buildOperationTypeOutput(writer: RustWriter, shape: OperationShape): String = - writer.format(symbolProvider.toSymbol(shape)) - - private fun buildOperationTypeRetry(writer: RustWriter, customizations: List): String = - customizations.mapNotNull { it.retryType() }.firstOrNull()?.let { writer.format(it) } ?: "()" - - private fun generateMakeOperation( - implBlockWriter: RustWriter, - shape: OperationShape, - operationName: String, - customizations: List - ) { - val baseReturnType = buildOperationType(implBlockWriter, shape, customizations) - val returnType = "std::result::Result<$baseReturnType, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>" - val outputSymbol = symbolProvider.toSymbol(shape) - - val bodyMetadata = RustWriter.root().body("self", shape) - val mut = customizations.any { it.mutSelf() } - val consumes = customizations.any { it.consumesSelf() } || bodyMetadata.takesOwnership - val self = "self".letIf(mut) { "mut $it" }.letIf(!consumes) { "&$it" } - - implBlockWriter.docs("Consumes the builder and constructs an Operation<#D>", outputSymbol) - implBlockWriter.rust("##[allow(clippy::let_and_return)]") // For codegen simplicity, allow `let x = ...; x` - implBlockWriter.rustBlockTemplate( - "pub fn make_operation($self, _config: &#{config}::Config) -> $returnType", - *codegenScope - ) { - writeCustomizations(customizations, OperationSection.MutateInput("self", "_config")) - rust("let properties = smithy_http::property_bag::SharedPropertyBag::new();") - rust("let request = self.request_builder_base()?;") - withBlock("let body =", ";") { - body("self", shape) - } - rust("let request = Self::assemble(request, body);") - rustTemplate( - """ - ##[allow(unused_mut)] - let mut request = #{operation}::Request::from_parts(request.map(#{SdkBody}::from), properties); - """, - *codegenScope - ) - writeCustomizations(customizations, OperationSection.MutateRequest("request", "_config")) - rustTemplate( - """ - let op = #{operation}::Operation::new(request, #{OperationType}::new()) - .with_metadata(#{operation}::Metadata::new(${operationName.dq()}, ${sdkId.dq()})); - """, - *codegenScope, - "OperationType" to symbolProvider.toSymbol(shape) - ) - writeCustomizations(customizations, OperationSection.FinalizeOperation("op", "_config")) - rust("Ok(op)") - } - } -} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt index 65f958a33c6dca6904f4465e68fb7cf3770436d8..f7165bb3ebc08baae30c49f756812dd226806385 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt @@ -7,18 +7,22 @@ package software.amazon.smithy.rust.codegen.smithy.generators import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.rust.codegen.rustlang.RustModule +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RustCrate import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfigGenerator import software.amazon.smithy.rust.codegen.smithy.generators.error.CombinedErrorGenerator import software.amazon.smithy.rust.codegen.smithy.generators.error.TopLevelErrorGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.util.inputShape class ServiceGenerator( private val rustCrate: RustCrate, - private val protocolGenerator: HttpProtocolGenerator, + private val protocolGenerator: ProtocolGenerator, private val protocolSupport: ProtocolSupport, - private val config: ProtocolConfig, + private val config: CodegenContext, private val decorator: RustCodegenDecorator, ) { private val index = TopDownIndex.of(config.model) @@ -34,7 +38,7 @@ class ServiceGenerator( operation, decorator.operationCustomizations(config, operation, listOf()) ) - HttpProtocolTestGenerator(config, protocolSupport, operation, operationWriter).render() + ProtocolTestGenerator(config, protocolSupport, operation, operationWriter).render() } } rustCrate.withModule(RustModule.Error) { writer -> diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/config/ServiceConfigGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/config/ServiceConfigGenerator.kt index c59b034d878844b2b63fe5b7e00cdb0ad4050209..2cc0fdfd827dfc9613a1eb79fe5f445b0d0c6df7 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/config/ServiceConfigGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/config/ServiceConfigGenerator.kt @@ -14,9 +14,9 @@ import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.raw import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.customize.NamedSectionGenerator import software.amazon.smithy.rust.codegen.smithy.customize.Section -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.util.hasTrait /** @@ -96,9 +96,9 @@ typealias ConfigCustomization = NamedSectionGenerator class ServiceConfigGenerator(private val customizations: List = listOf()) { companion object { - fun withBaseBehavior(protocolConfig: ProtocolConfig, extraCustomizations: List): ServiceConfigGenerator { + fun withBaseBehavior(codegenContext: CodegenContext, extraCustomizations: List): ServiceConfigGenerator { val baseFeatures = mutableListOf() - if (protocolConfig.serviceShape.needsIdempotencyToken(protocolConfig.model)) { + if (codegenContext.serviceShape.needsIdempotencyToken(codegenContext.model)) { baseFeatures.add(IdempotencyTokenProviderCustomization()) } return ServiceConfigGenerator(baseFeatures + extraCustomizations) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/TopLevelErrorGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/TopLevelErrorGenerator.kt index 5164fc31c1b9f20476cad6a054e8b6bb5a77f2ac..851e39bfee970013e6344718d4e24731e6e95d83 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/TopLevelErrorGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/TopLevelErrorGenerator.kt @@ -16,9 +16,9 @@ import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustCrate -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig /** * Each service defines it's own "top-level" error combining all possible errors that a service can emit. @@ -35,15 +35,15 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig * } * ``` */ -class TopLevelErrorGenerator(protocolConfig: ProtocolConfig, private val operations: List) { - private val symbolProvider = protocolConfig.symbolProvider - private val model = protocolConfig.model +class TopLevelErrorGenerator(codegenContext: CodegenContext, private val operations: List) { + private val symbolProvider = codegenContext.symbolProvider + private val model = codegenContext.model - private val allErrors = operations.flatMap { it.errors }.distinctBy { it.getName(protocolConfig.serviceShape) } - .map { protocolConfig.model.expectShape(it, StructureShape::class.java) } - .sortedBy { it.id.getName(protocolConfig.serviceShape) } + private val allErrors = operations.flatMap { it.errors }.distinctBy { it.getName(codegenContext.serviceShape) } + .map { codegenContext.model.expectShape(it, StructureShape::class.java) } + .sortedBy { it.id.getName(codegenContext.serviceShape) } - private val sdkError = CargoDependency.SmithyHttp(protocolConfig.runtimeConfig).asType().member("result::SdkError") + private val sdkError = CargoDependency.SmithyHttp(codegenContext.runtimeConfig).asType().member("result::SdkError") fun render(crate: RustCrate) { crate.withModule(RustModule.default("error_meta", false)) { writer -> writer.renderDefinition() diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt index da9d690baec69095e0dd0b03a467a18c49b67141..07393a81f2fe6d22fa71d800d1cd73abc5cbd449 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt @@ -26,12 +26,13 @@ import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.autoDeref import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.generators.OperationBuildError -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.smithy.generators.redactIfNecessary import software.amazon.smithy.rust.codegen.smithy.isOptional @@ -74,19 +75,24 @@ class RequestBindingGenerator( private val httpTrait: HttpTrait, ) { private val index = HttpBindingIndex.of(model) - private val buildError = runtimeConfig.operationBuildError() private val Encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder") + private val codegenScope = arrayOf( + "BuildError" to runtimeConfig.operationBuildError(), + "HttpRequestBuilder" to RuntimeType.HttpRequestBuilder, + "Input" to symbolProvider.toSymbol(inputShape), + ) + constructor( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, defaultTimestampFormat: TimestampFormatTrait.Format, httpBindingResolver: HttpBindingResolver, shape: OperationShape, inputShape: StructureShape, ) : this( - protocolConfig.model, - protocolConfig.symbolProvider, - protocolConfig.runtimeConfig, + codegenContext.model, + codegenContext.symbolProvider, + codegenContext.runtimeConfig, defaultTimestampFormat, shape, inputShape, @@ -102,18 +108,22 @@ class RequestBindingGenerator( val hasHeaders = addHeaders(implBlockWriter) val hasQuery = uriQuery(implBlockWriter) Attribute.Custom("allow(clippy::unnecessary_wraps)").render(implBlockWriter) - implBlockWriter.rustBlock( - "fn update_http_builder(&self, builder: #1T) -> std::result::Result<#1T, #2T>", - RuntimeType.HttpRequestBuilder, - buildError + implBlockWriter.rustBlockTemplate( + """ + fn update_http_builder( + input: &#{Input}, + builder: #{HttpRequestBuilder} + ) -> std::result::Result<#{HttpRequestBuilder}, #{BuildError}> + """, + *codegenScope ) { write("let mut uri = String::new();") - write("self.uri_base(&mut uri)?;") + write("uri_base(input, &mut uri)?;") if (hasQuery) { - write("self.uri_query(&mut uri);") + write("uri_query(input, &mut uri);") } if (hasHeaders) { - write("let builder = self.add_headers(builder)?;") + write("let builder = add_headers(input, builder)?;") } write("Ok(builder.method(${httpTrait.method.dq()}).uri(uri))") } @@ -131,15 +141,18 @@ class RequestBindingGenerator( shape, HttpBinding.Location.PREFIX_HEADERS ) - val buildErrorT = runtimeConfig.operationBuildError() if (headers.isEmpty() && prefixHeaders.isEmpty()) { return false } - writer.rustBlock( - "fn add_headers(&self, mut builder: #1T) -> std::result::Result<#1T, #2T>", - RuntimeType.HttpRequestBuilder, - buildErrorT + writer.rustBlockTemplate( + """ + fn add_headers( + _input: &#{Input}, + mut builder: #{HttpRequestBuilder} + ) -> std::result::Result<#{HttpRequestBuilder}, #{BuildError}> + """, + *codegenScope, ) { headers.forEach { httpBinding -> renderHeaders(httpBinding) } prefixHeaders.forEach { httpBinding -> @@ -160,7 +173,7 @@ class RequestBindingGenerator( is MapShape -> model.expectShape(memberType.value.target) else -> TODO("unexpected member for prefix headers: $memberType") } - ifSet(memberType, memberSymbol, "&self.$memberName") { field -> + ifSet(memberType, memberSymbol, "&_input.$memberName") { field -> rustTemplate( """ for (k, v) in $field { @@ -195,7 +208,7 @@ class RequestBindingGenerator( val memberType = model.expectShape(memberShape.target) val memberSymbol = symbolProvider.toSymbol(memberShape) val memberName = symbolProvider.toMemberName(memberShape) - ifSet(memberType, memberSymbol, "&self.$memberName") { field -> + ifSet(memberType, memberSymbol, "&_input.$memberName") { field -> listForEach(memberType, field) { innerField, targetId -> val innerMemberType = model.expectShape(targetId) if (innerMemberType.isPrimitive()) { @@ -271,9 +284,9 @@ class RequestBindingGenerator( } val combinedArgs = listOf(formatString, *args.toTypedArray()) writer.addImport(RuntimeType.stdfmt.member("Write").toSymbol(), null) - writer.rustBlock( - "fn uri_base(&self, output: &mut String) -> Result<(), #T>", - runtimeConfig.operationBuildError() + writer.rustBlockTemplate( + "fn uri_base(_input: &#{Input}, output: &mut String) -> Result<(), #{BuildError}>", + *codegenScope ) { httpTrait.uri.labels.map { label -> val member = inputShape.expectMember(label.content) @@ -289,12 +302,12 @@ class RequestBindingGenerator( * * This function uses smithy_http::query::Query to append params to a query string: * ```rust - * fn uri_query(&self, mut output: &mut String) { + * fn uri_query(input: &Input, mut output: &mut String) { * let mut query = smithy_http::query::Query::new(&mut output); - * if let Some(inner_89) = &self.null_value { + * if let Some(inner_89) = &input.null_value { * query.push_kv("Null", &smithy_http::query::fmt_string(&inner_89)); * } - * if let Some(inner_90) = &self.empty_string { + * if let Some(inner_90) = &input.empty_string { * query.push_kv("Empty", &smithy_http::query::fmt_string(&inner_90)); * } * } @@ -309,7 +322,7 @@ class RequestBindingGenerator( return false } val preloadedParams = literalParams.keys + dynamicParams.map { it.locationName } - writer.rustBlock("fn uri_query(&self, mut output: &mut String)") { + writer.rustBlockTemplate("fn uri_query(_input: &#{Input}, mut output: &mut String)", *codegenScope) { write("let mut query = #T::new(&mut output);", RuntimeType.QueryFormat(runtimeConfig, "Writer")) literalParams.forEach { (k, v) -> // When `v` is an empty string, no value should be set. @@ -330,7 +343,7 @@ class RequestBindingGenerator( val memberName = symbolProvider.toMemberName(memberShape) val targetShape = model.expectShape(memberShape.target, MapShape::class.java) val stringFormatter = RuntimeType.QueryFormat(runtimeConfig, "fmt_string") - ifSet(model.expectShape(param.member.target), memberSymbol, "&self.$memberName") { field -> + ifSet(model.expectShape(param.member.target), memberSymbol, "&_input.$memberName") { field -> rustBlock("for (k, v) in $field") { // if v is a list, generate another level of iteration listForEach(model.expectShape(targetShape.value.target), "v") { innerField, _ -> @@ -347,7 +360,7 @@ class RequestBindingGenerator( val memberSymbol = symbolProvider.toSymbol(memberShape) val memberName = symbolProvider.toMemberName(memberShape) val outerTarget = model.expectShape(memberShape.target) - ifSet(outerTarget, memberSymbol, "&self.$memberName") { field -> + ifSet(outerTarget, memberSymbol, "&_input.$memberName") { field -> // if `param` is a list, generate another level of iteration listForEach(outerTarget, field) { innerField, targetId -> val target = model.expectShape(targetId) @@ -399,7 +412,7 @@ class RequestBindingGenerator( ) } val input = safeName("input") - rust("let $input = &self.${symbolProvider.toMemberName(member)};") + rust("let $input = &_input.${symbolProvider.toMemberName(member)};") if (symbol.isOptional()) { rust("let $input = $input.as_ref().ok_or(${buildError()})?;") } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt index 6bc6be90e9a26c10c4bdcfcf311cf385009c22e8..6478a1b3f2b89d29d24467ba4674175c74aed683 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt @@ -31,8 +31,8 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.stripOuter import software.amazon.smithy.rust.codegen.rustlang.withBlock +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.makeOptional import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation @@ -47,13 +47,13 @@ import software.amazon.smithy.rust.codegen.util.toSnakeCase class ResponseBindingGenerator( private val protocol: Protocol, - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, private val operationShape: OperationShape ) { - private val runtimeConfig = protocolConfig.runtimeConfig - private val symbolProvider = protocolConfig.symbolProvider - private val model = protocolConfig.model - private val service = protocolConfig.serviceShape + private val runtimeConfig = codegenContext.runtimeConfig + private val symbolProvider = codegenContext.symbolProvider + private val model = codegenContext.model + private val service = codegenContext.serviceShape private val index = HttpBindingIndex.of(model) private val headerUtil = CargoDependency.SmithyHttp(runtimeConfig).asType().member("header") private val defaultTimestampFormat = TimestampFormatTrait.Format.EPOCH_SECONDS diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/protocol/MakeOperationGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/protocol/MakeOperationGenerator.kt new file mode 100644 index 0000000000000000000000000000000000000000..75cabcb077d2e63518add694ebce7282dffae53b --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/protocol/MakeOperationGenerator.kt @@ -0,0 +1,167 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.smithy.generators.protocol + +import software.amazon.smithy.aws.traits.ServiceTrait +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.rustlang.Attribute +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.asType +import software.amazon.smithy.rust.codegen.rustlang.docs +import software.amazon.smithy.rust.codegen.rustlang.rust +import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.rustlang.withBlock +import software.amazon.smithy.rust.codegen.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization +import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection +import software.amazon.smithy.rust.codegen.smithy.customize.writeCustomizations +import software.amazon.smithy.rust.codegen.smithy.generators.http.RequestBindingGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError +import software.amazon.smithy.rust.codegen.smithy.letIf +import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.util.dq +import software.amazon.smithy.rust.codegen.util.getTrait +import software.amazon.smithy.rust.codegen.util.inputShape + +/** Generates the `make_operation` function on input structs */ +open class MakeOperationGenerator( + protected val codegenContext: CodegenContext, + private val protocol: Protocol, + private val bodyGenerator: ProtocolBodyGenerator, + private val functionName: String = "make_operation", + private val public: Boolean = true +) { + protected val runtimeConfig = codegenContext.runtimeConfig + protected val symbolProvider = codegenContext.symbolProvider + protected val httpBindingResolver = protocol.httpBindingResolver + + private val sdkId = + codegenContext.serviceShape.getTrait()?.sdkId?.toLowerCase()?.replace(" ", "") + ?: codegenContext.serviceShape.id.getName(codegenContext.serviceShape) + + private val codegenScope = arrayOf( + "config" to RuntimeType.Config, + "header_util" to CargoDependency.SmithyHttp(runtimeConfig).asType().member("header"), + "http" to RuntimeType.http, + "HttpRequestBuilder" to RuntimeType.HttpRequestBuilder, + "OpBuildError" to codegenContext.runtimeConfig.operationBuildError(), + "operation" to RuntimeType.operationModule(runtimeConfig), + "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), + ) + + fun generateMakeOperation( + implBlockWriter: RustWriter, + shape: OperationShape, + customizations: List, + ) { + val operationName = symbolProvider.toSymbol(shape).name + val baseReturnType = buildOperationType(implBlockWriter, shape, customizations) + val returnType = "std::result::Result<$baseReturnType, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>" + val outputSymbol = symbolProvider.toSymbol(shape) + + val takesOwnership = bodyGenerator.bodyMetadata(shape).takesOwnership + val mut = customizations.any { it.mutSelf() } + val consumes = customizations.any { it.consumesSelf() } || takesOwnership + val self = "self".letIf(mut) { "mut $it" }.letIf(!consumes) { "&$it" } + val fnType = if (public) "pub fn" else "fn" + + implBlockWriter.docs("Consumes the builder and constructs an Operation<#D>", outputSymbol) + implBlockWriter.rust("##[allow(clippy::let_and_return)]") // For codegen simplicity, allow `let x = ...; x` + implBlockWriter.rustBlockTemplate( + "$fnType $functionName($self, _config: &#{config}::Config) -> $returnType", + *codegenScope + ) { + generateRequestBuilderBaseFn(this, shape) + writeCustomizations(customizations, OperationSection.MutateInput(customizations, "self", "_config")) + rust("let properties = smithy_http::property_bag::SharedPropertyBag::new();") + rust("let request = request_builder_base(&self)?;") + withBlock("let body =", ";") { + bodyGenerator.generateBody(this, "self", shape) + } + rust("let request = Self::assemble(request, body);") + rustTemplate( + """ + ##[allow(unused_mut)] + let mut request = #{operation}::Request::from_parts(request.map(#{SdkBody}::from), properties); + """, + *codegenScope + ) + writeCustomizations(customizations, OperationSection.MutateRequest(customizations, "request", "_config")) + rustTemplate( + """ + let op = #{operation}::Operation::new(request, #{OperationType}::new()) + .with_metadata(#{operation}::Metadata::new(${operationName.dq()}, ${sdkId.dq()})); + """, + *codegenScope, + "OperationType" to symbolProvider.toSymbol(shape) + ) + writeCustomizations(customizations, OperationSection.FinalizeOperation(customizations, "op", "_config")) + rust("Ok(op)") + } + } + + private fun buildOperationType( + writer: RustWriter, + shape: OperationShape, + customizations: List, + ): String { + val operationT = RuntimeType.operation(runtimeConfig) + val output = buildOperationTypeOutput(writer, shape) + val retry = buildOperationTypeRetry(writer, customizations) + return with(writer) { "${format(operationT)}<$output, $retry>" } + } + + private fun buildOperationTypeOutput(writer: RustWriter, shape: OperationShape): String = + writer.format(symbolProvider.toSymbol(shape)) + + private fun buildOperationTypeRetry(writer: RustWriter, customizations: List): String = + customizations.mapNotNull { it.retryType() }.firstOrNull()?.let { writer.format(it) } ?: "()" + + protected fun RustWriter.inRequestBuilderBaseFn(inputShape: StructureShape, f: RustWriter.() -> Unit) { + Attribute.Custom("allow(clippy::unnecessary_wraps)").render(this) + rustBlockTemplate( + "fn request_builder_base(input: &#{Input}) -> std::result::Result<#{HttpRequestBuilder}, #{OpBuildError}>", + *codegenScope, + "Input" to symbolProvider.toSymbol(inputShape) + ) { + f(this) + } + } + + open fun generateRequestBuilderBaseFn(writer: RustWriter, operationShape: OperationShape) { + val inputShape = operationShape.inputShape(codegenContext.model) + val httpBindingGenerator = RequestBindingGenerator( + codegenContext, + protocol.defaultTimestampFormat, + httpBindingResolver, + operationShape, + inputShape, + ) + val contentType = httpBindingResolver.requestContentType(operationShape) + httpBindingGenerator.renderUpdateHttpBuilder(writer) + writer.inRequestBuilderBaseFn(inputShape) { + writer.rust("let mut builder = update_http_builder(input, #T::new())?;", RuntimeType.HttpRequestBuilder) + val additionalHeaders = listOf("content-type" to contentType) + protocol.additionalHeaders(operationShape) + for (header in additionalHeaders) { + writer.rustTemplate( + """ + builder = #{header_util}::set_header_if_absent( + builder, + #{http}::header::HeaderName::from_static(${header.first.dq()}), + ${header.second.dq()} + ); + """, + *codegenScope + ) + } + rust("Ok(builder)") + } + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/protocol/ProtocolGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/protocol/ProtocolGenerator.kt new file mode 100644 index 0000000000000000000000000000000000000000..9be95ffbffe7f3124be58aedba8bec7d0f9155ad --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/protocol/ProtocolGenerator.kt @@ -0,0 +1,143 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.smithy.generators.protocol + +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.rustlang.Attribute +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.asType +import software.amazon.smithy.rust.codegen.rustlang.documentShape +import software.amazon.smithy.rust.codegen.rustlang.rust +import software.amazon.smithy.rust.codegen.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization +import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection +import software.amazon.smithy.rust.codegen.smithy.customize.writeCustomizations +import software.amazon.smithy.rust.codegen.smithy.generators.BuilderGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.implBlock +import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError +import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.util.inputShape + +interface ProtocolBodyGenerator { + data class BodyMetadata(val takesOwnership: Boolean) + + fun bodyMetadata(operationShape: OperationShape): BodyMetadata + + fun generateBody(writer: RustWriter, self: String, operationShape: OperationShape) +} + +interface ProtocolTraitImplGenerator { + fun generateTraitImpls(operationWriter: RustWriter, operationShape: OperationShape) +} + +/** + * Class providing scaffolding for HTTP based protocols that must build an HTTP request (headers / URL) and a body. + */ +open class ProtocolGenerator( + codegenContext: CodegenContext, + private val protocol: Protocol, + private val makeOperationGenerator: MakeOperationGenerator, + private val traitGenerator: ProtocolTraitImplGenerator, +) { + private val runtimeConfig = codegenContext.runtimeConfig + private val symbolProvider = codegenContext.symbolProvider + private val model = codegenContext.model + + private val codegenScope = arrayOf( + "HttpRequestBuilder" to RuntimeType.HttpRequestBuilder, + "OpBuildError" to codegenContext.runtimeConfig.operationBuildError(), + "Request" to RuntimeType.Http("request::Request"), + "RequestBuilder" to RuntimeType.HttpRequestBuilder, + "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), + "config" to RuntimeType.Config, + "header_util" to CargoDependency.SmithyHttp(codegenContext.runtimeConfig).asType().member("header"), + "http" to RuntimeType.http, + "operation" to RuntimeType.operationModule(runtimeConfig), + ) + + fun renderOperation( + operationWriter: RustWriter, + inputWriter: RustWriter, + operationShape: OperationShape, + customizations: List + ) { + val inputShape = operationShape.inputShape(model) + val builderGenerator = BuilderGenerator(model, symbolProvider, operationShape.inputShape(model)) + builderGenerator.render(inputWriter) + + // TODO: One day, it should be possible for callers to invoke + // buildOperationType* directly to get the type rather than depending + // on these aliases. + val operationTypeOutput = buildOperationTypeOutput(inputWriter, operationShape) + val operationTypeRetry = buildOperationTypeRetry(inputWriter, customizations) + val inputPrefix = symbolProvider.toSymbol(inputShape).name + inputWriter.rust( + """ + ##[doc(hidden)] pub type ${inputPrefix}OperationOutputAlias = $operationTypeOutput; + ##[doc(hidden)] pub type ${inputPrefix}OperationRetryAlias = $operationTypeRetry; + """ + ) + + // impl OperationInputShape { ... } + val operationName = symbolProvider.toSymbol(operationShape).name + inputWriter.implBlock(inputShape, symbolProvider) { + writeCustomizations( + customizations, + OperationSection.InputImpl(customizations, operationShape, inputShape, protocol) + ) + makeOperationGenerator.generateMakeOperation(this, operationShape, customizations) + rustBlockTemplate( + "fn assemble(mut builder: #{RequestBuilder}, body: #{SdkBody}) -> #{Request}<#{SdkBody}>", + *codegenScope + ) { + rustTemplate( + """ + if let Some(content_length) = body.content_length() { + builder = #{header_util}::set_header_if_absent( + builder, + #{http}::header::CONTENT_LENGTH, + content_length + ); + } + builder.body(body).expect("should be valid request") + """, + *codegenScope + ) + } + + // pub fn builder() -> ... { } + builderGenerator.renderConvenienceMethod(this) + } + + // pub struct Operation { ... } + operationWriter.documentShape(operationShape, model) + Attribute.Derives(setOf(RuntimeType.Clone, RuntimeType.Default, RuntimeType.Debug)).render(operationWriter) + operationWriter.rustBlock("pub struct $operationName") { + write("_private: ()") + } + operationWriter.implBlock(operationShape, symbolProvider) { + builderGenerator.renderConvenienceMethod(this) + + rustBlock("pub fn new() -> Self") { + rust("Self { _private: () }") + } + + writeCustomizations(customizations, OperationSection.OperationImplBlock(customizations)) + } + traitGenerator.generateTraitImpls(operationWriter, operationShape) + } + + private fun buildOperationTypeOutput(writer: RustWriter, shape: OperationShape): String = + writer.format(symbolProvider.toSymbol(shape)) + + private fun buildOperationTypeRetry(writer: RustWriter, customizations: List): String = + customizations.mapNotNull { it.retryType() }.firstOrNull()?.let { writer.format(it) } ?: "()" +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/protocol/ProtocolTestGenerator.kt similarity index 90% rename from codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt rename to codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/protocol/ProtocolTestGenerator.kt index eac04087f7185bc3ab8d49e749e55a6782f78c21..f62d49b0cb46da074e696efc56c04be7f09a6db8 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0. */ -package software.amazon.smithy.rust.codegen.smithy.generators +package software.amazon.smithy.rust.codegen.smithy.generators.protocol import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.knowledge.OperationIndex @@ -31,7 +31,9 @@ import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.generators.Instantiator import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.findMemberWithTrait @@ -54,20 +56,20 @@ data class ProtocolSupport( /** * Generate protocol tests for an operation */ -class HttpProtocolTestGenerator( - private val protocolConfig: ProtocolConfig, +class ProtocolTestGenerator( + private val codegenContext: CodegenContext, private val protocolSupport: ProtocolSupport, private val operationShape: OperationShape, private val writer: RustWriter ) { private val logger = Logger.getLogger(javaClass.name) - private val inputShape = operationShape.inputShape(protocolConfig.model) - private val outputShape = operationShape.outputShape(protocolConfig.model) - private val operationSymbol = protocolConfig.symbolProvider.toSymbol(operationShape) - private val operationIndex = OperationIndex.of(protocolConfig.model) + private val inputShape = operationShape.inputShape(codegenContext.model) + private val outputShape = operationShape.outputShape(codegenContext.model) + private val operationSymbol = codegenContext.symbolProvider.toSymbol(operationShape) + private val operationIndex = OperationIndex.of(codegenContext.model) - private val instantiator = with(protocolConfig) { + private val instantiator = with(codegenContext) { Instantiator(symbolProvider, model, runtimeConfig) } @@ -123,7 +125,7 @@ class HttpProtocolTestGenerator( private fun List.filterMatching(): List { return if (RunOnly.isNullOrEmpty()) { this.filter { testCase -> - testCase.testCase.protocol == protocolConfig.protocol && + testCase.testCase.protocol == codegenContext.protocol && !DisableTests.contains(testCase.testCase.id) } } else { @@ -174,7 +176,7 @@ class HttpProtocolTestGenerator( rust("/* test case disabled for this protocol (not yet supported) */") return } - val customToken = if (inputShape.findMemberWithTrait(protocolConfig.model) != null) { + val customToken = if (inputShape.findMemberWithTrait(codegenContext.model) != null) { """.make_token("00000000-0000-4000-8000-000000000000")""" } else "" rust( @@ -195,7 +197,7 @@ class HttpProtocolTestGenerator( let ep = #T::endpoint::Endpoint::mutable(#T::Uri::from_static(${withScheme.dq()})); ep.set_endpoint(http_request.uri_mut(), parts.acquire().get()); """, - CargoDependency.SmithyHttp(protocolConfig.runtimeConfig).asType(), CargoDependency.Http.asType() + CargoDependency.SmithyHttp(codegenContext.runtimeConfig).asType(), CargoDependency.Http.asType() ) } rust( @@ -239,7 +241,7 @@ class HttpProtocolTestGenerator( } private fun expectFail(testCase: HttpMessageTestCase): Boolean = ExpectFail.find { - it.id == testCase.id && it.action == testCase.action() && it.service == protocolConfig.serviceShape.id.toString() + it.id == testCase.id && it.action == testCase.action() && it.service == codegenContext.serviceShape.id.toString() } != null private fun RustWriter.renderHttpResponseTestCase( @@ -268,11 +270,11 @@ class HttpProtocolTestGenerator( .body(#T::from(${testCase.body.orNull()?.dq()?.replace("#", "##") ?: "vec![]"})) .unwrap(); """, - RuntimeType.sdkBody(runtimeConfig = protocolConfig.runtimeConfig) + RuntimeType.sdkBody(runtimeConfig = codegenContext.runtimeConfig) ) write( "let mut op_response = #T::new(http_response);", - RuntimeType.operationModule(protocolConfig.runtimeConfig).member("Response") + RuntimeType.operationModule(codegenContext.runtimeConfig).member("Response") ) rustTemplate( """ @@ -287,12 +289,12 @@ class HttpProtocolTestGenerator( """, "op" to operationSymbol, "bytes" to RuntimeType.Bytes, - "parse_http_response" to CargoDependency.SmithyHttp(protocolConfig.runtimeConfig).asType() + "parse_http_response" to CargoDependency.SmithyHttp(codegenContext.runtimeConfig).asType() .member("response::ParseHttpResponse"), ) if (expectedShape.hasTrait()) { - val errorSymbol = operationShape.errorSymbol(protocolConfig.symbolProvider) - val errorVariant = protocolConfig.symbolProvider.toSymbol(expectedShape).name + val errorSymbol = operationShape.errorSymbol(codegenContext.symbolProvider) + val errorVariant = codegenContext.symbolProvider.toSymbol(expectedShape).name rust("""let parsed = parsed.expect_err("should be error response");""") rustBlock("if let #TKind::$errorVariant(actual_error) = parsed.kind", errorSymbol) { rust("assert_eq!(expected_output, actual_error);") @@ -303,8 +305,8 @@ class HttpProtocolTestGenerator( } else { rust("let parsed = parsed.unwrap();") outputShape.members().forEach { member -> - val memberName = protocolConfig.symbolProvider.toMemberName(member) - if (member.isStreaming(protocolConfig.model)) { + val memberName = codegenContext.symbolProvider.toMemberName(member) + if (member.isStreaming(codegenContext.model)) { rust( """assert_eq!( parsed.$memberName.collect().await.unwrap().into_bytes(), @@ -312,10 +314,10 @@ class HttpProtocolTestGenerator( );""" ) } else { - when (protocolConfig.model.expectShape(member.target)) { + when (codegenContext.model.expectShape(member.target)) { is DoubleShape, is FloatShape -> { addUseImports( - RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "FloatEquals").toSymbol() + RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "FloatEquals").toSymbol() ) rust( """ @@ -352,8 +354,8 @@ class HttpProtocolTestGenerator( "#T(&body, ${ rustWriter.escape(body).dq() }, #T::from(${(mediaType ?: "unknown").dq()}))", - RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_body"), - RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "MediaType") + RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "validate_body"), + RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "MediaType") ) } } @@ -374,7 +376,7 @@ class HttpProtocolTestGenerator( assertOk(rustWriter) { write( "#T(&http_request, $variableName)", - RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_headers") + RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "validate_headers") ) } } @@ -409,7 +411,7 @@ class HttpProtocolTestGenerator( assertOk(rustWriter) { write( "#T(&http_request, $variableName)", - RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, checkFunction) + RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, checkFunction) ) } } @@ -419,7 +421,7 @@ class HttpProtocolTestGenerator( * for pretty prettying protocol test helper results */ private fun assertOk(rustWriter: RustWriter, inner: RustWriter.() -> Unit) { - rustWriter.write("#T(", RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "assert_ok")) + rustWriter.write("#T(", RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "assert_ok")) inner(rustWriter) rustWriter.write(");") } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt index a3105bf38d0407cee67f83b29b154499b5308606..3dd32fe1ceda3460fdaa0515b51ceede14ad9c72 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt @@ -16,10 +16,9 @@ import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator @@ -40,8 +39,10 @@ sealed class AwsJsonVersion { } class AwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFactory { - override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpBoundProtocolGenerator { - return HttpBoundProtocolGenerator(protocolConfig, AwsJson(protocolConfig, version)) + override fun protocol(codegenContext: CodegenContext): Protocol = AwsJson(codegenContext, version) + + override fun buildProtocolGenerator(codegenContext: CodegenContext): HttpBoundProtocolGenerator { + return HttpBoundProtocolGenerator(codegenContext, protocol(codegenContext)) } override fun transformModel(model: Model): Model = model @@ -92,12 +93,12 @@ class AwsJsonHttpBindingResolver( * customizes wraps [JsonSerializerGenerator] to add this functionality. */ class AwsJsonSerializerGenerator( - private val protocolConfig: ProtocolConfig, + private val codegenContext: CodegenContext, httpBindingResolver: HttpBindingResolver, private val jsonSerializerGenerator: JsonSerializerGenerator = - JsonSerializerGenerator(protocolConfig, httpBindingResolver) + JsonSerializerGenerator(codegenContext, httpBindingResolver) ) : StructuredDataSerializerGenerator by jsonSerializerGenerator { - private val runtimeConfig = protocolConfig.runtimeConfig + private val runtimeConfig = codegenContext.runtimeConfig private val codegenScope = arrayOf( "Error" to CargoDependency.SmithyTypes(runtimeConfig).asType().member("Error"), "SdkBody" to RuntimeType.sdkBody(runtimeConfig), @@ -106,12 +107,12 @@ class AwsJsonSerializerGenerator( override fun operationSerializer(operationShape: OperationShape): RuntimeType { var serializer = jsonSerializerGenerator.operationSerializer(operationShape) if (serializer == null) { - val inputShape = operationShape.inputShape(protocolConfig.model) - val fnName = protocolConfig.symbolProvider.serializeFunctionName(operationShape) + val inputShape = operationShape.inputShape(codegenContext.model) + val fnName = codegenContext.symbolProvider.serializeFunctionName(operationShape) serializer = RuntimeType.forInlineFun(fnName, RustModule.private("operation_ser")) { it.rustBlockTemplate( "pub fn $fnName(_input: &#{target}) -> Result<#{SdkBody}, #{Error}>", - *codegenScope, "target" to protocolConfig.symbolProvider.toSymbol(inputShape) + *codegenScope, "target" to codegenContext.symbolProvider.toSymbol(inputShape) ) { rustTemplate("""Ok(#{SdkBody}::from("{}"))""", *codegenScope) } @@ -122,10 +123,10 @@ class AwsJsonSerializerGenerator( } class AwsJson( - private val protocolConfig: ProtocolConfig, + private val codegenContext: CodegenContext, awsJsonVersion: AwsJsonVersion ) : Protocol { - private val runtimeConfig = protocolConfig.runtimeConfig + private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, "Error" to RuntimeType.GenericError(runtimeConfig), @@ -137,18 +138,18 @@ class AwsJson( private val jsonDeserModule = RustModule.private("json_deser") override val httpBindingResolver: HttpBindingResolver = - AwsJsonHttpBindingResolver(protocolConfig.model, awsJsonVersion) + AwsJsonHttpBindingResolver(codegenContext.model, awsJsonVersion) override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS override fun additionalHeaders(operationShape: OperationShape): List> = - listOf("x-amz-target" to "${protocolConfig.serviceShape.id.name}.${operationShape.id.name}") + listOf("x-amz-target" to "${codegenContext.serviceShape.id.name}.${operationShape.id.name}") override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = - JsonParserGenerator(protocolConfig, httpBindingResolver) + JsonParserGenerator(codegenContext, httpBindingResolver) override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = - AwsJsonSerializerGenerator(protocolConfig, httpBindingResolver) + AwsJsonSerializerGenerator(codegenContext, httpBindingResolver) override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType = RuntimeType.forInlineFun("parse_http_generic_error", jsonDeserModule) { writer -> diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsQuery.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsQuery.kt index 92b1e5265d9ae5ef616120974c450c090f97f6ba..4bbd217388f862ec1e7f2707a693a7d599b05d57 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsQuery.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsQuery.kt @@ -17,10 +17,9 @@ import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.smithy.protocols.parse.AwsQueryParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.AwsQuerySerializerGenerator @@ -28,8 +27,10 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.Structured import software.amazon.smithy.rust.codegen.util.getTrait class AwsQueryFactory : ProtocolGeneratorFactory { - override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpBoundProtocolGenerator = - HttpBoundProtocolGenerator(protocolConfig, AwsQueryProtocol(protocolConfig)) + override fun protocol(codegenContext: CodegenContext): Protocol = AwsQueryProtocol(codegenContext) + + override fun buildProtocolGenerator(codegenContext: CodegenContext): HttpBoundProtocolGenerator = + HttpBoundProtocolGenerator(codegenContext, protocol(codegenContext)) override fun transformModel(model: Model): Model = model @@ -57,8 +58,8 @@ class AwsQueryBindingResolver(private val model: Model) : } } -class AwsQueryProtocol(private val protocolConfig: ProtocolConfig) : Protocol { - private val runtimeConfig = protocolConfig.runtimeConfig +class AwsQueryProtocol(private val codegenContext: CodegenContext) : Protocol { + private val runtimeConfig = codegenContext.runtimeConfig private val awsQueryErrors: RuntimeType = RuntimeType.wrappedXmlErrors(runtimeConfig) private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, @@ -69,15 +70,15 @@ class AwsQueryProtocol(private val protocolConfig: ProtocolConfig) : Protocol { ) private val xmlDeserModule = RustModule.private("xml_deser") - override val httpBindingResolver: HttpBindingResolver = AwsQueryBindingResolver(protocolConfig.model) + override val httpBindingResolver: HttpBindingResolver = AwsQueryBindingResolver(codegenContext.model) override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = - AwsQueryParserGenerator(protocolConfig, awsQueryErrors) + AwsQueryParserGenerator(codegenContext, awsQueryErrors) override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = - AwsQuerySerializerGenerator(protocolConfig) + AwsQuerySerializerGenerator(codegenContext) override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType = RuntimeType.forInlineFun("parse_http_generic_error", xmlDeserModule) { writer -> diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/Ec2Query.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/Ec2Query.kt index 1238911a7052307435af3dd67d474a943d385a18..de573752b84820ebb5f298c1d173bf444e887da8 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/Ec2Query.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/Ec2Query.kt @@ -15,18 +15,19 @@ import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.smithy.protocols.parse.Ec2QueryParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.Ec2QuerySerializerGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator class Ec2QueryFactory : ProtocolGeneratorFactory { - override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpBoundProtocolGenerator = - HttpBoundProtocolGenerator(protocolConfig, Ec2QueryProtocol(protocolConfig)) + override fun protocol(codegenContext: CodegenContext): Protocol = Ec2QueryProtocol(codegenContext) + + override fun buildProtocolGenerator(codegenContext: CodegenContext): HttpBoundProtocolGenerator = + HttpBoundProtocolGenerator(codegenContext, protocol(codegenContext)) override fun transformModel(model: Model): Model = model @@ -40,8 +41,8 @@ class Ec2QueryFactory : ProtocolGeneratorFactory { } } -class Ec2QueryProtocol(private val protocolConfig: ProtocolConfig) : Protocol { - private val runtimeConfig = protocolConfig.runtimeConfig +class Ec2QueryProtocol(private val codegenContext: CodegenContext) : Protocol { + private val runtimeConfig = codegenContext.runtimeConfig private val ec2QueryErrors: RuntimeType = RuntimeType.ec2QueryErrors(runtimeConfig) private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, @@ -53,7 +54,7 @@ class Ec2QueryProtocol(private val protocolConfig: ProtocolConfig) : Protocol { private val xmlDeserModule = RustModule.private("xml_deser") override val httpBindingResolver: HttpBindingResolver = StaticHttpBindingResolver( - protocolConfig.model, + codegenContext.model, HttpTrait.builder() .code(200) .method("POST") @@ -66,10 +67,10 @@ class Ec2QueryProtocol(private val protocolConfig: ProtocolConfig) : Protocol { override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = - Ec2QueryParserGenerator(protocolConfig, ec2QueryErrors) + Ec2QueryParserGenerator(codegenContext, ec2QueryErrors) override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = - Ec2QuerySerializerGenerator(protocolConfig) + Ec2QuerySerializerGenerator(codegenContext) override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType = RuntimeType.forInlineFun("parse_http_generic_error", xmlDeserModule) { writer -> diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolGenerator.kt index 22f15ae40bc7d2a94184d3f45aca142582301886..dd46f97723dc49b279c5b14d83cdeb3dc94a28b0 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -15,7 +15,6 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.ErrorTrait -import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustModule @@ -29,15 +28,18 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol -import software.amazon.smithy.rust.codegen.smithy.generators.http.RequestBindingGenerator import software.amazon.smithy.rust.codegen.smithy.generators.http.ResponseBindingGenerator import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.MakeOperationGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolBodyGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolBodyGenerator.BodyMetadata +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolTraitImplGenerator import software.amazon.smithy.rust.codegen.smithy.generators.setterName import software.amazon.smithy.rust.codegen.smithy.isOptional import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator @@ -55,44 +57,23 @@ import software.amazon.smithy.rust.codegen.util.isStreaming import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.codegen.util.toSnakeCase -interface Protocol { - val httpBindingResolver: HttpBindingResolver - - val defaultTimestampFormat: TimestampFormatTrait.Format - - fun additionalHeaders(operationShape: OperationShape): List> = emptyList() - - fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator - - fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator - - /** - * Generates a function signature like the following: - * ```rust - * fn parse_http_generic_error(response: &Response) -> smithy_types::error::Error - * ``` - **/ - fun parseHttpGenericError(operationShape: OperationShape): RuntimeType - - /** - * Generates a function signature like the following: - * ```rust - * fn parse_event_stream_generic_error(payload: &Bytes) -> smithy_types::error::Error - * ``` - * - * Event Stream generic errors are almost identical to HTTP generic errors, except that - * there are no response headers or statuses available to further inform the error parsing. - **/ - fun parseEventStreamGenericError(operationShape: OperationShape): RuntimeType -} - class HttpBoundProtocolGenerator( - private val protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, + protocol: Protocol, +) : ProtocolGenerator( + codegenContext, + protocol, + MakeOperationGenerator(codegenContext, protocol, HttpBoundProtocolBodyGenerator(codegenContext, protocol)), + HttpBoundProtocolTraitImplGenerator(codegenContext, protocol), +) + +private class HttpBoundProtocolTraitImplGenerator( + private val codegenContext: CodegenContext, private val protocol: Protocol, -) : HttpProtocolGenerator(protocolConfig) { - private val symbolProvider = protocolConfig.symbolProvider - private val model = protocolConfig.model - private val runtimeConfig = protocolConfig.runtimeConfig +) : ProtocolTraitImplGenerator { + private val symbolProvider = codegenContext.symbolProvider + private val model = codegenContext.model + private val runtimeConfig = codegenContext.runtimeConfig private val httpBindingResolver = protocol.httpBindingResolver private val operationSerModule = RustModule.private("operation_ser") @@ -100,167 +81,11 @@ class HttpBoundProtocolGenerator( "ParseStrict" to RuntimeType.parseStrict(runtimeConfig), "ParseResponse" to RuntimeType.parseResponse(runtimeConfig), "http" to RuntimeType.http, - "hyper" to CargoDependency.HyperWithStream.asType(), "operation" to RuntimeType.operationModule(runtimeConfig), "Bytes" to RuntimeType.Bytes, - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), - "BuildError" to runtimeConfig.operationBuildError(), - "SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType() ) - override fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata { - val serializerGenerator = protocol.structuredDataSerializer(operationShape) - val inputShape = operationShape.inputShape(model) - val payloadMemberName = - httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName - return if (payloadMemberName == null) { - serializerGenerator.operationSerializer(operationShape)?.let { serializer -> - rust( - "#T(&self).map_err(|err|#T::SerializationError(err.into()))?", - serializer, - runtimeConfig.operationBuildError() - ) - } ?: rustTemplate("#{SdkBody}::from(\"\")", *codegenScope) - BodyMetadata(takesOwnership = false) - } else { - val member = inputShape.expectMember(payloadMemberName) - if (operationShape.isInputEventStream(model)) { - serializeViaEventStream(operationShape, member, serializerGenerator) - } else { - serializeViaPayload(member, serializerGenerator) - } - } - } - - private fun RustWriter.serializeViaEventStream( - operationShape: OperationShape, - memberShape: MemberShape, - serializerGenerator: StructuredDataSerializerGenerator - ): BodyMetadata { - val memberName = symbolProvider.toMemberName(memberShape) - val unionShape = model.expectShape(memberShape.target, UnionShape::class.java) - - val marshallerConstructorFn = EventStreamMarshallerGenerator( - model, - runtimeConfig, - symbolProvider, - unionShape, - serializerGenerator, - httpBindingResolver.requestContentType(operationShape), - ).render() - - // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the - // parameters that are not `@eventHeader` or `@eventPayload`. - rustTemplate( - """ - { - let marshaller = #{marshallerConstructorFn}(); - let signer = _config.new_event_stream_signer(properties.clone()); - let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, #{OperationError}> = - self.$memberName.into_body_stream(marshaller, signer); - let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); - body - } - """, - *codegenScope, - "marshallerConstructorFn" to marshallerConstructorFn, - "OperationError" to operationShape.errorSymbol(symbolProvider) - ) - return BodyMetadata(takesOwnership = true) - } - - private fun RustWriter.serializeViaPayload( - member: MemberShape, - serializerGenerator: StructuredDataSerializerGenerator - ): BodyMetadata { - val fnName = "ser_payload_${member.container.name.toSnakeCase()}" - val bodyMetadata: BodyMetadata = RustWriter.root().renderPayload(member, "payload", serializerGenerator) - val ref = when (bodyMetadata.takesOwnership) { - true -> "" - false -> "&" - } - val serializer = RuntimeType.forInlineFun(fnName, operationSerModule) { - it.rustBlockTemplate( - "pub fn $fnName(payload: $ref #{Member}) -> std::result::Result<#{SdkBody}, #{BuildError}>", - "Member" to symbolProvider.toSymbol(member), - *codegenScope - ) { - // If this targets a member & the member is None, return an empty vec - val ref = when (bodyMetadata.takesOwnership) { - false -> ".as_ref()" - true -> "" - } - - if (symbolProvider.toSymbol(member).isOptional()) { - rustTemplate( - """ - let payload = match payload$ref { - Some(t) => t, - None => return Ok(#{SdkBody}::from("")) - };""", - *codegenScope - ) - } - // When the body is a streaming blob it _literally_ is a SdkBody already - // mute this clippy warning to make the codegen a little simpler - Attribute.Custom("allow(clippy::useless_conversion)").render(this) - withBlock("Ok(#T::from(", "))", RuntimeType.sdkBody(runtimeConfig)) { - renderPayload(member, "payload", serializerGenerator) - } - } - } - rust("#T($ref self.${symbolProvider.toMemberName(member)})?", serializer) - return bodyMetadata - } - - private fun RustWriter.renderPayload( - member: MemberShape, - payloadName: String, - serializer: StructuredDataSerializerGenerator - ): BodyMetadata { - return when (val targetShape = model.expectShape(member.target)) { - // Write the raw string to the payload - is StringShape -> { - if (targetShape.hasTrait()) { - rust("$payloadName.as_str()") - } else { - rust("""$payloadName.to_string()""") - } - BodyMetadata(takesOwnership = false) - } - - // This works for streaming & non streaming blobs because they both have `into_inner()` which - // can be converted into an SDK body! - is BlobShape -> { - // Write the raw blob to the payload - rust("$payloadName.into_inner()") - BodyMetadata(takesOwnership = true) - } - is StructureShape, is UnionShape -> { - check( - !((targetShape as? UnionShape)?.isEventStream() ?: false) - ) { "Event Streams should be handled further up" } - - // JSON serialize the structure or union targeted - rust( - """#T(&$payloadName).map_err(|err|#T::SerializationError(err.into()))?""", - serializer.payloadSerializer(member), runtimeConfig.operationBuildError() - ) - BodyMetadata(takesOwnership = false) - } - is DocumentShape -> { - rust( - "#T(&$payloadName).map_err(|err|#T::SerializationError(err.into()))?", - serializer.documentSerializer(), - runtimeConfig.operationBuildError() - ) - BodyMetadata(takesOwnership = false) - } - else -> TODO("Unexpected payload target type") - } - } - - override fun traitImplementations(operationWriter: RustWriter, operationShape: OperationShape) { + override fun generateTraitImpls(operationWriter: RustWriter, operationShape: OperationShape) { val outputSymbol = symbolProvider.toSymbol(operationShape.outputShape(model)) val operationName = symbolProvider.toSymbol(operationShape).name @@ -369,10 +194,9 @@ class HttpBoundProtocolGenerator( operationShape.errors.forEach { error -> val errorShape = model.expectShape(error, StructureShape::class.java) val variantName = symbolProvider.toSymbol(model.expectShape(error)).name + val errorCode = httpBindingResolver.errorCode(errorShape).dq() withBlock( - "${ - httpBindingResolver.errorCode(errorShape).dq() - } => #1T { meta: generic, kind: #1TKind::$variantName({", + "$errorCode => #1T { meta: generic, kind: #1TKind::$variantName({", "})},", errorSymbol ) { @@ -459,48 +283,13 @@ class HttpBoundProtocolGenerator( } } - override fun toHttpRequestImpl( - implBlockWriter: RustWriter, - operationShape: OperationShape, - inputShape: StructureShape - ) { - val httpBindingGenerator = RequestBindingGenerator( - protocolConfig, - protocol.defaultTimestampFormat, - httpBindingResolver, - operationShape, - inputShape, - ) - val contentType = httpBindingResolver.requestContentType(operationShape) - httpBindingGenerator.renderUpdateHttpBuilder(implBlockWriter) - generateRequestBuilderBase(implBlockWriter) { - rust("let mut builder = self.update_http_builder(#T::new())?;", RuntimeType.HttpRequestBuilder) - val additionalHeaders = listOf("content-type" to contentType) + protocol.additionalHeaders(operationShape) - for (header in additionalHeaders) { - rustTemplate( - """ - builder = #{header_util}::set_header_if_absent( - builder, - #{http}::header::HeaderName::from_static(${header.first.dq()}), - ${header.second.dq()} - ); - """, - "http" to RuntimeType.http, - "header_util" to CargoDependency.SmithyHttp(runtimeConfig).asType().member("header") - - ) - } - rust("Ok(builder)") - } - } - private fun RustWriter.renderShapeParser( operationShape: OperationShape, outputShape: StructureShape, bindings: List, errorSymbol: RuntimeType, ) { - val httpBindingGenerator = ResponseBindingGenerator(protocol, protocolConfig, operationShape) + val httpBindingGenerator = ResponseBindingGenerator(protocol, codegenContext, operationShape) val structuredDataParser = protocol.structuredDataParser(operationShape) Attribute.AllowUnusedMut.render(this) rust("let mut output = #T::default();", outputShape.builderSymbol(symbolProvider)) @@ -614,3 +403,193 @@ class HttpBoundProtocolGenerator( } } } + +class HttpBoundProtocolBodyGenerator( + codegenContext: CodegenContext, + private val protocol: Protocol, +) : ProtocolBodyGenerator { + private val symbolProvider = codegenContext.symbolProvider + private val model = codegenContext.model + private val runtimeConfig = codegenContext.runtimeConfig + private val httpBindingResolver = protocol.httpBindingResolver + + private val operationSerModule = RustModule.private("operation_ser") + + private val codegenScope = arrayOf( + "ParseStrict" to RuntimeType.parseStrict(runtimeConfig), + "ParseResponse" to RuntimeType.parseResponse(runtimeConfig), + "http" to RuntimeType.http, + "hyper" to CargoDependency.HyperWithStream.asType(), + "operation" to RuntimeType.operationModule(runtimeConfig), + "Bytes" to RuntimeType.Bytes, + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "BuildError" to runtimeConfig.operationBuildError(), + "SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType() + ) + + override fun bodyMetadata(operationShape: OperationShape): BodyMetadata { + val inputShape = operationShape.inputShape(model) + val payloadMemberName = + httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName + + // Only streaming operations take ownership, so that's event streams and blobs + return if (payloadMemberName == null) { + BodyMetadata(takesOwnership = false) + } else if (operationShape.isInputEventStream(model)) { + BodyMetadata(takesOwnership = true) + } else { + val member = inputShape.expectMember(payloadMemberName) + when (model.expectShape(member.target)) { + is StringShape, is DocumentShape, is StructureShape, is UnionShape -> BodyMetadata(takesOwnership = false) + is BlobShape -> BodyMetadata(takesOwnership = true) + else -> TODO("Unexpected payload target type") + } + } + } + + override fun generateBody(writer: RustWriter, self: String, operationShape: OperationShape) { + val bodyMetadata = bodyMetadata(operationShape) + val serializerGenerator = protocol.structuredDataSerializer(operationShape) + val inputShape = operationShape.inputShape(model) + val payloadMemberName = + httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName + if (payloadMemberName == null) { + serializerGenerator.operationSerializer(operationShape)?.let { serializer -> + writer.rust( + "#T(&self).map_err(|err|#T::SerializationError(err.into()))?", + serializer, + runtimeConfig.operationBuildError() + ) + } ?: writer.rustTemplate("#{SdkBody}::from(\"\")", *codegenScope) + } else { + val member = inputShape.expectMember(payloadMemberName) + if (operationShape.isInputEventStream(model)) { + writer.serializeViaEventStream(operationShape, member, serializerGenerator) + } else { + writer.serializeViaPayload(bodyMetadata, member, serializerGenerator) + } + } + } + + private fun RustWriter.serializeViaEventStream( + operationShape: OperationShape, + memberShape: MemberShape, + serializerGenerator: StructuredDataSerializerGenerator + ) { + val memberName = symbolProvider.toMemberName(memberShape) + val unionShape = model.expectShape(memberShape.target, UnionShape::class.java) + + val marshallerConstructorFn = EventStreamMarshallerGenerator( + model, + runtimeConfig, + symbolProvider, + unionShape, + serializerGenerator, + httpBindingResolver.requestContentType(operationShape), + ).render() + + // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the + // parameters that are not `@eventHeader` or `@eventPayload`. + rustTemplate( + """ + { + let marshaller = #{marshallerConstructorFn}(); + let signer = _config.new_event_stream_signer(properties.clone()); + let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, #{OperationError}> = + self.$memberName.into_body_stream(marshaller, signer); + let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); + body + } + """, + *codegenScope, + "marshallerConstructorFn" to marshallerConstructorFn, + "OperationError" to operationShape.errorSymbol(symbolProvider) + ) + } + + private fun RustWriter.serializeViaPayload( + bodyMetadata: BodyMetadata, + member: MemberShape, + serializerGenerator: StructuredDataSerializerGenerator + ) { + val fnName = "ser_payload_${member.container.name.toSnakeCase()}" + val ref = when (bodyMetadata.takesOwnership) { + true -> "" + false -> "&" + } + val serializer = RuntimeType.forInlineFun(fnName, operationSerModule) { + it.rustBlockTemplate( + "pub fn $fnName(payload: $ref #{Member}) -> std::result::Result<#{SdkBody}, #{BuildError}>", + "Member" to symbolProvider.toSymbol(member), + *codegenScope + ) { + // If this targets a member & the member is None, return an empty vec + val ref = when (bodyMetadata.takesOwnership) { + false -> ".as_ref()" + true -> "" + } + + if (symbolProvider.toSymbol(member).isOptional()) { + rustTemplate( + """ + let payload = match payload$ref { + Some(t) => t, + None => return Ok(#{SdkBody}::from("")) + };""", + *codegenScope + ) + } + // When the body is a streaming blob it _literally_ is a SdkBody already + // mute this clippy warning to make the codegen a little simpler + Attribute.Custom("allow(clippy::useless_conversion)").render(this) + withBlock("Ok(#T::from(", "))", RuntimeType.sdkBody(runtimeConfig)) { + renderPayload(member, "payload", serializerGenerator) + } + } + } + rust("#T($ref self.${symbolProvider.toMemberName(member)})?", serializer) + } + + private fun RustWriter.renderPayload( + member: MemberShape, + payloadName: String, + serializer: StructuredDataSerializerGenerator + ) { + when (val targetShape = model.expectShape(member.target)) { + // Write the raw string to the payload + is StringShape -> { + if (targetShape.hasTrait()) { + rust("$payloadName.as_str()") + } else { + rust("""$payloadName.to_string()""") + } + } + + // This works for streaming & non streaming blobs because they both have `into_inner()` which + // can be converted into an SDK body! + is BlobShape -> { + // Write the raw blob to the payload + rust("$payloadName.into_inner()") + } + is StructureShape, is UnionShape -> { + check( + !((targetShape as? UnionShape)?.isEventStream() ?: false) + ) { "Event Streams should be handled further up" } + + // JSON serialize the structure or union targeted + rust( + """#T(&$payloadName).map_err(|err|#T::SerializationError(err.into()))?""", + serializer.payloadSerializer(member), runtimeConfig.operationBuildError() + ) + } + is DocumentShape -> { + rust( + "#T(&$payloadName).map_err(|err|#T::SerializationError(err.into()))?", + serializer.documentSerializer(), + runtimeConfig.operationBuildError() + ) + } + else -> TODO("Unexpected payload target type") + } + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/Protocol.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/Protocol.kt new file mode 100644 index 0000000000000000000000000000000000000000..59f18eafcd6762f4bbcca8065d8211eca02cb56b --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/Protocol.kt @@ -0,0 +1,104 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.smithy.protocols + +import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait +import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait +import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait +import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait +import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.ServiceIndex +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.model.traits.Trait +import software.amazon.smithy.rust.codegen.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator +import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator + +/** + * Describes a protocol to the [HttpBoundProtocolGenerator]. + */ +interface Protocol { + /** Resolves HTTP bindings (which part of a request fields are mapped to) */ + val httpBindingResolver: HttpBindingResolver + + /** The timestamp format that should be used if no override is specified in the model */ + val defaultTimestampFormat: TimestampFormatTrait.Format + + /** Returns additional HTTP headers that should be included for the given operation for this protocol */ + fun additionalHeaders(operationShape: OperationShape): List> = emptyList() + + /** Returns a deserialization code generator for this protocol */ + fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator + + /** Returns a serialization code generator for this protocol */ + fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator + + /** + * Generates a function signature like the following: + * ```rust + * fn parse_http_generic_error(response: &Response) -> smithy_types::error::Error + * ``` + */ + fun parseHttpGenericError(operationShape: OperationShape): RuntimeType + + /** + * Generates a function signature like the following: + * ```rust + * fn parse_event_stream_generic_error(payload: &Bytes) -> smithy_types::error::Error + * ``` + * + * Event Stream generic errors are almost identical to HTTP generic errors, except that + * there are no response headers or statuses available to further inform the error parsing. + */ + fun parseEventStreamGenericError(operationShape: OperationShape): RuntimeType +} + +typealias ProtocolMap = Map> + +interface ProtocolGeneratorFactory { + fun protocol(codegenContext: CodegenContext): Protocol + fun buildProtocolGenerator(codegenContext: CodegenContext): T + fun transformModel(model: Model): Model + fun symbolProvider(model: Model, base: RustSymbolProvider): RustSymbolProvider = base + fun support(): ProtocolSupport +} + +class ProtocolLoader(private val supportedProtocols: ProtocolMap) { + fun protocolFor( + model: Model, + serviceShape: ServiceShape + ): Pair> { + val protocols: MutableMap = ServiceIndex.of(model).getProtocols(serviceShape) + val matchingProtocols = + protocols.keys.mapNotNull { protocolId -> supportedProtocols[protocolId]?.let { protocolId to it } } + if (matchingProtocols.isEmpty()) { + throw CodegenException("No matching protocol — service offers: ${protocols.keys}. We offer: ${supportedProtocols.keys}") + } + return matchingProtocols.first() + } + + companion object { + val DefaultProtocols = mapOf( + AwsJson1_0Trait.ID to AwsJsonFactory(AwsJsonVersion.Json10), + AwsJson1_1Trait.ID to AwsJsonFactory(AwsJsonVersion.Json11), + AwsQueryTrait.ID to AwsQueryFactory(), + Ec2QueryTrait.ID to Ec2QueryFactory(), + RestJson1Trait.ID to RestJsonFactory(), + RestXmlTrait.ID to RestXmlFactory(), + ) + val Default = ProtocolLoader(DefaultProtocols) + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/ProtocolLoader.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/ProtocolLoader.kt deleted file mode 100644 index 2cbef22d754753af7f2edf899680b6b370796f67..0000000000000000000000000000000000000000 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/ProtocolLoader.kt +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -package software.amazon.smithy.rust.codegen.smithy.protocols - -import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait -import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait -import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait -import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait -import software.amazon.smithy.aws.traits.protocols.RestJson1Trait -import software.amazon.smithy.aws.traits.protocols.RestXmlTrait -import software.amazon.smithy.codegen.core.CodegenException -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.knowledge.ServiceIndex -import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.traits.Trait -import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory - -typealias ProtocolMap = Map> - -class ProtocolLoader(private val supportedProtocols: ProtocolMap) { - fun protocolFor( - model: Model, - serviceShape: ServiceShape - ): Pair> { - val protocols: MutableMap = ServiceIndex.of(model).getProtocols(serviceShape) - val matchingProtocols = - protocols.keys.mapNotNull { protocolId -> supportedProtocols[protocolId]?.let { protocolId to it } } - if (matchingProtocols.isEmpty()) { - throw CodegenException("No matching protocol — service offers: ${protocols.keys}. We offer: ${supportedProtocols.keys}") - } - return matchingProtocols.first() - } - - companion object { - val DefaultProtocols = mapOf( - AwsJson1_0Trait.ID to AwsJsonFactory(AwsJsonVersion.Json10), - AwsJson1_1Trait.ID to AwsJsonFactory(AwsJsonVersion.Json11), - AwsQueryTrait.ID to AwsQueryFactory(), - Ec2QueryTrait.ID to Ec2QueryFactory(), - RestJson1Trait.ID to RestJsonFactory(), - RestXmlTrait.ID to RestXmlFactory(), - ) - val Default = ProtocolLoader(DefaultProtocols) - } -} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestJson.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestJson.kt index eb2e34160cf50b4716d452e1076193122d45ce4d..74dd59712530cc07d63e01e996485be144880d28 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestJson.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestJson.kt @@ -12,19 +12,19 @@ import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator class RestJsonFactory : ProtocolGeneratorFactory { - override fun buildProtocolGenerator( - protocolConfig: ProtocolConfig - ): HttpBoundProtocolGenerator = HttpBoundProtocolGenerator(protocolConfig, RestJson(protocolConfig)) + override fun protocol(codegenContext: CodegenContext): Protocol = RestJson(codegenContext) + + override fun buildProtocolGenerator(codegenContext: CodegenContext): HttpBoundProtocolGenerator = + HttpBoundProtocolGenerator(codegenContext, RestJson(codegenContext)) override fun transformModel(model: Model): Model = model @@ -38,8 +38,8 @@ class RestJsonFactory : ProtocolGeneratorFactory { } } -class RestJson(private val protocolConfig: ProtocolConfig) : Protocol { - private val runtimeConfig = protocolConfig.runtimeConfig +class RestJson(private val codegenContext: CodegenContext) : Protocol { + private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, "Error" to RuntimeType.GenericError(runtimeConfig), @@ -51,16 +51,16 @@ class RestJson(private val protocolConfig: ProtocolConfig) : Protocol { private val jsonDeserModule = RustModule.private("json_deser") override val httpBindingResolver: HttpBindingResolver = - HttpTraitHttpBindingResolver(protocolConfig.model, ProtocolContentTypes.consistent("application/json")) + HttpTraitHttpBindingResolver(codegenContext.model, ProtocolContentTypes.consistent("application/json")) override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - return JsonParserGenerator(protocolConfig, httpBindingResolver) + return JsonParserGenerator(codegenContext, httpBindingResolver) } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator { - return JsonSerializerGenerator(protocolConfig, httpBindingResolver) + return JsonSerializerGenerator(codegenContext, httpBindingResolver) } override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType = diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestXml.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestXml.kt index f434de2eeba6a0d618ea374701d5259eff13ab3e..655f2ca51d86cd78277d7fd96de5bceff34e539d 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestXml.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/RestXml.kt @@ -14,23 +14,21 @@ import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.smithy.protocols.parse.RestXmlParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.XmlBindingTraitSerializerGenerator import software.amazon.smithy.rust.codegen.util.expectTrait -class RestXmlFactory(private val generator: (ProtocolConfig) -> Protocol = { RestXml(it) }) : +class RestXmlFactory(private val generator: (CodegenContext) -> Protocol = { RestXml(it) }) : ProtocolGeneratorFactory { - override fun buildProtocolGenerator( - protocolConfig: ProtocolConfig - ): HttpBoundProtocolGenerator { - return HttpBoundProtocolGenerator(protocolConfig, generator(protocolConfig)) - } + override fun protocol(codegenContext: CodegenContext): Protocol = generator(codegenContext) + + override fun buildProtocolGenerator(codegenContext: CodegenContext): HttpBoundProtocolGenerator = + HttpBoundProtocolGenerator(codegenContext, protocol(codegenContext)) override fun transformModel(model: Model): Model = model @@ -44,9 +42,9 @@ class RestXmlFactory(private val generator: (ProtocolConfig) -> Protocol = { Res } } -open class RestXml(private val protocolConfig: ProtocolConfig) : Protocol { - private val restXml = protocolConfig.serviceShape.expectTrait() - private val runtimeConfig = protocolConfig.runtimeConfig +open class RestXml(private val codegenContext: CodegenContext) : Protocol { + private val restXml = codegenContext.serviceShape.expectTrait() + private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, "Error" to RuntimeType.GenericError(runtimeConfig), @@ -62,17 +60,17 @@ open class RestXml(private val protocolConfig: ProtocolConfig) : Protocol { } override val httpBindingResolver: HttpBindingResolver = - HttpTraitHttpBindingResolver(protocolConfig.model, ProtocolContentTypes.consistent("application/xml")) + HttpTraitHttpBindingResolver(codegenContext.model, ProtocolContentTypes.consistent("application/xml")) override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - return RestXmlParserGenerator(protocolConfig, restXmlErrors) + return RestXmlParserGenerator(codegenContext, restXmlErrors) } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator { - return XmlBindingTraitSerializerGenerator(protocolConfig, httpBindingResolver) + return XmlBindingTraitSerializerGenerator(codegenContext, httpBindingResolver) } override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType = diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/AwsQueryParserGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/AwsQueryParserGenerator.kt index e5aeadf6b56f487762f9f9395e2a086bd1b6db31..3a35ef90b74f9a09c65ec39ff66793b47d42bb80 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/AwsQueryParserGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/AwsQueryParserGenerator.kt @@ -6,8 +6,8 @@ package software.amazon.smithy.rust.codegen.smithy.protocols.parse import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig /** * The AWS query protocol's responses are identical to REST XML's, except that they are wrapped @@ -25,14 +25,14 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig * of the response parsing, but it overrides [operationParser] to add the protocol differences. */ class AwsQueryParserGenerator( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, xmlErrors: RuntimeType, private val xmlBindingTraitParserGenerator: XmlBindingTraitParserGenerator = XmlBindingTraitParserGenerator( - protocolConfig, + codegenContext, xmlErrors ) { context, inner -> - val operationName = protocolConfig.symbolProvider.toSymbol(context.shape).name + val operationName = codegenContext.symbolProvider.toSymbol(context.shape).name val responseWrapperName = operationName + "Response" val resultWrapperName = operationName + "Result" rustTemplate( diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/Ec2QueryParserGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/Ec2QueryParserGenerator.kt index 20838b1404ba778f084d37a25bae5bcbd6200687..26885027762859103342c29f16dfaedabf01fdd4 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/Ec2QueryParserGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/Ec2QueryParserGenerator.kt @@ -6,8 +6,8 @@ package software.amazon.smithy.rust.codegen.smithy.protocols.parse import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig /** * The EC2 query protocol's responses are identical to REST XML's, except that they are wrapped @@ -23,14 +23,14 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig * of the response parsing, but it overrides [operationParser] to add the protocol differences. */ class Ec2QueryParserGenerator( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, xmlErrors: RuntimeType, private val xmlBindingTraitParserGenerator: XmlBindingTraitParserGenerator = XmlBindingTraitParserGenerator( - protocolConfig, + codegenContext, xmlErrors ) { context, inner -> - val operationName = protocolConfig.symbolProvider.toSymbol(context.shape).name + val operationName = codegenContext.symbolProvider.toSymbol(context.shape).name val responseWrapperName = operationName + "Response" rustTemplate( """ diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/JsonParserGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/JsonParserGenerator.kt index f6d743e0560ac45bc12f12c6d5674d75aa065f5c..0bbd07c5240c6f5cb3e835cece11ef23d94ea7d0 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/JsonParserGenerator.kt @@ -34,9 +34,9 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.rustlang.withBlockTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.canUseDefault -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.smithy.generators.setterName @@ -52,12 +52,12 @@ import software.amazon.smithy.rust.codegen.util.toPascalCase import software.amazon.smithy.utils.StringUtils class JsonParserGenerator( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, ) : StructuredDataParserGenerator { - private val model = protocolConfig.model - private val symbolProvider = protocolConfig.symbolProvider - private val runtimeConfig = protocolConfig.runtimeConfig + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val runtimeConfig = codegenContext.runtimeConfig private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType() private val jsonDeserModule = RustModule.private("json_deser") private val codegenScope = arrayOf( diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/RestXmlParserGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/RestXmlParserGenerator.kt index 755e51c3d2bc9e4f28bfd0510c092439dc052359..fd0bb9366bb4ebc62e864bbc54d751dd821eac6a 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/RestXmlParserGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/RestXmlParserGenerator.kt @@ -6,15 +6,15 @@ package software.amazon.smithy.rust.codegen.smithy.protocols.parse import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig class RestXmlParserGenerator( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, xmlErrors: RuntimeType, private val xmlBindingTraitParserGenerator: XmlBindingTraitParserGenerator = XmlBindingTraitParserGenerator( - protocolConfig, + codegenContext, xmlErrors ) { context, inner -> val shapeName = context.outputShapeName diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt index 4de4aae311400a42dd55455ba7ee7482a0f51e2a..c1caf39072017cf097d6b5201fc542a7f59579a3 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt @@ -37,8 +37,8 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.rustlang.withBlockTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.smithy.generators.setterName @@ -63,7 +63,7 @@ data class OperationWrapperContext( ) class XmlBindingTraitParserGenerator( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, private val xmlErrors: RuntimeType, private val writeOperationWrapper: RustWriter.(OperationWrapperContext, OperationInnerWriteable) -> Unit, ) : StructuredDataParserGenerator { @@ -89,12 +89,12 @@ class XmlBindingTraitParserGenerator( */ data class Ctx(val tag: String, val accum: String?) - private val symbolProvider = protocolConfig.symbolProvider - private val smithyXml = CargoDependency.smithyXml(protocolConfig.runtimeConfig).asType() + private val symbolProvider = codegenContext.symbolProvider + private val smithyXml = CargoDependency.smithyXml(codegenContext.runtimeConfig).asType() private val xmlError = smithyXml.member("decode::XmlError") private val scopedDecoder = smithyXml.member("decode::ScopedDecoder") - private val runtimeConfig = protocolConfig.runtimeConfig + private val runtimeConfig = codegenContext.runtimeConfig // The symbols we want all the time private val codegenScope = arrayOf( @@ -106,7 +106,7 @@ class XmlBindingTraitParserGenerator( "ScopedDecoder" to scopedDecoder, "smithy_types" to CargoDependency.SmithyTypes(runtimeConfig).asType() ) - private val model = protocolConfig.model + private val model = codegenContext.model private val index = HttpBindingIndex.of(model) private val xmlIndex = XmlNameIndex.of(model) private val xmlDeserModule = RustModule.private("xml_deser") diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/AwsQuerySerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/AwsQuerySerializerGenerator.kt index 99a791094666020b85c7e7c87e6a0edb44d71a51..6ce444c75b30cf34e62da823e53aabb2b9d53b6a 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/AwsQuerySerializerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/AwsQuerySerializerGenerator.kt @@ -8,10 +8,10 @@ package software.amazon.smithy.rust.codegen.smithy.protocols.serialize import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.traits.XmlFlattenedTrait import software.amazon.smithy.model.traits.XmlNameTrait -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.util.getTrait -class AwsQuerySerializerGenerator(protocolConfig: ProtocolConfig) : QuerySerializerGenerator(protocolConfig) { +class AwsQuerySerializerGenerator(codegenContext: CodegenContext) : QuerySerializerGenerator(codegenContext) { override val protocolName: String get() = "AWS Query" override fun MemberShape.queryKeyName(prioritizedFallback: String?): String = diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/Ec2QuerySerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/Ec2QuerySerializerGenerator.kt index 8256e6a76a42ba1efc306ca08df3d73e86cde66e..2e38f04e87a6b0b477b16d60357b3379303b2dd3 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/Ec2QuerySerializerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/Ec2QuerySerializerGenerator.kt @@ -8,11 +8,11 @@ package software.amazon.smithy.rust.codegen.smithy.protocols.serialize import software.amazon.smithy.aws.traits.protocols.Ec2QueryNameTrait import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.traits.XmlNameTrait -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.util.getTrait import software.amazon.smithy.utils.StringUtils -class Ec2QuerySerializerGenerator(protocolConfig: ProtocolConfig) : QuerySerializerGenerator(protocolConfig) { +class Ec2QuerySerializerGenerator(codegenContext: CodegenContext) : QuerySerializerGenerator(codegenContext) { override val protocolName: String get() = "EC2 Query" override fun MemberShape.queryKeyName(prioritizedFallback: String?): String = diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/JsonSerializerGenerator.kt index aab7e90f52584aeb7131d92f1d59e15f927bb020..e314769d807d57fd10993791f6706dfd1f39f9f4 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -31,9 +31,9 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.isOptional import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation @@ -46,7 +46,7 @@ import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.toPascalCase class JsonSerializerGenerator( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, ) : StructuredDataSerializerGenerator { private data class Context( @@ -122,9 +122,9 @@ class JsonSerializerGenerator( val shape: StructureShape, ) - private val model = protocolConfig.model - private val symbolProvider = protocolConfig.symbolProvider - private val runtimeConfig = protocolConfig.runtimeConfig + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val runtimeConfig = codegenContext.runtimeConfig private val smithyTypes = CargoDependency.SmithyTypes(runtimeConfig).asType() private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType() private val serializerError = smithyTypes.member("Error") diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/QuerySerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/QuerySerializerGenerator.kt index b1effdbc72574de1582d6c12c4e61617215551f2..db5cfc74e085c14533ae6c4632ed92240c9fb659 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/QuerySerializerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/QuerySerializerGenerator.kt @@ -31,9 +31,9 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.isOptional import software.amazon.smithy.rust.codegen.smithy.protocols.serializeFunctionName import software.amazon.smithy.rust.codegen.smithy.rustType @@ -44,7 +44,7 @@ import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.orNull import software.amazon.smithy.rust.codegen.util.toPascalCase -abstract class QuerySerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSerializerGenerator { +abstract class QuerySerializerGenerator(codegenContext: CodegenContext) : StructuredDataSerializerGenerator { protected data class Context( /** Expression that yields a QueryValueWriter */ val writerExpression: String, @@ -77,10 +77,10 @@ abstract class QuerySerializerGenerator(protocolConfig: ProtocolConfig) : Struct } } - protected val model = protocolConfig.model - protected val symbolProvider = protocolConfig.symbolProvider - protected val runtimeConfig = protocolConfig.runtimeConfig - private val serviceShape = protocolConfig.serviceShape + protected val model = codegenContext.model + protected val symbolProvider = codegenContext.symbolProvider + protected val runtimeConfig = codegenContext.runtimeConfig + private val serviceShape = codegenContext.serviceShape private val serializerError = RuntimeType.Infallible private val smithyTypes = CargoDependency.SmithyTypes(runtimeConfig).asType() private val smithyQuery = CargoDependency.smithyQuery(runtimeConfig).asType() diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt index 52065036c2d852334a85fae008baf59ccab616e3..a416414ebcf26bdb1dc56f41bec9d2cbb50b591d 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt @@ -36,8 +36,8 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.stripOuter import software.amazon.smithy.rust.codegen.rustlang.withBlock +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.isOptional import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation @@ -52,12 +52,12 @@ import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.toPascalCase class XmlBindingTraitSerializerGenerator( - protocolConfig: ProtocolConfig, + codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver ) : StructuredDataSerializerGenerator { - private val symbolProvider = protocolConfig.symbolProvider - private val runtimeConfig = protocolConfig.runtimeConfig - private val model = protocolConfig.model + private val symbolProvider = codegenContext.symbolProvider + private val runtimeConfig = codegenContext.runtimeConfig + private val model = codegenContext.model private val smithyXml = CargoDependency.smithyXml(runtimeConfig).asType() private val codegenScope = arrayOf( @@ -74,7 +74,7 @@ class XmlBindingTraitSerializerGenerator( private val xmlSerModule = RustModule.private("xml_ser") private val xmlIndex = XmlNameIndex.of(model) - private val rootNamespace = protocolConfig.serviceShape.getTrait() + private val rootNamespace = codegenContext.serviceShape.getTrait() private val util = SerializerUtil(model) sealed class Ctx { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/OperationNormalizer.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/OperationNormalizer.kt index 451ce54ae7c1f7426d91923b2e3778c57f93dd06..342a6d9b1a71c9465e53643f0127e31e701803ba 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/OperationNormalizer.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/OperationNormalizer.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.util.orNull +import software.amazon.smithy.rust.codegen.util.rename import java.util.Optional import kotlin.streams.toList @@ -89,10 +90,3 @@ object OperationNormalizer { private fun empty(id: ShapeId) = StructureShape.builder().id(id) } - -private fun StructureShape.Builder.rename(newId: ShapeId): StructureShape.Builder { - val renamedMembers = this.build().members().map { - it.toBuilder().id(newId.withMember(it.memberName)).build() - } - return this.id(newId).members(renamedMembers) -} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/testutil/TestHelpers.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/testutil/TestHelpers.kt index 0368255476a5c865c37b956386762ace8beb024d..d192172477bdec09c38e018c8740586057224cab 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/testutil/TestHelpers.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/testutil/TestHelpers.kt @@ -11,13 +11,13 @@ import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.smithy.CodegenConfig +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeCrateLocation import software.amazon.smithy.rust.codegen.smithy.RustCodegenPlugin import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.implBlock import software.amazon.smithy.rust.codegen.smithy.letIf @@ -39,7 +39,7 @@ fun testSymbolProvider(model: Model, serviceShape: ServiceShape? = null): RustSy TestSymbolVisitorConfig ) -fun testProtocolConfig(model: Model, serviceShape: ServiceShape? = null): ProtocolConfig = ProtocolConfig( +fun testProtocolConfig(model: Model, serviceShape: ServiceShape? = null): CodegenContext = CodegenContext( model, testSymbolProvider(model), TestRuntimeConfig, diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Synthetics.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Synthetics.kt new file mode 100644 index 0000000000000000000000000000000000000000..aa8b0365e9a80a38eca6be641b03d9aefa655f0c --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Synthetics.kt @@ -0,0 +1,61 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.util + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.ToShapeId + +/** + * Clones an entire operation and its input/output shapes under a new name. + */ +fun Model.Builder.cloneOperation( + model: Model, + oldOperation: ToShapeId, + idTransform: (ShapeId) -> ShapeId +): Model.Builder { + val operationShape = model.expectShape(oldOperation.toShapeId(), OperationShape::class.java) + val inputShape = model.expectShape( + checkNotNull(operationShape.input.orNull()) { + "cloneOperation expects OperationNormalizer to be run first to add input shapes to all operations" + }, + StructureShape::class.java + ) + val outputShape = model.expectShape( + checkNotNull(operationShape.output.orNull()) { + "cloneOperation expects OperationNormalizer to be run first to add output shapes to all operations" + }, + StructureShape::class.java + ) + + val inputId = idTransform(inputShape.id) + addShape(inputShape.toBuilder().rename(inputId).build()) + val outputId = idTransform(outputShape.id) + if (outputId != inputId) { + addShape(outputShape.toBuilder().rename(outputId).build()) + } + val operationId = idTransform(operationShape.id) + addShape( + operationShape.toBuilder() + .id(operationId) + .input(inputId) + .output(outputId) + .build() + ) + return this +} + +/** + * Renames a StructureShape builder and automatically fixes all the members. + */ +fun StructureShape.Builder.rename(newId: ShapeId): StructureShape.Builder { + val renamedMembers = this.build().members().map { + it.toBuilder().id(newId.withMember(it.memberName)).build() + } + return this.id(newId).members(renamedMembers) +} diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/RequestBindingGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/RequestBindingGeneratorTest.kt index 0a47a272fa1e3d4dc484db1d8df3a82d8aac07fb..416fa3f1a77579025176075f1c217b5abb773757 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/RequestBindingGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/RequestBindingGeneratorTest.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.http.RequestBindingGenerator @@ -111,23 +112,42 @@ class RequestBindingGeneratorTest { private fun renderOperation(writer: RustWriter) { inputShape.renderWithModelBuilder(model, symbolProvider, writer) val inputShape = model.expectShape(operationShape.input.get(), StructureShape::class.java) + val bindingGen = RequestBindingGenerator( + model, + symbolProvider, + TestRuntimeConfig, + TimestampFormatTrait.Format.EPOCH_SECONDS, + operationShape, + inputShape, + httpTrait + ) writer.rustBlock("impl PutObjectInput") { - RequestBindingGenerator( - model, - symbolProvider, - TestRuntimeConfig, - TimestampFormatTrait.Format.EPOCH_SECONDS, - operationShape, - inputShape, - httpTrait - ).renderUpdateHttpBuilder(this) + // RequestBindingGenerator's functions expect to be rendered inside a function, + // but the unit test needs to call some of these functions individually. This generates + // some wrappers that can be called directly from the tests. The functions will get duplicated, + // but that's not a problem. + + rustBlock("pub fn test_uri_query(&self, mut output: &mut String)") { + bindingGen.renderUpdateHttpBuilder(this) + rust("uri_query(self, output)") + } + + rustBlock( + "pub fn test_uri_base(&self, mut output: &mut String) -> Result<(), #T>", + TestRuntimeConfig.operationBuildError() + ) { + bindingGen.renderUpdateHttpBuilder(this) + rust("uri_base(self, output)") + } + rustBlock( - "pub fn request_builder_base(&self) -> Result<#T, #T>", + "pub fn test_request_builder_base(&self) -> Result<#T, #T>", RuntimeType.HttpRequestBuilder, TestRuntimeConfig.operationBuildError() ) { - write("let builder = #T::new();", RuntimeType.HttpRequestBuilder) - write("self.update_http_builder(builder)") + bindingGen.renderUpdateHttpBuilder(this) + rust("let builder = #T::new();", RuntimeType.HttpRequestBuilder) + rust("update_http_builder(self, builder)") } } } @@ -152,10 +172,10 @@ class RequestBindingGeneratorTest { .some_value("svq!!%&") .build().expect("build should succeed"); let mut o = String::new(); - inp.uri_base(&mut o); + inp.test_uri_base(&mut o); assert_eq!(o.as_str(), "/somebucket%2Fok/1970-04-28T03%3A58%3A45Z"); o.clear(); - inp.uri_query(&mut o); + inp.test_uri_query(&mut o); assert_eq!(o.as_str(), "?paramName=svq%21%21%25%26&hello=0&hello=1&hello=2&hello=44") """ ) @@ -176,7 +196,7 @@ class RequestBindingGeneratorTest { .enabled(true) .build().expect("build should succeed"); let mut o = String::new(); - inp.uri_query(&mut o); + inp.test_uri_query(&mut o); assert_eq!(o.as_str(), "?primitive=1&enabled=true") """, clippy = true @@ -201,7 +221,7 @@ class RequestBindingGeneratorTest { .media_type("base64encodethis") .prefix("k".to_string(), "😹".to_string()) .build().unwrap(); - let http_request = inp.request_builder_base().expect("valid input").body(()).unwrap(); + let http_request = inp.test_request_builder_base().expect("valid input").body(()).unwrap(); assert_eq!(http_request.uri(), "/buk/1970-04-28T03%3A58%3A45Z?paramName=qp&hello=0&hello=1"); assert_eq!(http_request.method(), "PUT"); let mut date_header = http_request.headers().get_all("X-Dates").iter(); @@ -233,7 +253,7 @@ class RequestBindingGeneratorTest { .key(ts.clone()) .prefix("😹".to_string(), "😹".to_string()) .build().unwrap(); - let err = inp.request_builder_base().expect_err("can't make a header out of a cat emoji"); + let err = inp.test_request_builder_base().expect_err("can't make a header out of a cat emoji"); assert_eq!(format!("{}", err), "Invalid field in input: prefix (Details: `😹` cannot be used as a header name: invalid HTTP header name)"); """ ) @@ -252,7 +272,7 @@ class RequestBindingGeneratorTest { .key(ts.clone()) .prefix("valid-key".to_string(), "\n can't put a newline in a header value".to_string()) .build().unwrap(); - let err = inp.request_builder_base().expect_err("can't make a header with a newline"); + let err = inp.test_request_builder_base().expect_err("can't make a header with a newline"); assert_eq!(format!("{}", err), "Invalid field in input: prefix (Details: `\n can\'t put a newline in a header value` cannot be used as a header value: failed to parse header value)"); """ ) @@ -270,7 +290,7 @@ class RequestBindingGeneratorTest { .key(ts.clone()) .string_header("\n is not valid") .build().unwrap(); - let err = inp.request_builder_base().expect_err("can't make a header with a newline"); + let err = inp.test_request_builder_base().expect_err("can't make a header with a newline"); // make sure we obey the sensitive trait assert_eq!(format!("{}", err), "Invalid field in input: string_header (Details: `*** Sensitive Data Redacted ***` cannot be used as a header value: failed to parse header value)"); """ @@ -289,7 +309,7 @@ class RequestBindingGeneratorTest { // .bucket_name("buk") .key(ts.clone()) .build().unwrap(); - let err = inp.request_builder_base().expect_err("can't build request with bucket unset"); + let err = inp.test_request_builder_base().expect_err("can't build request with bucket unset"); assert!(matches!(err, ${writer.format(TestRuntimeConfig.operationBuildError())}::MissingField { .. })) """ ) @@ -307,7 +327,7 @@ class RequestBindingGeneratorTest { // don't set key // .key(ts.clone()) .build().unwrap(); - let err = inp.request_builder_base().expect_err("can't build request with bucket unset"); + let err = inp.test_request_builder_base().expect_err("can't build request with bucket unset"); assert!(matches!(err, ${writer.format(TestRuntimeConfig.operationBuildError())}::MissingField { .. })) """ ) @@ -324,7 +344,7 @@ class RequestBindingGeneratorTest { .bucket_name("") .key(ts.clone()) .build().unwrap(); - let err = inp.request_builder_base().expect_err("can't build request with bucket unset"); + let err = inp.test_request_builder_base().expect_err("can't build request with bucket unset"); assert!(matches!(err, ${writer.format(TestRuntimeConfig.operationBuildError())}::MissingField { .. })) """ ) diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/ResponseBindingGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/ResponseBindingGeneratorTest.kt index 36683af9bd4becf86a3a67070b6d01ade215ee3f..13d2f9677322db1489bfe5835c094640711c54e3 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/ResponseBindingGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/ResponseBindingGeneratorTest.kt @@ -12,7 +12,7 @@ import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.generators.http.ResponseBindingGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.smithy.protocols.HttpTraitHttpBindingResolver @@ -70,7 +70,7 @@ class ResponseBindingGeneratorTest { private val model = OperationNormalizer.transform(baseModel) private val operationShape = model.expectShape(ShapeId.from("smithy.example#PutObject"), OperationShape::class.java) private val symbolProvider = testSymbolProvider(model) - private val testProtocolConfig: ProtocolConfig = testProtocolConfig(model) + private val testCodegenContext: CodegenContext = testProtocolConfig(model) private fun RustWriter.renderOperation() { operationShape.outputShape(model).renderWithModelBuilder(model, symbolProvider, this) @@ -80,8 +80,8 @@ class ResponseBindingGeneratorTest { .filter { it.location == HttpLocation.HEADER } bindings.forEach { binding -> val runtimeType = ResponseBindingGenerator( - RestJson(testProtocolConfig), - testProtocolConfig, + RestJson(testCodegenContext), + testCodegenContext, operationShape ).generateDeserializeHeaderFn(binding) // little hack to force these functions to be generated diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/protocol/ProtocolTestGeneratorTest.kt similarity index 70% rename from codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt rename to codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/protocol/ProtocolTestGeneratorTest.kt index e57db2edf928fb8277ad848f0990d206a37b4c45..1f645bbabc817dd06f2fffdb5987c0616e831b06 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/protocol/ProtocolTestGeneratorTest.kt @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0. */ -package software.amazon.smithy.rust.codegen.smithy.generators +package software.amazon.smithy.rust.codegen.smithy.generators.protocol import io.kotest.matchers.string.shouldContain import org.junit.jupiter.api.Test @@ -12,25 +12,118 @@ import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.escape import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.CodegenVisitor import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolBodyGenerator.BodyMetadata +import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap +import software.amazon.smithy.rust.codegen.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.testutil.generatePluginContext import software.amazon.smithy.rust.codegen.util.CommandFailed import software.amazon.smithy.rust.codegen.util.dq +import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.codegen.util.runCommand import java.nio.file.Path -class HttpProtocolTestGeneratorTest { +private class TestProtocolBodyGenerator(private val body: String) : ProtocolBodyGenerator { + override fun bodyMetadata(operationShape: OperationShape): BodyMetadata = + BodyMetadata(takesOwnership = false) + + override fun generateBody(writer: RustWriter, self: String, operationShape: OperationShape) { + writer.writeWithNoFormatting(body) + } +} + +private class TestProtocolTraitImplGenerator( + private val codegenContext: CodegenContext, + private val correctResponse: String +) : ProtocolTraitImplGenerator { + private val symbolProvider = codegenContext.symbolProvider + + override fun generateTraitImpls(operationWriter: RustWriter, operationShape: OperationShape) { + operationWriter.rustTemplate( + """ + impl #{parse_strict} for ${operationShape.id.name}{ + type Output = Result<#{output}, #{error}>; + fn parse(&self, response: &#{response}<#{bytes}>) -> Self::Output { + ${operationWriter.escape(correctResponse)} + } + }""", + "parse_strict" to RuntimeType.parseStrict(codegenContext.runtimeConfig), + "output" to symbolProvider.toSymbol(operationShape.outputShape(codegenContext.model)), + "error" to operationShape.errorSymbol(symbolProvider), + "response" to RuntimeType.Http("Response"), + "bytes" to RuntimeType.Bytes + ) + } +} + +private class TestProtocolMakeOperationGenerator( + codegenContext: CodegenContext, + protocol: Protocol, + body: String, + private val httpRequestBuilder: String +) : MakeOperationGenerator(codegenContext, protocol, TestProtocolBodyGenerator(body)) { + override fun generateRequestBuilderBaseFn(writer: RustWriter, operationShape: OperationShape) { + writer.inRequestBuilderBaseFn(operationShape.inputShape(codegenContext.model)) { + withBlock("Ok(#T::new()", ")", RuntimeType.HttpRequestBuilder) { + writeWithNoFormatting(httpRequestBuilder) + } + } + } +} + +// A stubbed test protocol to do enable testing intentionally broken protocols +private class TestProtocolGenerator( + codegenContext: CodegenContext, + protocol: Protocol, + httpRequestBuilder: String, + body: String, + correctResponse: String +) : ProtocolGenerator( + codegenContext, + protocol, + TestProtocolMakeOperationGenerator(codegenContext, protocol, body, httpRequestBuilder), + TestProtocolTraitImplGenerator(codegenContext, correctResponse) +) + +private class TestProtocolFactory( + private val httpRequestBuilder: String, + private val body: String, + private val correctResponse: String +) : ProtocolGeneratorFactory { + override fun protocol(codegenContext: CodegenContext): Protocol { + return RestJson(codegenContext) + } + + override fun buildProtocolGenerator(codegenContext: CodegenContext): ProtocolGenerator { + return TestProtocolGenerator( + codegenContext, + protocol(codegenContext), + httpRequestBuilder, + body, + correctResponse + ) + } + + override fun transformModel(model: Model): Model = model + + override fun support(): ProtocolSupport { + return ProtocolSupport(true, true, true, true) + } +} + +class ProtocolTestGeneratorTest { private val model = """ namespace com.example @@ -117,57 +210,6 @@ class HttpProtocolTestGeneratorTest { body: String = "${correctBody.dq()}.to_string().into()", correctResponse: String = """Ok(crate::output::SayHelloOutput::builder().value("hey there!").build())""" ): Path { - - // A stubbed test protocol to do enable testing intentionally broken protocols - class TestProtocol(private val protocolConfig: ProtocolConfig) : HttpProtocolGenerator(protocolConfig) { - private val symbolProvider = protocolConfig.symbolProvider - override fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata { - writeWithNoFormatting(body) - return BodyMetadata(takesOwnership = false) - } - - override fun traitImplementations(operationWriter: RustWriter, operationShape: OperationShape) { - operationWriter.rustTemplate( - """ - impl #{parse_strict} for ${operationShape.id.name}{ - type Output = Result<#{output}, #{error}>; - fn parse(&self, response: &#{response}<#{bytes}>) -> Self::Output { - ${operationWriter.escape(correctResponse)} - } - }""", - "parse_strict" to RuntimeType.parseStrict(protocolConfig.runtimeConfig), - "output" to symbolProvider.toSymbol(operationShape.outputShape(protocolConfig.model)), - "error" to operationShape.errorSymbol(symbolProvider), - "response" to RuntimeType.Http("Response"), - "bytes" to RuntimeType.Bytes - ) - } - - override fun toHttpRequestImpl( - implBlockWriter: RustWriter, - operationShape: OperationShape, - inputShape: StructureShape - ) { - generateRequestBuilderBase(implBlockWriter) { - withBlock("Ok(#T::new()", ")", RuntimeType.HttpRequestBuilder) { - writeWithNoFormatting(httpRequestBuilder) - } - } - } - } - - class TestProtocolFactory : ProtocolGeneratorFactory { - override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpProtocolGenerator { - return TestProtocol(protocolConfig) - } - - override fun transformModel(model: Model): Model = model - - override fun support(): ProtocolSupport { - return ProtocolSupport(true, true, true, true) - } - } - val (pluginContext, testDir) = generatePluginContext(model) val visitor = CodegenVisitor( pluginContext, @@ -176,7 +218,7 @@ class HttpProtocolTestGeneratorTest { override val order: Byte = 0 override fun protocols(serviceId: ShapeId, currentProtocols: ProtocolMap): ProtocolMap { // Intentionally replace the builtin implementation of RestJson1 with our fake protocol - return mapOf(RestJson1Trait.ID to TestProtocolFactory()) + return mapOf(RestJson1Trait.ID to TestProtocolFactory(httpRequestBuilder, body, correctResponse)) } } ) diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/EventStreamTestTools.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/EventStreamTestTools.kt index 5c92edcd79f391218528f430cd180a9858a2fbee..6e4b6293bc224b061a5961790a327e45cc048920 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/EventStreamTestTools.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/EventStreamTestTools.kt @@ -18,9 +18,9 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.smithy.generators.error.CombinedErrorGenerator @@ -123,7 +123,7 @@ object EventStreamTestModels { val validTestUnion: String, val validSomeError: String, val validUnmodeledError: String, - val protocolBuilder: (ProtocolConfig) -> Protocol, + val protocolBuilder: (CodegenContext) -> Protocol, ) { override fun toString(): String = protocolShapeId } diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt index 828c5155015480f921e8b67e2884b800ca3334a0..0305abea21e4cfc692a72fbe710b46f58cad3661 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGeneratorTest.kt @@ -9,7 +9,7 @@ import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ArgumentsSource import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.rustlang.rust -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.protocols.EventStreamTestModels import software.amazon.smithy.rust.codegen.smithy.protocols.EventStreamTestTools import software.amazon.smithy.rust.codegen.testutil.TestRuntimeConfig @@ -22,7 +22,7 @@ class EventStreamUnmarshallerGeneratorTest { fun test(testCase: EventStreamTestModels.TestCase) { val test = EventStreamTestTools.generateTestProject(testCase.model) - val protocolConfig = ProtocolConfig( + val codegenContext = CodegenContext( test.model, test.symbolProvider, TestRuntimeConfig, @@ -30,7 +30,7 @@ class EventStreamUnmarshallerGeneratorTest { ShapeId.from(testCase.protocolShapeId), "test" ) - val protocol = testCase.protocolBuilder(protocolConfig) + val protocol = testCase.protocolBuilder(codegenContext) val generator = EventStreamUnmarshallerGenerator( protocol, test.model, diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/EventStreamMarshallerGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/EventStreamMarshallerGeneratorTest.kt index 12b8bae35a152597f544c6797f0117d0b1cc2288..1d08a5ec612a2c846f14eabe283a78fc6ae51d69 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/EventStreamMarshallerGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/EventStreamMarshallerGeneratorTest.kt @@ -11,8 +11,8 @@ import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.protocols.EventStreamTestModels import software.amazon.smithy.rust.codegen.smithy.protocols.EventStreamTestTools import software.amazon.smithy.rust.codegen.testutil.TestRuntimeConfig @@ -26,7 +26,7 @@ class EventStreamMarshallerGeneratorTest { fun test(testCase: EventStreamTestModels.TestCase) { val test = EventStreamTestTools.generateTestProject(testCase.model) - val protocolConfig = ProtocolConfig( + val codegenContext = CodegenContext( test.model, test.symbolProvider, TestRuntimeConfig, @@ -34,7 +34,7 @@ class EventStreamMarshallerGeneratorTest { ShapeId.from(testCase.protocolShapeId), "test" ) - val protocol = testCase.protocolBuilder(protocolConfig) + val protocol = testCase.protocolBuilder(codegenContext) val generator = EventStreamMarshallerGenerator( test.model, TestRuntimeConfig, diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/util/SyntheticsTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/util/SyntheticsTest.kt new file mode 100644 index 0000000000000000000000000000000000000000..1c810a7b5f601a34b794f6d5e891f5f6c2cb853b --- /dev/null +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/util/SyntheticsTest.kt @@ -0,0 +1,78 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.util + +import io.kotest.matchers.shouldBe +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.util.cloneOperation +import software.amazon.smithy.rust.codegen.util.orNull +import software.amazon.smithy.rust.codegen.util.rename + +class SyntheticsTest { + @Test + fun `it should clone operations`() { + val model = """ + namespace test + + service TestService { + version: "2019-12-16", + operations: ["SomeOperation"], + } + + structure TestInput { + one: String, + two: String, + } + structure TestOutput { + three: String, + four: String, + } + + operation SomeOperation { input: TestInput, output: TestOutput } + """.asSmithyModel() + + val transformed = model.toBuilder().cloneOperation(model, ShapeId.from("test#SomeOperation")) { shapeId -> + ShapeId.fromParts(shapeId.namespace + ".cloned", shapeId.name + "Foo") + }.build() + + val newOp = transformed.expectShape(ShapeId.from("test.cloned#SomeOperationFoo"), OperationShape::class.java) + newOp.input.orNull() shouldBe ShapeId.from("test.cloned#TestInputFoo") + newOp.output.orNull() shouldBe ShapeId.from("test.cloned#TestOutputFoo") + + val newIn = transformed.expectShape(ShapeId.from("test.cloned#TestInputFoo"), StructureShape::class.java) + for (member in newIn.members()) { + member.id shouldBe ShapeId.fromParts("test.cloned", "TestInputFoo", member.memberName) + } + + val newOut = transformed.expectShape(ShapeId.from("test.cloned#TestOutputFoo"), StructureShape::class.java) + for (member in newOut.members()) { + member.id shouldBe ShapeId.fromParts("test.cloned", "TestOutputFoo", member.memberName) + } + } + + @Test + fun `it should rename structs`() { + val model = """ + namespace test + + structure SomeInput { + one: String, + two: String, + } + """.asSmithyModel() + + val original = model.expectShape(ShapeId.from("test#SomeInput"), StructureShape::class.java) + val new = original.toBuilder().rename(ShapeId.from("new#SomeOtherInput")).build() + new.id shouldBe ShapeId.from("new#SomeOtherInput") + for (member in new.members()) { + member.id shouldBe ShapeId.fromParts("new", "SomeOtherInput", member.memberName) + } + } +} diff --git a/design/src/rfcs/rfc0003_presigning_api.md b/design/src/rfcs/rfc0003_presigning_api.md index c5afb2a052204059b76d3a2bbb47a320e08939b6..476922d925cf9bb4a4d54ba1b8d06187402eef6f 100644 --- a/design/src/rfcs/rfc0003_presigning_api.md +++ b/design/src/rfcs/rfc0003_presigning_api.md @@ -1,7 +1,7 @@ RFC: API for Presigned URLs ============================ -> Status: RFC +> Status: Implemented For a summarized list of proposed changes, see the [Changes Checklist](#changes-checklist) section. @@ -198,17 +198,17 @@ to rename the `presigned` member to `presigned_value` Changes Checklist ----------------- -- [ ] Update `aws-sigv4` to support query param signing -- [ ] Create `PresignedOperationSyntheticTrait` -- [ ] Customize models for known presigned operations -- [ ] Create `PresigningConfig` and its builder -- [ ] Implement `PresignedUrlGeneratorLayer` -- [ ] Create new AWS codegen decorator to: - - [ ] Add new `presigned()` method to input code generator - - [ ] Add new `presigned()` method to fluent client generator -- [ ] Update `RustReservedWords` to reserve `presigned()` -- [ ] Add integration test to S3 -- [ ] Add examples for using presigning for: - - [ ] S3 GetObject and PutObject - - [ ] CloudFront download URLs - - [ ] Polly SynthesizeSpeech +- [x] Update `aws-sigv4` to support query param signing +- [x] Create `PresignedOperationSyntheticTrait` +- [x] Customize models for known presigned operations +- [x] Create `PresigningConfig` and its builder +- [x] Implement `PresignedUrlGeneratorLayer` +- [x] Create new AWS codegen decorator to: + - [x] Add new `presigned()` method to input code generator + - [x] Add new `presigned()` method to fluent client generator +- [x] Update `RustReservedWords` to reserve `presigned()` +- [x] Add integration test to S3 +- [x] Add integration test to Polly +- [x] Add examples for using presigning for: + - [x] S3 GetObject and PutObject + - [x] Polly SynthesizeSpeech