Unverified Commit eb1cb074 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Generate Rust docs for the server's `OperationRegistry` (#1498)

Including a quickstart example for the service being generated that the
user can copy and paste.
parent a94b4b4e
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@
package software.amazon.smithy.rust.codegen.server.smithy
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.CratesIo
import software.amazon.smithy.rust.codegen.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.rustlang.InlineDependency
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig

@@ -21,6 +22,7 @@ object ServerCargoDependency {
    val Nom: CargoDependency = CargoDependency("nom", CratesIo("7"))
    val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2"))
    val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"))
    val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.0"), scope = DependencyScope.Dev)

    fun SmithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-server")
}
+122 −0
Original line number Diff line number Diff line
@@ -10,8 +10,10 @@ import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.DocumentationTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.asType
@@ -25,9 +27,15 @@ import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.smithy.Errors
import software.amazon.smithy.rust.codegen.smithy.Inputs
import software.amazon.smithy.rust.codegen.smithy.Outputs
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase

/**
@@ -46,6 +54,7 @@ class ServerOperationRegistryGenerator(
    private val httpBindingResolver: HttpBindingResolver,
    private val operations: List<OperationShape>,
) {
    private val crateName = coreCodegenContext.settings.moduleName
    private val model = coreCodegenContext.model
    private val protocol = coreCodegenContext.protocol
    private val symbolProvider = coreCodegenContext.symbolProvider
@@ -70,6 +79,7 @@ class ServerOperationRegistryGenerator(
    private val operationRegistryBuilderNameWithArguments = "$operationRegistryBuilderName<$genericArguments>"

    fun render(writer: RustWriter) {
        renderOperationRegistryRustDocs(writer)
        renderOperationRegistryStruct(writer)
        renderOperationRegistryBuilderStruct(writer)
        renderOperationRegistryBuilderError(writer)
@@ -78,6 +88,83 @@ class ServerOperationRegistryGenerator(
        renderRouterImplementationFromOperationRegistryBuilder(writer)
    }

    private fun renderOperationRegistryRustDocs(writer: RustWriter) {
        writer.rustTemplate(
            """
##[allow(clippy::tabs_in_doc_comments)]
/// The `${operationRegistryName}` is the place where you can register
/// your service's operation implementations.
/// 
/// Use [`${operationRegistryBuilderName}`] to construct the
/// `${operationRegistryName}`. For each of the [operations] modeled in
/// your Smithy service, you need to provide an implementation in the
/// form of a Rust async function or closure that takes in the
/// operation's input as their first parameter, and returns the
/// operation's output. If your operation is fallible (i.e. it
/// contains the `errors` member in your Smithy model), the function
/// implementing the operation has to be fallible (i.e. return a
/// [`Result`]). **You must register an implementation for all
/// operations with the correct signature**, or your application
/// will fail to compile.
///
/// The operation registry can be converted into an [`#{Router}`] for
/// your service. This router will take care of routing HTTP
/// requests to the matching operation implementation, adhering to
/// your service's protocol and the [HTTP binding traits] that you
/// used in your Smithy model. This router can be converted into a
/// type implementing [`tower::make::MakeService`], a _service
/// factory_. You can feed this value to a [Hyper server], and the
/// server will instantiate and [`serve`] your service.
/// 
/// Here's a full example to get you started:
/// 
/// ```rust
/// use std::net::SocketAddr;
${ if (operations.any { it.errors.isNotEmpty() }) {
"/// use ${crateName}::{${Inputs.namespace}, ${Outputs.namespace}, ${Errors.namespace}};"
} else {
"/// use ${crateName}::{${Inputs.namespace}, ${Outputs.namespace}};"
} }
/// use ${crateName}::operation_registry::${operationRegistryBuilderName};
/// use #{Router};
///
/// ##[#{Tokio}::main]
/// pub async fn main() {
///    let app: Router = ${operationRegistryBuilderName}::default()
${operationNames.map { ".$it($it)" }.joinToString("\n") { it.prependIndent("///        ") }}
///        .build()
///        .expect("unable to build operation registry")
///        .into();
///
///    let bind: SocketAddr = format!("{}:{}", "127.0.0.1", "6969")
///        .parse()
///        .expect("unable to parse the server bind address and port");
///
///    let server = #{Hyper}::Server::bind(&bind).serve(app.into_make_service());
///
///    // Run your service!
///    // if let Err(err) = server.await {
///    //   eprintln!("server error: {}", err);
///    // }
/// }
///
${operationImplementationStubs(operations)}
/// ```
///
/// [`serve`]: https://docs.rs/hyper/0.14.16/hyper/server/struct.Builder.html##method.serve
/// [`tower::make::MakeService`]: https://docs.rs/tower/latest/tower/make/trait.MakeService.html
/// [HTTP binding traits]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html
/// [operations]: https://awslabs.github.io/smithy/1.0/spec/core/model.html##operation
/// [Hyper server]: https://docs.rs/hyper/latest/hyper/server/index.html
""",
            "Router" to ServerRuntimeType.Router(runtimeConfig),
            // These should be dev-dependencies. Not all sSDKs depend on `Hyper` (only those that convert the body
            // `to_bytes`), and none depend on `tokio`.
            "Tokio" to ServerCargoDependency.TokioDev.asType(),
            "Hyper" to CargoDependency.Hyper.copy(scope = DependencyScope.Dev).asType()
        )
    }

    private fun renderOperationRegistryStruct(writer: RustWriter) {
        writer.rustBlock("pub struct $operationRegistryNameWithArguments") {
            val members = operationNames
@@ -343,4 +430,39 @@ class ServerOperationRegistryGenerator(
            )
        }
    }

    private fun operationImplementationStubs(operations: List<OperationShape>): String =
        operations.joinToString("\n///\n") {
            val operationDocumentation = it.getTrait<DocumentationTrait>()?.value
            val ret = if (!operationDocumentation.isNullOrBlank()) {
                operationDocumentation.replace("#", "##").prependIndent("/// /// ") + "\n"
            } else ""
            ret +
                    """
                    /// ${it.signature()} {
                    ///     todo!()
                    /// }
                    """.trimIndent()
        }

    /**
     * Returns the function signature for an operation handler implementation. Used in the documentation.
     */
    private fun OperationShape.signature(): String {
        val inputSymbol = symbolProvider.toSymbol(inputShape(model))
        val outputSymbol = symbolProvider.toSymbol(outputShape(model))
        val errorSymbol = errorSymbol(symbolProvider)

        val inputT = "${Inputs.namespace}::${inputSymbol.name}"
        val t = "${Outputs.namespace}::${outputSymbol.name}"
        val outputT = if (errors.isEmpty()) {
            t
        } else {
            val e = "${Errors.namespace}::${errorSymbol.name}"
            "Result<$t, $e>"
        }

        val operationName = symbolProvider.toSymbol(this).name.toSnakeCase()
        return "async fn $operationName(input: $inputT) -> $outputT"
    }
}
+9 −1
Original line number Diff line number Diff line
@@ -55,7 +55,15 @@ open class ServerServiceGenerator(
        rustCrate.withModule(RustModule.public("operation_handler", "Operation handlers definition and implementation.")) { writer ->
            renderOperationHandler(writer, operations)
        }
        rustCrate.withModule(RustModule.public("operation_registry", "A registry of your service's operations.")) { writer ->
        rustCrate.withModule(
            RustModule.public(
                "operation_registry",
                """
                Contains the [`operation_registry::OperationRegistry`], a place where
                you can register your service's operation implementations.
                """
            )
        ) { writer ->
            renderOperationRegistry(writer, operations)
        }
        renderExtras(operations)
+48 −0
Original line number Diff line number Diff line
@@ -6,11 +6,18 @@
package software.amazon.smithy.rust.codegen.server.smithy.testutil

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.server.smithy.RustCodegenServerPlugin
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.ServerCodegenConfig
import software.amazon.smithy.rust.codegen.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.smithy.ServerRustSettings
import software.amazon.smithy.rust.codegen.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.testutil.testSymbolProvider

// These are the settings we default to if the user does not override them in their `smithy-build.json`.
val ServerTestSymbolVisitorConfig = SymbolVisitorConfig(
@@ -29,3 +36,44 @@ fun serverTestSymbolProvider(
        serviceShape ?: ServiceShape.builder().version("test").id("test#Service").build(),
        ServerTestSymbolVisitorConfig
    )

fun serverTestRustSettings(
    service: ShapeId = ShapeId.from("notrelevant#notrelevant"),
    moduleName: String = "test-module",
    moduleVersion: String = "0.0.1",
    moduleAuthors: List<String> = listOf("notrelevant"),
    moduleDescription: String = "not relevant",
    moduleRepository: String? = null,
    runtimeConfig: RuntimeConfig = TestRuntimeConfig,
    codegenConfig: ServerCodegenConfig = ServerCodegenConfig(),
    license: String? = null,
    examplesUri: String? = null,
    customizationConfig: ObjectNode? = null
) = ServerRustSettings(
    service,
    moduleName,
    moduleVersion,
    moduleAuthors,
    moduleDescription,
    moduleRepository,
    runtimeConfig,
    codegenConfig,
    license,
    examplesUri,
    customizationConfig
)

fun serverTestCodegenContext(
    model: Model,
    serviceShape: ServiceShape? = null,
    settings: ServerRustSettings = serverTestRustSettings(),
    protocolShapeId: ShapeId? = null
): ServerCodegenContext = ServerCodegenContext(
    model,
    testSymbolProvider(model),
    serviceShape
        ?: model.serviceShapes.firstOrNull()
        ?: ServiceShape.builder().version("test").id("test#Service").build(),
    protocolShapeId ?: ShapeId.from("test#Protocol"),
    settings,
)
+120 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rust.codegen.server.smithy.generators

import io.kotest.matchers.string.shouldContain
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestRustSettings
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.util.lookup

class ServerOperationRegistryGeneratorTest {
    private val model = """
        namespace test
        
        use aws.protocols#restJson1

        @restJson1
        service Service {
            operations: [
                Frobnify,
                SayHello,
            ],
        }

        /// Only the Frobnify operation is documented,
        /// over multiple lines.
        /// And here are #hash #tags!
        @http(method: "GET", uri: "/frobnify")
        operation Frobnify {
            input: FrobnifyInputOutput,
            output: FrobnifyInputOutput,
            errors: [FrobnifyFailure]
        }

        @http(method: "GET", uri: "/hello")
        operation SayHello {
            input: SayHelloInputOutput,
            output: SayHelloInputOutput,
        }

        structure FrobnifyInputOutput {}
        structure SayHelloInputOutput {}

        @error("server")
        structure FrobnifyFailure {}
    """.asSmithyModel()

    @Test
    fun `it generates quickstart example`() {
        val serviceShape = model.lookup<ServiceShape>("test#Service")
        val (protocolShapeId, protocolGeneratorFactory) = ServerProtocolLoader(ServerProtocolLoader.DefaultProtocols).protocolFor(
            model,
            serviceShape
        )
        val serverCodegenContext = serverTestCodegenContext(
            model,
            serviceShape,
            settings = serverTestRustSettings(moduleName = "service"),
            protocolShapeId = protocolShapeId
        )

        val index = TopDownIndex.of(serverCodegenContext.model)
        val operations = index.getContainedOperations(serverCodegenContext.serviceShape).sortedBy { it.id }
        val httpBindingResolver = protocolGeneratorFactory.protocol(serverCodegenContext).httpBindingResolver

        val generator = ServerOperationRegistryGenerator(serverCodegenContext, httpBindingResolver, operations)
        val writer = RustWriter.forModule("operation_registry")
        generator.render(writer)

        writer.toString() shouldContain
                """
                /// ```rust
                /// use std::net::SocketAddr;
                /// use service::{input, output, error};
                /// use service::operation_registry::OperationRegistryBuilder;
                /// use aws_smithy_http_server::routing::Router;
                ///
                /// #[tokio::main]
                /// pub async fn main() {
                ///    let app: Router = OperationRegistryBuilder::default()
                ///        .frobnify(frobnify)
                ///        .say_hello(say_hello)
                ///        .build()
                ///        .expect("unable to build operation registry")
                ///        .into();
                ///
                ///    let bind: SocketAddr = format!("{}:{}", "127.0.0.1", "6969")
                ///        .parse()
                ///        .expect("unable to parse the server bind address and port");
                ///
                ///    let server = hyper::Server::bind(&bind).serve(app.into_make_service());
                ///
                ///    // Run your service!
                ///    // if let Err(err) = server.await {
                ///    //   eprintln!("server error: {}", err);
                ///    // }
                /// }
                ///
                /// /// Only the Frobnify operation is documented,
                /// /// over multiple lines.
                /// /// And here are #hash #tags!
                /// async fn frobnify(input: input::FrobnifyInputOutput) -> Result<output::FrobnifyInputOutput, error::FrobnifyError> {
                ///     todo!()
                /// }
                ///
                /// async fn say_hello(input: input::SayHelloInputOutput) -> output::SayHelloInputOutput {
                ///     todo!()
                /// }
                /// ```
                ///""".trimIndent()
    }
}
Loading