diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt index b1ca880ebc1ba798fbb79dc3ffdc6bcf9173f669..56a662348fdd5e3c43a9d5890cd7cf8f8add4728 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock @@ -13,6 +14,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTypeParameters import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope @@ -33,7 +35,7 @@ data class ConfigMethod( val docs: String, /** The parameters of the method. **/ val params: List, - /** In case the method is fallible, the error type it returns. **/ + /** In case the method is fallible, the concrete error type it returns. **/ val errorType: RuntimeType?, /** The code block inside the method. **/ val initializer: Initializer, @@ -104,15 +106,45 @@ data class Initializer( * } * * has two variable bindings. The `bar` name is bound to a `String` variable and the `baz` name is bound to a - * `u64` variable. + * `u64` variable. Both are bindings that use concrete types. Types can also be generic: + * + * ```rust + * fn foo(bar: T) { } * ``` */ -data class Binding( - /** The name of the variable. */ - val name: String, - /** The type of the variable. */ - val ty: RuntimeType, -) +sealed class Binding { + data class Generic( + /** The name of the variable. The name of the type parameter will be the PascalCased variable name. */ + val name: String, + /** The type of the variable. */ + val ty: RuntimeType, + /** + * The generic type parameters contained in `ty`. For example, if `ty` renders to `Vec` with `T` being a + * generic type parameter, then `genericTys` should be a singleton set containing `"T"`. + * You can't use `L`, `H`, or `M` as the names to refer to any generic types. + * */ + val genericTys: Set, + ) : Binding() + + data class Concrete( + /** The name of the variable. */ + val name: String, + /** The type of the variable. */ + val ty: RuntimeType, + ) : Binding() + + fun name() = + when (this) { + is Concrete -> this.name + is Generic -> this.name + } + + fun ty() = + when (this) { + is Concrete -> this.ty + is Generic -> this.ty + } +} class ServiceConfigGenerator( codegenContext: ServerCodegenContext, @@ -271,10 +303,10 @@ class ServiceConfigGenerator( writable { rustTemplate( """ - if !self.${it.requiredBuilderFlagName()} { - return #{Err}(${serviceName}ConfigError::${it.requiredErrorVariant()}); - } - """, + if !self.${it.requiredBuilderFlagName()} { + return #{Err}(${serviceName}ConfigError::${it.requiredErrorVariant()}); + } + """, *codegenScope, ) } @@ -303,19 +335,19 @@ class ServiceConfigGenerator( writable { rust( """ - ##[error("service is not fully configured; invoke `${it.name}` on the config builder")] - ${it.requiredErrorVariant()}, - """, + ##[error("service is not fully configured; invoke `${it.name}` on the config builder")] + ${it.requiredErrorVariant()}, + """, ) } } rustTemplate( """ - ##[derive(Debug, #{ThisError}::Error)] - pub enum ${serviceName}ConfigError { - #{Variants:W} - } - """, + ##[derive(Debug, #{ThisError}::Error)] + pub enum ${serviceName}ConfigError { + #{Variants:W} + } + """, "ThisError" to ServerCargoDependency.ThisError.toType(), "Variants" to variants.join("\n"), ) @@ -327,8 +359,20 @@ class ServiceConfigGenerator( writable { val paramBindings = it.params.map { binding -> - writable { rustTemplate("${binding.name}: #{BindingTy},", "BindingTy" to binding.ty) } + writable { rustTemplate("${binding.name()}: #{BindingTy},", "BindingTy" to binding.ty()) } }.join("\n") + val genericBindings = it.params.filterIsInstance() + val lhmBindings = + genericBindings.filter { + it.genericTys.contains("L") || it.genericTys.contains("H") || it.genericTys.contains("M") + } + if (lhmBindings.isNotEmpty()) { + throw CodegenException( + "Injected config method `${it.name}` has generic bindings that use `L`, `H`, or `M` to refer to the generic types. This is not allowed. Invalid generic bindings: $lhmBindings", + ) + } + val paramBindingsGenericTys = genericBindings.flatMap { it.genericTys }.toSet() + val paramBindingsGenericsWritable = rustTypeParameters(*paramBindingsGenericTys.toTypedArray()) // This produces a nested type like: "S>", where // - "S" denotes a "stack type" with two generic type parameters: the first is the "inner" part of the stack @@ -345,7 +389,7 @@ class ServiceConfigGenerator( rustTemplate( "#{StackType}<#{Ty}, #{Acc:W}>", "StackType" to stackType, - "Ty" to next.ty, + "Ty" to next.ty(), "Acc" to acc, ) } @@ -362,12 +406,12 @@ class ServiceConfigGenerator( writable { rustTemplate( """ - ${serviceName}ConfigBuilder< - #{LayersReturnTy:W}, - #{HttpPluginsReturnTy:W}, - #{ModelPluginsReturnTy:W}, - > - """, + ${serviceName}ConfigBuilder< + #{LayersReturnTy:W}, + #{HttpPluginsReturnTy:W}, + #{ModelPluginsReturnTy:W}, + > + """, "LayersReturnTy" to layersReturnTy, "HttpPluginsReturnTy" to httpPluginsReturnTy, "ModelPluginsReturnTy" to modelPluginsReturnTy, @@ -391,14 +435,15 @@ class ServiceConfigGenerator( docs(it.docs) rustBlockTemplate( """ - pub fn ${it.name}( - ##[allow(unused_mut)] - mut self, - #{ParamBindings:W} - ) -> #{ReturnTy:W} - """, + pub fn ${it.name}#{ParamBindingsGenericsWritable}( + ##[allow(unused_mut)] + mut self, + #{ParamBindings:W} + ) -> #{ReturnTy:W} + """, "ReturnTy" to returnTy, "ParamBindings" to paramBindings, + "ParamBindingsGenericsWritable" to paramBindingsGenericsWritable, ) { rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code) @@ -412,9 +457,9 @@ class ServiceConfigGenerator( conditionalBlock("Ok(", ")", conditional = it.errorType != null) { val registrations = ( - it.initializer.layerBindings.map { ".layer(${it.name})" } + - it.initializer.httpPluginBindings.map { ".http_plugin(${it.name})" } + - it.initializer.modelPluginBindings.map { ".model_plugin(${it.name})" } + it.initializer.layerBindings.map { ".layer(${it.name()})" } + + it.initializer.httpPluginBindings.map { ".http_plugin(${it.name()})" } + + it.initializer.modelPluginBindings.map { ".model_plugin(${it.name()})" } ).joinToString("") rust("self$registrations") } @@ -437,9 +482,9 @@ class ServiceConfigGenerator( writable { rustBlockTemplate( """ - /// Build the configuration. - pub fn build(self) -> #{BuilderBuildReturnTy:W} - """, + /// Build the configuration. + pub fn build(self) -> #{BuilderBuildReturnTy:W} + """, "BuilderBuildReturnTy" to builderBuildReturnType(), ) { rustTemplate( @@ -450,12 +495,12 @@ class ServiceConfigGenerator( conditionalBlock("Ok(", ")", isBuilderFallible) { rust( """ - super::${serviceName}Config { - layers: self.layers, - http_plugins: self.http_plugins, - model_plugins: self.model_plugins, - } - """, + super::${serviceName}Config { + layers: self.layers, + http_plugins: self.http_plugins, + model_plugins: self.model_plugins, + } + """, ) } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt index 63f1f7385d231c5544e7e2cb9d2a78c8ef63815c..a744e53791b2db0e995f6ca75624a367047a6890 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt @@ -5,7 +5,10 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.string.shouldContain import org.junit.jupiter.api.Test +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -43,8 +46,9 @@ internal class ServiceConfigGeneratorTest { docs = "Docs", params = listOf( - Binding("auth_spec", RuntimeType.String), - Binding("authorizer", RuntimeType.U64), + Binding.Concrete("auth_spec", RuntimeType.String), + Binding.Concrete("authorizer", RuntimeType.U64), + Binding.Generic("generic_list", RuntimeType("::std::vec::Vec"), setOf("T")), ), errorType = RuntimeType.std.resolve("io::Error"), initializer = @@ -53,30 +57,30 @@ internal class ServiceConfigGeneratorTest { writable { rustTemplate( """ - if authorizer != 69 { - return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1")); - } - - if auth_spec.len() != 69 { - return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2")); - } - let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; - let authz_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; - """, + if authorizer != 69 { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1")); + } + + if auth_spec.len() != 69 && generic_list.len() != 69 { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2")); + } + let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; + let authz_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; + """, *codegenScope, ) }, layerBindings = emptyList(), httpPluginBindings = listOf( - Binding( + Binding.Concrete( "authn_plugin", smithyHttpServer.resolve("plugin::IdentityPlugin"), ), ), modelPluginBindings = listOf( - Binding( + Binding.Concrete( "authz_plugin", smithyHttpServer.resolve("plugin::IdentityPlugin"), ), @@ -108,7 +112,7 @@ internal class ServiceConfigGeneratorTest { // One model plugin has been applied. PluginStack, > = SimpleServiceConfig::builder() - .aws_auth("a".repeat(69).to_owned(), 69) + .aws_auth("a".repeat(69).to_owned(), 69, vec![69]) .expect("failed to configure aws_auth") .build() .unwrap(); @@ -120,7 +124,7 @@ internal class ServiceConfigGeneratorTest { rust( """ let actual_err = SimpleServiceConfig::builder() - .aws_auth("a".to_owned(), 69) + .aws_auth("a".to_owned(), 69, vec![69]) .unwrap_err(); let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 2").to_string(); assert_eq!(actual_err.to_string(), expected); @@ -132,7 +136,7 @@ internal class ServiceConfigGeneratorTest { rust( """ let actual_err = SimpleServiceConfig::builder() - .aws_auth("a".repeat(69).to_owned(), 6969) + .aws_auth("a".repeat(69).to_owned(), 6969, vec!["69"]) .unwrap_err(); let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 1").to_string(); assert_eq!(actual_err.to_string(), expected); @@ -154,7 +158,7 @@ internal class ServiceConfigGeneratorTest { } @Test - fun `it should inject an method that applies three non-required layers`() { + fun `it should inject a method that applies three non-required layers`() { val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel() val decorator = @@ -182,18 +186,18 @@ internal class ServiceConfigGeneratorTest { writable { rustTemplate( """ - let layer1 = #{Identity}::new(); - let layer2 = #{Identity}::new(); - let layer3 = #{Identity}::new(); - """, + let layer1 = #{Identity}::new(); + let layer2 = #{Identity}::new(); + let layer3 = #{Identity}::new(); + """, *codegenScope, ) }, layerBindings = listOf( - Binding("layer1", identityLayer), - Binding("layer2", identityLayer), - Binding("layer3", identityLayer), + Binding.Concrete("layer1", identityLayer), + Binding.Concrete("layer2", identityLayer), + Binding.Concrete("layer3", identityLayer), ), httpPluginBindings = emptyList(), modelPluginBindings = emptyList(), @@ -240,4 +244,50 @@ internal class ServiceConfigGeneratorTest { } } } + + @Test + fun `it should throw an exception if a generic binding using L, H, or M is used`() { + val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel() + + val decorator = + object : ServerCodegenDecorator { + override val name: String + get() = "InvalidGenericBindingsDecorator" + override val order: Byte + get() = 69 + + override fun configMethods(codegenContext: ServerCodegenContext): List { + val identityLayer = RuntimeType.Tower.resolve("layer::util::Identity") + return listOf( + ConfigMethod( + name = "invalid_generic_bindings", + docs = "Docs", + params = + listOf( + Binding.Generic("param1_bad", identityLayer, setOf("L")), + Binding.Generic("param2_bad", identityLayer, setOf("H")), + Binding.Generic("param3_bad", identityLayer, setOf("M")), + Binding.Generic("param4_ok", identityLayer, setOf("N")), + ), + errorType = null, + initializer = + Initializer( + code = writable {}, + layerBindings = emptyList(), + httpPluginBindings = emptyList(), + modelPluginBindings = emptyList(), + ), + isRequired = false, + ), + ) + } + } + + val codegenException = + shouldThrow { + serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, _ -> } + } + + codegenException.message.shouldContain("Injected config method `invalid_generic_bindings` has generic bindings that use `L`, `H`, or `M` to refer to the generic types. This is not allowed. Invalid generic bindings:") + } }