diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt index aed98437ca264cdd2d03958e78c10af8f50f3790..a0e8445d4b90a143c3818818a348b5d71538624a 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt @@ -114,6 +114,7 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na RuntimeType(name = path, dependency = CargoDependency.Http, namespace = "http") val HttpRequestBuilder = Http("request::Builder") + val HttpResponseBuilder = Http("response::Builder") val Serialize = RuntimeType("Serialize", CargoDependency.Serde, namespace = "serde") val Deserialize: RuntimeType = RuntimeType("Deserialize", CargoDependency.Serde, namespace = "serde") diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt index 90300319c41f7f13caaf843797b39b4fd8864a69..0e089cf2d07d2ab4584ce6a6e4c90abe974f8b4d 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt @@ -1,23 +1,35 @@ package software.amazon.smithy.rust.codegen.smithy.generators +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait +import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase +import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.lang.Custom import software.amazon.smithy.rust.codegen.lang.RustMetadata import software.amazon.smithy.rust.codegen.lang.RustWriter -import software.amazon.smithy.rust.codegen.lang.docs import software.amazon.smithy.rust.codegen.lang.escape +import software.amazon.smithy.rust.codegen.lang.rust import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.lang.withBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.inputShape +import software.amazon.smithy.rust.codegen.util.orNull +import software.amazon.smithy.rust.codegen.util.outputShape import java.util.logging.Logger data class ProtocolSupport( - val requestBodySerialization: Boolean + val requestBodySerialization: Boolean, + val responseDeserialization: Boolean, + val errorDeserialization: Boolean ) + /** * Generate protocol tests for an operation */ @@ -28,6 +40,7 @@ class HttpProtocolTestGenerator( private val writer: RustWriter ) { private val logger = Logger.getLogger(javaClass.name) + // TODO: remove these once Smithy publishes fixes. // These tests are not even attempted to be compiled val DisableTests = setOf( @@ -55,15 +68,35 @@ class HttpProtocolTestGenerator( "RestJsonHttpPrefixHeadersArePresent" // https://github.com/awslabs/smithy-rs/issues/35 ) private val inputShape = operationShape.inputShape(protocolConfig.model) - fun render() { - operationShape.getTrait(HttpRequestTestsTrait::class.java).map { - renderHttpRequestTests(it) - } + private val outputShape = operationShape.outputShape(protocolConfig.model) + private val operationSymbol = protocolConfig.symbolProvider.toSymbol(operationShape) + private val operationIndex = OperationIndex.of(protocolConfig.model) + + private val instantiator = with(protocolConfig) { + Instantiator(symbolProvider, model, runtimeConfig) } - private fun renderHttpRequestTests(httpRequestTestsTrait: HttpRequestTestsTrait) { - with(protocolConfig) { - val operationName = symbolProvider.toSymbol(operationShape).name + sealed class TestCase { + abstract val testCase: HttpMessageTestCase + + data class RequestTest(override val testCase: HttpRequestTestCase) : TestCase() + data class ResponseTest(override val testCase: HttpResponseTestCase, val targetShape: StructureShape) : + TestCase() + } + + fun render() { + val requestTests = operationShape.getTrait(HttpRequestTestsTrait::class.java) + .orNull()?.testCases.orEmpty().map { TestCase.RequestTest(it) } + val responseTests = operationShape.getTrait(HttpResponseTestsTrait::class.java) + .orNull()?.testCases.orEmpty().map { TestCase.ResponseTest(it, outputShape) } + + val errorTests = operationIndex.getErrors(operationShape).flatMap { error -> + val testCases = error.getTrait(HttpResponseTestsTrait::class.java).orNull()?.testCases.orEmpty() + testCases.map { TestCase.ResponseTest(it, error) } + } + val allTests: List = (requestTests + responseTests + errorTests).filterMatching() + if (allTests.isNotEmpty()) { + val operationName = operationSymbol.name val testModuleName = "${operationName.toSnakeCase()}_request_test" val moduleMeta = RustMetadata( public = false, @@ -73,68 +106,138 @@ class HttpProtocolTestGenerator( ) ) writer.withModule(testModuleName, moduleMeta) { - httpRequestTestsTrait.testCases.filter { it.protocol == protocol } - .filter { !DisableTests.contains(it.id) }.forEach { testCase -> - try { - renderHttpRequestTestCase(testCase, this) - } catch (ex: Exception) { - println("failed to generate ${testCase.id}") - ex.printStackTrace() - } - } + renderAllTestCases(allTests) } } } - private val instantiator = with(protocolConfig) { - Instantiator(symbolProvider, model, runtimeConfig) + private fun RustWriter.renderAllTestCases(allTests: List) { + allTests.forEach { + renderTestCaseBlock(it.testCase, this) { + when (it) { + is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase) + is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape) + } + } + } } - private fun renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase, testModuleWriter: RustWriter) { - httpRequestTestCase.documentation.map { - testModuleWriter.docs(testModuleWriter.escape(it)) + /** + * Filter out test cases that are disabled or don't match the service protocol + */ + private fun List.filterMatching(): List = this.filter { testCase -> + testCase.testCase.protocol == protocolConfig.protocol && + !DisableTests.contains(testCase.testCase.id) + } + + private fun renderTestCaseBlock( + testCase: HttpMessageTestCase, + testModuleWriter: RustWriter, + block: RustWriter.() -> Unit + ) { + testModuleWriter.setNewlinePrefix("/// ") + testCase.documentation.map { + testModuleWriter.writeWithNoFormatting(it) } - testModuleWriter.docs("Test ID: ${httpRequestTestCase.id}") + testModuleWriter.write("Test ID: ${testCase.id}") + testModuleWriter.setNewlinePrefix("") testModuleWriter.writeWithNoFormatting("#[test]") - if (ExpectFail.contains(httpRequestTestCase.id)) { + if (ExpectFail.contains(testCase.id)) { testModuleWriter.writeWithNoFormatting("#[should_panic]") } - testModuleWriter.rustBlock("fn test_${httpRequestTestCase.id.toSnakeCase()}()") { - writeInline("let input =") - instantiator.render(this, inputShape, httpRequestTestCase.params) - write(";") - if (protocolSupport.requestBodySerialization) { - write("let http_request = ${protocolConfig.symbolProvider.toSymbol(inputShape).name}::assemble(input.request_builder_base(), input.build_body());") - } else { - write("let http_request = ${protocolConfig.symbolProvider.toSymbol(inputShape).name}::assemble(input.request_builder_base(), vec![]);") - } - with(httpRequestTestCase) { - write( - """ + val fnName = when (testCase) { + is HttpResponseTestCase -> "_response" + is HttpRequestTestCase -> "_request" + else -> throw CodegenException("unknown test case type") + } + testModuleWriter.rustBlock("fn test_${testCase.id.toSnakeCase()}$fnName()") { + block(this) + } + } + + private fun RustWriter.renderHttpRequestTestCase( + httpRequestTestCase: HttpRequestTestCase + ) { + writeInline("let input =") + instantiator.render(this, inputShape, httpRequestTestCase.params) + write(";") + if (protocolSupport.requestBodySerialization) { + write("let http_request = ${protocolConfig.symbolProvider.toSymbol(inputShape).name}::assemble(input.request_builder_base(), input.build_body());") + } else { + write("let http_request = ${protocolConfig.symbolProvider.toSymbol(inputShape).name}::assemble(input.request_builder_base(), vec![]);") + } + with(httpRequestTestCase) { + write( + """ assert_eq!(http_request.method(), ${method.dq()}); assert_eq!(http_request.uri().path(), ${uri.dq()}); """ - ) + ) + } + checkQueryParams(this, httpRequestTestCase.queryParams) + checkForbidQueryParams(this, httpRequestTestCase.forbidQueryParams) + checkRequiredQueryParams(this, httpRequestTestCase.requireQueryParams) + checkHeaders(this, httpRequestTestCase.headers) + checkForbidHeaders(this, httpRequestTestCase.forbidHeaders) + checkRequiredHeaders(this, httpRequestTestCase.requireHeaders) + if (protocolSupport.requestBodySerialization) { + checkBody(this, httpRequestTestCase.body.orElse(""), httpRequestTestCase.bodyMediaType.orElse(null)) + } + + // Explicitly warn if the test case defined parameters that we aren't doing anything with + with(httpRequestTestCase) { + if (authScheme.isPresent) { + logger.warning("Test case provided authScheme but this was ignored") } - checkQueryParams(this, httpRequestTestCase.queryParams) - checkForbidQueryParams(this, httpRequestTestCase.forbidQueryParams) - checkRequiredQueryParams(this, httpRequestTestCase.requireQueryParams) - checkHeaders(this, httpRequestTestCase.headers) - checkForbidHeaders(this, httpRequestTestCase.forbidHeaders) - checkRequiredHeaders(this, httpRequestTestCase.requireHeaders) - if (protocolSupport.requestBodySerialization) { - checkBody(this, httpRequestTestCase.body.orElse(""), httpRequestTestCase.bodyMediaType.orElse(null)) + if (!httpRequestTestCase.vendorParams.isEmpty) { + logger.warning("Test case provided vendorParams but these were ignored") } + } + } - // Explicitly warn if the test case defined parameters that we aren't doing anything with - with(httpRequestTestCase) { - if (authScheme.isPresent) { - logger.warning("Test case provided authScheme but this was ignored") - } - if (!httpRequestTestCase.vendorParams.isEmpty) { - logger.warning("Test case provided vendorParams but these were ignored") - } + private fun RustWriter.renderHttpResponseTestCase( + httpResponseTestCase: HttpResponseTestCase, + expectedShape: StructureShape + ) { + if (!protocolSupport.responseDeserialization || ( + !protocolSupport.errorDeserialization && expectedShape.hasTrait( + ErrorTrait::class.java + ) + ) + ) { + rust("/* test case disabled for this protocol (not yet supported) */") + if (ExpectFail.contains(httpResponseTestCase.id)) { + // this test needs to fail, minor hack. Caused by overlap between ids of request & response tests + write("todo!()") } + return + } + writeInline("let expected_output =") + instantiator.render(this, expectedShape, httpResponseTestCase.params) + write(";") + write("let http_response = #T::new()", RuntimeType.HttpResponseBuilder) + httpResponseTestCase.headers.forEach { (key, value) -> + writeWithNoFormatting(".header(${key.dq()}, ${value.dq()})") + } + rust( + """ + .status(${httpResponseTestCase.code}) + .body(${httpResponseTestCase.body.orNull()?.dq()?.replace("#", "##") ?: "vec![]"}) + .unwrap(); + """ + ) + write("let parsed = #T::from_response(http_response);", operationSymbol) + if (expectedShape.hasTrait(ErrorTrait::class.java)) { + val errorSymbol = operationShape.errorSymbol(protocolConfig.symbolProvider) + val errorVariant = protocolConfig.symbolProvider.toSymbol(expectedShape).name + rustBlock("if let Err(#T::$errorVariant(actual_error)) = parsed", errorSymbol) { + write("assert_eq!(expected_output, actual_error);") + } + rustBlock("else") { + write("panic!(\"wrong variant: {:?}\", parsed);") + } + } else { + write("assert_eq!(parsed.unwrap(), expected_output);") } } @@ -154,7 +257,9 @@ class HttpProtocolTestGenerator( // When we generate a body instead of a stub, drop the trailing `;` and enable the assertion assertOk(rustWriter) { rustWriter.write( - "#T(input.build_body(), ${rustWriter.escape(body).dq()}, #T::from(${(mediaType ?: "unknown").dq()}))", + "#T(input.build_body(), ${ + rustWriter.escape(body).dq() + }, #T::from(${(mediaType ?: "unknown").dq()}))", RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_body"), RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "MediaType") ) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt index 29b6ebbc050b84f6f68058deb04b686f92c64e14..76fbbbe9ee1d589bf1b7be58bcfb2a22300a0e28 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt @@ -73,7 +73,7 @@ class BasicAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGenerat ) } - override fun support(): ProtocolSupport = ProtocolSupport(requestBodySerialization = true) + override fun support(): ProtocolSupport = ProtocolSupport(requestBodySerialization = true, responseDeserialization = false, errorDeserialization = false) } /** diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt index 5ce14b8b0bd30fb3bf943f1f1ceec95cceda093f..7e7e39748957f66615e0dbfd5d0e8d629f32155c 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt @@ -36,7 +36,7 @@ class AwsRestJsonFactory : ProtocolGeneratorFactory { override fun support(): ProtocolSupport { // TODO: Support body for RestJson - return ProtocolSupport(requestBodySerialization = false) + return ProtocolSupport(requestBodySerialization = false, responseDeserialization = false, errorDeserialization = false) } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt index 134ff9add0883c2862616a2f938f3a2ac17533fa..090b06cc70c2a052a5d72452626b5cb82dadbbc2 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt @@ -14,3 +14,8 @@ fun OperationShape.inputShape(model: Model): StructureShape { // The Rust Smithy generator adds an input to all shapes automatically return model.expectShape(this.input.get(), StructureShape::class.java) } + +fun OperationShape.outputShape(model: Model): StructureShape { + // The Rust Smithy generator adds an output to all shapes automatically + return model.expectShape(this.output.get(), StructureShape::class.java) +} diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt index 1ac13ede555ef80033253c602ce49ddb61b93486..5fc32b6ae296846bd394102a547c0e0700ce0334 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt @@ -1,17 +1,19 @@ package software.amazon.smithy.rust.codegen.smithy.generators import io.kotest.matchers.string.shouldContain -import org.intellij.lang.annotations.Language import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import software.amazon.smithy.aws.traits.protocols.RestJson1Trait +import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.lang.RustWriter +import software.amazon.smithy.rust.codegen.lang.rust import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.util.CommandFailed import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.lookup +import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.testutil.TestRuntimeConfig import software.amazon.smithy.rust.testutil.asSmithyModel import software.amazon.smithy.rust.testutil.compileAndTest @@ -23,6 +25,7 @@ class HttpProtocolTestGeneratorTest { use aws.protocols#restJson1 use smithy.test#httpRequestTests + use smithy.test#httpResponseTests @restJson1 service HelloService { @@ -56,8 +59,29 @@ class HttpProtocolTestGeneratorTest { bodyMediaType: "application/json" } ]) + @httpResponseTests([{ + id: "basic_response_test", + protocol: restJson1, + documentation: "Parses operations with empty JSON bodies", + body: "{\"value\": \"hey there!\"}", + params: {"value": "hey there!"}, + bodyMediaType: "application/json", + headers: {"Content-Type": "application/x-amz-json-1.1"}, + code: 200, + }]) operation SayHello { - input: SayHelloInput + input: SayHelloInput, + output: SayHelloOutput, + errors: [BadRequest] + } + + structure SayHelloOutput { + value: String + } + + @error("client") + structure BadRequest { + message: String } structure SayHelloInput { @@ -84,14 +108,31 @@ class HttpProtocolTestGeneratorTest { private fun writeHttpImpl( writer: RustWriter, httpRequestBuilder: String, - @Language( - "Rust", - prefix = "fn foo() -> String {", - suffix = "}" - ) body: String = "${correctBody.dq()}.to_string()" + body: String = "${correctBody.dq()}.to_string()", + correctResponse: String = """Ok(SayHelloOutput::builder().value("hey there!").build())""" ) { + val operation = model.lookup("com.example#SayHello") + writer.withModule("error") { + StructureGenerator(model, symbolProvider, this, model.lookup("com.example#BadRequest")).render() + CombinedErrorGenerator(model, symbolProvider, operation).render(this) + } writer.withModule("operation") { StructureGenerator(model, symbolProvider, this, model.lookup("com.example#SayHelloInput")).render() + StructureGenerator(model, symbolProvider, this, model.lookup("com.example#SayHelloOutput")).render() + rustBlock("pub struct SayHello") { + rust("_input: SayHelloInput") + } + rustBlock("impl SayHello") { + rustBlock( + "pub fn from_response(_response: #T>) -> Result<#T, #T>", + + RuntimeType.Http("response::Response"), + symbolProvider.toSymbol(operation.outputShape(model)), + operation.errorSymbol(symbolProvider) + ) { + writeWithNoFormatting(correctResponse) + } + } rustBlock("impl SayHelloInput") { rustBlock("pub fn request_builder_base(&self) -> #T", RuntimeType.HttpRequestBuilder) { write("#T::new()", RuntimeType.HttpRequestBuilder) @@ -100,7 +141,11 @@ class HttpProtocolTestGeneratorTest { rustBlock("pub fn build_body(&self) -> String") { write(body) } - rustBlock("pub fn assemble>>(builder: #T, body: T) -> #T>", RuntimeType.HttpRequestBuilder, RuntimeType.Http("request::Request")) { + rustBlock( + "pub fn assemble>>(builder: #T, body: T) -> #T>", + RuntimeType.HttpRequestBuilder, + RuntimeType.Http("request::Request") + ) { write("let body = body.into();") write("builder.header(#T, body.len()).body(body)", RuntimeType.Http("header::CONTENT_LENGTH")) write(""".expect("http request should be valid")""") @@ -115,7 +160,11 @@ class HttpProtocolTestGeneratorTest { ) HttpProtocolTestGenerator( protocolConfig, - ProtocolSupport(requestBodySerialization = true), + ProtocolSupport( + requestBodySerialization = true, + responseDeserialization = true, + errorDeserialization = false + ), model.lookup("com.example#SayHello"), this ).render() @@ -136,7 +185,26 @@ class HttpProtocolTestGeneratorTest { val testOutput = writer.compileAndTest() // Verify the test actually ran - testOutput shouldContain "test_say_hello ... ok" + testOutput shouldContain "test_say_hello_request ... ok" + } + + @Test + fun `test incorrect response parsing`() { + val writer = RustWriter.root() + writeHttpImpl( + writer, + """ + .uri("/?Hi=Hello%20there&required") + .header("X-Greeting", "Hi") + .method("POST") + """, + correctResponse = "Ok(SayHelloOutput::builder().build())" + ) + val err = assertThrows { + writer.compileAndTest(expectFailure = true) + } + + err.message shouldContain "test_basic_response_test_response ... FAILED" } @Test @@ -158,7 +226,7 @@ class HttpProtocolTestGeneratorTest { writer.compileAndTest(expectFailure = true) } - err.message shouldContain "test_say_hello ... FAILED" + err.message shouldContain "test_say_hello_request ... FAILED" err.message shouldContain "body did not match" } @@ -180,7 +248,7 @@ class HttpProtocolTestGeneratorTest { writer.compileAndTest(expectFailure = true) } // Verify the test actually ran - err.message shouldContain "test_say_hello ... FAILED" + err.message shouldContain "test_say_hello_request ... FAILED" err.message shouldContain "missing query param" } @@ -202,7 +270,7 @@ class HttpProtocolTestGeneratorTest { writer.compileAndTest(expectFailure = true) } // Verify the test actually ran - err.message shouldContain "test_say_hello ... FAILED" + err.message shouldContain "test_say_hello_request ... FAILED" err.message shouldContain "forbidden query param" } @@ -224,7 +292,7 @@ class HttpProtocolTestGeneratorTest { writer.compileAndTest(expectFailure = true) } // Verify the test actually ran - err.message shouldContain "test_say_hello ... FAILED" + err.message shouldContain "test_say_hello_request ... FAILED" err.message shouldContain "required query param missing" } @@ -244,7 +312,7 @@ class HttpProtocolTestGeneratorTest { val err = assertThrows { writer.compileAndTest(expectFailure = true) } - err.message shouldContain "test_say_hello ... FAILED" + err.message shouldContain "test_say_hello_request ... FAILED" err.message shouldContain "invalid header value" } }