Unverified Commit c8ba2d5b authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Fix auth failures in codegen-client-tests in orchestrator mode (#2812)

## Motivation and Context
In orchestrator mode, most `codegen-client-test` tests were failing due
to being unable to find a matching auth scheme, or due to some of the
test models referencing the `@sigv4` trait. This PR fixes all of those
failures, and adds the `smithy.runtime.mode` flag to
`codegen-client-test` as well.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 5eb885c2
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ data class CodegenTest(
    val service: String,
    val module: String,
    val extraConfig: String? = null,
    val extraCodegenConfig: String? = null,
    val imports: List<String> = emptyList(),
)

@@ -38,6 +39,7 @@ private fun generateSmithyBuild(projectDir: String, pluginName: String, tests: L
                        "relativePath": "$projectDir/rust-runtime"
                    },
                    "codegen": {
                        ${it.extraCodegenConfig ?: ""}
                    },
                    "service": "${it.service}",
                    "module": "${it.module}",
+74 −62
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ plugins {
val smithyVersion: String by project
val defaultRustDocFlags: String by project
val properties = PropertyRetriever(rootProject, project)
fun getSmithyRuntimeMode(): String = properties.get("smithy.runtime.mode") ?: "middleware"

val pluginName = "rust-client-codegen"
val workingDirUnderBuildDir = "smithyprojections/codegen-client-test/"
@@ -33,70 +34,81 @@ dependencies {
    implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
}

val allCodegenTests = "../codegen-core/common-test-models".let { commonModels ->
    listOf(
        CodegenTest("com.amazonaws.simple#SimpleService", "simple", imports = listOf("$commonModels/simple.smithy")),
        CodegenTest("com.amazonaws.dynamodb#DynamoDB_20120810", "dynamo"),
        CodegenTest("com.amazonaws.ebs#Ebs", "ebs", imports = listOf("$commonModels/ebs.json")),
        CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"),
        CodegenTest("aws.protocoltests.json#JsonProtocol", "json_rpc11"),
        CodegenTest("aws.protocoltests.restjson#RestJson", "rest_json"),
        CodegenTest("aws.protocoltests.restjson#RestJsonExtras", "rest_json_extras", imports = listOf("$commonModels/rest-json-extras.smithy")),
        CodegenTest("aws.protocoltests.misc#MiscService", "misc", imports = listOf("$commonModels/misc.smithy")),
        CodegenTest(
            "aws.protocoltests.restxml#RestXml", "rest_xml",
            extraConfig = """, "codegen": { "addMessageToErrors": false } """,
        ),
data class ClientTest(
    val serviceShapeName: String,
    val moduleName: String,
    val dependsOn: List<String> = emptyList(),
    val addMessageToErrors: Boolean = true,
    val renameErrors: Boolean = true,
) {
    fun toCodegenTest(): CodegenTest = CodegenTest(
        serviceShapeName,
        moduleName,
        extraCodegenConfig = extraCodegenConfig(),
        imports = imports(),
    )

        CodegenTest(
            "aws.protocoltests.query#AwsQuery", "aws_query",
            extraConfig = """, "codegen": { "addMessageToErrors": false } """,
        ),
        CodegenTest(
            "aws.protocoltests.ec2#AwsEc2", "ec2_query",
            extraConfig = """, "codegen": { "addMessageToErrors": false } """,
        ),
        CodegenTest(
            "aws.protocoltests.restxml.xmlns#RestXmlWithNamespace",
            "rest_xml_namespace",
            extraConfig = """, "codegen": { "addMessageToErrors": false } """,
        ),
        CodegenTest(
            "aws.protocoltests.restxml#RestXmlExtras",
            "rest_xml_extras",
            extraConfig = """, "codegen": { "addMessageToErrors": false } """,
    private fun extraCodegenConfig(): String = StringBuilder().apply {
        append("\"addMessageToErrors\": $addMessageToErrors,\n")
        append("\"renameErrors\": $renameErrors\n,")
        append("\"enableNewSmithyRuntime\": \"${getSmithyRuntimeMode()}\"")
    }.toString()

    private fun imports(): List<String> = dependsOn.map { "../codegen-core/common-test-models/$it" }
}

val allCodegenTests = listOf(
    ClientTest("com.amazonaws.simple#SimpleService", "simple", dependsOn = listOf("simple.smithy")),
    ClientTest("com.amazonaws.dynamodb#DynamoDB_20120810", "dynamo"),
    ClientTest("com.amazonaws.ebs#Ebs", "ebs", dependsOn = listOf("ebs.json")),
    ClientTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"),
    ClientTest("aws.protocoltests.json#JsonProtocol", "json_rpc11"),
    ClientTest("aws.protocoltests.restjson#RestJson", "rest_json"),
    ClientTest(
        "aws.protocoltests.restjson#RestJsonExtras",
        "rest_json_extras",
        dependsOn = listOf("rest-json-extras.smithy"),
    ),
        CodegenTest(
    ClientTest("aws.protocoltests.misc#MiscService", "misc", dependsOn = listOf("misc.smithy")),
    ClientTest("aws.protocoltests.restxml#RestXml", "rest_xml", addMessageToErrors = false),
    ClientTest("aws.protocoltests.query#AwsQuery", "aws_query", addMessageToErrors = false),
    ClientTest("aws.protocoltests.ec2#AwsEc2", "ec2_query", addMessageToErrors = false),
    ClientTest("aws.protocoltests.restxml.xmlns#RestXmlWithNamespace", "rest_xml_namespace", addMessageToErrors = false),
    ClientTest("aws.protocoltests.restxml#RestXmlExtras", "rest_xml_extras", addMessageToErrors = false),
    ClientTest(
        "aws.protocoltests.restxmlunwrapped#RestXmlExtrasUnwrappedErrors",
        "rest_xml_extras_unwrapped",
            extraConfig = """, "codegen": { "addMessageToErrors": false } """,
        addMessageToErrors = false,
    ),
        CodegenTest(
    ClientTest(
        "crate#Config",
        "naming_test_ops",
            """
            , "codegen": { "renameErrors": false }
            """.trimIndent(),
            imports = listOf("$commonModels/naming-obstacle-course-ops.smithy"),
        dependsOn = listOf("naming-obstacle-course-ops.smithy"),
        renameErrors = false,
    ),
        CodegenTest(
    ClientTest(
        "casing#ACRONYMInside_Service",
        "naming_test_casing",
            imports = listOf("$commonModels/naming-obstacle-course-casing.smithy"),
        dependsOn = listOf("naming-obstacle-course-casing.smithy"),
    ),
        CodegenTest(
    ClientTest(
        "naming_obs_structs#NamingObstacleCourseStructs",
        "naming_test_structs",
            """
            , "codegen": { "renameErrors": false }
            """.trimIndent(),
            imports = listOf("$commonModels/naming-obstacle-course-structs.smithy"),
        dependsOn = listOf("naming-obstacle-course-structs.smithy"),
        renameErrors = false,
    ),
        CodegenTest("aws.protocoltests.json#TestService", "endpoint-rules"),
        CodegenTest("com.aws.example#PokemonService", "pokemon-service-client", imports = listOf("$commonModels/pokemon.smithy", "$commonModels/pokemon-common.smithy")),
        CodegenTest("com.aws.example#PokemonService", "pokemon-service-awsjson-client", imports = listOf("$commonModels/pokemon-awsjson.smithy", "$commonModels/pokemon-common.smithy")),
    )
}
    ClientTest("aws.protocoltests.json#TestService", "endpoint-rules"),
    ClientTest(
        "com.aws.example#PokemonService",
        "pokemon-service-client",
        dependsOn = listOf("pokemon.smithy", "pokemon-common.smithy"),
    ),
    ClientTest(
        "com.aws.example#PokemonService",
        "pokemon-service-awsjson-client",
        dependsOn = listOf("pokemon-awsjson.smithy", "pokemon-common.smithy"),
    ),
).map(ClientTest::toCodegenTest)

project.registerGenerateSmithyBuildTask(rootProject, pluginName, allCodegenTests)
project.registerGenerateCargoWorkspaceTask(rootProject, pluginName, allCodegenTests, workingDirUnderBuildDir)
+2 −37
Original line number Diff line number Diff line
@@ -7,25 +7,17 @@ package software.amazon.smithy.rust.codegen.client.smithy.customizations

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.OptionalAuthTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.AuthOption
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf

val noAuthSchemeShapeId: ShapeId = ShapeId.from("aws.smithy.rs#NoAuth")

private fun noAuthModule(codegenContext: ClientCodegenContext): RuntimeType =
    CargoDependency.smithyRuntime(codegenContext.runtimeConfig)
        .withFeature("no-auth")
        .toType()
        .resolve("client::auth::no_auth")

@@ -37,38 +29,11 @@ class NoAuthDecorator : ClientCodegenDecorator {
        codegenContext: ClientCodegenContext,
        operationShape: OperationShape,
        baseAuthOptions: List<AuthOption>,
    ): List<AuthOption> = baseAuthOptions.letIf(operationShape.hasTrait<OptionalAuthTrait>()) {
        it + AuthOption.StaticAuthOption(noAuthSchemeShapeId) {
    ): List<AuthOption> = baseAuthOptions +
        AuthOption.StaticAuthOption(noAuthSchemeShapeId) {
            rustTemplate(
                "#{NO_AUTH_SCHEME_ID},",
                "NO_AUTH_SCHEME_ID" to noAuthModule(codegenContext).resolve("NO_AUTH_SCHEME_ID"),
            )
        }
}

    override fun operationCustomizations(
        codegenContext: ClientCodegenContext,
        operation: OperationShape,
        baseCustomizations: List<OperationCustomization>,
    ): List<OperationCustomization> = baseCustomizations + AnonymousAuthCustomization(codegenContext, operation)
}

class AnonymousAuthCustomization(
    private val codegenContext: ClientCodegenContext,
    private val operationShape: OperationShape,
) : OperationCustomization() {
    override fun section(section: OperationSection): Writable = writable {
        if (
            codegenContext.smithyRuntimeMode.generateOrchestrator &&
            section is OperationSection.AdditionalRuntimePlugins &&
            operationShape.hasTrait<OptionalAuthTrait>()
        ) {
            section.addOperationRuntimePlugin(this) {
                rustTemplate(
                    "#{NoAuthRuntimePlugin}::new()",
                    "NoAuthRuntimePlugin" to noAuthModule(codegenContext).resolve("NoAuthRuntimePlugin"),
                )
            }
        }
    }
}
+3 −1
Original line number Diff line number Diff line
@@ -28,6 +28,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
import software.amazon.smithy.rust.codegen.core.smithy.mapRustType
@@ -232,7 +233,8 @@ internal class EndpointParamsGenerator(private val parameters: Parameters) {
        rustWriter.rustBlock("impl ParamsBuilder") {
            docs("Consume this builder, creating [`Params`].")
            rustBlockTemplate(
                "pub fn build(self) -> Result<#{Params}, #{ParamsError}>",
                "pub fn build(self) -> #{Result}<#{Params}, #{ParamsError}>",
                *preludeScope,
                "Params" to paramsStruct(),
                "ParamsError" to paramsError(),
            ) {
+4 −2
Original line number Diff line number Diff line
@@ -27,6 +27,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.inputShape
@@ -45,6 +46,7 @@ class EndpointParamsInterceptorGenerator(
        val orchestrator = runtimeApi.resolve("client::orchestrator")
        val smithyTypes = CargoDependency.smithyTypes(rc).toType()
        arrayOf(
            *preludeScope,
            "BoxError" to RuntimeType.boxError(rc),
            "ConfigBag" to RuntimeType.configBag(rc),
            "ConfigBagAccessors" to RuntimeType.smithyRuntimeApi(rc)
@@ -78,7 +80,7 @@ class EndpointParamsInterceptorGenerator(
                    &self,
                    context: &#{BeforeSerializationInterceptorContextRef}<'_, #{Input}, #{Output}, #{Error}>,
                    cfg: &mut #{ConfigBag},
                ) -> Result<(), #{BoxError}> {
                ) -> #{Result}<(), #{BoxError}> {
                    use #{ConfigBagAccessors};
                    let _input = context.input()
                        .downcast_ref::<${operationInput.name}>()
@@ -91,7 +93,7 @@ class EndpointParamsInterceptorGenerator(
                        .build()
                        .map_err(|err| #{ContextAttachedError}::new("endpoint params could not be built", err))?;
                    cfg.interceptor_state().set_endpoint_resolver_params(#{EndpointResolverParams}::new(params));
                    Ok(())
                    #{Ok}(())
                }
            }
            """,
Loading