Unverified Commit a89a0832 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Rewrite `ServerOperationRegistryGenerator` (#1491)

This code was written a long time ago, and made heavy use of string
interpolation instead of the more robust Rust templating primitives from
`RustWriter.kt`.
parent b88d72e2
Loading
Loading
Loading
Loading
+173 −139
Original line number Diff line number Diff line
@@ -11,27 +11,42 @@ import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase

/**
 * ServerOperationRegistryGenerator
 * [ServerOperationRegistryGenerator] renders the `OperationRegistry` struct, a place where users can register their
 * service's operation implementations.
 *
 * Users can construct the operation registry using a builder. They can subsequently convert the operation registry into
 * the [`aws_smithy_http_server::Router`], a [`tower::Service`] that will route incoming requests to their operation
 * handlers, invoking them and returning the response.
 *
 * [`aws_smithy_http_server::Router`]: https://docs.rs/aws-smithy-http-server/latest/aws_smithy_http_server/struct.Router.html
 * [`tower::Service`]: https://docs.rs/tower/latest/tower/trait.Service.html
 */
class ServerOperationRegistryGenerator(
    coreCodegenContext: CoreCodegenContext,
    private val httpBindingResolver: HttpBindingResolver,
    private val operations: List<OperationShape>,
) {
    private val model = coreCodegenContext.model
    private val protocol = coreCodegenContext.protocol
    private val symbolProvider = coreCodegenContext.symbolProvider
    private val serviceName = coreCodegenContext.serviceShape.toShapeId().name
@@ -43,10 +58,13 @@ class ServerOperationRegistryGenerator(
        "ServerOperationHandler" to ServerRuntimeType.OperationHandler(runtimeConfig),
        "Tower" to ServerCargoDependency.Tower.asType(),
        "Phantom" to ServerRuntimeType.Phantom,
        "StdError" to RuntimeType.StdError
        "StdError" to RuntimeType.StdError,
        "Display" to RuntimeType.Display,
        "From" to RuntimeType.From,
    )
    private val operationRegistryName = "OperationRegistry"
    private val operationRegistryBuilderName = "OperationRegistryBuilder"
    private val operationRegistryBuilderName = "${operationRegistryName}Builder"
    private val operationRegistryErrorName = "${operationRegistryBuilderName}Error"
    private val genericArguments = "B, " + operations.mapIndexed { i, _ -> "Op$i, In$i" }.joinToString()
    private val operationRegistryNameWithArguments = "$operationRegistryName<$genericArguments>"
    private val operationRegistryBuilderNameWithArguments = "$operationRegistryBuilderName<$genericArguments>"
@@ -60,17 +78,8 @@ class ServerOperationRegistryGenerator(
        renderRouterImplementationFromOperationRegistryBuilder(writer)
    }

    /*
     * Renders the `OperationRegistry` structure, holding all the operations and their generic inputs.
     */
    private fun renderOperationRegistryStruct(writer: RustWriter) {
        // A lot of things can become pretty complex in this type as it will hold 2 generics per operation.
        Attribute.Custom("allow(clippy::all)").render(writer)
        writer.rustBlock(
            """
            pub struct $operationRegistryNameWithArguments
            """.trimIndent()
        ) {
        writer.rustBlock("pub struct $operationRegistryNameWithArguments") {
            val members = operationNames
                .mapIndexed { i, operationName -> "$operationName: Op$i" }
                .joinToString(separator = ",\n")
@@ -84,17 +93,11 @@ class ServerOperationRegistryGenerator(
        }
    }

    /*
     * Renders the `OperationRegistryBuilder` structure, used to build the `OperationRegistry`, which can then be converted into a Smithy router.
    /**
     * Renders the `OperationRegistryBuilder` structure, used to build the `OperationRegistry`.
     */
    private fun renderOperationRegistryBuilderStruct(writer: RustWriter) {
        // A lot of things can become pretty complex in this type as it will hold 2 generics per operation.
        Attribute.Custom("allow(clippy::all)").render(writer)
        writer.rustBlock(
            """
            pub struct $operationRegistryBuilderNameWithArguments
            """.trimIndent()
        ) {
        writer.rustBlock("pub struct $operationRegistryBuilderNameWithArguments") {
            val members = operationNames
                .mapIndexed { i, operationName -> "$operationName: Option<Op$i>" }
                .joinToString(separator = ",\n")
@@ -108,48 +111,46 @@ class ServerOperationRegistryGenerator(
        }
    }

    /*
    /**
     * Renders the `OperationRegistryBuilderError` type, used to error out in case there are uninitialized fields.
     * This enum implement `Debug`, `Display` and `std::error::Error`.
     * This is an enum deriving `Debug` and implementing `Display` and `std::error::Error`.
     */
    private fun renderOperationRegistryBuilderError(writer: RustWriter) {
        Attribute.Derives(setOf(RuntimeType.Debug)).render(writer)
        writer.rustTemplate(
            """
            pub enum ${operationRegistryBuilderName}Error {
            pub enum ${operationRegistryErrorName}{
                UninitializedField(&'static str)
            }
            impl #{Display} for ${operationRegistryBuilderName}Error {
            impl #{Display} for ${operationRegistryErrorName}{
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    match self {
                        Self::UninitializedField(v) => write!(f, "{}", v),
                    }
                }
            }
            impl #{StdError} for ${operationRegistryBuilderName}Error {}
            """.trimIndent(),
            impl #{StdError} for $operationRegistryErrorName {}
            """,
            *codegenScope,
            "Display" to RuntimeType.Display,
        )
    }

    /*
    /**
     * Renders the `OperationRegistryBuilder` `Default` implementation, used to create a new builder that can be
     * populated with the service's operations.
     * populated with the service's operation implementations.
     */
    private fun renderOperationRegistryBuilderDefault(writer: RustWriter) {
        writer.rustBlockTemplate(
            """
            impl<$genericArguments> Default for $operationRegistryBuilderNameWithArguments
            """.trimIndent()
        ) {
            val defaultOperations = operationNames.map { operationName ->
        writer.rustBlockTemplate("impl<$genericArguments> std::default::Default for $operationRegistryBuilderNameWithArguments") {
            val defaultOperations = operationNames.joinToString(separator = "\n,") { operationName ->
                "$operationName: Default::default()"
            }.joinToString(separator = "\n,")
            }
            rustTemplate(
                """
                fn default() -> Self {
                    Self { $defaultOperations, _phantom: #{Phantom} }
                    Self {
                        $defaultOperations,
                        _phantom: #{Phantom}
                    }
                }
                """,
                *codegenScope
@@ -157,97 +158,99 @@ class ServerOperationRegistryGenerator(
        }
    }

    /*
     * Renders the `OperationRegistryBuilder` implementation, where operations are stored.
    /**
     * Renders the `OperationRegistryBuilder`'s impl block, where operations are stored.
     * The `build()` method converts the builder into an `OperationRegistry` instance.
     */
    private fun renderOperationRegistryBuilderImplementation(writer: RustWriter) {
        // A lot of things can become pretty complex in this type as it will hold 2 generics per operation.
        Attribute.Custom("allow(clippy::all)").render(writer)
        writer.rustBlockTemplate(
        writer.rustBlock("impl<$genericArguments> $operationRegistryBuilderNameWithArguments") {
            operationNames.forEachIndexed { i, operationName ->
                rust(
                    """
            impl<$genericArguments> $operationRegistryBuilderNameWithArguments
            """.trimIndent(),
            *codegenScope
        ) {
            val registerOperations = operationNames.mapIndexed { i, operationName ->
                """pub fn $operationName(self, value: Op$i) -> Self {
                    pub fn $operationName(self, value: Op$i) -> Self {
                        let mut new = self;
                        new.$operationName = Some(value);
                        new
                }"""
            }.joinToString(separator = "\n")
            val registerOperationsBuilder = operationNames.map { operationName ->
                    }
                    """
                )
            }

            rustBlock("pub fn build(self) -> Result<$operationRegistryNameWithArguments, ${operationRegistryErrorName}>") {
                withBlock("Ok( $operationRegistryName {", "})") {
                    for (operationName in operationNames) {
                        rust(
                            """
                            $operationName: match self.$operationName {
                                Some(v) => v,
                    None => return Err(${operationRegistryBuilderName}Error::UninitializedField(${operationName.dq()})),
                }
                                None => return Err(${operationRegistryErrorName}::UninitializedField("$operationName")),
                            },
                            """
            }.joinToString(separator = "\n,")
                        )
                    }
                    rustTemplate("_phantom: #{Phantom}", *codegenScope)
                }
            }
        }
    }

    /**
     * Renders the converter between the `OperationRegistry` and the `Router` via the `std::convert::From` trait.
     */
    private fun renderRouterImplementationFromOperationRegistryBuilder(writer: RustWriter) {
        val operationTraitBounds = writable {
            operations.forEachIndexed { i, operation ->
                rustTemplate(
                    """
                $registerOperations
                pub fn build(self) -> Result<$operationRegistryNameWithArguments, ${operationRegistryBuilderName}Error> {
                    Ok($operationRegistryName { $registerOperationsBuilder, _phantom: #{Phantom} })
                }
                    Op$i: #{ServerOperationHandler}::Handler<B, In$i, #{OperationInput}>,
                    In$i: 'static + Send,
                    """,
                *codegenScope
                    *codegenScope,
                    "OperationInput" to symbolProvider.toSymbol(operation.inputShape(model))
                )
            }
        }

    /*
     * Renders the conversion code between the `OperationRegistry` and the `Router` via the `std::convert::From` trait.
     */
    private fun renderRouterImplementationFromOperationRegistryBuilder(writer: RustWriter) {
        // A lot of things can become pretty complex in this type as it will hold 2 generics per operation.
        val operationsTraitBounds = operations
            .mapIndexed { i, operation ->
                val operationName = symbolProvider.toSymbol(operation).name
                """Op$i: #{ServerOperationHandler}::Handler<B, In$i, crate::input::${operationName}Input>,
                In$i: 'static + Send"""
            }.joinToString(separator = ",\n")
        Attribute.Custom("allow(clippy::all)").render(writer)
        writer.rustBlockTemplate(
            // The bound `B: Send` is required because of [`tower::util::BoxCloneService`].
            // [`tower::util::BoxCloneService`]: https://docs.rs/tower/latest/tower/util/struct.BoxCloneService.html#method.new
            """
            impl<$genericArguments> From<$operationRegistryNameWithArguments> for #{Router}<B>
            impl<$genericArguments> #{From}<$operationRegistryNameWithArguments> for #{Router}<B>
            where
                B: Send + 'static,
                $operationsTraitBounds
            """.trimIndent(),
            *codegenScope
                #{operationTraitBounds:W}
            """,
            *codegenScope,
            "operationTraitBounds" to operationTraitBounds
        ) {
            rustBlock("fn from(registry: $operationRegistryNameWithArguments) -> Self") {
                val requestSpecsVarNames = operationNames.map { "${it}_request_spec" }
                val requestSpecs = requestSpecsVarNames.zip(operations) { requestSpecVarName, operation ->
                    "let $requestSpecVarName = ${operation.requestSpec()};"
                }.joinToString(separator = "\n")
                val towerServices = requestSpecsVarNames.zip(operationNames) { requestSpecVarName, operationName ->
                    "(#{Tower}::util::BoxCloneService::new(#{ServerOperationHandler}::operation(registry.$operationName)), $requestSpecVarName)"
                }.joinToString(prefix = "vec![", separator = ",\n", postfix = "]")

                requestSpecsVarNames.zip(operations).forEach { (requestSpecVarName, operation) ->
                    rustTemplate(
                    """
                    $requestSpecs
                    #{Router}::${runtimeRouterConstructor()}($towerServices)
                    """.trimIndent(),
                        "let $requestSpecVarName = #{RequestSpec:W};",
                        "RequestSpec" to operation.requestSpec()
                    )
                }

                withBlockTemplate("#{Router}::${runtimeRouterConstructor()}(vec![", "])", *codegenScope) {
                    requestSpecsVarNames.zip(operationNames).forEach { (requestSpecVarName, operationName) ->
                        rustTemplate(
                            "(#{Tower}::util::BoxCloneService::new(#{ServerOperationHandler}::operation(registry.$operationName)), $requestSpecVarName),",
                            *codegenScope
                        )
                    }
                }
            }
        }
    }

    /*
     * Renders the `PhantomData` generic members.
    /**
     * Returns the `PhantomData` generic members in a comma-separated list.
     */
    private fun phantomMembers(): String {
        return operationNames
            .mapIndexed { i, _ -> "In$i" }
            .joinToString(separator = ",\n")
    }
    private fun phantomMembers() = operationNames.mapIndexed { i, _ -> "In$i" }.joinToString(separator = ",\n")

    /*
    /**
     * Finds the runtime function to construct a new `Router` based on the Protocol.
     */
    private fun runtimeRouterConstructor(): String =
@@ -259,54 +262,85 @@ class ServerOperationRegistryGenerator(
            else -> TODO("Protocol $protocol not supported yet")
        }

    /*
     * Returns the `RequestSpec`s for an operation based on its HTTP-bound route.
    /**
     * Returns a writable for the `RequestSpec` for an operation based on the service's protocol.
     */
    private fun OperationShape.requestSpec(): String =
    private fun OperationShape.requestSpec(): Writable =
        when (protocol) {
            RestJson1Trait.ID, RestXmlTrait.ID -> restRequestSpec()
            AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> awsJsonOperationName()
            else -> TODO("Protocol $protocol not supported yet")
        }

    /*
     * Returns an AwsJson specific runtime `RequestSpec`.
    /**
     * Returns the operation name as required by the awsJson1.x protocols.
     */
    private fun OperationShape.awsJsonOperationName(): String {
    private fun OperationShape.awsJsonOperationName(): Writable {
        val operationName = symbolProvider.toSymbol(this).name
        // TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait
        return """String::from("$serviceName.$operationName")"""
        return writable {
            rust("""String::from("$serviceName.$operationName")""")
        }
    }

    /*
     * Generates a REST (RestJson1, RestXml) specific runtime `RequestSpec`.
    /**
     * Generates a restJson1 or restXml specific `RequestSpec`.
     */
    private fun OperationShape.restRequestSpec(): String {
    private fun OperationShape.restRequestSpec(): Writable {
        val httpTrait = httpBindingResolver.httpTrait(this)
        val namespace = ServerRuntimeType.RequestSpecModule(runtimeConfig).fullyQualifiedName()
        val extraCodegenScope =
            arrayOf("RequestSpec", "UriSpec", "PathAndQuerySpec", "PathSpec", "QuerySpec", "PathSegment", "QuerySegment").map {
                it to ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType().member("routing::request_spec::$it")
            }.toTypedArray()

        // TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait.
        val pathSegments = httpTrait.uri.segments.map {
            "$namespace::PathSegment::" +
                if (it.isGreedyLabel) "Greedy"
                else if (it.isLabel) "Label"
                else "Literal(String::from(\"${it.content}\"))"
        val pathSegmentsVec = writable {
            withBlock("vec![", "]") {
                for (segment in httpTrait.uri.segments) {
                    val variant = when {
                        segment.isGreedyLabel -> "Greedy"
                        segment.isLabel -> "Label"
                        else -> """Literal(String::from("${segment.content}"))"""
                    }
                    rustTemplate(
                        "#{PathSegment}::$variant,",
                        *extraCodegenScope
                    )
                }
            }
        val querySegments = httpTrait.uri.queryLiterals.map {
            "$namespace::QuerySegment::" +
                if (it.value == "") "Key(String::from(\"${it.key}\"))"
                else "KeyValue(String::from(\"${it.key}\"), String::from(\"${it.value}\"))"
        }

        return """
            $namespace::RequestSpec::new(
                http::Method::${httpTrait.method},
                $namespace::UriSpec::new(
                    $namespace::PathAndQuerySpec::new(
                        $namespace::PathSpec::from_vector_unchecked(vec![${pathSegments.joinToString()}]),
                        $namespace::QuerySpec::from_vector_unchecked(vec![${querySegments.joinToString()}])
        val querySegmentsVec = writable {
            withBlock("vec![", "]") {
                for (queryLiteral in httpTrait.uri.queryLiterals) {
                    val variant = if (queryLiteral.value == "") {
                        """Key(String::from("${queryLiteral.key}"))"""
                    } else {
                        """KeyValue(String::from("${queryLiteral.key}"), String::from("${queryLiteral.value}"))"""
                    }
                    rustTemplate("#{QuerySegment}::$variant,", *extraCodegenScope)
                }
            }
        }

        return writable {
            rustTemplate(
                """
                #{RequestSpec}::new(
                    #{Method}::${httpTrait.method},
                    #{UriSpec}::new(
                        #{PathAndQuerySpec}::new(
                            #{PathSpec}::from_vector_unchecked(#{PathSegmentsVec:W}),
                            #{QuerySpec}::from_vector_unchecked(#{QuerySegmentsVec:W})
                        )
                    ),
                )
        """.trimIndent()
                """,
                *codegenScope,
                *extraCodegenScope,
                "PathSegmentsVec" to pathSegmentsVec,
                "QuerySegmentsVec" to querySegmentsVec,
                "Method" to CargoDependency.Http.asType().member("Method"),
            )
        }
    }
}