Unverified Commit f0929e74 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Allow server decorators to inject methods on config (#3111)

PR #3095 added a code-generated `${serviceName}Config` object on which
users can register layers and plugins. For example:

```rust
let config = PokemonServiceConfig::builder()
    .layer(layers)
    .http_plugin(authn_plugin)
    .model_plugin(authz_plugin)
    .build();
```

This PR makes it so that server decorators can inject methods on this
config builder object. These methods can apply arbitrary layers, HTTP
plugins, and/or model plugins. Moreover, the decorator can configure
whether invoking such method is required or not.

For example, a decorator can inject an `aws_auth` method that configures
some plugins using its input arguments. Missing invocation of this
method
will result in the config failing to build:

```rust
let _: SimpleServiceConfig<
    // No layers have been applied.
    tower::layer::util::Identity,
    // One HTTP plugin has been applied.
    PluginStack<IdentityPlugin, IdentityPlugin>,
    // One model plugin has been applied.
    PluginStack<IdentityPlugin, IdentityPlugin>,
> = SimpleServiceConfig::builder()
    // This method has been injected in the config builder!
    .aws_auth("a".repeat(69).to_owned(), 69)
    // The method will apply one HTTP plugin and one model plugin,
    // configuring them with the input arguments. Configuration can be
    // declared to be fallible, in which case we get a `Result` we unwrap
    // here.
    .expect("failed to configure aws_auth")
    .build()
    // Since `aws_auth` has been marked as required, if the user misses
    // invoking it, this would panic here.
    .unwrap();
```

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent cc303ab1
Loading
Loading
Loading
Loading
+1 −0
Original line number Original line Diff line number Diff line
@@ -22,6 +22,7 @@ object ServerCargoDependency {
    val Nom: CargoDependency = CargoDependency("nom", CratesIo("7"))
    val Nom: CargoDependency = CargoDependency("nom", CratesIo("7"))
    val OnceCell: CargoDependency = CargoDependency("once_cell", CratesIo("1.13"))
    val OnceCell: CargoDependency = CargoDependency("once_cell", CratesIo("1.13"))
    val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2"))
    val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2"))
    val ThisError: CargoDependency = CargoDependency("thiserror", CratesIo("1.0"))
    val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"))
    val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"))
    val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.23.1"), scope = DependencyScope.Dev)
    val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.23.1"), scope = DependencyScope.Dev)
    val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5"))
    val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5"))
+7 −1
Original line number Original line Diff line number Diff line
@@ -78,6 +78,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.Unconstraine
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedMapGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedMapGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isBuilderFallible
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator
@@ -591,11 +592,15 @@ open class ServerCodegenVisitor(
        logger.info("[rust-server-codegen] Generating a service $shape")
        logger.info("[rust-server-codegen] Generating a service $shape")
        val serverProtocol = protocolGeneratorFactory.protocol(codegenContext) as ServerProtocol
        val serverProtocol = protocolGeneratorFactory.protocol(codegenContext) as ServerProtocol


        val configMethods = codegenDecorator.configMethods(codegenContext)
        val isConfigBuilderFallible = configMethods.isBuilderFallible()

        // Generate root.
        // Generate root.
        rustCrate.lib {
        rustCrate.lib {
            ServerRootGenerator(
            ServerRootGenerator(
                serverProtocol,
                serverProtocol,
                codegenContext,
                codegenContext,
                isConfigBuilderFallible,
            ).render(this)
            ).render(this)
        }
        }


@@ -612,9 +617,10 @@ open class ServerCodegenVisitor(
            ServerServiceGenerator(
            ServerServiceGenerator(
                codegenContext,
                codegenContext,
                serverProtocol,
                serverProtocol,
                isConfigBuilderFallible,
            ).render(this)
            ).render(this)


            ServiceConfigGenerator(codegenContext).render(this)
            ServiceConfigGenerator(codegenContext, configMethods).render(this)


            ScopeMacroGenerator(codegenContext).render(this)
            ScopeMacroGenerator(codegenContext).render(this)
        }
        }
+12 −4
Original line number Original line Diff line number Diff line
@@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
import software.amazon.smithy.rust.codegen.server.smithy.ValidationResult
import software.amazon.smithy.rust.codegen.server.smithy.ValidationResult
import software.amazon.smithy.rust.codegen.server.smithy.generators.ConfigMethod
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator
import java.util.logging.Logger
import java.util.logging.Logger
@@ -41,6 +42,12 @@ interface ServerCodegenDecorator : CoreCodegenDecorator<ServerCodegenContext, Se
     * Therefore, ensure that all the structure shapes returned by this method are not in the service's closure.
     * Therefore, ensure that all the structure shapes returned by this method are not in the service's closure.
     */
     */
    fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List<StructureShape> = emptyList()
    fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List<StructureShape> = emptyList()

    /**
     * Configuration methods that should be injected into the `${serviceName}Config` struct to allow users to configure
     * pre-applied layers and plugins.
     */
    fun configMethods(codegenContext: ServerCodegenContext): List<ConfigMethod> = emptyList()
}
}


/**
/**
@@ -74,10 +81,11 @@ class CombinedServerCodegenDecorator(decorators: List<ServerCodegenDecorator>) :
            decorator.postprocessValidationExceptionNotAttachedErrorMessage(accumulated)
            decorator.postprocessValidationExceptionNotAttachedErrorMessage(accumulated)
        }
        }


    override fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List<StructureShape> {
    override fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List<StructureShape> =
        return orderedDecorators.map { decorator -> decorator.postprocessGenerateAdditionalStructures(operationShape) }
        orderedDecorators.flatMap { it.postprocessGenerateAdditionalStructures(operationShape) }
            .flatten()

    }
    override fun configMethods(codegenContext: ServerCodegenContext): List<ConfigMethod> =
        orderedDecorators.flatMap { it.configMethods(codegenContext) }


    companion object {
    companion object {
        fun fromClasspath(
        fun fromClasspath(
+32 −6
Original line number Original line Diff line number Diff line
@@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output
open class ServerRootGenerator(
open class ServerRootGenerator(
    val protocol: ServerProtocol,
    val protocol: ServerProtocol,
    private val codegenContext: ServerCodegenContext,
    private val codegenContext: ServerCodegenContext,
    private val isConfigBuilderFallible: Boolean,
) {
) {
    private val index = TopDownIndex.of(codegenContext.model)
    private val index = TopDownIndex.of(codegenContext.model)
    private val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet(
    private val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet(
@@ -57,6 +58,8 @@ open class ServerRootGenerator(
            }
            }
            .join("//!\n")
            .join("//!\n")


        val unwrapConfigBuilder = if (isConfigBuilderFallible) ".expect(\"config failed to build\")" else ""

        writer.rustTemplate(
        writer.rustTemplate(
            """
            """
            //! A fast and customizable Rust implementation of the $serviceName Smithy service.
            //! A fast and customizable Rust implementation of the $serviceName Smithy service.
@@ -75,7 +78,10 @@ open class ServerRootGenerator(
            //! ## async fn dummy() {
            //! ## async fn dummy() {
            //! use $crateName::{$serviceName, ${serviceName}Config};
            //! use $crateName::{$serviceName, ${serviceName}Config};
            //!
            //!
            //! ## let app = $serviceName::builder(${serviceName}Config::builder().build()).build_unchecked();
            //! ## let app = $serviceName::builder(
            //! ##     ${serviceName}Config::builder()
            //! ##         .build()$unwrapConfigBuilder
            //! ## ).build_unchecked();
            //! let server = app.into_make_service();
            //! let server = app.into_make_service();
            //! let bind: SocketAddr = "127.0.0.1:6969".parse()
            //! let bind: SocketAddr = "127.0.0.1:6969".parse()
            //!     .expect("unable to parse the server bind address and port");
            //!     .expect("unable to parse the server bind address and port");
@@ -92,7 +98,10 @@ open class ServerRootGenerator(
            //! use $crateName::$serviceName;
            //! use $crateName::$serviceName;
            //!
            //!
            //! ## async fn dummy() {
            //! ## async fn dummy() {
            //! ## let app = $serviceName::builder(${serviceName}Config::builder().build()).build_unchecked();
            //! ## let app = $serviceName::builder(
            //! ##     ${serviceName}Config::builder()
            //! ##         .build()$unwrapConfigBuilder
            //! ## ).build_unchecked();
            //! let handler = LambdaHandler::new(app);
            //! let handler = LambdaHandler::new(app);
            //! lambda_http::run(handler).await.unwrap();
            //! lambda_http::run(handler).await.unwrap();
            //! ## }
            //! ## }
@@ -118,7 +127,7 @@ open class ServerRootGenerator(
            //! let http_plugins = HttpPlugins::new()
            //! let http_plugins = HttpPlugins::new()
            //!         .push(LoggingPlugin)
            //!         .push(LoggingPlugin)
            //!         .push(MetricsPlugin);
            //!         .push(MetricsPlugin);
            //! let config = ${serviceName}Config::builder().build();
            //! let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder;
            //! let builder: $builderName<Body, _, _, _> = $serviceName::builder(config);
            //! let builder: $builderName<Body, _, _, _> = $serviceName::builder(config);
            //! ```
            //! ```
            //!
            //!
@@ -183,13 +192,13 @@ open class ServerRootGenerator(
            //!
            //!
            //! ## Example
            //! ## Example
            //!
            //!
            //! ```rust
            //! ```rust,no_run
            //! ## use std::net::SocketAddr;
            //! ## use std::net::SocketAddr;
            //! use $crateName::{$serviceName, ${serviceName}Config};
            //! use $crateName::{$serviceName, ${serviceName}Config};
            //!
            //!
            //! ##[#{Tokio}::main]
            //! ##[#{Tokio}::main]
            //! pub async fn main() {
            //! pub async fn main() {
            //!    let config = ${serviceName}Config::builder().build();
            //!    let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder;
            //!    let app = $serviceName::builder(config)
            //!    let app = $serviceName::builder(config)
            ${builderFieldNames.values.joinToString("\n") { "//!        .$it($it)" }}
            ${builderFieldNames.values.joinToString("\n") { "//!        .$it($it)" }}
            //!        .build()
            //!        .build()
@@ -236,6 +245,23 @@ open class ServerRootGenerator(
    fun render(rustWriter: RustWriter) {
    fun render(rustWriter: RustWriter) {
        documentation(rustWriter)
        documentation(rustWriter)


        rustWriter.rust("pub use crate::service::{$serviceName, ${serviceName}Config, ${serviceName}ConfigBuilder, ${serviceName}Builder, MissingOperationsError};")
        // Only export config builder error if fallible.
        val configErrorReExport = if (isConfigBuilderFallible) {
            "${serviceName}ConfigError,"
        } else {
            ""
        }
        rustWriter.rust(
            """
            pub use crate::service::{
                $serviceName,
                ${serviceName}Config,
                ${serviceName}ConfigBuilder,
                $configErrorReExport
                ${serviceName}Builder,
                MissingOperationsError
            };
            """,
        )
    }
    }
}
}
+8 −2
Original line number Original line Diff line number Diff line
@@ -33,6 +33,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output
class ServerServiceGenerator(
class ServerServiceGenerator(
    private val codegenContext: ServerCodegenContext,
    private val codegenContext: ServerCodegenContext,
    private val protocol: ServerProtocol,
    private val protocol: ServerProtocol,
    private val isConfigBuilderFallible: Boolean,
) {
) {
    private val runtimeConfig = codegenContext.runtimeConfig
    private val runtimeConfig = codegenContext.runtimeConfig
    private val smithyHttpServer = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType()
    private val smithyHttpServer = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType()
@@ -107,6 +108,11 @@ class ServerServiceGenerator(
            val docHandler = DocHandlerGenerator(codegenContext, operationShape, "handler", "///")
            val docHandler = DocHandlerGenerator(codegenContext, operationShape, "handler", "///")
            val handler = docHandler.docSignature()
            val handler = docHandler.docSignature()
            val handlerFixed = docHandler.docFixedSignature()
            val handlerFixed = docHandler.docFixedSignature()
            val unwrapConfigBuilder = if (isConfigBuilderFallible) {
                ".expect(\"config failed to build\")"
            } else {
                ""
            }
            rustTemplate(
            rustTemplate(
                """
                """
                /// Sets the [`$structName`](crate::operation_shape::$structName) operation.
                /// Sets the [`$structName`](crate::operation_shape::$structName) operation.
@@ -123,7 +129,7 @@ class ServerServiceGenerator(
                ///
                ///
                #{Handler:W}
                #{Handler:W}
                ///
                ///
                /// let config = ${serviceName}Config::builder().build();
                /// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder;
                /// let app = $serviceName::builder(config)
                /// let app = $serviceName::builder(config)
                ///     .$fieldName(handler)
                ///     .$fieldName(handler)
                ///     /* Set other handlers */
                ///     /* Set other handlers */
@@ -186,7 +192,7 @@ class ServerServiceGenerator(
                ///
                ///
                #{HandlerFixed:W}
                #{HandlerFixed:W}
                ///
                ///
                /// let config = ${serviceName}Config::builder().build();
                /// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder;
                /// let svc = #{Tower}::util::service_fn(handler);
                /// let svc = #{Tower}::util::service_fn(handler);
                /// let app = $serviceName::builder(config)
                /// let app = $serviceName::builder(config)
                ///     .${fieldName}_service(svc)
                ///     .${fieldName}_service(svc)
Loading