Unverified Commit 12bed905 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Fix bug where Custom Auth Schemes were not respected (#3087)

## Motivation and Context
- Fixes #3034
## Description
Because AuthSchemeOptions were being registered at the operation level,
there was no way for them to be overridden by customer-provided runtime
plugins. This moves them into a separate plugin that is added at
Client/Default priority.

## Testing
- new unit test

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 40f46621
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -468,3 +468,9 @@ message = "Clients now have a default async sleep implementation so that one doe
references = ["smithy-rs#3071"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client" }
author = "jdisanti"

[[smithy-rs]]
message = "Enable custom auth schemes to work by changing the code generated auth options to be set at the client level at `DEFAULTS` priority."
references = ["smithy-rs#3034", "smithy-rs#3087"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" }
author = "rcoh"
+22 −9
Original line number Diff line number Diff line
@@ -35,24 +35,37 @@ import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasEventStreamOperations
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isInputEventStream
import software.amazon.smithy.rust.codegen.core.util.thenSingletonListOf

class SigV4AuthDecorator : ClientCodegenDecorator {
    override val name: String get() = "SigV4AuthDecorator"
    override val order: Byte = 0

    private val sigv4a = "sigv4a"

    private fun sigv4(runtimeConfig: RuntimeConfig) = writable {
        val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth")
        rust("#T", awsRuntimeAuthModule.resolve("sigv4::SCHEME_ID"))
    }

    private fun sigv4a(runtimeConfig: RuntimeConfig) = writable {
        val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth")
        featureGateBlock(sigv4a) {
            rust("#T", awsRuntimeAuthModule.resolve("sigv4a::SCHEME_ID"))
        }
    }

    override fun authOptions(
        codegenContext: ClientCodegenContext,
        operationShape: OperationShape,
        baseAuthSchemeOptions: List<AuthSchemeOption>,
    ): List<AuthSchemeOption> = baseAuthSchemeOptions + AuthSchemeOption.StaticAuthSchemeOption(SigV4Trait.ID) {
        val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(codegenContext.runtimeConfig).resolve("auth")
        rust("#T,", awsRuntimeAuthModule.resolve("sigv4::SCHEME_ID"))
        if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) {
            featureGateBlock("sigv4a") {
                rust("#T", awsRuntimeAuthModule.resolve("sigv4a::SCHEME_ID"))
            }
            rust(",")
        }
    ): List<AuthSchemeOption> {
        val supportsSigV4a = codegenContext.serviceShape.supportedAuthSchemes().contains(sigv4a)
            .thenSingletonListOf { sigv4a(codegenContext.runtimeConfig) }
        return baseAuthSchemeOptions + AuthSchemeOption.StaticAuthSchemeOption(
            SigV4Trait.ID,
            listOf(sigv4(codegenContext.runtimeConfig)) + supportsSigV4a,
        )
    }

    override fun serviceRuntimePluginCustomizations(
+5 −3
Original line number Diff line number Diff line
@@ -97,10 +97,12 @@ class HttpAuthDecorator : ClientCodegenDecorator {
                options.add(
                    StaticAuthSchemeOption(
                        schemeShapeId,
                        listOf(
                            writable {
                            rustTemplate("$name,", *codegenScope)
                                rustTemplate(name, *codegenScope)
                            },
                        ),
                    ),
                )
            }
            when (authScheme) {
+12 −6
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.AuthSchemeOpt
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType

val noAuthSchemeShapeId: ShapeId = ShapeId.from("aws.smithy.rs#NoAuth")
@@ -30,10 +31,15 @@ class NoAuthDecorator : ClientCodegenDecorator {
        operationShape: OperationShape,
        baseAuthSchemeOptions: List<AuthSchemeOption>,
    ): List<AuthSchemeOption> = baseAuthSchemeOptions +
        AuthSchemeOption.StaticAuthSchemeOption(noAuthSchemeShapeId) {
        AuthSchemeOption.StaticAuthSchemeOption(
            noAuthSchemeShapeId,
            listOf(
                writable {
                    rustTemplate(
                "#{NO_AUTH_SCHEME_ID},",
                        "#{NO_AUTH_SCHEME_ID}",
                        "NO_AUTH_SCHEME_ID" to noAuthModule(codegenContext).resolve("NO_AUTH_SCHEME_ID"),
                    )
        }
                },
            ),
        )
}
+1 −1
Original line number Diff line number Diff line
@@ -30,7 +30,7 @@ sealed interface AuthSchemeOption {
    /** Auth scheme for the `StaticAuthSchemeOptionResolver` */
    data class StaticAuthSchemeOption(
        val schemeShapeId: ShapeId,
        val constructor: Writable,
        val constructor: List<Writable>,
    ) : AuthSchemeOption

    class CustomResolver(/* unimplemented */) : AuthSchemeOption
Loading