Loading codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/CodegenVisitor.kt +8 −2 Original line number Diff line number Diff line Loading @@ -22,6 +22,7 @@ import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.server.smithy.generators.ServiceGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.server.smithy.protocols.RestJson1HttpDeserializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.RestJson1HttpSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.CodegenContext Loading Loading @@ -273,7 +274,12 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC ServiceGenerator( rustCrate, protocolGenerator, protocolGeneratorFactory.support(), ProtocolSupport( requestDeserialization = true, requestBodyDeserialization = true, responseSerialization = true, errorSerialization = true ), codegenContext, codegenDecorator ) Loading @@ -296,7 +302,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC impl Error { ##[allow(dead_code)] fn generic(msg: &'static str) -> Self { pub fn generic(msg: &'static str) -> Self { Self::Generic(msg.into()) } } Loading codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenPlugin.kt +24 −0 Original line number Diff line number Diff line Loading @@ -22,6 +22,11 @@ import software.amazon.smithy.rust.codegen.smithy.customize.CombinedCodegenDecor import java.util.logging.Level import java.util.logging.Logger /** Rust Codegen Plugin * This is the entrypoint for code generation, triggered by the smithy-build plugin. * `resources/META-INF.services/software.amazon.smithy.build.SmithyBuildPlugin` refers to this class by name which * enables the smithy-build plugin to invoke `execute` with all of the Smithy plugin context + models. */ class RustCodegenPlugin : SmithyBuildPlugin { private val logger = Logger.getLogger(javaClass.name) Loading @@ -30,23 +35,42 @@ class RustCodegenPlugin : SmithyBuildPlugin { override fun execute(context: PluginContext) { // Suppress extremely noisy logs about reserved words Logger.getLogger(ReservedWordSymbolProvider::class.java.name).level = Level.OFF // Discover `RustCodegenDecorators` on the classpath. `RustCodegenDectorator` return different types of // customization. A customization is a function of: // - location (eg. the mutate section of an operation) // - context (eg. the of the operation) // - writer: The active RustWriter at the given location val codegenDecorator = CombinedCodegenDecorator.fromClasspath(context) // CodegenVistor is the main driver of code generation that traverses the model and generates code CodegenVisitor(context, codegenDecorator).execute() } companion object { /** SymbolProvider * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider * * The Symbol provider is composed of a base `SymbolVisitor` which handles the core funcitonality, then is layered * with other symbol providers, documented inline, to handle the full scope of Smithy types. */ fun baseSymbolProvider( model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig = DefaultConfig ) = SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) // Generate different types for EventStream shapes (eg. transcribe streaming) .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model) } // Generate `ByteStream` instead of `Blob` for streaming binary shapes (eg. S3 GetObject) .let { StreamingShapeSymbolProvider(it, model) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes .let { BaseSymbolMetadataProvider(it) } // Streaming shapes need different derives (eg. they cannot derive Eq) .let { StreamingShapeMetadataProvider(it, model) } // Rename shapes that clash with Rust reserved words & and other SDK specific features eg. `send()` cannot // be the name of an operation input .let { RustReservedWordSymbolProvider(it) } } } codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceGenerator.kt +6 −1 Original line number Diff line number Diff line Loading @@ -7,6 +7,8 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RustCrate import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator Loading @@ -14,7 +16,6 @@ import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfi import software.amazon.smithy.rust.codegen.smithy.generators.error.CombinedErrorGenerator import software.amazon.smithy.rust.codegen.smithy.generators.error.TopLevelErrorGenerator import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.util.inputShape class ServiceGenerator( Loading @@ -37,6 +38,10 @@ class ServiceGenerator( operation, decorator.operationCustomizations(context, operation, listOf()) ) // render protocol tests into `operation.rs` (note operationWriter vs. // inputWriter) ProtocolTestGenerator(context, protocolSupport, operation, operationWriter) .render() } } rustCrate.withModule(RustModule.Error) { writer -> Loading codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ProtocolTestGenerator.kt 0 → 100644 +447 −0 File added.Preview size limit exceeded, changes collapsed. Show changes codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RestJson1.kt +2 −162 Original line number Diff line number Diff line Loading @@ -6,22 +6,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.HttpErrorTrait import software.amazon.smithy.model.traits.IdempotencyTokenTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.protocoltests.traits.AppliesTo import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.CratesIo import software.amazon.smithy.rust.codegen.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.Writable Loading @@ -47,13 +39,10 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.transformers.errorMessageMember import software.amazon.smithy.rust.codegen.testutil.TokioTest import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectTrait import software.amazon.smithy.rust.codegen.util.findMemberWithTrait import software.amazon.smithy.rust.codegen.util.getTrait import software.amazon.smithy.rust.codegen.util.hasStreamingMember import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.codegen.util.toSnakeCase Loading Loading @@ -90,7 +79,6 @@ class RestJson1HttpSerializerGenerator( fun render(writer: RustWriter, operationShape: OperationShape) { renderResponseSerializer(writer, operationShape) renderErrorSerializer(writer, operationShape) renderTests(writer, operationShape) } private fun renderResponseSerializer(writer: RustWriter, operationShape: OperationShape) { Loading Loading @@ -121,7 +109,7 @@ class RestJson1HttpSerializerGenerator( HttpLocation.RESPONSE_CODE -> { rustTemplate( """ let status = output.${it.memberName} let status = output.${it.memberName.toLowerCase()} .ok_or(#{JsonSerdeError}::generic(${(it.member.memberName + " missing or empty").dq()}))?; let http_status: u16 = #{Convert}::TryFrom::<i32>::try_from(status) .map_err(|_| #{JsonSerdeError}::generic(${("invalid status code").dq()}))?; Loading Loading @@ -221,100 +209,6 @@ class RestJson1HttpSerializerGenerator( ) } } private fun renderTests(writer: RustWriter, operationShape: OperationShape) { val operationIndex = OperationIndex.of(model) val outputShape = operationShape.outputShape(model) val responseTests = operationShape .getTrait<HttpResponseTestsTrait>() ?.getTestCasesFor(AppliesTo.SERVER) .orEmpty() .map { it to outputShape } val errorTests = operationIndex.getErrors(operationShape).flatMap { error -> error.getTrait<HttpResponseTestsTrait>()?.testCases.orEmpty().map { it to error } } if (responseTests.isEmpty() && errorTests.isEmpty()) { return } val testModuleName = "serialize_${operationShape.id.name.toSnakeCase()}_test" val moduleMeta = RustMetadata( public = false, additionalAttributes = listOf( Attribute.Cfg("test"), Attribute.Custom( "allow(unreachable_code, unused_variables)" ) ) ) writer.write("") writer.withModule(testModuleName, moduleMeta) { responseTests.forEach { renderSerializeResponseTestCase(operationShape, it.first, it.second) } errorTests.forEach { renderSerializeResponseTestCase(operationShape, it.first, it.second) } } } private fun RustWriter.renderSerializeResponseTestCase( operationShape: OperationShape, testCase: HttpResponseTestCase, shape: StructureShape ) { val isError = shape.hasTrait<ErrorTrait>() val fnName = if (isError) "serialize_${operationShape.id.name.toSnakeCase()}_error" else "serialize_${operationShape.id.name.toSnakeCase()}_response" val variantName = if (isError) "${format(operationShape.errorSymbol(symbolProvider))}Kind::${symbolProvider.toSymbol(shape).name}" else "" Attribute.Custom("test").render(this) rustBlock("fn ${testCase.id.toSnakeCase()}()") { rust("let config = #T::Config::builder().build();", RuntimeType.Config) writeInline("let expected = ") instantiator.render(this, shape, testCase.params) write(";") if (isError) { rust("let expected = $variantName(expected);") } rust( """let response = #T(&expected).expect("failed to serialize response");""", operation.member(fnName) ) rust("assert_eq!(response.status(), ${testCase.code});") rustTemplate( """ use #{ParseHttpResponse}; let parser = #{op}::new(); let actual = parser.parse_loaded(&response); """.trimIndent(), *codegenScope, "op" to symbolProvider.toSymbol(operationShape), ) if (isError) { rust("""let actual = actual.expect_err("failed to parse error");""") rust( """ match (&expected, &actual.kind) { ($variantName(expected), $variantName(actual)) => assert_eq!(expected, actual), _ => panic!("incorrect error type"), }; """.trimIndent() ) } else { rust("""let actual = actual.expect("failed to parse error");""") rust("assert_eq!(expected, actual);") } } } } class RestJson1HttpDeserializerGenerator( Loading Loading @@ -348,7 +242,7 @@ class RestJson1HttpDeserializerGenerator( fun render(writer: RustWriter, operationShape: OperationShape) { renderRequestDeserializer(writer, operationShape) renderRequestDeserializerTests(writer, operationShape) // renderRequestDeserializerTests(writer, operationShape) } private fun renderRequestDeserializer(writer: RustWriter, operationShape: OperationShape) { Loading Loading @@ -562,58 +456,4 @@ class RestJson1HttpDeserializerGenerator( val memberName = binding.memberName.toSnakeCase() return "deser_label_${containerName}_$memberName" } private fun renderRequestDeserializerTests(writer: RustWriter, operationShape: OperationShape) { val testCases = operationShape.getTrait<HttpRequestTestsTrait>()?.getTestCasesFor(AppliesTo.SERVER) ?: return val testModuleName = "deser_${operationShape.id.name.toSnakeCase()}_test" val moduleMeta = RustMetadata( public = false, additionalAttributes = listOf( Attribute.Cfg("test"), Attribute.Custom( "allow(unreachable_code, unused_variables)" ) ) ) writer.write("") writer.withModule(testModuleName, moduleMeta) { testCases.forEach { renderRequestDeserializerTestCase(it, operationShape) } } } private fun RustWriter.renderRequestDeserializerTestCase(testCase: HttpRequestTestCase, operationShape: OperationShape) { TokioTest.render(this) rustBlock("async fn ${testCase.id.toSnakeCase()}()") { val inputShape = operationShape.inputShape(model) val deserFnName = "deser_${operationShape.id.name.toSnakeCase()}_request" val customToken = inputShape.findMemberWithTrait<IdempotencyTokenTrait>(model)?.let { """.make_token("00000000-0000-4000-8000-000000000000")""" } ?: "" rust("let config = #T::Config::builder()$customToken.build();", RuntimeType.Config) writeInline("let expected = ") instantiator.render(this, inputShape, testCase.params) write(";") rust( """let op = expected.make_operation(&config).await.expect("failed to build operation");""" ) rust("let (request, parts) = op.into_request_response().0.into_parts();") rustTemplate( "let request = request.map(|body| #{Bytes}::from(body.bytes().unwrap().to_vec()));", *codegenScope ) rust("assert_eq!(request.method().as_str(), ${testCase.method.dq()});") rust("assert_eq!(request.uri().to_owned(), ${testCase.uri.dq()});") rust( """let actual = #T(&request).expect("failed to parse request");""", operation.member(deserFnName) ) rust("assert_eq!(expected, actual);") } } } Loading
codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/CodegenVisitor.kt +8 −2 Original line number Diff line number Diff line Loading @@ -22,6 +22,7 @@ import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.server.smithy.generators.ServiceGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.server.smithy.protocols.RestJson1HttpDeserializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.RestJson1HttpSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.CodegenContext Loading Loading @@ -273,7 +274,12 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC ServiceGenerator( rustCrate, protocolGenerator, protocolGeneratorFactory.support(), ProtocolSupport( requestDeserialization = true, requestBodyDeserialization = true, responseSerialization = true, errorSerialization = true ), codegenContext, codegenDecorator ) Loading @@ -296,7 +302,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC impl Error { ##[allow(dead_code)] fn generic(msg: &'static str) -> Self { pub fn generic(msg: &'static str) -> Self { Self::Generic(msg.into()) } } Loading
codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenPlugin.kt +24 −0 Original line number Diff line number Diff line Loading @@ -22,6 +22,11 @@ import software.amazon.smithy.rust.codegen.smithy.customize.CombinedCodegenDecor import java.util.logging.Level import java.util.logging.Logger /** Rust Codegen Plugin * This is the entrypoint for code generation, triggered by the smithy-build plugin. * `resources/META-INF.services/software.amazon.smithy.build.SmithyBuildPlugin` refers to this class by name which * enables the smithy-build plugin to invoke `execute` with all of the Smithy plugin context + models. */ class RustCodegenPlugin : SmithyBuildPlugin { private val logger = Logger.getLogger(javaClass.name) Loading @@ -30,23 +35,42 @@ class RustCodegenPlugin : SmithyBuildPlugin { override fun execute(context: PluginContext) { // Suppress extremely noisy logs about reserved words Logger.getLogger(ReservedWordSymbolProvider::class.java.name).level = Level.OFF // Discover `RustCodegenDecorators` on the classpath. `RustCodegenDectorator` return different types of // customization. A customization is a function of: // - location (eg. the mutate section of an operation) // - context (eg. the of the operation) // - writer: The active RustWriter at the given location val codegenDecorator = CombinedCodegenDecorator.fromClasspath(context) // CodegenVistor is the main driver of code generation that traverses the model and generates code CodegenVisitor(context, codegenDecorator).execute() } companion object { /** SymbolProvider * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider * * The Symbol provider is composed of a base `SymbolVisitor` which handles the core funcitonality, then is layered * with other symbol providers, documented inline, to handle the full scope of Smithy types. */ fun baseSymbolProvider( model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig = DefaultConfig ) = SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) // Generate different types for EventStream shapes (eg. transcribe streaming) .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model) } // Generate `ByteStream` instead of `Blob` for streaming binary shapes (eg. S3 GetObject) .let { StreamingShapeSymbolProvider(it, model) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes .let { BaseSymbolMetadataProvider(it) } // Streaming shapes need different derives (eg. they cannot derive Eq) .let { StreamingShapeMetadataProvider(it, model) } // Rename shapes that clash with Rust reserved words & and other SDK specific features eg. `send()` cannot // be the name of an operation input .let { RustReservedWordSymbolProvider(it) } } }
codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceGenerator.kt +6 −1 Original line number Diff line number Diff line Loading @@ -7,6 +7,8 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RustCrate import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator Loading @@ -14,7 +16,6 @@ import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfi import software.amazon.smithy.rust.codegen.smithy.generators.error.CombinedErrorGenerator import software.amazon.smithy.rust.codegen.smithy.generators.error.TopLevelErrorGenerator import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.util.inputShape class ServiceGenerator( Loading @@ -37,6 +38,10 @@ class ServiceGenerator( operation, decorator.operationCustomizations(context, operation, listOf()) ) // render protocol tests into `operation.rs` (note operationWriter vs. // inputWriter) ProtocolTestGenerator(context, protocolSupport, operation, operationWriter) .render() } } rustCrate.withModule(RustModule.Error) { writer -> Loading
codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ProtocolTestGenerator.kt 0 → 100644 +447 −0 File added.Preview size limit exceeded, changes collapsed. Show changes
codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/RestJson1.kt +2 −162 Original line number Diff line number Diff line Loading @@ -6,22 +6,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.HttpErrorTrait import software.amazon.smithy.model.traits.IdempotencyTokenTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.protocoltests.traits.AppliesTo import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.CratesIo import software.amazon.smithy.rust.codegen.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.Writable Loading @@ -47,13 +39,10 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.transformers.errorMessageMember import software.amazon.smithy.rust.codegen.testutil.TokioTest import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectTrait import software.amazon.smithy.rust.codegen.util.findMemberWithTrait import software.amazon.smithy.rust.codegen.util.getTrait import software.amazon.smithy.rust.codegen.util.hasStreamingMember import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.codegen.util.toSnakeCase Loading Loading @@ -90,7 +79,6 @@ class RestJson1HttpSerializerGenerator( fun render(writer: RustWriter, operationShape: OperationShape) { renderResponseSerializer(writer, operationShape) renderErrorSerializer(writer, operationShape) renderTests(writer, operationShape) } private fun renderResponseSerializer(writer: RustWriter, operationShape: OperationShape) { Loading Loading @@ -121,7 +109,7 @@ class RestJson1HttpSerializerGenerator( HttpLocation.RESPONSE_CODE -> { rustTemplate( """ let status = output.${it.memberName} let status = output.${it.memberName.toLowerCase()} .ok_or(#{JsonSerdeError}::generic(${(it.member.memberName + " missing or empty").dq()}))?; let http_status: u16 = #{Convert}::TryFrom::<i32>::try_from(status) .map_err(|_| #{JsonSerdeError}::generic(${("invalid status code").dq()}))?; Loading Loading @@ -221,100 +209,6 @@ class RestJson1HttpSerializerGenerator( ) } } private fun renderTests(writer: RustWriter, operationShape: OperationShape) { val operationIndex = OperationIndex.of(model) val outputShape = operationShape.outputShape(model) val responseTests = operationShape .getTrait<HttpResponseTestsTrait>() ?.getTestCasesFor(AppliesTo.SERVER) .orEmpty() .map { it to outputShape } val errorTests = operationIndex.getErrors(operationShape).flatMap { error -> error.getTrait<HttpResponseTestsTrait>()?.testCases.orEmpty().map { it to error } } if (responseTests.isEmpty() && errorTests.isEmpty()) { return } val testModuleName = "serialize_${operationShape.id.name.toSnakeCase()}_test" val moduleMeta = RustMetadata( public = false, additionalAttributes = listOf( Attribute.Cfg("test"), Attribute.Custom( "allow(unreachable_code, unused_variables)" ) ) ) writer.write("") writer.withModule(testModuleName, moduleMeta) { responseTests.forEach { renderSerializeResponseTestCase(operationShape, it.first, it.second) } errorTests.forEach { renderSerializeResponseTestCase(operationShape, it.first, it.second) } } } private fun RustWriter.renderSerializeResponseTestCase( operationShape: OperationShape, testCase: HttpResponseTestCase, shape: StructureShape ) { val isError = shape.hasTrait<ErrorTrait>() val fnName = if (isError) "serialize_${operationShape.id.name.toSnakeCase()}_error" else "serialize_${operationShape.id.name.toSnakeCase()}_response" val variantName = if (isError) "${format(operationShape.errorSymbol(symbolProvider))}Kind::${symbolProvider.toSymbol(shape).name}" else "" Attribute.Custom("test").render(this) rustBlock("fn ${testCase.id.toSnakeCase()}()") { rust("let config = #T::Config::builder().build();", RuntimeType.Config) writeInline("let expected = ") instantiator.render(this, shape, testCase.params) write(";") if (isError) { rust("let expected = $variantName(expected);") } rust( """let response = #T(&expected).expect("failed to serialize response");""", operation.member(fnName) ) rust("assert_eq!(response.status(), ${testCase.code});") rustTemplate( """ use #{ParseHttpResponse}; let parser = #{op}::new(); let actual = parser.parse_loaded(&response); """.trimIndent(), *codegenScope, "op" to symbolProvider.toSymbol(operationShape), ) if (isError) { rust("""let actual = actual.expect_err("failed to parse error");""") rust( """ match (&expected, &actual.kind) { ($variantName(expected), $variantName(actual)) => assert_eq!(expected, actual), _ => panic!("incorrect error type"), }; """.trimIndent() ) } else { rust("""let actual = actual.expect("failed to parse error");""") rust("assert_eq!(expected, actual);") } } } } class RestJson1HttpDeserializerGenerator( Loading Loading @@ -348,7 +242,7 @@ class RestJson1HttpDeserializerGenerator( fun render(writer: RustWriter, operationShape: OperationShape) { renderRequestDeserializer(writer, operationShape) renderRequestDeserializerTests(writer, operationShape) // renderRequestDeserializerTests(writer, operationShape) } private fun renderRequestDeserializer(writer: RustWriter, operationShape: OperationShape) { Loading Loading @@ -562,58 +456,4 @@ class RestJson1HttpDeserializerGenerator( val memberName = binding.memberName.toSnakeCase() return "deser_label_${containerName}_$memberName" } private fun renderRequestDeserializerTests(writer: RustWriter, operationShape: OperationShape) { val testCases = operationShape.getTrait<HttpRequestTestsTrait>()?.getTestCasesFor(AppliesTo.SERVER) ?: return val testModuleName = "deser_${operationShape.id.name.toSnakeCase()}_test" val moduleMeta = RustMetadata( public = false, additionalAttributes = listOf( Attribute.Cfg("test"), Attribute.Custom( "allow(unreachable_code, unused_variables)" ) ) ) writer.write("") writer.withModule(testModuleName, moduleMeta) { testCases.forEach { renderRequestDeserializerTestCase(it, operationShape) } } } private fun RustWriter.renderRequestDeserializerTestCase(testCase: HttpRequestTestCase, operationShape: OperationShape) { TokioTest.render(this) rustBlock("async fn ${testCase.id.toSnakeCase()}()") { val inputShape = operationShape.inputShape(model) val deserFnName = "deser_${operationShape.id.name.toSnakeCase()}_request" val customToken = inputShape.findMemberWithTrait<IdempotencyTokenTrait>(model)?.let { """.make_token("00000000-0000-4000-8000-000000000000")""" } ?: "" rust("let config = #T::Config::builder()$customToken.build();", RuntimeType.Config) writeInline("let expected = ") instantiator.render(this, inputShape, testCase.params) write(";") rust( """let op = expected.make_operation(&config).await.expect("failed to build operation");""" ) rust("let (request, parts) = op.into_request_response().0.into_parts();") rustTemplate( "let request = request.map(|body| #{Bytes}::from(body.bytes().unwrap().to_vec()));", *codegenScope ) rust("assert_eq!(request.method().as_str(), ${testCase.method.dq()});") rust("assert_eq!(request.uri().to_owned(), ${testCase.uri.dq()});") rust( """let actual = #T(&request).expect("failed to parse request");""", operation.member(deserFnName) ) rust("assert_eq!(expected, actual);") } } }