Unverified Commit b2bdcba5 authored by Harry Barber's avatar Harry Barber Committed by GitHub
Browse files

Parameterize `Plugin` by service rather than protocol (#2772)

## Motivation and Context

Closes https://github.com/awslabs/smithy-rs/issues/1839

Currently, `Plugin` is parameterized by protocol and operation. To
improve symmetry, extensibility and uniformity we switch this to be
parameterized by service instead. The protocol can still be recovered
via the `type Protocol` associated type on `ServiceShape`.

## Description

- Add `ServiceShape` trait, encoding the properties of a Smithy service.
- Change `Plugin<Protocol, Operation, S>` to `Plugin<Service, Operation,
S>`.
- Add `FilterByOperation` and `filter_by_operation` `Plugin`s.
parent 3d0db565
Loading
Loading
Loading
Loading
+117 −9
Original line number Diff line number Diff line
@@ -182,15 +182,64 @@ meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "ysaito1001"

[[smithy-rs]]
message = """
The middleware system has been reworked as we push for a unified, simple, and consistent API. The following changes have been made in service of this goal:
message = """The middleware system has been reworked as we push for a unified, simple, and consistent API. The following changes have been made in service of this goal:

- A `ServiceShape` trait has been added.
- The `Plugin` trait has been simplified.
- The `Operation` structure has been removed.
- A `Scoped` `Plugin` has been added.

The `Plugin` trait has now been simplified and the `Operation` struct has been removed.

## Addition of `ServiceShape`

Since the [0.52 release](https://github.com/awslabs/smithy-rs/releases/tag/release-2022-12-12) the `OperationShape` has existed.

```rust
/// Models the [Smithy Operation shape].
///
/// [Smithy Operation shape]: https://awslabs.github.io/smithy/1.0/spec/core/model.html#operation
pub trait OperationShape {
    /// The ID of the operation.
    const ID: ShapeId;

    /// The operation input.
    type Input;
    /// The operation output.
    type Output;
    /// The operation error. [`Infallible`](std::convert::Infallible) in the case where no error
    /// exists.
    type Error;
}
```

This allowed `Plugin` authors to access these associated types and constants. See the [`PrintPlugin`](https://github.com/awslabs/smithy-rs/blob/main/examples/pokemon-service/src/plugin.rs) as an example.

We continue with this approach and introduce the following trait:

```rust
/// Models the [Smithy Service shape].
///
/// [Smithy Service shape]: https://smithy.io/2.0/spec/service-types.html
pub trait ServiceShape {
    /// The [`ShapeId`] of the service.
    const ID: ShapeId;

    /// The version of the service.
    const VERSION: Option<&'static str>;

    /// The [Protocol] applied to this service.
    ///
    /// [Protocol]: https://smithy.io/2.0/spec/protocol-traits.html
    type Protocol;

    /// An enumeration of all operations contained in this service.
    type Operations;
}
```

With the changes to `Plugin`, described below, middleware authors now have access to this information at compile time.

## Simplication of the `Plugin` trait

Previously,
@@ -209,14 +258,16 @@ modified an `Operation`.
Now,

```rust
trait Plugin<Protocol, Operation, S> {
    type Service;
trait Plugin<Service, Operation, T> {
    type Output;

    fn apply(&self, svc: S) -> Self::Service;
    fn apply(&self, input: T) -> Self::Output;
}
```

maps a `tower::Service` to a `tower::Service`. This is equivalent to `tower::Layer` with two extra type parameters: `Protocol` and `Operation`.
maps a `tower::Service` to a `tower::Service`. This is equivalent to `tower::Layer` with two extra type parameters: `Service` and `Operation`, which implement `ServiceShape` and `OperationShape` respectively.

Having both `Service` and `Operation` as type parameters also provides an even surface for advanced users to extend the codegenerator in a structured way. See [this issue](https://github.com/awslabs/smithy-rs/issues/2777) for more context.

The following middleware setup

@@ -286,18 +337,33 @@ where

pub struct PrintPlugin;

impl<P, Op, S, L> Plugin<P, Op, S, L> for PrintPlugin
impl<Service, Op, T> Plugin<Service, Operation, T> for PrintPlugin
where
    Op: OperationShape,
{
    type Service = PrintService<S>;
    type Output = PrintService<S>;

    fn apply(&self, svc: S) -> Self::Service {
    fn apply(&self, inner: T) -> Self::Output {
        PrintService { inner, name: Op::ID.name() }
    }
}
```

Alternatively, using the new `ServiceShape`, implemented on `Ser`:

```rust
impl<Service, Operation, T> Plugin<Service, Operation, T> for PrintPlugin
where
    Ser: ServiceShape,
{
    type Service = PrintService<S>;

    fn apply(&self, inner: T) -> Self::Service {
        PrintService { inner, name: Ser::ID.name() }
    }
}
```

A single `Plugin` can no longer apply a `tower::Layer` on HTTP requests/responses _and_ modelled structures at the same time (see middleware positions [C](https://awslabs.github.io/smithy-rs/design/server/middleware.html#c-operation-specific-http-middleware) and [D](https://awslabs.github.io/smithy-rs/design/server/middleware.html#d-operation-specific-model-middleware). Instead one `Plugin` must be specified for each and passed to the service builder constructor separately:

```rust
@@ -442,3 +508,45 @@ message = "Add a `send_with` function on `-Input` types for sending requests wit
author = "thomas-k-cameron"
references = ["smithy-rs#2652"]
meta = { "breaking" = false, "tada" = true, "bug" = false }

[[smithy-rs]]
message = """Remove `filter_by_operation_id` and `plugin_from_operation_id_fn` in favour of `filter_by_operation` and `plugin_from_operation_fn`.

Previously, we provided `filter_by_operation_id` which filtered `Plugin` application via a predicate over the Shape ID.

```rust
use aws_smithy_http_server::plugin::filter_by_operation_id;
use pokemon_service_server_sdk::operation_shape::CheckHealth;

let filtered = filter_by_operation_id(plugin, |name| name != CheckHealth::NAME);
```

This had the problem that the user is unable to exhaustively match over a `&'static str`. To remedy this we have switched to `filter_by_operation` which is a predicate over an enum containing all operations contained in the service.

```rust
use aws_smithy_http_server::plugin::filter_by_operation_id;
use pokemon_service_server_sdk::service::Operation;

let filtered = filter_by_operation(plugin, |op: Operation| op != Operation::CheckHealth);
```

Similarly, `plugin_from_operation_fn` now allows for

```rust
use aws_smithy_http_server::plugin::plugin_from_operation_fn;
use pokemon_service_server_sdk::service::Operation;

fn map<S>(op: Operation, inner: S) -> PrintService<S> {
    match op {
        Operation::CheckHealth => PrintService { name: op.shape_id().name(), inner },
        Operation::GetPokemonSpecies => PrintService { name: "hello world", inner },
        _ => todo!()
    }
}

let plugin = plugin_from_operation_fn(map);
```
"""
references = ["smithy-rs#2740", "smithy-rs#2759", "smithy-rs#2779"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "hlbarber"
+1 −1
Original line number Diff line number Diff line
@@ -108,7 +108,7 @@ class ScopeMacroGenerator(
            /// ## use #{SmithyHttpServer}::plugin::{Plugin, Scoped};
            /// ## use $crateName::scope;
            /// ## struct MockPlugin;
            /// ## impl<P, Op, S> Plugin<P, Op, S> for MockPlugin { type Service = u32; fn apply(&self, svc: S) -> u32 { 3 } }
            /// ## impl<S, Op, T> Plugin<S, Op, T> for MockPlugin { type Output = u32; fn apply(&self, input: T) -> u32 { 3 } }
            /// ## let scoped_a = Scoped::new::<ScopeA>(MockPlugin);
            /// ## let scoped_b = Scoped::new::<ScopeB>(MockPlugin);
            /// ## let a = Plugin::<(), $crateName::operation_shape::$firstOperationName, u64>::apply(&scoped_a, 6);
+95 −20
Original line number Diff line number Diff line
@@ -50,7 +50,8 @@ class ServerServiceGenerator(
    private val crateName = codegenContext.moduleUseName()

    private val service = codegenContext.serviceShape
    private val serviceName = service.id.name.toPascalCase()
    private val serviceId = service.id
    private val serviceName = serviceId.name.toPascalCase()
    private val builderName = "${serviceName}Builder"
    private val builderBodyGenericTypeName = "Body"

@@ -136,30 +137,30 @@ class ServerServiceGenerator(
                    HandlerType: #{SmithyHttpServer}::operation::Handler<crate::operation_shape::$structName, HandlerExtractors>,

                    ModelPlugin: #{SmithyHttpServer}::plugin::Plugin<
                        #{Protocol},
                        $serviceName,
                        crate::operation_shape::$structName,
                        #{SmithyHttpServer}::operation::IntoService<crate::operation_shape::$structName, HandlerType>
                    >,
                    #{SmithyHttpServer}::operation::UpgradePlugin::<UpgradeExtractors>: #{SmithyHttpServer}::plugin::Plugin<
                        #{Protocol},
                        $serviceName,
                        crate::operation_shape::$structName,
                        ModelPlugin::Service
                        ModelPlugin::Output
                    >,
                    HttpPlugin: #{SmithyHttpServer}::plugin::Plugin<
                        #{Protocol},
                        $serviceName,
                        crate::operation_shape::$structName,
                        <
                            #{SmithyHttpServer}::operation::UpgradePlugin::<UpgradeExtractors>
                            as #{SmithyHttpServer}::plugin::Plugin<
                                #{Protocol},
                                $serviceName,
                                crate::operation_shape::$structName,
                                ModelPlugin::Service
                                ModelPlugin::Output
                            >
                        >::Service
                        >::Output
                    >,

                    HttpPlugin::Service: #{Tower}::Service<#{Http}::Request<Body>, Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>, Error = ::std::convert::Infallible> + Clone + Send + 'static,
                    <HttpPlugin::Service as #{Tower}::Service<#{Http}::Request<Body>>>::Future: Send + 'static,
                    HttpPlugin::Output: #{Tower}::Service<#{Http}::Request<Body>, Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>, Error = ::std::convert::Infallible> + Clone + Send + 'static,
                    <HttpPlugin::Output as #{Tower}::Service<#{Http}::Request<Body>>>::Future: Send + 'static,

                {
                    use #{SmithyHttpServer}::operation::OperationShapeExt;
@@ -199,30 +200,30 @@ class ServerServiceGenerator(
                    S: #{SmithyHttpServer}::operation::OperationService<crate::operation_shape::$structName, ServiceExtractors>,

                    ModelPlugin: #{SmithyHttpServer}::plugin::Plugin<
                        #{Protocol},
                        $serviceName,
                        crate::operation_shape::$structName,
                        #{SmithyHttpServer}::operation::Normalize<crate::operation_shape::$structName, S>
                    >,
                    #{SmithyHttpServer}::operation::UpgradePlugin::<UpgradeExtractors>: #{SmithyHttpServer}::plugin::Plugin<
                        #{Protocol},
                        $serviceName,
                        crate::operation_shape::$structName,
                        ModelPlugin::Service
                        ModelPlugin::Output
                    >,
                    HttpPlugin: #{SmithyHttpServer}::plugin::Plugin<
                        #{Protocol},
                        $serviceName,
                        crate::operation_shape::$structName,
                        <
                            #{SmithyHttpServer}::operation::UpgradePlugin::<UpgradeExtractors>
                            as #{SmithyHttpServer}::plugin::Plugin<
                                #{Protocol},
                                $serviceName,
                                crate::operation_shape::$structName,
                                ModelPlugin::Service
                                ModelPlugin::Output
                            >
                        >::Service
                        >::Output
                    >,

                    HttpPlugin::Service: #{Tower}::Service<#{Http}::Request<Body>, Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>, Error = ::std::convert::Infallible> + Clone + Send + 'static,
                    <HttpPlugin::Service as #{Tower}::Service<#{Http}::Request<Body>>>::Future: Send + 'static,
                    HttpPlugin::Output: #{Tower}::Service<#{Http}::Request<Body>, Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>, Error = ::std::convert::Infallible> + Clone + Send + 'static,
                    <HttpPlugin::Output as #{Tower}::Service<#{Http}::Request<Body>>>::Future: Send + 'static,

                {
                    use #{SmithyHttpServer}::operation::OperationShapeExt;
@@ -353,7 +354,6 @@ class ServerServiceGenerator(
            for (operationShape in operations) {
                val fieldName = builderFieldNames[operationShape]!!
                val (specBuilderFunctionName, _) = requestSpecMap.getValue(operationShape)
                val operationZstTypeName = operationStructNames[operationShape]!!
                rustTemplate(
                    """
                    (
@@ -590,6 +590,75 @@ class ServerServiceGenerator(
        )
    }

    private fun serviceShapeImpl(): Writable = writable {
        val namespace = serviceId.namespace
        val name = serviceId.name
        val absolute = serviceId.toString().replace("#", "##")
        val version = codegenContext.serviceShape.version?.let { "Some(\"$it\")" } ?: "None"
        rustTemplate(
            """
            impl #{SmithyHttpServer}::service::ServiceShape for $serviceName {
                const ID: #{SmithyHttpServer}::shape_id::ShapeId = #{SmithyHttpServer}::shape_id::ShapeId::new("$absolute", "$namespace", "$name");

                const VERSION: Option<&'static str> = $version;

                type Protocol = #{Protocol};

                type Operations = Operation;
            }
            """,
            "Protocol" to protocol.markerStruct(),
            *codegenScope,
        )
    }

    private fun operationEnum(): Writable = writable {
        val operations = operationStructNames.values.joinToString(",")
        val matchArms: Writable = operationStructNames.map {
                (shape, name) ->
            writable {
                val absolute = shape.id.toString().replace("#", "##")
                rustTemplate(
                    """
                    Operation::$name => #{SmithyHttpServer}::shape_id::ShapeId::new("$absolute", "${shape.id.namespace}", "${shape.id.name}")
                    """,
                    *codegenScope,
                )
            }
        }.join(",")
        rustTemplate(
            """
            /// An enumeration of all [operations](https://smithy.io/2.0/spec/service-types.html##operation) in $serviceName.
            ##[derive(Debug, PartialEq, Eq, Clone, Copy)]
            pub enum Operation {
                $operations
            }

            impl Operation {
                /// Returns the [operations](https://smithy.io/2.0/spec/service-types.html##operation) [`ShapeId`](#{SmithyHttpServer}::shape_id::ShapeId).
                pub fn shape_id(&self) -> #{SmithyHttpServer}::shape_id::ShapeId {
                    match self {
                        #{Arms}
                    }
                }
            }
            """,
            *codegenScope,
            "Arms" to matchArms,
        )

        for ((_, value) in operationStructNames) {
            rustTemplate(
                """
                impl #{SmithyHttpServer}::service::ContainsOperation<crate::operation_shape::$value> for $serviceName {
                    const VALUE: Operation = Operation::$value;
                }
                """,
                *codegenScope,
            )
        }
    }

    fun render(writer: RustWriter) {
        writer.rustTemplate(
            """
@@ -600,11 +669,17 @@ class ServerServiceGenerator(
            #{RequestSpecs:W}

            #{Struct:W}

            #{Operations}

            #{ServiceImpl}
            """,
            "Builder" to builder(),
            "MissingOperationsError" to missingOperationsError(),
            "RequestSpecs" to requestSpecsModule(),
            "Struct" to serviceStruct(),
            "Operations" to operationEnum(),
            "ServiceImpl" to serviceShapeImpl(),
            *codegenScope,
        )
    }
+17 −17
Original line number Diff line number Diff line
@@ -425,20 +425,20 @@ state in <<fork>>
<!-- TODO(missing_doc): Link to "Write a Plugin" documentation -->

A [`Plugin`](https://docs.rs/aws-smithy-http-server/latest/aws_smithy_http_server/plugin/trait.Plugin.html) is a
[`tower::Layer`] with two extra type parameters, `Protocol` and `Operation`. This allows the middleware to be
[`tower::Layer`] with two extra type parameters, `Service` and `Operation`, corresponding to [Smithy Service](https://awslabs.github.io/smithy/2.0/spec/service-types.html#service) and [Smithy Operation](https://awslabs.github.io/smithy/2.0/spec/service-types.html#operation). This allows the middleware to be
parameterized them and change behavior depending on the context in which it's applied.

```rust
# extern crate aws_smithy_http_server;
pub trait Plugin<Protocol, Operation, S> {
    type Service;
pub trait Plugin<Service, Operation, T> {
    type Output;

    fn apply(&self, svc: S) -> Self::Service;
    fn apply(&self, input: T) -> Self::Output;
}
# use aws_smithy_http_server::plugin::Plugin as Pl;
# impl<P, Op, S, T: Pl<P, Op, S>> Plugin<P, Op, S> for T {
#   type Service = <T as Pl<P, Op, S>>::Service;
#   fn apply(&self, svc: S) -> Self::Service { <T as Pl<P, Op, S>>::apply(self, svc) }
# impl<Ser, Op, T, U: Pl<Ser, Op, T>> Plugin<Ser, Op, T> for U {
#   type Output = <U as Pl<Ser, Op, T>>::Output;
#   fn apply(&self, input: T) -> Self::Output { <U as Pl<Ser, Op, T>>::apply(self, input) }
# }
```

@@ -538,19 +538,19 @@ The builder has two setter methods for each [Smithy Operation](https://awslabs.g
        HandlerType:Handler<GetPokemonSpecies, HandlerExtractors>,

        ModelPlugin: Plugin<
            RestJson1,
            PokemonService,
            GetPokemonSpecies,
            IntoService<GetPokemonSpecies, HandlerType>
        >,
        UpgradePlugin::<UpgradeExtractors>: Plugin<
            RestJson1,
            PokemonService,
            GetPokemonSpecies,
            ModelPlugin::Service
            ModelPlugin::Output
        >,
        HttpPlugin: Plugin<
            RestJson1,
            PokemonService,
            GetPokemonSpecies,
            UpgradePlugin::<UpgradeExtractors>::Service
            UpgradePlugin::<UpgradeExtractors>::Output
        >,
    {
        let svc = GetPokemonSpecies::from_handler(handler);
@@ -566,19 +566,19 @@ The builder has two setter methods for each [Smithy Operation](https://awslabs.g
        S: OperationService<GetPokemonSpecies, ServiceExtractors>,

        ModelPlugin: Plugin<
            RestJson1,
            PokemonService,
            GetPokemonSpecies,
            Normalize<GetPokemonSpecies, S>
        >,
        UpgradePlugin::<UpgradeExtractors>: Plugin<
            RestJson1,
            PokemonService,
            GetPokemonSpecies,
            ModelPlugin::Service
            ModelPlugin::Output
        >,
        HttpPlugin: Plugin<
            RestJson1,
            PokemonService,
            GetPokemonSpecies,
            UpgradePlugin::<UpgradeExtractors>::Service
            UpgradePlugin::<UpgradeExtractors>::Output
        >,
    {
        let svc = GetPokemonSpecies::from_service(service);
+18 −45
Original line number Diff line number Diff line
@@ -284,14 +284,17 @@ Suppose we want to apply a different `Layer` to every operation. In this case, p
Consider the following middleware:

```rust
# extern crate aws_smithy_http_server;
# extern crate tower;
use aws_smithy_http_server::shape_id::ShapeId;
use std::task::{Context, Poll};
use tower::Service;

/// A [`Service`] that adds a print log.
pub struct PrintService<S> {
    inner: S,
    name: &'static str,
    operation_id: ShapeId,
    service_id: ShapeId
}

impl<R, S> Service<R> for PrintService<S>
@@ -307,7 +310,7 @@ where
    }

    fn call(&mut self, req: R) -> Self::Future {
        println!("Hi {}", self.name);
        println!("Hi {} in {}", self.operation_id.name(), self.service_id.name());
        self.inner.call(req)
    }
}
@@ -319,57 +322,27 @@ An example of a `PrintPlugin` which prints the operation name:

```rust
# extern crate aws_smithy_http_server;
# pub struct PrintService<S> { inner: S, name: &'static str }
use aws_smithy_http_server::{plugin::Plugin, operation::OperationShape};
# use aws_smithy_http_server::shape_id::ShapeId;
# pub struct PrintService<S> { inner: S, operation_id: ShapeId, service_id: ShapeId }
use aws_smithy_http_server::{plugin::Plugin, operation::OperationShape, service::ServiceShape};

/// A [`Plugin`] for a service builder to add a [`PrintService`] over operations.
#[derive(Debug)]
pub struct PrintPlugin;

impl<P, Op, S> Plugin<P, Op, S> for PrintPlugin
impl<Ser, Op, T> Plugin<Ser, Op, T> for PrintPlugin
where
    Ser: ServiceShape,
    Op: OperationShape,
{
    type Service = PrintService<S>;

    fn apply(&self, inner: S) -> Self::Service {
        PrintService { name: Op::ID.name(), inner }
    }
}
```

An alternative example which prints the protocol name:
    type Output = PrintService<T>;

```rust
# extern crate aws_smithy_http_server;
# pub struct PrintService<S> { name: &'static str, inner: S}
use aws_smithy_http_server::{
    plugin::Plugin,
    proto::{
        aws_json_10::AwsJson1_0,
        rest_xml::RestXml,
    }
};

/// A [`Plugin`] for a service builder to add a [`PrintService`] over operations.
#[derive(Debug)]
pub struct PrintPlugin;

impl<Op, S> Plugin<AwsJson1_0, Op, S> for PrintPlugin
{
    type Service = PrintService<S>;

    fn apply(&self, inner: S) -> Self::Service {
        PrintService { name: "AWS JSON 1.0", inner }
    }
    fn apply(&self, inner: T) -> Self::Output {
        PrintService {
            inner,
            operation_id: Op::ID,
            service_id: Ser::ID,
        }

impl<Op, S> Plugin<RestXml, Op, S> for PrintPlugin
{
    type Service = PrintService<S>;

    fn apply(&self, inner: S) -> Self::Service {
        PrintService { name: "AWS REST XML", inner }
    }
}
```
@@ -403,7 +376,7 @@ This allows for:
# extern crate aws_smithy_http_server;
# use aws_smithy_http_server::plugin::{PluginStack, Plugin};
# struct PrintPlugin;
# impl<P, Op, S> Plugin<P, Op, S> for PrintPlugin { type Service = S; fn apply(&self, svc: S) -> Self::Service { svc }}
# impl<Ser, Op, T> Plugin<Ser, Op, T> for PrintPlugin { type Output = T; fn apply(&self, svc: T) -> Self::Output { svc }}
# trait PrintExt<EP> { fn print(self) -> PluginPipeline<PluginStack<PrintPlugin, EP>>; }
# impl<EP> PrintExt<EP> for PluginPipeline<EP> { fn print(self) -> PluginPipeline<PluginStack<PrintPlugin, EP>> { self.push(PrintPlugin) }}
# use pokemon_service_server_sdk::{operation_shape::GetPokemonSpecies, input::*, output::*, error::*};
Loading