Unverified Commit 0313a3c7 authored by 82marbag's avatar 82marbag Committed by GitHub
Browse files

Allow third-party devs to write custom plugins for service builders (#1736)



* Allow custom layers to every operation

Third-party developers can add custom functions to a service builder,
through a `BuilderModifier`.
These functionalities are applied to every operation of that service.

An example of a builder modifier can be found in
rust-runtime/aws-smithy-http-server/examples/pokemon-service/lib.rs

Signed-off-by: default avatarHarry Barber <hlbarber@amazon.co.uk>
Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>
Co-authored-by: default avatarHarry Barber <hlbarber@amazon.co.uk>
parent c52bb03e
Loading
Loading
Loading
Loading
+34 −9
Original line number Diff line number Diff line
@@ -64,6 +64,7 @@ class ServerServiceGeneratorV2(

    /** A `Writable` block containing all the `Handler` and `Operation` setters for the builder. */
    private fun builderSetters(): Writable = writable {
        val pluginType = listOf("Pl")
        for ((index, pair) in builderFieldNames.zip(operationStructNames).withIndex()) {
            val (fieldName, structName) = pair

@@ -128,16 +129,17 @@ class ServerServiceGeneratorV2(
                /// [`$structName`](crate::operation_shape::$structName) using either
                /// [`OperationShape::from_handler`](#{SmithyHttpServer}::operation::OperationShapeExt::from_handler) or
                /// [`OperationShape::from_service`](#{SmithyHttpServer}::operation::OperationShapeExt::from_service).
                pub fn ${fieldName}_operation<NewOp, NewExts>(self, value: NewOp) -> $builderName<${(replacedOpGenerics + replacedExtGenerics).joinToString(", ")}>
                pub fn ${fieldName}_operation<NewOp, NewExts>(self, value: NewOp) -> $builderName<${(replacedOpGenerics + replacedExtGenerics + pluginType).joinToString(", ")}>
                {
                    $builderName {
                        ${switchedFields.joinToString(", ")},
                        _exts: std::marker::PhantomData
                        _exts: std::marker::PhantomData,
                        plugin: self.plugin,
                    }
                }
                """,
                "Protocol" to protocol.markerStruct(),
                "HandlerSetterGenerics" to (replacedOpServiceGenerics + (replacedExtGenerics.map { writable(it) })).join(", "),
                "HandlerSetterGenerics" to (replacedOpServiceGenerics + ((replacedExtGenerics + pluginType).map { writable(it) })).join(", "),
                *codegenScope,
            )

@@ -159,6 +161,7 @@ class ServerServiceGeneratorV2(
                    crate::operation_shape::${symbolProvider.toSymbol(operation).name.toPascalCase()},
                    $exts,
                    B,
                    Pl,
                >,
                $type::Service: Clone + Send + 'static,
                <$type::Service as #{Tower}::Service<#{Http}::Request<B>>>::Future: Send + 'static,
@@ -174,18 +177,26 @@ class ServerServiceGeneratorV2(
    /** Returns a `Writable` containing the builder struct definition and its implementations. */
    private fun builder(): Writable = writable {
        val extensionTypesDefault = extensionTypes.map { "$it = ()" }
        val structGenerics = (builderOps + extensionTypesDefault).joinToString(", ")
        val builderGenerics = (builderOps + extensionTypes).joinToString(", ")
        val pluginName = "Pl"
        val pluginTypeList = listOf(pluginName)
        val newPluginType = "New$pluginName"
        val pluginTypeDefault = listOf("$pluginName = #{SmithyHttpServer}::plugin::IdentityPlugin")
        val structGenerics = (builderOps + extensionTypesDefault + pluginTypeDefault).joinToString(", ")
        val builderGenerics = (builderOps + extensionTypes + pluginTypeList).joinToString(", ")
        val builderGenericsNoPlugin = (builderOps + extensionTypes).joinToString(", ")

        // Generate router construction block.
        val router = protocol
            .routerConstruction(
                builderFieldNames
                    .map {
                        writable { rustTemplate("self.$it.upgrade()") }
                        writable { rustTemplate("self.$it.upgrade(&self.plugin)") }
                    }
                    .asIterable(),
            )
        val setterFields = builderFieldNames.map { item ->
            "$item: self.$item"
        }.joinToString(", ")
        rustTemplate(
            """
            /// The service builder for [`$serviceName`].
@@ -194,7 +205,8 @@ class ServerServiceGeneratorV2(
            pub struct $builderName<$structGenerics> {
                ${builderFields.joinToString(", ")},
                ##[allow(unused_parens)]
                _exts: std::marker::PhantomData<(${extensionTypes.joinToString(", ")})>
                _exts: std::marker::PhantomData<(${extensionTypes.joinToString(", ")})>,
                plugin: $pluginName,
            }

            impl<$builderGenerics> $builderName<$builderGenerics> {
@@ -213,6 +225,17 @@ class ServerServiceGeneratorV2(
                    }
                }
            }

            impl<$builderGenerics, $newPluginType> #{SmithyHttpServer}::plugin::Pluggable<$newPluginType> for $builderName<$builderGenerics> {
                type Output = $builderName<$builderGenericsNoPlugin, #{SmithyHttpServer}::plugin::PluginStack<$pluginName, $newPluginType>>;
                fn apply(self, plugin: $newPluginType) -> Self::Output {
                    $builderName {
                        $setterFields,
                        _exts: self._exts,
                        plugin: #{SmithyHttpServer}::plugin::PluginStack::new(self.plugin, plugin),
                    }
                }
            }
            """,
            "Setters" to builderSetters(),
            "BuildConstraints" to buildConstraints.join(", "),
@@ -265,7 +288,8 @@ class ServerServiceGeneratorV2(
                pub fn builder() -> $builderName<#{NotSetGenerics:W}> {
                    $builderName {
                        #{NotSetFields:W},
                        _exts: std::marker::PhantomData
                        _exts: std::marker::PhantomData,
                        plugin: #{SmithyHttpServer}::plugin::IdentityPlugin
                    }
                }

@@ -276,7 +300,8 @@ class ServerServiceGeneratorV2(
                pub fn unchecked_builder() -> $builderName<#{InternalFailureGenerics:W}> {
                    $builderName {
                        #{InternalFailureFields:W},
                        _exts: std::marker::PhantomData
                        _exts: std::marker::PhantomData,
                        plugin: #{SmithyHttpServer}::plugin::IdentityPlugin
                    }
                }
            }
+3 −0
Original line number Diff line number Diff line
@@ -19,6 +19,9 @@ use pokemon_service_server_sdk::{error, input, model, model::CapturingPayload, o
use rand::Rng;
use tracing_subscriber::{prelude::*, EnvFilter};

#[doc(hidden)]
pub mod plugin;

const PIKACHU_ENGLISH_FLAVOR_TEXT: &str =
    "When several of these Pokémon gather, their electricity could build and cause lightning storms.";
const PIKACHU_SPANISH_FLAVOR_TEXT: &str =
+80 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use aws_smithy_http_server::plugin::Plugin;

/// A [`Service`](tower::Service) that adds a print log.
#[derive(Clone, Debug)]
pub struct PrintService<S> {
    inner: S,
    name: &'static str,
}

impl<R, S> tower::Service<R> for PrintService<S>
where
    S: tower::Service<R>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: R) -> Self::Future {
        println!("Hi {}", self.name);
        self.inner.call(req)
    }
}

/// A [`Layer`](tower::Layer) which constructs the [`PrintService`].
#[derive(Debug)]
pub struct PrintLayer {
    name: &'static str,
}
impl<S> tower::Layer<S> for PrintLayer {
    type Service = PrintService<S>;

    fn layer(&self, service: S) -> Self::Service {
        PrintService {
            inner: service,
            name: self.name,
        }
    }
}

/// A [`Plugin`]() for a service builder to add a [`PrintLayer`] over operations.
#[derive(Debug)]
pub struct PrintPlugin;
impl<P, Op, S, L> Plugin<P, Op, S, L> for PrintPlugin
where
    Op: aws_smithy_http_server::operation::OperationShape,
{
    type Service = S;
    type Layer = tower::layer::util::Stack<L, PrintLayer>;

    fn map(
        &self,
        input: aws_smithy_http_server::operation::Operation<S, L>,
    ) -> aws_smithy_http_server::operation::Operation<Self::Service, Self::Layer> {
        input.layer(PrintLayer { name: Op::NAME })
    }
}

/// An extension to service builders to add the `print()` function.
pub trait PrintExt: aws_smithy_http_server::plugin::Pluggable<PrintPlugin> {
    /// Causes all operations to print the operation name when called.
    ///
    /// This works by applying the [`PrintPlugin`].
    fn print(self) -> Self::Output
    where
        Self: Sized,
    {
        self.apply(PrintPlugin)
    }
}

impl<Builder> PrintExt for Builder where Builder: aws_smithy_http_server::plugin::Pluggable<PrintPlugin> {}
+2 −0
Original line number Diff line number Diff line
@@ -18,6 +18,8 @@ pub mod logging;
#[doc(hidden)]
pub mod operation;
#[doc(hidden)]
pub mod plugin;
#[doc(hidden)]
pub mod protocols;
#[doc(hidden)]
pub mod rejection;
+19 −13
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ use tracing::error;

use crate::{
    body::BoxBody,
    plugin::Plugin,
    request::{FromParts, FromRequest},
    response::IntoResponse,
    runtime_error::InternalFailureException,
@@ -220,14 +221,14 @@ where

/// Provides an interface to convert a representation of an operation to a HTTP [`Service`](tower::Service) with
/// canonical associated types.
pub trait Upgradable<Protocol, Operation, Exts, B> {
pub trait Upgradable<Protocol, Operation, Exts, B, Plugin> {
    type Service: Service<http::Request<B>, Response = http::Response<BoxBody>>;

    /// Performs an upgrade from a representation of an operation to a HTTP [`Service`](tower::Service).
    fn upgrade(self) -> Self::Service;
    fn upgrade(self, plugin: &Plugin) -> Self::Service;
}

impl<P, Op, Exts, B, S, L, PollError> Upgradable<P, Op, Exts, B> for Operation<S, L>
impl<P, Op, Exts, B, Pl, S, L, PollError> Upgradable<P, Op, Exts, B, Pl> for Operation<S, L>
where
    // `Op` is used to specify the operation shape
    Op: OperationShape,
@@ -245,21 +246,26 @@ where
    // The signature of the inner service is correct
    S: Service<(Op::Input, Exts), Response = Op::Output, Error = OperationError<Op::Error, PollError>> + Clone,

    // Layer applies correctly to `Upgrade<P, Op, Exts, B, S>`
    L: Layer<Upgrade<P, Op, Exts, B, S>>,
    // The plugin takes this operation as input
    Pl: Plugin<P, Op, S, L>,

    // The modified Layer applies correctly to `Upgrade<P, Op, Exts, B, S>`
    Pl::Layer: Layer<Upgrade<P, Op, Exts, B, Pl::Service>>,

    // The signature of the output is correct
    L::Service: Service<http::Request<B>, Response = http::Response<BoxBody>>,
    <Pl::Layer as Layer<Upgrade<P, Op, Exts, B, Pl::Service>>>::Service:
        Service<http::Request<B>, Response = http::Response<BoxBody>>,
{
    type Service = L::Service;
    type Service = <Pl::Layer as Layer<Upgrade<P, Op, Exts, B, Pl::Service>>>::Service;

    /// Takes the [`Operation<S, L>`](Operation), applies [`UpgradeLayer`] to
    /// Takes the [`Operation<S, L>`](Operation), applies [`Plugin`], then applies [`UpgradeLayer`] to
    /// the modified `S`, then finally applies the modified `L`.
    ///
    /// The composition is made explicit in the method constraints and return type.
    fn upgrade(self) -> Self::Service {
        let layer = Stack::new(UpgradeLayer::new(), self.layer);
        layer.layer(self.inner)
    fn upgrade(self, plugin: &Pl) -> Self::Service {
        let mapped = plugin.map(self);
        let layer = Stack::new(UpgradeLayer::new(), mapped.layer);
        layer.layer(mapped.inner)
    }
}

@@ -273,13 +279,13 @@ pub struct MissingOperation;
/// This _does_ implement [`Upgradable`] but produces a [`Service`] which always returns an internal failure message.
pub struct FailOnMissingOperation;

impl<P, Op, Exts, B> Upgradable<P, Op, Exts, B> for FailOnMissingOperation
impl<P, Op, Exts, B, Pl> Upgradable<P, Op, Exts, B, Pl> for FailOnMissingOperation
where
    InternalFailureException: IntoResponse<P>,
{
    type Service = MissingFailure<P>;

    fn upgrade(self) -> Self::Service {
    fn upgrade(self, _plugin: &Pl) -> Self::Service {
        MissingFailure { _protocol: PhantomData }
    }
}
Loading