Unverified Commit 7fabbb74 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Add server operation registry and router (#850)

This commit adds two things:

1. A runtime router implementing `tower`'s
   [`Service`](https://docs.rs/tower-service/0.3.1/tower_service/trait.Service.html)
   that adheres to [Smithy's `http` trait
   specification](https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html#http-trait),
   that is linear in the number of registered routes.
2. A code-generated "operation registry" that allows service
   implementers to provide Rust functions and declare them as the
   handlers for their service's operations.

The framework will receive HTTP requests from the server and route them
to the corresponding operation handler.
parent 476e988d
Loading
Loading
Loading
Loading
+129 −0
Original line number Original line Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.rustlang.*
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RuntimeType.Companion.RequestSpecModule
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpTraitHttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolContentTypes
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase

/**
 * OperationRegistryGenerator
 */
class OperationRegistryGenerator(
    codegenContext: CodegenContext,
    private val operations: List<OperationShape>,
) {
    private val serverCrate = "aws_smithy_http_server"
    private val service = codegenContext.serviceShape
    private val model = codegenContext.model
    private val symbolProvider = codegenContext.symbolProvider
    private val operationNames = operations.map { symbolProvider.toSymbol(it).name.toSnakeCase() }
    private val runtimeConfig = codegenContext.runtimeConfig
    private val codegenScope = arrayOf(
        "Router" to RuntimeType.Router(runtimeConfig),
    )
    private val httpBindingResolver: HttpBindingResolver =
        HttpTraitHttpBindingResolver(codegenContext.model, ProtocolContentTypes.consistent("application/json"))

    fun render(writer: RustWriter) {
        Attribute.Derives(setOf(RuntimeType.Debug, RuntimeType.DeriveBuilder)).render(writer)
        Attribute.Custom("builder(pattern = \"owned\")").render(writer)
        // Generic arguments of the `OperationRegistryBuilder<Fun0, Fut0, ..., FunN, FutN>`.
        val operationsGenericArguments = operations.mapIndexed { i, _ -> "Fun$i, Fut$i"}.joinToString()
        val operationRegistryName = "${service.getContextualName(service)}OperationRegistry<${operationsGenericArguments}>"
        writer.rustBlock("""
            pub struct $operationRegistryName
            where
                ${operationsTraitBounds()}
            """.trimIndent()) {
            val members = operationNames
                .mapIndexed { i, operationName -> "$operationName: Fun$i" }
                .joinToString(separator = ",\n")
            rust(members)
        }

        writer.rustBlockTemplate("""
            impl<${operationsGenericArguments}> From<$operationRegistryName> for #{Router}
            where
                ${operationsTraitBounds()}
            """.trimIndent(), *codegenScope) {
            rustBlock("fn from(registry: ${operationRegistryName}) -> Self") {
                val operationInOutWrappers = operations.map {
                    val operationName = symbolProvider.toSymbol(it).name
                    Pair("crate::operation::$operationName${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}",
                    "crate::operation::$operationName${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}")
                }
                val requestSpecsVarNames = operationNames.map { "${it}_request_spec" }
                val routes = requestSpecsVarNames.zip(operationNames).zip(operationInOutWrappers) { (requestSpecVarName, operationName), (inputWrapper, outputWrapper) ->
                    ".route($requestSpecVarName, $serverCrate::routing::operation_handler::operation::<_, _, $inputWrapper, _, $outputWrapper>(registry.$operationName))"
                }.joinToString(separator = "\n")

                val requestSpecs = requestSpecsVarNames.zip(operations) { requestSpecVarName, operation ->
                    "let $requestSpecVarName = ${operation.requestSpec()};"
                }.joinToString(separator = "\n")
                rustTemplate("""
                    $requestSpecs
                    #{Router}::new()
                        $routes
                    """.trimIndent(), *codegenScope)
            }
        }
    }

    private fun operationsTraitBounds(): String = operations
        .mapIndexed { i, operation ->
            val outputType = if (operation.errors.isNotEmpty()) {
                "Result<${symbolProvider.toSymbol(operation.outputShape(model)).fullName}, ${operation.errorSymbol(symbolProvider).fullyQualifiedName()}>"
            } else {
                symbolProvider.toSymbol(operation.outputShape(model)).fullName
            }
            """
            Fun$i: FnOnce(${symbolProvider.toSymbol(operation.inputShape(model))}) -> Fut$i + Clone + Send + Sync + 'static,
            Fut$i: std::future::Future<Output = $outputType> + Send
            """.trimIndent()
        }.joinToString(separator = ",\n")

    private fun OperationShape.requestSpec(): String {
        val httpTrait = httpBindingResolver.httpTrait(this)
        val namespace = RequestSpecModule(runtimeConfig).fullyQualifiedName()

        // TODO Support the `endpoint` trait: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait

        val pathSegments = httpTrait.uri.segments.map {
            "$namespace::PathSegment::" +
                if (it.isGreedyLabel) "Greedy"
                else if (it.isLabel) "Label"
                else "Literal(String::from(\"${it.content}\"))"
        }
        val querySegments = httpTrait.uri.queryLiterals.map {
            "$namespace::QuerySegment::" +
                if (it.value == "") "Key(String::from(\"${it.key}\"))"
                else "KeyValue(String::from(\"${it.key}\"), String::from(\"${it.value}\"))"
        }

        return """
            $namespace::RequestSpec::new(
                http::Method::${httpTrait.method},
                $namespace::UriSpec {
                    host_prefix: None,
                    path_and_query: $namespace::PathAndQuerySpec {
                        path_segments: $namespace::PathSpec::from_vector_unchecked(vec![${pathSegments.joinToString()}]),
                        query_segments: $namespace::QuerySpec::from_vector_unchecked(vec![${querySegments.joinToString()}])
                    }
                }
            )""".trimIndent()
    }
}
 No newline at end of file
+5 −1
Original line number Original line Diff line number Diff line
@@ -36,7 +36,7 @@ class ServerServiceGenerator(
     */
     */
    fun render() {
    fun render() {
        val operations = index.getContainedOperations(context.serviceShape).sortedBy { it.id }
        val operations = index.getContainedOperations(context.serviceShape).sortedBy { it.id }
        operations.map { operation ->
        for (operation in operations) {
            rustCrate.useShapeWriter(operation) { operationWriter ->
            rustCrate.useShapeWriter(operation) { operationWriter ->
                protocolGenerator.serverRenderOperation(
                protocolGenerator.serverRenderOperation(
                    operationWriter,
                    operationWriter,
@@ -54,5 +54,9 @@ class ServerServiceGenerator(
                }
                }
            }
            }
        }
        }
        rustCrate.withModule(RustModule.public("operation_registry", "A registry of your service's operations.")) { writer ->
            OperationRegistryGenerator(context, operations)
                .render(writer)
        }
    }
    }
}
}
+71 −65
Original line number Original line Diff line number Diff line
@@ -38,13 +38,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescripto
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolBodyGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolBodyGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.*
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase
import java.util.logging.Logger
import java.util.logging.Logger


/*
/*
@@ -77,7 +71,7 @@ class ServerHttpProtocolGenerator(


/*
/*
 * Generate all operation input parsers and output serializers for streaming and
 * Generate all operation input parsers and output serializers for streaming and
 * non-straming types.
 * non-streaming types.
 */
 */
private class ServerHttpProtocolImplGenerator(
private class ServerHttpProtocolImplGenerator(
    private val codegenContext: CodegenContext,
    private val codegenContext: CodegenContext,
@@ -109,37 +103,48 @@ private class ServerHttpProtocolImplGenerator(
        val outputSymbol = symbolProvider.toSymbol(operationShape.outputShape(model))
        val outputSymbol = symbolProvider.toSymbol(operationShape.outputShape(model))
        val operationName = symbolProvider.toSymbol(operationShape).name
        val operationName = symbolProvider.toSymbol(operationShape).name


        // For streaming response bodies, we need to generate a different implementation of the parse traits.
        operationWriter.renderTraits(operationName, inputSymbol, outputSymbol, operationShape)
        // These will first offer the streaming input to the parser & potentially read the body into memory
        // if an error occurred or if the streaming parser indicates that it needs the full data to proceed.
        if (operationShape.outputShape(model).hasStreamingMember(model)) {
            with(operationWriter) {
                renderStreamingTraits(operationName, outputSymbol, operationShape)
            }
        } else {
            with(operationWriter) {
                renderNonStreamingTraits(operationName, inputSymbol, outputSymbol, operationShape)
            }
        }
    }
    }


    /*
    /*
     * Generation of non-streaming traits. A non-streaming trait requires the HTTP body to be fully read in
     * Generation of `FromRequest` and `IntoResponse`. They are currently only implemented for non-streaming request
     * memory before parsing or deserialization. From a server perspective we need a way to parse an HTTP
     * and response bodies, that is, models without streaming traits
     * request from `Bytes` and serialize a HTTP response to `Bytes`. These traits are the public entrypoint
     * (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html).
     * of the ser/de logic of the smithy-rs server.
     * For non-streaming request bodies, we require the HTTP body to be fully read in memory before parsing or
     * deserialization. From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize
     * an HTTP response to `Bytes`.
     * TODO Add support for streaming.
     * These traits are the public entrypoint of the ser/de logic of the `aws-smithy-http-server` server.
     */
     */
    private fun RustWriter.renderNonStreamingTraits(
    private fun RustWriter.renderTraits(
        operationName: String?,
        operationName: String?,
        inputSymbol: Symbol,
        inputSymbol: Symbol,
        outputSymbol: Symbol,
        outputSymbol: Symbol,
        operationShape: OperationShape
        operationShape: OperationShape
    ) {
    ) {
        // Implement Axum `FromRequest` trait for non streaming input types.
        // Implement Axum `FromRequest` trait for input types.
        val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"
        val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"

        val fromRequest = if (operationShape.inputShape(model).hasStreamingMember(model)) {
            // For streaming request bodies, we need to generate a different implementation of the `FromRequest` trait.
            // It will first offer the streaming input to the parser and potentially read the body into memory
            // if an error occurred or if the streaming parser indicates that it needs the full data to proceed.
            """
            async fn from_request(_req: &mut #{Axum}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
                todo!("Streaming support for input shapes is not yet supported in `smithy-rs`")
            }
            """.trimIndent()
        } else {
            """
            async fn from_request(req: &mut #{Axum}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
                #{SmithyHttpServer}::protocols::check_json_content_type(req)?;
                Ok($inputName(#{parse_request}(req).await?))
            }
            """.trimIndent()
        }
        rustTemplate(
        rustTemplate(
            """
            """
            struct $inputName(#{I});
            pub(crate) struct $inputName(#{I});
            ##[#{Axum}::async_trait]
            ##[#{Axum}::async_trait]
            impl<B> #{Axum}::extract::FromRequest<B> for $inputName
            impl<B> #{Axum}::extract::FromRequest<B> for $inputName
            where
            where
@@ -149,30 +154,46 @@ private class ServerHttpProtocolImplGenerator(
                #{SmithyRejection}: From<<B as #{SmithyHttpServer}::HttpBody>::Error>
                #{SmithyRejection}: From<<B as #{SmithyHttpServer}::HttpBody>::Error>
            {
            {
                type Rejection = #{SmithyRejection};
                type Rejection = #{SmithyRejection};
                async fn from_request(req: &mut #{Axum}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
                $fromRequest
                    #{SmithyHttpServer}::protocols::check_json_content_type(req)?;
                    Ok($inputName(#{parse_request}(req).await?))
                }
            }""".trimIndent(),
            }""".trimIndent(),
            *codegenScope,
            *codegenScope,
            "I" to inputSymbol,
            "I" to inputSymbol,
            "parse_request" to serverParseRequest(operationShape)
            "parse_request" to serverParseRequest(operationShape)
        )
        )


        // Implement Axum `IntoResponse` for non streaming output types.
        // Implement Axum `IntoResponse` for output types.
        val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
        val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
        val errorSymbol = operationShape.errorSymbol(symbolProvider)
        val errorSymbol = operationShape.errorSymbol(symbolProvider)


        val handleSerializeOutput = """
        // For streaming response bodies, we need to generate a different implementation of the `IntoResponse` trait.
        // The body type will have to be a `StreamBody`. The service implementer will return a `Stream` from their handler.
        val intoResponseStreaming = "todo!(\"Streaming support for output shapes is not yet supported in `smithy-rs`\")"
        if (operationShape.errors.isNotEmpty()) {
            val intoResponseImpl = if (operationShape.outputShape(model).hasStreamingMember(model)) {
                intoResponseStreaming
            } else {
                """
                match self {
                    Self::Output(o) => {
                        match #{serialize_response}(&o) {
                            Ok(response) => response,
                            Ok(response) => response,
                            Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from output")
                            Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from output")
                        }
                    },
                    Self::Error(err) => {
                        match #{serialize_error}(&err) {
                            Ok(response) => response,
                            Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from error")
                        }
                    }
                }
                """.trimIndent()
                """.trimIndent()
        if (operationShape.errors.isNotEmpty()) {
            }
            // The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control
            // The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control
            // that can in turn be converted into a response.
            // that can in turn be converted into a response.
            rustTemplate(
            rustTemplate(
                """
                """
                enum $outputName {
                pub(crate) enum $outputName {
                    Output(#{O}),
                    Output(#{O}),
                    Error(#{E})
                    Error(#{E})
                }
                }
@@ -182,19 +203,7 @@ private class ServerHttpProtocolImplGenerator(
                    type BodyError = <Self::Body as #{SmithyHttpServer}::HttpBody>::Error;
                    type BodyError = <Self::Body as #{SmithyHttpServer}::HttpBody>::Error;


                    fn into_response(self) -> #{http}::Response<Self::Body> {
                    fn into_response(self) -> #{http}::Response<Self::Body> {
                        match self {
                        $intoResponseImpl
                            Self::Output(o) => {
                                match #{serialize_response}(&o) {
                                    $handleSerializeOutput
                                }
                            },
                            Self::Error(err) => {
                                match #{serialize_error}(&err) {
                                    Ok(response) => response,
                                    Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from error")
                                }
                            }
                        }
                    }
                    }
                }""".trimIndent(),
                }""".trimIndent(),
                *codegenScope,
                *codegenScope,
@@ -204,21 +213,29 @@ private class ServerHttpProtocolImplGenerator(
                "serialize_error" to serverSerializeError(operationShape)
                "serialize_error" to serverSerializeError(operationShape)
            )
            )
        } else {
        } else {
            val handleSerializeOutput = if (operationShape.outputShape(model).hasStreamingMember(model)) {
                intoResponseStreaming
            } else {
                """
                match #{serialize_response}(&self.0) {
                    Ok(response) => response,
                    Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from output")
                }
                """.trimIndent()
            }
            // The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type
            // The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type
            // we control that can in turn be converted into a response.
            // we control that can in turn be converted into a response.
            rustTemplate(
            rustTemplate(
                """
                """
                struct $outputName(#{O});
                pub(crate) struct $outputName(#{O});
                ##[#{Axum}::async_trait]
                ##[#{Axum}::async_trait]
                impl #{Axum}::response::IntoResponse for $outputName {
                impl #{Axum}::response::IntoResponse for $outputName {
                    type Body = #{SmithyHttpServer}::Body;
                    type Body = #{SmithyHttpServer}::Body;
                    type BodyError = <Self::Body as #{SmithyHttpServer}::HttpBody>::Error;
                    type BodyError = <Self::Body as #{SmithyHttpServer}::HttpBody>::Error;


                    fn into_response(self) -> #{http}::Response<Self::Body> {
                    fn into_response(self) -> #{http}::Response<Self::Body> {
                        match #{serialize_response}(&self.0) {
                        $handleSerializeOutput
                        $handleSerializeOutput
                    }
                    }
                    }
                }""".trimIndent(),
                }""".trimIndent(),
                *codegenScope,
                *codegenScope,
                "O" to outputSymbol,
                "O" to outputSymbol,
@@ -268,17 +285,6 @@ private class ServerHttpProtocolImplGenerator(
        )
        )
    }
    }


    /*
     * TODO: implement streaming traits
     */
    private fun RustWriter.renderStreamingTraits(
        operationName: String,
        outputSymbol: Symbol,
        operationShape: OperationShape
    ) {
        logger.warning("[rust-server-codegen] $operationName: streaming trait is not yet implemented")
    }

    private fun serverParseRequest(operationShape: OperationShape): RuntimeType {
    private fun serverParseRequest(operationShape: OperationShape): RuntimeType {
        val fnName = "parse_${operationShape.id.name.toSnakeCase()}_request"
        val fnName = "parse_${operationShape.id.name.toSnakeCase()}_request"
        val inputShape = operationShape.inputShape(model)
        val inputShape = operationShape.inputShape(model)
+1 −1
Original line number Original line Diff line number Diff line
@@ -192,6 +192,7 @@ data class CargoDependency(
        val Axum: CargoDependency = CargoDependency("axum", CratesIo("0.3"))
        val Axum: CargoDependency = CargoDependency("axum", CratesIo("0.3"))
        val Bytes: CargoDependency = CargoDependency("bytes", CratesIo("1"))
        val Bytes: CargoDependency = CargoDependency("bytes", CratesIo("1"))
        val BytesUtils: CargoDependency = CargoDependency("bytes-utils", CratesIo("0.1.1"))
        val BytesUtils: CargoDependency = CargoDependency("bytes-utils", CratesIo("0.1.1"))
        val DeriveBuilder = CargoDependency("derive_builder", CratesIo("0.10"))
        val FastRand: CargoDependency = CargoDependency("fastrand", CratesIo("1"))
        val FastRand: CargoDependency = CargoDependency("fastrand", CratesIo("1"))
        val Hex: CargoDependency = CargoDependency("hex", CratesIo("0.4.3"))
        val Hex: CargoDependency = CargoDependency("hex", CratesIo("0.4.3"))
        val HttpBody: CargoDependency = CargoDependency("http-body", CratesIo("0.4"))
        val HttpBody: CargoDependency = CargoDependency("http-body", CratesIo("0.4"))
@@ -213,7 +214,6 @@ data class CargoDependency(
        fun SmithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("eventstream")
        fun SmithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("eventstream")
        fun SmithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http")
        fun SmithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http")
        fun SmithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-server")
        fun SmithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-server")
        fun SmithyHttpTower(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-tower")
        fun SmithyProtocolTestHelpers(runtimeConfig: RuntimeConfig) =
        fun SmithyProtocolTestHelpers(runtimeConfig: RuntimeConfig) =
            runtimeConfig.runtimeCrate("protocol-test").copy(scope = DependencyScope.Dev)
            runtimeConfig.runtimeCrate("protocol-test").copy(scope = DependencyScope.Dev)
        fun smithyJson(runtimeConfig: RuntimeConfig): CargoDependency = runtimeConfig.runtimeCrate("json")
        fun smithyJson(runtimeConfig: RuntimeConfig): CargoDependency = runtimeConfig.runtimeCrate("json")
+0 −1
Original line number Original line Diff line number Diff line
@@ -317,7 +317,6 @@ sealed class Attribute {
        val container: Boolean = false
        val container: Boolean = false
    ) : Attribute() {
    ) : Attribute() {
        override fun render(writer: RustWriter) {
        override fun render(writer: RustWriter) {

            val bang = if (container) "!" else ""
            val bang = if (container) "!" else ""
            writer.raw("#$bang[$annotation]")
            writer.raw("#$bang[$annotation]")
            symbols.forEach {
            symbols.forEach {
Loading