Unverified Commit 988eb617 authored by Harry Barber's avatar Harry Barber Committed by GitHub
Browse files

Add `Scoped` `Plugin` (#2759)

## Motivation and Context

The
[`FilterByOperationName`](https://docs.rs/aws-smithy-http-server/0.55.4/aws_smithy_http_server/plugin/struct.FilterByOperationName.html

)
allows the customer to filter application of a plugin. However this is a
_runtime_ filter. A faster and type safe alternative would be a nice
option.

## Description

Add `Scoped` `Plugin` and `scope` macro.

---------

Co-authored-by: default avatardavid-perez <d@vidp.dev>
parent 312d1905
Loading
Loading
Loading
Loading
+36 −3
Original line number Diff line number Diff line
@@ -183,6 +183,12 @@ author = "ysaito1001"

[[smithy-rs]]
message = """
The middleware system has been reworked as we push for a unified, simple, and consistent API. The following changes have been made in service of this goal:

- The `Plugin` trait has been simplified.
- The `Operation` structure has been removed.
- A `Scoped` `Plugin` has been added.

The `Plugin` trait has now been simplified and the `Operation` struct has been removed.

## Simplication of the `Plugin` trait
@@ -317,7 +323,7 @@ let plugin = /* some plugin */;
let layer = LayerPlugin::new::<SomeProtocol, SomeOperation>(plugin);
```

## Remove `Operation`
## Removal of `Operation`

The `aws_smithy_http_server::operation::Operation` structure has now been removed. Previously, there existed a `{operation_name}_operation` setter on the service builder, which accepted an `Operation`. This allowed users to

@@ -356,7 +362,7 @@ let app = PokemonService::builder_without_plugins()
    .unwrap();
```

Applying a `tower::Layer` to a _single_ operation is now done through the `Plugin` API:
Applying a `tower::Layer` to a _subset_ of operations is should now be done through the `Plugin` API via `filter_by_operation_id`

```rust
use aws_smithy_http_server::plugin::{PluginLayer, filter_by_operation_name, IdentityPlugin};
@@ -371,7 +377,34 @@ let app = PokemonService::builder_with_plugins(scoped_plugin, IdentityPlugin)
    .unwrap();
```

or the new `Scoped` `Plugin` introduced below.

# Addition of `Scoped`

Currently, users can selectively apply a `Plugin` via the `filter_by_operation_id` function

```rust
use aws_smithy_http_server::plugin::filter_by_operation_id;
// Only apply `plugin` to `CheckHealth` and `GetStorage` operation
let filtered_plugin = filter_by_operation_id(plugin, |name| name == CheckHealth::ID || name == GetStorage::ID);
```

In addition to this, we now provide `Scoped`, which selectively applies a `Plugin` at _compiletime_. Users should prefer this to `filter_by_operation_id` when applicable.

```rust
use aws_smithy_http_server::plugin::Scoped;
use pokemon_service_server_sdk::scoped;

scope! {
    /// Includes only the `CheckHealth` and `GetStorage` operation.
    struct SomeScope {
        includes: [CheckHealth, GetStorage]
    }
}
let scoped_plugin = Scoped::new::<SomeScope>(plugin);
```

"""
references = ["smithy-rs#2740"]
references = ["smithy-rs#2740", "smithy-rs#2759"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "hlbarber"
+3 −0
Original line number Diff line number Diff line
@@ -63,6 +63,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedT
import software.amazon.smithy.rust.codegen.server.smithy.generators.MapConstraintViolationGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.PubCrateConstrainedCollectionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.PubCrateConstrainedMapGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ScopeMacroGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGeneratorWithoutPublicConstrainedTypes
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGenerator
@@ -610,6 +611,8 @@ open class ServerCodegenVisitor(
                codegenContext,
                serverProtocol,
            ).render(this)

            ScopeMacroGenerator(codegenContext).render(this)
        }
    }

+182 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext

class ScopeMacroGenerator(
    private val codegenContext: ServerCodegenContext,
) {
    private val runtimeConfig = codegenContext.runtimeConfig
    private val codegenScope =
        arrayOf(
            "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(),
        )

    /** Calculate all `operationShape`s contained within the `ServiceShape`. */
    private val index = TopDownIndex.of(codegenContext.model)
    private val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet(compareBy { it.id })

    private fun macro(): Writable = writable {
        val firstOperationName = codegenContext.symbolProvider.toSymbol(operations.first()).name.toPascalCase()
        val operationNames = operations.joinToString(" ") {
            codegenContext.symbolProvider.toSymbol(it).name.toPascalCase()
        }

        // When writing `macro_rules!` we add whitespace between `$` and the arguments to avoid Kotlin templating.

        // To acheive the desired API we need to calculate the set theoretic complement `B \ A`.
        // The macro below, for rules prefixed with `@`, encodes a state machine which performs this.
        // The initial state is `(A) () (B)`, where `A` and `B` are lists of elements of `A` and `B`.
        // The rules, in order:
        // - Terminate on pattern `() (t0, t1, ...) (b0, b1, ...)`, the complement has been calculated as
        // `{ t0, t1, ..., b0, b1, ...}`.
        // - Send pattern `(x, a0, a1, ...) (t0, t1, ...) (x, b0, b1, ...)` to
        // `(a0, a1, ...) (t0, t1, ...) (b0, b1, ...)`, eliminating a matching `x` from `A` and `B`.
        // - Send pattern `(a0, a1, ...) (t0, t1, ...) ()` to `(a0, a1, ...) () (t0, t1, ...)`, restarting the search.
        // - Send pattern `(a0, a1, ...) (t0, t1, ...) (b0, b1, ...)` to `(a0, a1, ...) (b0, t0, t1, ...) (b1, ...)`,
        // iterating through the `B`.
        val operationBranches = operations
            .map { codegenContext.symbolProvider.toSymbol(it).name.toPascalCase() }.joinToString("") {
                """
                // $it match found, pop from both `member` and `not_member`
                (@ $ name: ident, $ contains: ident ($it $($ member: ident)*) ($($ temp: ident)*) ($it $($ not_member: ident)*)) => {
                    scope! { @ $ name, $ contains ($($ member)*) ($($ temp)*) ($($ not_member)*) }
                };
                // $it match not found, pop from `not_member` into `temp` stack
                (@ $ name: ident, $ contains: ident ($it $($ member: ident)*) ($($ temp: ident)*) ($ other: ident $($ not_member: ident)*)) => {
                    scope! { @ $ name, $ contains ($it $($ member)*) ($ other $($ temp)*) ($($ not_member)*) }
                };
                """
            }
        val crateName = codegenContext.moduleName.toSnakeCase()

        // If we have a second operation we can perform further checks
        val otherOperationName: String? = operations.toList().getOrNull(1)?.let {
            codegenContext.symbolProvider.toSymbol(it).name
        }
        val furtherTests = if (otherOperationName != null) {
            writable {
                rustTemplate(
                    """
                    /// ## let a = Plugin::<(), $otherOperationName, u64>::apply(&scoped_a, 6);
                    /// ## let b = Plugin::<(), $otherOperationName, u64>::apply(&scoped_b, 6);
                    /// ## assert_eq!(a, 6_u64);
                    /// ## assert_eq!(b, 3_u32);
                    """,
                )
            }
        } else {
            writable {}
        }

        rustTemplate(
            """
            /// A macro to help with scoping [plugins](#{SmithyHttpServer}::plugin) to a subset of all operations.
            ///
            /// In contrast to [`aws_smithy_http_server::scope`](#{SmithyHttpServer}::scope), this macro has knowledge
            /// of the service and any operations _not_ specified will be placed in the opposing group.
            ///
            /// ## Example
            ///
            /// ```rust
            /// scope! {
            ///     /// Includes [`$firstOperationName`], excluding all other operations.
            ///     struct ScopeA {
            ///         includes: [$firstOperationName]
            ///     }
            /// }
            ///
            /// scope! {
            ///     /// Excludes [`$firstOperationName`], excluding all other operations.
            ///     struct ScopeB {
            ///         excludes: [$firstOperationName]
            ///     }
            /// }
            ///
            /// ## use #{SmithyHttpServer}::plugin::{Plugin, Scoped};
            /// ## use $crateName::scope;
            /// ## struct MockPlugin;
            /// ## impl<P, Op, S> Plugin<P, Op, S> for MockPlugin { type Service = u32; fn apply(&self, svc: S) -> u32 { 3 } }
            /// ## let scoped_a = Scoped::new::<ScopeA>(MockPlugin);
            /// ## let scoped_b = Scoped::new::<ScopeB>(MockPlugin);
            /// ## let a = Plugin::<(), $crateName::operation_shape::$firstOperationName, u64>::apply(&scoped_a, 6);
            /// ## let b = Plugin::<(), $crateName::operation_shape::$firstOperationName, u64>::apply(&scoped_b, 6);
            /// ## assert_eq!(a, 3_u32);
            /// ## assert_eq!(b, 6_u64);
            /// ```
            ##[macro_export]
            macro_rules! scope {
                // Completed, render impls
                (@ $ name: ident, $ contains: ident () ($($ temp: ident)*) ($($ not_member: ident)*)) => {
                    $(
                        impl #{SmithyHttpServer}::plugin::scoped::Membership<$ temp> for $ name {
                            type Contains = #{SmithyHttpServer}::plugin::scoped::$ contains;
                        }
                    )*
                    $(
                        impl #{SmithyHttpServer}::plugin::scoped::Membership<$ not_member> for $ name {
                            type Contains = #{SmithyHttpServer}::plugin::scoped::$ contains;
                        }
                    )*
                };
                // All `not_member`s exhausted, move `temp` into `not_member`
                (@ $ name: ident, $ contains: ident ($($ member: ident)*) ($($ temp: ident)*) ()) => {
                    scope! { @ $ name, $ contains ($($ member)*) () ($($ temp)*) }
                };
                $operationBranches
                (
                    $(##[$ attrs:meta])*
                    $ vis:vis struct $ name:ident {
                        includes: [$($ include:ident),*]
                    }
                ) => {
                    use $ crate::operation_shape::*;
                    #{SmithyHttpServer}::scope! {
                        $(##[$ attrs])*
                        $ vis struct $ name {
                            includes: [$($ include),*],
                            excludes: []
                        }
                    }
                    scope! { @ $ name, False ($($ include)*) () ($operationNames) }
                };
                (
                    $(##[$ attrs:meta])*
                    $ vis:vis struct $ name:ident {
                        excludes: [$($ exclude:ident),*]
                    }
                ) => {
                    use $ crate::operation_shape::*;

                    #{SmithyHttpServer}::scope! {
                        $(##[$ attrs])*
                        $ vis struct $ name {
                            includes: [],
                            excludes: [$($ exclude),*]
                        }
                    }
                    scope! { @ $ name, True ($($ exclude)*) () ($operationNames) }
                };
            }
            """,
            *codegenScope,
            "FurtherTests" to furtherTests,
        )
    }

    fun render(writer: RustWriter) {
        macro()(writer)
    }
}
+20 −4
Original line number Diff line number Diff line
@@ -208,15 +208,24 @@ A "HTTP layer" can be applied to specific operations.
# extern crate aws_smithy_http_server;
# use tower::{util::service_fn, Layer};
# use std::time::Duration;
# use pokemon_service_server_sdk::{operation_shape::GetPokemonSpecies, PokemonService, input::*, output::*, error::*};
# use pokemon_service_server_sdk::{operation_shape::GetPokemonSpecies, input::*, output::*, error::*};
# use aws_smithy_http_server::{operation::OperationShapeExt, plugin::*, operation::*};
# let handler = |req: GetPokemonSpeciesInput| async { Result::<GetPokemonSpeciesOutput, GetPokemonSpeciesError>::Ok(todo!()) };
# struct LoggingLayer;
# impl LoggingLayer { pub fn new() -> Self { Self } }
# impl<S> Layer<S> for LoggingLayer { type Service = S; fn layer(&self, svc: S) -> Self::Service { svc } }
use pokemon_service_server_sdk::{PokemonService, scope};

scope! {
    /// Only log on `GetPokemonSpecies` and `GetStorage`
    struct LoggingScope {
        includes: [GetPokemonSpecies, GetStorage]
    }
}

// Construct `LoggingLayer`.
let logging_plugin = LayerPlugin(LoggingLayer::new());
let logging_plugin = filter_by_operation_id(logging_plugin, |name| name == GetPokemonSpecies::ID);
let logging_plugin = Scoped::new::<LoggingScope>(logging_plugin);
let http_plugins = PluginPipeline::new().push(logging_plugin);

let app /* : PokemonService<Route<B>> */ = PokemonService::builder_with_plugins(http_plugins, IdentityPlugin)
@@ -244,11 +253,18 @@ A "model layer" can be applied to specific operations.
# struct BufferLayer;
# impl BufferLayer { pub fn new(size: usize) -> Self { Self } }
# impl<S> Layer<S> for BufferLayer { type Service = S; fn layer(&self, svc: S) -> Self::Service { svc } }
use pokemon_service_server_sdk::PokemonService;
use pokemon_service_server_sdk::{PokemonService, scope};

scope! {
    /// Only buffer on `GetPokemonSpecies` and `GetStorage`
    struct BufferScope {
        includes: [GetPokemonSpecies, GetStorage]
    }
}

// Construct `BufferLayer`.
let buffer_plugin = LayerPlugin(BufferLayer::new(3));
let buffer_plugin = filter_by_operation_id(buffer_plugin, |name| name != GetPokemonSpecies::ID);
let buffer_plugin = Scoped::new::<BufferScope>(buffer_plugin);
let model_plugins = PluginPipeline::new().push(buffer_plugin);

let app /* : PokemonService<Route<B>> */ = PokemonService::builder_with_plugins(IdentityPlugin, model_plugins)
+13 −4
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ use std::{net::SocketAddr, sync::Arc};
use aws_smithy_http_server::{
    extension::OperationExtensionExt,
    instrumentation::InstrumentExt,
    plugin::{alb_health_check::AlbHealthCheckLayer, IdentityPlugin, PluginPipeline},
    plugin::{alb_health_check::AlbHealthCheckLayer, IdentityPlugin, PluginPipeline, Scoped},
    request::request_id::ServerRequestIdProviderLayer,
    AddExtensionLayer,
};
@@ -26,7 +26,7 @@ use pokemon_service_common::{
    capture_pokemon, check_health, get_pokemon_species, get_server_statistics, setup_tracing,
    stream_pokemon_radio, State,
};
use pokemon_service_server_sdk::PokemonService;
use pokemon_service_server_sdk::{scope, PokemonService};

#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
@@ -44,9 +44,18 @@ pub async fn main() {
    let args = Args::parse();
    setup_tracing();

    scope! {
        /// A scope containing `GetPokemonSpecies` and `GetStorage`
        struct PrintScope {
            includes: [GetPokemonSpecies, GetStorage]
        }
    }
    // Scope the `PrintPlugin`, defined in `plugin.rs`, to `PrintScope`
    let print_plugin = Scoped::new::<PrintScope>(PluginPipeline::new().print());

    let plugins = PluginPipeline::new()
        // Apply the `PrintPlugin` defined in `plugin.rs`
        .print()
        // Apply the scoped `PrintPlugin`
        .push(print_plugin)
        // Apply the `OperationExtensionPlugin` defined in `aws_smithy_http_server::extension`. This allows other
        // plugins or tests to access a `aws_smithy_http_server::extension::OperationExtension` from
        // `Response::extensions`, or infer routing failure when it's missing.
Loading