Unverified Commit 7fabbb74 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Add server operation registry and router (#850)

This commit adds two things:

1. A runtime router implementing `tower`'s
   [`Service`](https://docs.rs/tower-service/0.3.1/tower_service/trait.Service.html)
   that adheres to [Smithy's `http` trait
   specification](https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html#http-trait),
   that is linear in the number of registered routes.
2. A code-generated "operation registry" that allows service
   implementers to provide Rust functions and declare them as the
   handlers for their service's operations.

The framework will receive HTTP requests from the server and route them
to the corresponding operation handler.
parent 476e988d
Loading
Loading
Loading
Loading
+129 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.rustlang.*
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RuntimeType.Companion.RequestSpecModule
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpTraitHttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolContentTypes
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase

/**
 * OperationRegistryGenerator
 */
class OperationRegistryGenerator(
    codegenContext: CodegenContext,
    private val operations: List<OperationShape>,
) {
    private val serverCrate = "aws_smithy_http_server"
    private val service = codegenContext.serviceShape
    private val model = codegenContext.model
    private val symbolProvider = codegenContext.symbolProvider
    private val operationNames = operations.map { symbolProvider.toSymbol(it).name.toSnakeCase() }
    private val runtimeConfig = codegenContext.runtimeConfig
    private val codegenScope = arrayOf(
        "Router" to RuntimeType.Router(runtimeConfig),
    )
    private val httpBindingResolver: HttpBindingResolver =
        HttpTraitHttpBindingResolver(codegenContext.model, ProtocolContentTypes.consistent("application/json"))

    fun render(writer: RustWriter) {
        Attribute.Derives(setOf(RuntimeType.Debug, RuntimeType.DeriveBuilder)).render(writer)
        Attribute.Custom("builder(pattern = \"owned\")").render(writer)
        // Generic arguments of the `OperationRegistryBuilder<Fun0, Fut0, ..., FunN, FutN>`.
        val operationsGenericArguments = operations.mapIndexed { i, _ -> "Fun$i, Fut$i"}.joinToString()
        val operationRegistryName = "${service.getContextualName(service)}OperationRegistry<${operationsGenericArguments}>"
        writer.rustBlock("""
            pub struct $operationRegistryName
            where
                ${operationsTraitBounds()}
            """.trimIndent()) {
            val members = operationNames
                .mapIndexed { i, operationName -> "$operationName: Fun$i" }
                .joinToString(separator = ",\n")
            rust(members)
        }

        writer.rustBlockTemplate("""
            impl<${operationsGenericArguments}> From<$operationRegistryName> for #{Router}
            where
                ${operationsTraitBounds()}
            """.trimIndent(), *codegenScope) {
            rustBlock("fn from(registry: ${operationRegistryName}) -> Self") {
                val operationInOutWrappers = operations.map {
                    val operationName = symbolProvider.toSymbol(it).name
                    Pair("crate::operation::$operationName${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}",
                    "crate::operation::$operationName${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}")
                }
                val requestSpecsVarNames = operationNames.map { "${it}_request_spec" }
                val routes = requestSpecsVarNames.zip(operationNames).zip(operationInOutWrappers) { (requestSpecVarName, operationName), (inputWrapper, outputWrapper) ->
                    ".route($requestSpecVarName, $serverCrate::routing::operation_handler::operation::<_, _, $inputWrapper, _, $outputWrapper>(registry.$operationName))"
                }.joinToString(separator = "\n")

                val requestSpecs = requestSpecsVarNames.zip(operations) { requestSpecVarName, operation ->
                    "let $requestSpecVarName = ${operation.requestSpec()};"
                }.joinToString(separator = "\n")
                rustTemplate("""
                    $requestSpecs
                    #{Router}::new()
                        $routes
                    """.trimIndent(), *codegenScope)
            }
        }
    }

    private fun operationsTraitBounds(): String = operations
        .mapIndexed { i, operation ->
            val outputType = if (operation.errors.isNotEmpty()) {
                "Result<${symbolProvider.toSymbol(operation.outputShape(model)).fullName}, ${operation.errorSymbol(symbolProvider).fullyQualifiedName()}>"
            } else {
                symbolProvider.toSymbol(operation.outputShape(model)).fullName
            }
            """
            Fun$i: FnOnce(${symbolProvider.toSymbol(operation.inputShape(model))}) -> Fut$i + Clone + Send + Sync + 'static,
            Fut$i: std::future::Future<Output = $outputType> + Send
            """.trimIndent()
        }.joinToString(separator = ",\n")

    private fun OperationShape.requestSpec(): String {
        val httpTrait = httpBindingResolver.httpTrait(this)
        val namespace = RequestSpecModule(runtimeConfig).fullyQualifiedName()

        // TODO Support the `endpoint` trait: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait

        val pathSegments = httpTrait.uri.segments.map {
            "$namespace::PathSegment::" +
                if (it.isGreedyLabel) "Greedy"
                else if (it.isLabel) "Label"
                else "Literal(String::from(\"${it.content}\"))"
        }
        val querySegments = httpTrait.uri.queryLiterals.map {
            "$namespace::QuerySegment::" +
                if (it.value == "") "Key(String::from(\"${it.key}\"))"
                else "KeyValue(String::from(\"${it.key}\"), String::from(\"${it.value}\"))"
        }

        return """
            $namespace::RequestSpec::new(
                http::Method::${httpTrait.method},
                $namespace::UriSpec {
                    host_prefix: None,
                    path_and_query: $namespace::PathAndQuerySpec {
                        path_segments: $namespace::PathSpec::from_vector_unchecked(vec![${pathSegments.joinToString()}]),
                        query_segments: $namespace::QuerySpec::from_vector_unchecked(vec![${querySegments.joinToString()}])
                    }
                }
            )""".trimIndent()
    }
}
 No newline at end of file
+5 −1
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@ class ServerServiceGenerator(
     */
    fun render() {
        val operations = index.getContainedOperations(context.serviceShape).sortedBy { it.id }
        operations.map { operation ->
        for (operation in operations) {
            rustCrate.useShapeWriter(operation) { operationWriter ->
                protocolGenerator.serverRenderOperation(
                    operationWriter,
@@ -54,5 +54,9 @@ class ServerServiceGenerator(
                }
            }
        }
        rustCrate.withModule(RustModule.public("operation_registry", "A registry of your service's operations.")) { writer ->
            OperationRegistryGenerator(context, operations)
                .render(writer)
        }
    }
}
+71 −65
Original line number Diff line number Diff line
@@ -38,13 +38,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescripto
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolBodyGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase
import software.amazon.smithy.rust.codegen.util.*
import java.util.logging.Logger

/*
@@ -77,7 +71,7 @@ class ServerHttpProtocolGenerator(

/*
 * Generate all operation input parsers and output serializers for streaming and
 * non-straming types.
 * non-streaming types.
 */
private class ServerHttpProtocolImplGenerator(
    private val codegenContext: CodegenContext,
@@ -109,37 +103,48 @@ private class ServerHttpProtocolImplGenerator(
        val outputSymbol = symbolProvider.toSymbol(operationShape.outputShape(model))
        val operationName = symbolProvider.toSymbol(operationShape).name

        // For streaming response bodies, we need to generate a different implementation of the parse traits.
        // These will first offer the streaming input to the parser & potentially read the body into memory
        // if an error occurred or if the streaming parser indicates that it needs the full data to proceed.
        if (operationShape.outputShape(model).hasStreamingMember(model)) {
            with(operationWriter) {
                renderStreamingTraits(operationName, outputSymbol, operationShape)
            }
        } else {
            with(operationWriter) {
                renderNonStreamingTraits(operationName, inputSymbol, outputSymbol, operationShape)
            }
        }
        operationWriter.renderTraits(operationName, inputSymbol, outputSymbol, operationShape)
    }

    /*
     * Generation of non-streaming traits. A non-streaming trait requires the HTTP body to be fully read in
     * memory before parsing or deserialization. From a server perspective we need a way to parse an HTTP
     * request from `Bytes` and serialize a HTTP response to `Bytes`. These traits are the public entrypoint
     * of the ser/de logic of the smithy-rs server.
     * Generation of `FromRequest` and `IntoResponse`. They are currently only implemented for non-streaming request
     * and response bodies, that is, models without streaming traits
     * (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html).
     * For non-streaming request bodies, we require the HTTP body to be fully read in memory before parsing or
     * deserialization. From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize
     * an HTTP response to `Bytes`.
     * TODO Add support for streaming.
     * These traits are the public entrypoint of the ser/de logic of the `aws-smithy-http-server` server.
     */
    private fun RustWriter.renderNonStreamingTraits(
    private fun RustWriter.renderTraits(
        operationName: String?,
        inputSymbol: Symbol,
        outputSymbol: Symbol,
        operationShape: OperationShape
    ) {
        // Implement Axum `FromRequest` trait for non streaming input types.
        // Implement Axum `FromRequest` trait for input types.
        val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"

        val fromRequest = if (operationShape.inputShape(model).hasStreamingMember(model)) {
            // For streaming request bodies, we need to generate a different implementation of the `FromRequest` trait.
            // It will first offer the streaming input to the parser and potentially read the body into memory
            // if an error occurred or if the streaming parser indicates that it needs the full data to proceed.
            """
            async fn from_request(_req: &mut #{Axum}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
                todo!("Streaming support for input shapes is not yet supported in `smithy-rs`")
            }
            """.trimIndent()
        } else {
            """
            async fn from_request(req: &mut #{Axum}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
                #{SmithyHttpServer}::protocols::check_json_content_type(req)?;
                Ok($inputName(#{parse_request}(req).await?))
            }
            """.trimIndent()
        }
        rustTemplate(
            """
            struct $inputName(#{I});
            pub(crate) struct $inputName(#{I});
            ##[#{Axum}::async_trait]
            impl<B> #{Axum}::extract::FromRequest<B> for $inputName
            where
@@ -149,30 +154,46 @@ private class ServerHttpProtocolImplGenerator(
                #{SmithyRejection}: From<<B as #{SmithyHttpServer}::HttpBody>::Error>
            {
                type Rejection = #{SmithyRejection};
                async fn from_request(req: &mut #{Axum}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
                    #{SmithyHttpServer}::protocols::check_json_content_type(req)?;
                    Ok($inputName(#{parse_request}(req).await?))
                }
                $fromRequest
            }""".trimIndent(),
            *codegenScope,
            "I" to inputSymbol,
            "parse_request" to serverParseRequest(operationShape)
        )

        // Implement Axum `IntoResponse` for non streaming output types.
        // Implement Axum `IntoResponse` for output types.
        val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
        val errorSymbol = operationShape.errorSymbol(symbolProvider)

        val handleSerializeOutput = """
        // For streaming response bodies, we need to generate a different implementation of the `IntoResponse` trait.
        // The body type will have to be a `StreamBody`. The service implementer will return a `Stream` from their handler.
        val intoResponseStreaming = "todo!(\"Streaming support for output shapes is not yet supported in `smithy-rs`\")"
        if (operationShape.errors.isNotEmpty()) {
            val intoResponseImpl = if (operationShape.outputShape(model).hasStreamingMember(model)) {
                intoResponseStreaming
            } else {
                """
                match self {
                    Self::Output(o) => {
                        match #{serialize_response}(&o) {
                            Ok(response) => response,
                            Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from output")
                        }
                    },
                    Self::Error(err) => {
                        match #{serialize_error}(&err) {
                            Ok(response) => response,
                            Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from error")
                        }
                    }
                }
                """.trimIndent()
        if (operationShape.errors.isNotEmpty()) {
            }
            // The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control
            // that can in turn be converted into a response.
            rustTemplate(
                """
                enum $outputName {
                pub(crate) enum $outputName {
                    Output(#{O}),
                    Error(#{E})
                }
@@ -182,19 +203,7 @@ private class ServerHttpProtocolImplGenerator(
                    type BodyError = <Self::Body as #{SmithyHttpServer}::HttpBody>::Error;

                    fn into_response(self) -> #{http}::Response<Self::Body> {
                        match self {
                            Self::Output(o) => {
                                match #{serialize_response}(&o) {
                                    $handleSerializeOutput
                                }
                            },
                            Self::Error(err) => {
                                match #{serialize_error}(&err) {
                                    Ok(response) => response,
                                    Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from error")
                                }
                            }
                        }
                        $intoResponseImpl
                    }
                }""".trimIndent(),
                *codegenScope,
@@ -204,21 +213,29 @@ private class ServerHttpProtocolImplGenerator(
                "serialize_error" to serverSerializeError(operationShape)
            )
        } else {
            val handleSerializeOutput = if (operationShape.outputShape(model).hasStreamingMember(model)) {
                intoResponseStreaming
            } else {
                """
                match #{serialize_response}(&self.0) {
                    Ok(response) => response,
                    Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from output")
                }
                """.trimIndent()
            }
            // The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type
            // we control that can in turn be converted into a response.
            rustTemplate(
                """
                struct $outputName(#{O});
                pub(crate) struct $outputName(#{O});
                ##[#{Axum}::async_trait]
                impl #{Axum}::response::IntoResponse for $outputName {
                    type Body = #{SmithyHttpServer}::Body;
                    type BodyError = <Self::Body as #{SmithyHttpServer}::HttpBody>::Error;

                    fn into_response(self) -> #{http}::Response<Self::Body> {
                        match #{serialize_response}(&self.0) {
                        $handleSerializeOutput
                    }
                    }
                }""".trimIndent(),
                *codegenScope,
                "O" to outputSymbol,
@@ -268,17 +285,6 @@ private class ServerHttpProtocolImplGenerator(
        )
    }

    /*
     * TODO: implement streaming traits
     */
    private fun RustWriter.renderStreamingTraits(
        operationName: String,
        outputSymbol: Symbol,
        operationShape: OperationShape
    ) {
        logger.warning("[rust-server-codegen] $operationName: streaming trait is not yet implemented")
    }

    private fun serverParseRequest(operationShape: OperationShape): RuntimeType {
        val fnName = "parse_${operationShape.id.name.toSnakeCase()}_request"
        val inputShape = operationShape.inputShape(model)
+1 −1
Original line number Diff line number Diff line
@@ -192,6 +192,7 @@ data class CargoDependency(
        val Axum: CargoDependency = CargoDependency("axum", CratesIo("0.3"))
        val Bytes: CargoDependency = CargoDependency("bytes", CratesIo("1"))
        val BytesUtils: CargoDependency = CargoDependency("bytes-utils", CratesIo("0.1.1"))
        val DeriveBuilder = CargoDependency("derive_builder", CratesIo("0.10"))
        val FastRand: CargoDependency = CargoDependency("fastrand", CratesIo("1"))
        val Hex: CargoDependency = CargoDependency("hex", CratesIo("0.4.3"))
        val HttpBody: CargoDependency = CargoDependency("http-body", CratesIo("0.4"))
@@ -213,7 +214,6 @@ data class CargoDependency(
        fun SmithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("eventstream")
        fun SmithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http")
        fun SmithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-server")
        fun SmithyHttpTower(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-tower")
        fun SmithyProtocolTestHelpers(runtimeConfig: RuntimeConfig) =
            runtimeConfig.runtimeCrate("protocol-test").copy(scope = DependencyScope.Dev)
        fun smithyJson(runtimeConfig: RuntimeConfig): CargoDependency = runtimeConfig.runtimeCrate("json")
+0 −1
Original line number Diff line number Diff line
@@ -317,7 +317,6 @@ sealed class Attribute {
        val container: Boolean = false
    ) : Attribute() {
        override fun render(writer: RustWriter) {

            val bang = if (container) "!" else ""
            writer.raw("#$bang[$annotation]")
            symbols.forEach {
Loading