Unverified Commit a1ec8b89 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Allow injecting methods with generic type parameters in the config object (#3274)

This is a follow-up to #3111. Currently, the injected methods are
limited to taking in concrete types. This PR allows for these methods to
take in generic type parameters as well.

```rust
impl<L, H, M> SimpleServiceConfigBuilder<L, H, M> {
    pub fn aws_auth<C>(config: C) {
        ...
    }
}
```

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 0d6cf722
Loading
Loading
Loading
Loading
+91 −46
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
@@ -13,6 +14,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTypeParameters
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
@@ -33,7 +35,7 @@ data class ConfigMethod(
    val docs: String,
    /** The parameters of the method. **/
    val params: List<Binding>,
    /** In case the method is fallible, the error type it returns. **/
    /** In case the method is fallible, the concrete error type it returns. **/
    val errorType: RuntimeType?,
    /** The code block inside the method. **/
    val initializer: Initializer,
@@ -104,15 +106,45 @@ data class Initializer(
 * }
 *
 * has two variable bindings. The `bar` name is bound to a `String` variable and the `baz` name is bound to a
 * `u64` variable.
 * `u64` variable. Both are bindings that use concrete types. Types can also be generic:
 *
 * ```rust
 * fn foo<T>(bar: T) { }
 * ```
 */
data class Binding(
sealed class Binding {
    data class Generic(
        /** The name of the variable. The name of the type parameter will be the PascalCased variable name. */
        val name: String,
        /** The type of the variable. */
        val ty: RuntimeType,
        /**
         * The generic type parameters contained in `ty`. For example, if `ty` renders to `Vec<T>` with `T` being a
         * generic type parameter, then `genericTys` should be a singleton set containing `"T"`.
         * You can't use `L`, `H`, or `M` as the names to refer to any generic types.
         * */
        val genericTys: Set<String>,
    ) : Binding()

    data class Concrete(
        /** The name of the variable. */
        val name: String,
        /** The type of the variable. */
        val ty: RuntimeType,
)
    ) : Binding()

    fun name() =
        when (this) {
            is Concrete -> this.name
            is Generic -> this.name
        }

    fun ty() =
        when (this) {
            is Concrete -> this.ty
            is Generic -> this.ty
        }
}

class ServiceConfigGenerator(
    codegenContext: ServerCodegenContext,
@@ -327,8 +359,20 @@ class ServiceConfigGenerator(
            writable {
                val paramBindings =
                    it.params.map { binding ->
                        writable { rustTemplate("${binding.name}: #{BindingTy},", "BindingTy" to binding.ty) }
                        writable { rustTemplate("${binding.name()}: #{BindingTy},", "BindingTy" to binding.ty()) }
                    }.join("\n")
                val genericBindings = it.params.filterIsInstance<Binding.Generic>()
                val lhmBindings =
                    genericBindings.filter {
                        it.genericTys.contains("L") || it.genericTys.contains("H") || it.genericTys.contains("M")
                    }
                if (lhmBindings.isNotEmpty()) {
                    throw CodegenException(
                        "Injected config method `${it.name}` has generic bindings that use `L`, `H`, or `M` to refer to the generic types. This is not allowed. Invalid generic bindings: $lhmBindings",
                    )
                }
                val paramBindingsGenericTys = genericBindings.flatMap { it.genericTys }.toSet()
                val paramBindingsGenericsWritable = rustTypeParameters(*paramBindingsGenericTys.toTypedArray())

                // This produces a nested type like: "S<B, S<A, T>>", where
                // - "S" denotes a "stack type" with two generic type parameters: the first is the "inner" part of the stack
@@ -345,7 +389,7 @@ class ServiceConfigGenerator(
                            rustTemplate(
                                "#{StackType}<#{Ty}, #{Acc:W}>",
                                "StackType" to stackType,
                                "Ty" to next.ty,
                                "Ty" to next.ty(),
                                "Acc" to acc,
                            )
                        }
@@ -391,7 +435,7 @@ class ServiceConfigGenerator(
                docs(it.docs)
                rustBlockTemplate(
                    """
                    pub fn ${it.name}(
                pub fn ${it.name}#{ParamBindingsGenericsWritable}(
                    ##[allow(unused_mut)]
                    mut self,
                    #{ParamBindings:W}
@@ -399,6 +443,7 @@ class ServiceConfigGenerator(
                """,
                    "ReturnTy" to returnTy,
                    "ParamBindings" to paramBindings,
                    "ParamBindingsGenericsWritable" to paramBindingsGenericsWritable,
                ) {
                    rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code)

@@ -412,9 +457,9 @@ class ServiceConfigGenerator(
                    conditionalBlock("Ok(", ")", conditional = it.errorType != null) {
                        val registrations =
                            (
                                it.initializer.layerBindings.map { ".layer(${it.name})" } +
                                    it.initializer.httpPluginBindings.map { ".http_plugin(${it.name})" } +
                                    it.initializer.modelPluginBindings.map { ".model_plugin(${it.name})" }
                                it.initializer.layerBindings.map { ".layer(${it.name()})" } +
                                    it.initializer.httpPluginBindings.map { ".http_plugin(${it.name()})" } +
                                    it.initializer.modelPluginBindings.map { ".model_plugin(${it.name()})" }
                            ).joinToString("")
                        rust("self$registrations")
                    }
+75 −25
Original line number Diff line number Diff line
@@ -5,7 +5,10 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators

import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.string.shouldContain
import org.junit.jupiter.api.Test
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
@@ -43,8 +46,9 @@ internal class ServiceConfigGeneratorTest {
                            docs = "Docs",
                            params =
                                listOf(
                                    Binding("auth_spec", RuntimeType.String),
                                    Binding("authorizer", RuntimeType.U64),
                                    Binding.Concrete("auth_spec", RuntimeType.String),
                                    Binding.Concrete("authorizer", RuntimeType.U64),
                                    Binding.Generic("generic_list", RuntimeType("::std::vec::Vec<T>"), setOf("T")),
                                ),
                            errorType = RuntimeType.std.resolve("io::Error"),
                            initializer =
@@ -57,7 +61,7 @@ internal class ServiceConfigGeneratorTest {
                                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1"));
                                    }
                                    
                                                if auth_spec.len() != 69 {
                                    if auth_spec.len() != 69 && generic_list.len() != 69 {
                                        return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2"));
                                    }
                                    let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin;
@@ -69,14 +73,14 @@ internal class ServiceConfigGeneratorTest {
                                    layerBindings = emptyList(),
                                    httpPluginBindings =
                                        listOf(
                                            Binding(
                                            Binding.Concrete(
                                                "authn_plugin",
                                                smithyHttpServer.resolve("plugin::IdentityPlugin"),
                                            ),
                                        ),
                                    modelPluginBindings =
                                        listOf(
                                            Binding(
                                            Binding.Concrete(
                                                "authz_plugin",
                                                smithyHttpServer.resolve("plugin::IdentityPlugin"),
                                            ),
@@ -108,7 +112,7 @@ internal class ServiceConfigGeneratorTest {
                            // One model plugin has been applied.
                            PluginStack<IdentityPlugin, IdentityPlugin>,
                        > = SimpleServiceConfig::builder()
                            .aws_auth("a".repeat(69).to_owned(), 69)
                            .aws_auth("a".repeat(69).to_owned(), 69, vec![69])
                            .expect("failed to configure aws_auth")
                            .build()
                            .unwrap();
@@ -120,7 +124,7 @@ internal class ServiceConfigGeneratorTest {
                    rust(
                        """
                        let actual_err = SimpleServiceConfig::builder()
                            .aws_auth("a".to_owned(), 69)
                            .aws_auth("a".to_owned(), 69, vec![69])
                            .unwrap_err();
                        let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 2").to_string();
                        assert_eq!(actual_err.to_string(), expected);
@@ -132,7 +136,7 @@ internal class ServiceConfigGeneratorTest {
                    rust(
                        """
                        let actual_err = SimpleServiceConfig::builder()
                            .aws_auth("a".repeat(69).to_owned(), 6969)
                            .aws_auth("a".repeat(69).to_owned(), 6969, vec!["69"])
                            .unwrap_err();
                        let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 1").to_string();
                        assert_eq!(actual_err.to_string(), expected);
@@ -154,7 +158,7 @@ internal class ServiceConfigGeneratorTest {
    }

    @Test
    fun `it should inject an method that applies three non-required layers`() {
    fun `it should inject a method that applies three non-required layers`() {
        val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel()

        val decorator =
@@ -191,9 +195,9 @@ internal class ServiceConfigGeneratorTest {
                                        },
                                    layerBindings =
                                        listOf(
                                            Binding("layer1", identityLayer),
                                            Binding("layer2", identityLayer),
                                            Binding("layer3", identityLayer),
                                            Binding.Concrete("layer1", identityLayer),
                                            Binding.Concrete("layer2", identityLayer),
                                            Binding.Concrete("layer3", identityLayer),
                                        ),
                                    httpPluginBindings = emptyList(),
                                    modelPluginBindings = emptyList(),
@@ -240,4 +244,50 @@ internal class ServiceConfigGeneratorTest {
            }
        }
    }

    @Test
    fun `it should throw an exception if a generic binding using L, H, or M is used`() {
        val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel()

        val decorator =
            object : ServerCodegenDecorator {
                override val name: String
                    get() = "InvalidGenericBindingsDecorator"
                override val order: Byte
                    get() = 69

                override fun configMethods(codegenContext: ServerCodegenContext): List<ConfigMethod> {
                    val identityLayer = RuntimeType.Tower.resolve("layer::util::Identity")
                    return listOf(
                        ConfigMethod(
                            name = "invalid_generic_bindings",
                            docs = "Docs",
                            params =
                                listOf(
                                    Binding.Generic("param1_bad", identityLayer, setOf("L")),
                                    Binding.Generic("param2_bad", identityLayer, setOf("H")),
                                    Binding.Generic("param3_bad", identityLayer, setOf("M")),
                                    Binding.Generic("param4_ok", identityLayer, setOf("N")),
                                ),
                            errorType = null,
                            initializer =
                                Initializer(
                                    code = writable {},
                                    layerBindings = emptyList(),
                                    httpPluginBindings = emptyList(),
                                    modelPluginBindings = emptyList(),
                                ),
                            isRequired = false,
                        ),
                    )
                }
            }

        val codegenException =
            shouldThrow<CodegenException> {
                serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, _ -> }
            }

        codegenException.message.shouldContain("Injected config method `invalid_generic_bindings` has generic bindings that use `L`, `H`, or `M` to refer to the generic types. This is not allowed. Invalid generic bindings:")
    }
}