Unverified Commit d742a699 authored by Matteo Bigoi's avatar Matteo Bigoi Committed by GitHub
Browse files

Implement FromRequest / ToResponse traits for inputs/outputs/errors. (#839)

* Implement FromRequest / ToResponse traits for inputs/outputs/errors
* Update server runtime to expose needed functionalities
* Remove temporary traits from server runtime
* Update CODEOWNERS
parent 155d24aa
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
*                                       @awslabs/rust-sdk-owners
/codegen-server/                        @awslabs/smithy-rs-server
/codegen-server-test/                   @awslabs/smithy-rs-server
/rust-runtime/aws-smithy-http-server/   @awslabs/smithy-rs-server
+1 −0
Original line number Diff line number Diff line
@@ -50,6 +50,7 @@ resource Service {
        id: "RegisterServiceRequestTest",
        protocol: "aws.protocols#restJson1",
        uri: "/service/1",
        bodyMediaType: "application/json",
        params: { id: "1", name: "TestService" },
        body: "{\"name\":\"TestService\"}",
        method: "PUT",
+0 −81
Original line number Diff line number Diff line
@@ -16,9 +16,6 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerServiceGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader
@@ -132,14 +129,6 @@ class ServerCodegenVisitor(context: PluginContext, private val codegenDecorator:
        val serviceShapes = Walker(model).walkShapes(service)
        serviceShapes.forEach { it.accept(this) }
        codegenDecorator.extras(codegenContext, rustCrate)
        val module = RustMetadata(public = true)
        rustCrate.withModule(
            RustModule(
                "error",
                module,
                documentation = "All error types that operations can respond with."
            )
        ) { writer -> renderSerdeError(writer) }
        rustCrate.finalize(
            settings,
            model,
@@ -233,74 +222,4 @@ class ServerCodegenVisitor(context: PluginContext, private val codegenDecorator:
        )
            .render()
    }

    private fun renderSerdeError(writer: RustWriter) {
        writer.rust(
            """
            ##[derive(Debug)]
            pub enum Error {
                Generic(std::borrow::Cow<'static, str>),
                DeserializeJson(aws_smithy_json::deserialize::Error),
                DeserializeHeader(aws_smithy_http::header::ParseError),
                DeserializeLabel(std::string::String),
                BuildInput(aws_smithy_http::operation::BuildError),
                BuildResponse(http::Error),
                SmithyType(aws_smithy_types::Error),
            }

            impl Error {
                ##[allow(dead_code)]
                pub fn generic(msg: &'static str) -> Self {
                    Self::Generic(msg.into())
                }
            }

            impl From<aws_smithy_json::deserialize::Error> for Error {
                fn from(err: aws_smithy_json::deserialize::Error) -> Self {
                    Self::DeserializeJson(err)
                }
            }

            impl From<aws_smithy_http::operation::SerializationError> for Error {
                fn from(err: aws_smithy_http::operation::SerializationError) -> Self {
                    Self::BuildInput(err.into())
                }
            }

            impl From<aws_smithy_http::header::ParseError> for Error {
                fn from(err: aws_smithy_http::header::ParseError) -> Self {
                    Self::DeserializeHeader(err)
                }
            }

            impl From<aws_smithy_http::operation::BuildError> for Error {
                fn from(err: aws_smithy_http::operation::BuildError) -> Self {
                    Self::BuildInput(err)
                }
            }

            impl From<aws_smithy_types::Error> for Error {
                fn from(err: aws_smithy_types::Error) -> Self {
                    Self::SmithyType(err)
                }
            }

            impl std::fmt::Display for Error {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    match *self {
                        Self::Generic(ref msg) => write!(f, "serde error: {}", msg),
                        Self::DeserializeJson(ref err) => write!(f, "json parse error: {}", err),
                        Self::DeserializeHeader(ref err) => write!(f, "header parse error: {}", err),
                        Self::DeserializeLabel(ref msg) => write!(f, "label parse error: {}", msg),
                        Self::BuildInput(ref err) => write!(f, "json payload error: {}", err),
                        Self::BuildResponse(ref err) => write!(f, "http response error: {}", err),
                        Self::SmithyType(ref err) => write!(f, "type error: {}", err),
                    }
                }
            }

            impl std::error::Error for Error {}
            """.trimIndent()
        )
    }
}
+36 −36
Original line number Diff line number Diff line
@@ -27,13 +27,14 @@ import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.server.smithy.protocols.HttpServerTraits
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.Instantiator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.orNull
import software.amazon.smithy.rust.codegen.util.outputShape
@@ -58,15 +59,14 @@ class ServerProtocolTestGenerator(
    private val instantiator = with(codegenContext) {
        Instantiator(symbolProvider, model, runtimeConfig)
    }
    private val httpServerTraits = HttpServerTraits()

    private val codegenScope = arrayOf(
        "ParseStrictResponse" to RuntimeType.parseStrictResponse(codegenContext.runtimeConfig),
        "SerializeHttpResponse" to httpServerTraits.serializeHttpResponse(codegenContext.runtimeConfig),
        "ParseHttpRequest" to httpServerTraits.parseHttpRequest(codegenContext.runtimeConfig),
        "Bytes" to RuntimeType.Bytes,
        "SmithyHttp" to CargoDependency.SmithyHttp(codegenContext.runtimeConfig).asType(),
        "Http" to CargoDependency.Http.asType(),
        "Hyper" to CargoDependency.Hyper.asType(),
        "Axum" to CargoDependency.Axum.asType(),
        "SmithyHttpServer" to CargoDependency.SmithyHttpServer(codegenContext.runtimeConfig).asType(),
    )

    sealed class TestCase {
@@ -177,7 +177,6 @@ class ServerProtocolTestGenerator(
        writeInline("let expected =")
        instantiator.render(this, inputShape, httpRequestTestCase.params)
        write(";")
        rustTemplate("""let op = #{op}::new();""", "op" to operationSymbol)
        with(httpRequestTestCase) {
            host.orNull()?.also { host ->
                val withScheme = "http://$host"
@@ -203,7 +202,7 @@ class ServerProtocolTestGenerator(
        if (protocolSupport.requestBodyDeserialization) {
            // "If no request body is defined, then no assertions are made about the body of the message."
            httpRequestTestCase.body.orNull()?.also { body ->
                checkBody(this, body, httpRequestTestCase.uri)
                checkBody(this, body, httpRequestTestCase)
            }
        }

@@ -233,9 +232,7 @@ class ServerProtocolTestGenerator(
        expectedShape: StructureShape
    ) {
        if (!protocolSupport.responseSerialization || (
            !protocolSupport.errorSerialization && expectedShape.hasTrait(
                    ErrorTrait::class.java
                )
            !protocolSupport.errorSerialization && expectedShape.hasTrait<ErrorTrait>()
            )
        ) {
            rust("/* test case disabled for this protocol (not yet supported) */")
@@ -244,27 +241,35 @@ class ServerProtocolTestGenerator(
        writeInline("let output =")
        instantiator.render(this, expectedShape, testCase.params)
        write(";")
        val operationName = if (expectedShape.hasTrait<ErrorTrait>()) {
            "${operationSymbol.name}${ServerHttpProtocolGenerator.OPERATION_ERROR_WRAPPER_SUFFIX}"
        } else {
            "${operationSymbol.name}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
        }
        rustTemplate(
            """
            use #{SerializeHttpResponse};
            let op = #{op}::new();
            let http_response = op.serialize(&output).expect("unable to serialize `#{op}` into HTTP response body");
            let output = super::$operationName(output);
            use #{Axum}::response::IntoResponse;
            let http_response = output.into_response();
            """,
            *codegenScope,
            "op" to operationSymbol,
        )
        rust("""
        rust(
            """
            assert_eq!(
                http::StatusCode::from_u16(${testCase.code}).expect("invalid expected HTTP status code"),
                http_response.status()
            );
        """)
            """
        )
        if (testCase.body != null) {
            rust("""
                let body = std::str::from_utf8(http_response.body())
                    .expect("serialized response body does not contain valid UTF-8");
            rustTemplate(
                """
                let body = #{Hyper}::body::to_bytes(http_response.into_body()).await.expect("unable to extract body to bytes");
                assert_eq!("${testCase.body.get().replace("\"", "\\\"")}", body);
            """)
                """,
                *codegenScope
            )
        }
    }

@@ -276,32 +281,27 @@ class ServerProtocolTestGenerator(
        basicCheck(forbidHeaders, rustWriter, "forbidden_headers", "forbid_headers")
    }

    private fun checkBody(rustWriter: RustWriter, body: String, uri: String) {
    private fun checkBody(rustWriter: RustWriter, body: String, testCase: HttpRequestTestCase) {
        val operationName = "${operationSymbol.name}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"
        rustWriter.rustTemplate(
            """
            let http_request = http::Request::builder()
                .uri(${uri.dq()})
                .body(#{Bytes}::from_static(b${body.dq()}))
                .uri(${testCase.uri.dq()})
                .header("Content-Type", ${testCase.bodyMediaType.orNull()?.dq()})
                .body(#{SmithyHttpServer}::Body::from(#{Bytes}::from_static(b${body.dq()})))
                .unwrap();
            use #{Axum}::extract::FromRequest;
            let mut http_request = #{Axum}::extract::RequestParts::new(http_request);
            let input_wrapper = super::$operationName::from_request(&mut http_request).await.expect("failed to parse request");
            let input = input_wrapper.0;
            """,
            "body" to body,
            "uri" to uri,
            *codegenScope,
        )
        rustWriter.rustTemplate(
            """
            use #{ParseHttpRequest};
            let op = #{op}::new();
            let body = op.parse_loaded(&http_request).expect("failed to parse request");
            """,
            "op" to operationSymbol,
            *codegenScope,
        )
        if (body == "") {
            rustWriter.write("// No body")
            rustWriter.write("assert_eq!(std::str::from_utf8(body).unwrap(), ${"".dq()});")
            rustWriter.write("assert_eq!(std::str::from_utf8(input).expect(\"`body` does not contain valid UTF-8\"), ${"".dq()});")
        } else {
            rustWriter.write("assert_eq!(expected, body);")
            rustWriter.write("assert_eq!(input, expected);")
        }
    }

+117 −109

File changed.

Preview size limit exceeded, changes collapsed.

Loading