Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt +1 −0 Original line number Diff line number Diff line Loading @@ -114,6 +114,7 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na RuntimeType(name = path, dependency = CargoDependency.Http, namespace = "http") val HttpRequestBuilder = Http("request::Builder") val HttpResponseBuilder = Http("response::Builder") val Serialize = RuntimeType("Serialize", CargoDependency.Serde, namespace = "serde") val Deserialize: RuntimeType = RuntimeType("Deserialize", CargoDependency.Serde, namespace = "serde") Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt +160 −55 Original line number Diff line number Diff line package software.amazon.smithy.rust.codegen.smithy.generators import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.lang.Custom import software.amazon.smithy.rust.codegen.lang.RustMetadata import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.lang.docs import software.amazon.smithy.rust.codegen.lang.escape import software.amazon.smithy.rust.codegen.lang.rust import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.lang.withBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.orNull import software.amazon.smithy.rust.codegen.util.outputShape import java.util.logging.Logger data class ProtocolSupport( val requestBodySerialization: Boolean val requestBodySerialization: Boolean, val responseDeserialization: Boolean, val errorDeserialization: Boolean ) /** * Generate protocol tests for an operation */ Loading @@ -28,6 +40,7 @@ class HttpProtocolTestGenerator( private val writer: RustWriter ) { private val logger = Logger.getLogger(javaClass.name) // TODO: remove these once Smithy publishes fixes. // These tests are not even attempted to be compiled val DisableTests = setOf( Loading Loading @@ -55,15 +68,35 @@ class HttpProtocolTestGenerator( "RestJsonHttpPrefixHeadersArePresent" // https://github.com/awslabs/smithy-rs/issues/35 ) private val inputShape = operationShape.inputShape(protocolConfig.model) fun render() { operationShape.getTrait(HttpRequestTestsTrait::class.java).map { renderHttpRequestTests(it) private val outputShape = operationShape.outputShape(protocolConfig.model) private val operationSymbol = protocolConfig.symbolProvider.toSymbol(operationShape) private val operationIndex = OperationIndex.of(protocolConfig.model) private val instantiator = with(protocolConfig) { Instantiator(symbolProvider, model, runtimeConfig) } sealed class TestCase { abstract val testCase: HttpMessageTestCase data class RequestTest(override val testCase: HttpRequestTestCase) : TestCase() data class ResponseTest(override val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase() } private fun renderHttpRequestTests(httpRequestTestsTrait: HttpRequestTestsTrait) { with(protocolConfig) { val operationName = symbolProvider.toSymbol(operationShape).name fun render() { val requestTests = operationShape.getTrait(HttpRequestTestsTrait::class.java) .orNull()?.testCases.orEmpty().map { TestCase.RequestTest(it) } val responseTests = operationShape.getTrait(HttpResponseTestsTrait::class.java) .orNull()?.testCases.orEmpty().map { TestCase.ResponseTest(it, outputShape) } val errorTests = operationIndex.getErrors(operationShape).flatMap { error -> val testCases = error.getTrait(HttpResponseTestsTrait::class.java).orNull()?.testCases.orEmpty() testCases.map { TestCase.ResponseTest(it, error) } } val allTests: List<TestCase> = (requestTests + responseTests + errorTests).filterMatching() if (allTests.isNotEmpty()) { val operationName = operationSymbol.name val testModuleName = "${operationName.toSnakeCase()}_request_test" val moduleMeta = RustMetadata( public = false, Loading @@ -73,33 +106,58 @@ class HttpProtocolTestGenerator( ) ) writer.withModule(testModuleName, moduleMeta) { httpRequestTestsTrait.testCases.filter { it.protocol == protocol } .filter { !DisableTests.contains(it.id) }.forEach { testCase -> try { renderHttpRequestTestCase(testCase, this) } catch (ex: Exception) { println("failed to generate ${testCase.id}") ex.printStackTrace() renderAllTestCases(allTests) } } } private fun RustWriter.renderAllTestCases(allTests: List<TestCase>) { allTests.forEach { renderTestCaseBlock(it.testCase, this) { when (it) { is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase) is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape) } } } } private val instantiator = with(protocolConfig) { Instantiator(symbolProvider, model, runtimeConfig) /** * Filter out test cases that are disabled or don't match the service protocol */ private fun List<TestCase>.filterMatching(): List<TestCase> = this.filter { testCase -> testCase.testCase.protocol == protocolConfig.protocol && !DisableTests.contains(testCase.testCase.id) } private fun renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase, testModuleWriter: RustWriter) { httpRequestTestCase.documentation.map { testModuleWriter.docs(testModuleWriter.escape(it)) private fun renderTestCaseBlock( testCase: HttpMessageTestCase, testModuleWriter: RustWriter, block: RustWriter.() -> Unit ) { testModuleWriter.setNewlinePrefix("/// ") testCase.documentation.map { testModuleWriter.writeWithNoFormatting(it) } testModuleWriter.docs("Test ID: ${httpRequestTestCase.id}") testModuleWriter.write("Test ID: ${testCase.id}") testModuleWriter.setNewlinePrefix("") testModuleWriter.writeWithNoFormatting("#[test]") if (ExpectFail.contains(httpRequestTestCase.id)) { if (ExpectFail.contains(testCase.id)) { testModuleWriter.writeWithNoFormatting("#[should_panic]") } testModuleWriter.rustBlock("fn test_${httpRequestTestCase.id.toSnakeCase()}()") { val fnName = when (testCase) { is HttpResponseTestCase -> "_response" is HttpRequestTestCase -> "_request" else -> throw CodegenException("unknown test case type") } testModuleWriter.rustBlock("fn test_${testCase.id.toSnakeCase()}$fnName()") { block(this) } } private fun RustWriter.renderHttpRequestTestCase( httpRequestTestCase: HttpRequestTestCase ) { writeInline("let input =") instantiator.render(this, inputShape, httpRequestTestCase.params) write(";") Loading Loading @@ -136,6 +194,51 @@ class HttpProtocolTestGenerator( } } } private fun RustWriter.renderHttpResponseTestCase( httpResponseTestCase: HttpResponseTestCase, expectedShape: StructureShape ) { if (!protocolSupport.responseDeserialization || ( !protocolSupport.errorDeserialization && expectedShape.hasTrait( ErrorTrait::class.java ) ) ) { rust("/* test case disabled for this protocol (not yet supported) */") if (ExpectFail.contains(httpResponseTestCase.id)) { // this test needs to fail, minor hack. Caused by overlap between ids of request & response tests write("todo!()") } return } writeInline("let expected_output =") instantiator.render(this, expectedShape, httpResponseTestCase.params) write(";") write("let http_response = #T::new()", RuntimeType.HttpResponseBuilder) httpResponseTestCase.headers.forEach { (key, value) -> writeWithNoFormatting(".header(${key.dq()}, ${value.dq()})") } rust( """ .status(${httpResponseTestCase.code}) .body(${httpResponseTestCase.body.orNull()?.dq()?.replace("#", "##") ?: "vec![]"}) .unwrap(); """ ) write("let parsed = #T::from_response(http_response);", operationSymbol) if (expectedShape.hasTrait(ErrorTrait::class.java)) { val errorSymbol = operationShape.errorSymbol(protocolConfig.symbolProvider) val errorVariant = protocolConfig.symbolProvider.toSymbol(expectedShape).name rustBlock("if let Err(#T::$errorVariant(actual_error)) = parsed", errorSymbol) { write("assert_eq!(expected_output, actual_error);") } rustBlock("else") { write("panic!(\"wrong variant: {:?}\", parsed);") } } else { write("assert_eq!(parsed.unwrap(), expected_output);") } } private fun checkRequiredHeaders(rustWriter: RustWriter, requireHeaders: List<String>) { Loading @@ -154,7 +257,9 @@ class HttpProtocolTestGenerator( // When we generate a body instead of a stub, drop the trailing `;` and enable the assertion assertOk(rustWriter) { rustWriter.write( "#T(input.build_body(), ${rustWriter.escape(body).dq()}, #T::from(${(mediaType ?: "unknown").dq()}))", "#T(input.build_body(), ${ rustWriter.escape(body).dq() }, #T::from(${(mediaType ?: "unknown").dq()}))", RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_body"), RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "MediaType") ) Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt +1 −1 Original line number Diff line number Diff line Loading @@ -73,7 +73,7 @@ class BasicAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGenerat ) } override fun support(): ProtocolSupport = ProtocolSupport(requestBodySerialization = true) override fun support(): ProtocolSupport = ProtocolSupport(requestBodySerialization = true, responseDeserialization = false, errorDeserialization = false) } /** Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt +1 −1 Original line number Diff line number Diff line Loading @@ -36,7 +36,7 @@ class AwsRestJsonFactory : ProtocolGeneratorFactory<AwsRestJsonGenerator> { override fun support(): ProtocolSupport { // TODO: Support body for RestJson return ProtocolSupport(requestBodySerialization = false) return ProtocolSupport(requestBodySerialization = false, responseDeserialization = false, errorDeserialization = false) } } Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt +5 −0 Original line number Diff line number Diff line Loading @@ -14,3 +14,8 @@ fun OperationShape.inputShape(model: Model): StructureShape { // The Rust Smithy generator adds an input to all shapes automatically return model.expectShape(this.input.get(), StructureShape::class.java) } fun OperationShape.outputShape(model: Model): StructureShape { // The Rust Smithy generator adds an output to all shapes automatically return model.expectShape(this.output.get(), StructureShape::class.java) } Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt +1 −0 Original line number Diff line number Diff line Loading @@ -114,6 +114,7 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na RuntimeType(name = path, dependency = CargoDependency.Http, namespace = "http") val HttpRequestBuilder = Http("request::Builder") val HttpResponseBuilder = Http("response::Builder") val Serialize = RuntimeType("Serialize", CargoDependency.Serde, namespace = "serde") val Deserialize: RuntimeType = RuntimeType("Deserialize", CargoDependency.Serde, namespace = "serde") Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt +160 −55 Original line number Diff line number Diff line package software.amazon.smithy.rust.codegen.smithy.generators import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.lang.Custom import software.amazon.smithy.rust.codegen.lang.RustMetadata import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.lang.docs import software.amazon.smithy.rust.codegen.lang.escape import software.amazon.smithy.rust.codegen.lang.rust import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.lang.withBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.orNull import software.amazon.smithy.rust.codegen.util.outputShape import java.util.logging.Logger data class ProtocolSupport( val requestBodySerialization: Boolean val requestBodySerialization: Boolean, val responseDeserialization: Boolean, val errorDeserialization: Boolean ) /** * Generate protocol tests for an operation */ Loading @@ -28,6 +40,7 @@ class HttpProtocolTestGenerator( private val writer: RustWriter ) { private val logger = Logger.getLogger(javaClass.name) // TODO: remove these once Smithy publishes fixes. // These tests are not even attempted to be compiled val DisableTests = setOf( Loading Loading @@ -55,15 +68,35 @@ class HttpProtocolTestGenerator( "RestJsonHttpPrefixHeadersArePresent" // https://github.com/awslabs/smithy-rs/issues/35 ) private val inputShape = operationShape.inputShape(protocolConfig.model) fun render() { operationShape.getTrait(HttpRequestTestsTrait::class.java).map { renderHttpRequestTests(it) private val outputShape = operationShape.outputShape(protocolConfig.model) private val operationSymbol = protocolConfig.symbolProvider.toSymbol(operationShape) private val operationIndex = OperationIndex.of(protocolConfig.model) private val instantiator = with(protocolConfig) { Instantiator(symbolProvider, model, runtimeConfig) } sealed class TestCase { abstract val testCase: HttpMessageTestCase data class RequestTest(override val testCase: HttpRequestTestCase) : TestCase() data class ResponseTest(override val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase() } private fun renderHttpRequestTests(httpRequestTestsTrait: HttpRequestTestsTrait) { with(protocolConfig) { val operationName = symbolProvider.toSymbol(operationShape).name fun render() { val requestTests = operationShape.getTrait(HttpRequestTestsTrait::class.java) .orNull()?.testCases.orEmpty().map { TestCase.RequestTest(it) } val responseTests = operationShape.getTrait(HttpResponseTestsTrait::class.java) .orNull()?.testCases.orEmpty().map { TestCase.ResponseTest(it, outputShape) } val errorTests = operationIndex.getErrors(operationShape).flatMap { error -> val testCases = error.getTrait(HttpResponseTestsTrait::class.java).orNull()?.testCases.orEmpty() testCases.map { TestCase.ResponseTest(it, error) } } val allTests: List<TestCase> = (requestTests + responseTests + errorTests).filterMatching() if (allTests.isNotEmpty()) { val operationName = operationSymbol.name val testModuleName = "${operationName.toSnakeCase()}_request_test" val moduleMeta = RustMetadata( public = false, Loading @@ -73,33 +106,58 @@ class HttpProtocolTestGenerator( ) ) writer.withModule(testModuleName, moduleMeta) { httpRequestTestsTrait.testCases.filter { it.protocol == protocol } .filter { !DisableTests.contains(it.id) }.forEach { testCase -> try { renderHttpRequestTestCase(testCase, this) } catch (ex: Exception) { println("failed to generate ${testCase.id}") ex.printStackTrace() renderAllTestCases(allTests) } } } private fun RustWriter.renderAllTestCases(allTests: List<TestCase>) { allTests.forEach { renderTestCaseBlock(it.testCase, this) { when (it) { is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase) is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape) } } } } private val instantiator = with(protocolConfig) { Instantiator(symbolProvider, model, runtimeConfig) /** * Filter out test cases that are disabled or don't match the service protocol */ private fun List<TestCase>.filterMatching(): List<TestCase> = this.filter { testCase -> testCase.testCase.protocol == protocolConfig.protocol && !DisableTests.contains(testCase.testCase.id) } private fun renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase, testModuleWriter: RustWriter) { httpRequestTestCase.documentation.map { testModuleWriter.docs(testModuleWriter.escape(it)) private fun renderTestCaseBlock( testCase: HttpMessageTestCase, testModuleWriter: RustWriter, block: RustWriter.() -> Unit ) { testModuleWriter.setNewlinePrefix("/// ") testCase.documentation.map { testModuleWriter.writeWithNoFormatting(it) } testModuleWriter.docs("Test ID: ${httpRequestTestCase.id}") testModuleWriter.write("Test ID: ${testCase.id}") testModuleWriter.setNewlinePrefix("") testModuleWriter.writeWithNoFormatting("#[test]") if (ExpectFail.contains(httpRequestTestCase.id)) { if (ExpectFail.contains(testCase.id)) { testModuleWriter.writeWithNoFormatting("#[should_panic]") } testModuleWriter.rustBlock("fn test_${httpRequestTestCase.id.toSnakeCase()}()") { val fnName = when (testCase) { is HttpResponseTestCase -> "_response" is HttpRequestTestCase -> "_request" else -> throw CodegenException("unknown test case type") } testModuleWriter.rustBlock("fn test_${testCase.id.toSnakeCase()}$fnName()") { block(this) } } private fun RustWriter.renderHttpRequestTestCase( httpRequestTestCase: HttpRequestTestCase ) { writeInline("let input =") instantiator.render(this, inputShape, httpRequestTestCase.params) write(";") Loading Loading @@ -136,6 +194,51 @@ class HttpProtocolTestGenerator( } } } private fun RustWriter.renderHttpResponseTestCase( httpResponseTestCase: HttpResponseTestCase, expectedShape: StructureShape ) { if (!protocolSupport.responseDeserialization || ( !protocolSupport.errorDeserialization && expectedShape.hasTrait( ErrorTrait::class.java ) ) ) { rust("/* test case disabled for this protocol (not yet supported) */") if (ExpectFail.contains(httpResponseTestCase.id)) { // this test needs to fail, minor hack. Caused by overlap between ids of request & response tests write("todo!()") } return } writeInline("let expected_output =") instantiator.render(this, expectedShape, httpResponseTestCase.params) write(";") write("let http_response = #T::new()", RuntimeType.HttpResponseBuilder) httpResponseTestCase.headers.forEach { (key, value) -> writeWithNoFormatting(".header(${key.dq()}, ${value.dq()})") } rust( """ .status(${httpResponseTestCase.code}) .body(${httpResponseTestCase.body.orNull()?.dq()?.replace("#", "##") ?: "vec![]"}) .unwrap(); """ ) write("let parsed = #T::from_response(http_response);", operationSymbol) if (expectedShape.hasTrait(ErrorTrait::class.java)) { val errorSymbol = operationShape.errorSymbol(protocolConfig.symbolProvider) val errorVariant = protocolConfig.symbolProvider.toSymbol(expectedShape).name rustBlock("if let Err(#T::$errorVariant(actual_error)) = parsed", errorSymbol) { write("assert_eq!(expected_output, actual_error);") } rustBlock("else") { write("panic!(\"wrong variant: {:?}\", parsed);") } } else { write("assert_eq!(parsed.unwrap(), expected_output);") } } private fun checkRequiredHeaders(rustWriter: RustWriter, requireHeaders: List<String>) { Loading @@ -154,7 +257,9 @@ class HttpProtocolTestGenerator( // When we generate a body instead of a stub, drop the trailing `;` and enable the assertion assertOk(rustWriter) { rustWriter.write( "#T(input.build_body(), ${rustWriter.escape(body).dq()}, #T::from(${(mediaType ?: "unknown").dq()}))", "#T(input.build_body(), ${ rustWriter.escape(body).dq() }, #T::from(${(mediaType ?: "unknown").dq()}))", RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_body"), RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "MediaType") ) Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt +1 −1 Original line number Diff line number Diff line Loading @@ -73,7 +73,7 @@ class BasicAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGenerat ) } override fun support(): ProtocolSupport = ProtocolSupport(requestBodySerialization = true) override fun support(): ProtocolSupport = ProtocolSupport(requestBodySerialization = true, responseDeserialization = false, errorDeserialization = false) } /** Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt +1 −1 Original line number Diff line number Diff line Loading @@ -36,7 +36,7 @@ class AwsRestJsonFactory : ProtocolGeneratorFactory<AwsRestJsonGenerator> { override fun support(): ProtocolSupport { // TODO: Support body for RestJson return ProtocolSupport(requestBodySerialization = false) return ProtocolSupport(requestBodySerialization = false, responseDeserialization = false, errorDeserialization = false) } } Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt +5 −0 Original line number Diff line number Diff line Loading @@ -14,3 +14,8 @@ fun OperationShape.inputShape(model: Model): StructureShape { // The Rust Smithy generator adds an input to all shapes automatically return model.expectShape(this.input.get(), StructureShape::class.java) } fun OperationShape.outputShape(model: Model): StructureShape { // The Rust Smithy generator adds an output to all shapes automatically return model.expectShape(this.output.get(), StructureShape::class.java) }