Unverified Commit ee049ed2 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Ep2 tests (#2127)

* add a test of credentials

* Add tests of S3 customizations

* Add fix from separate PR

* Add FIPS and DualStack

* Fix date issue + other cleanup

* refactor out awsSdkIntegrationTest

* use extendIf in a couple more places

* Add test of string built in

* fix generation of region related code

* CR feedback
parent 3499c11b
Loading
Loading
Loading
Loading
+19 −8
Original line number Diff line number Diff line
@@ -163,18 +163,29 @@ fn smithy_to_aws(value: &SmithyEndpoint) -> Result<EndpointMetadata, Box<dyn Err
        None => return Ok((None, None)),
        _other => return Err("expected an array for authSchemes".into()),
    };
    let v4 = auth_schemes
    let auth_schemes = auth_schemes
        .iter()
        .flat_map(|doc| match doc {
            Document::Object(map)
                if map.get("name") == Some(&Document::String("sigv4".to_string())) =>
            {
                Some(map)
            }
            Document::Object(map) => Some(map),
            _ => None,
        })
        .next()
        .ok_or("could not find v4 as an acceptable auth scheme (the SDK does not support Bearer Auth at this time)")?;
        .map(|it| {
            let name = match it.get("name") {
                Some(Document::String(s)) => Some(s.as_str()),
                _ => None,
            };
            (name, it)
        });
    let (_, v4) = auth_schemes
        .clone()
        .find(|(name, _doc)| name.as_deref() == Some("sigv4"))
        .ok_or_else(|| {
            format!(
                "No auth schemes were supported. The Rust SDK only supports sigv4. \
                The authentication schemes supported by this endpoint were: {:?}",
                auth_schemes.flat_map(|(name, _)| name).collect::<Vec<_>>()
            )
        })?;

    let signing_scope = match v4.get("signingRegion") {
        Some(Document::String(s)) => Some(SigningRegion::from(Region::new(s.clone()))),
+76 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rustsdk

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.ShapeType
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rulesengine.language.EndpointRuleSet
import software.amazon.smithy.rulesengine.language.syntax.parameters.Builtins
import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter
import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType
import software.amazon.smithy.rulesengine.traits.ClientContextParamDefinition
import software.amazon.smithy.rulesengine.traits.ClientContextParamsTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointRulesetIndex
import software.amazon.smithy.rust.codegen.core.util.getTrait

fun EndpointRuleSet.getBuiltIn(builtIn: Parameter) = parameters.toList().find { it.builtIn == builtIn.builtIn }
fun ClientCodegenContext.getBuiltIn(builtIn: Parameter): Parameter? {
    val idx = EndpointRulesetIndex.of(model)
    val rules = idx.endpointRulesForService(serviceShape) ?: return null
    return rules.getBuiltIn(builtIn)
}

/**
 * For legacy SDKs, there are builtIn parameters that cannot be automatically used as context parameters.
 *
 * However, for the Rust SDK, these parameters can be used directly.
 */
fun Model.promoteBuiltInToContextParam(serviceId: ShapeId, builtInSrc: Parameter): Model {
    val model = this
    val idx = EndpointRulesetIndex.of(model)
    val service = model.expectShape(serviceId, ServiceShape::class.java)
    val rules = idx.endpointRulesForService(service) ?: return model
    // load the builtIn with a matching name from the ruleset allowing for any docs updates
    val builtIn = rules.getBuiltIn(builtInSrc) ?: return model

    return ModelTransformer.create().mapShapes(model) { shape ->
        if (shape !is ServiceShape || shape.id != service.id) {
            shape
        } else {
            val traitBuilder = shape.getTrait<ClientContextParamsTrait>()
                // there is a bug in the return type of the toBuilder method
                ?.let { ClientContextParamsTrait.builder().parameters(it.parameters) }
                ?: ClientContextParamsTrait.builder()
            val contextParamsTrait =
                traitBuilder.putParameter(
                    builtIn.name.asString(),
                    ClientContextParamDefinition.builder().documentation(builtIn.documentation.get()).type(
                        when (builtIn.type!!) {
                            ParameterType.STRING -> ShapeType.STRING
                            ParameterType.BOOLEAN -> ShapeType.BOOLEAN
                        },
                    ).build(),
                ).build()
            shape.toBuilder().removeTrait(ClientContextParamsTrait.ID).addTrait(contextParamsTrait).build()
        }
    }
}

class AddFIPSDualStackDecorator : ClientCodegenDecorator {
    override val name: String = "AddFipsDualStack"
    override val order: Byte = 0

    override fun transformModel(service: ServiceShape, model: Model): Model {
        return model
            .promoteBuiltInToContextParam(service.id, Builtins.FIPS)
            .promoteBuiltInToContextParam(service.id, Builtins.DUALSTACK)
    }
}
+1 −0
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ val DECORATORS: List<ClientCodegenDecorator> = listOf(
    AwsReadmeDecorator(),
    HttpConnectorDecorator(),
    AwsEndpointsStdLib(),
    AddFIPSDualStackDecorator(),

    // Service specific decorators
    ApiGatewayDecorator(),
+127 −100
Original line number Diff line number Diff line
@@ -33,11 +33,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
import software.amazon.smithy.rust.codegen.core.util.extendIf
import software.amazon.smithy.rust.codegen.core.util.letIf

class AwsEndpointDecorator : ClientCodegenDecorator {
    override val name: String = "AwsEndpoint"
    override val order: Byte = -100
    override val order: Byte = 100

    override fun transformModel(service: ServiceShape, model: Model): Model {
        val customServices = setOf(
@@ -77,7 +78,9 @@ class AwsEndpointDecorator : ClientCodegenDecorator {
        codegenContext: ClientCodegenContext,
        baseCustomizations: List<ConfigCustomization>,
    ): List<ConfigCustomization> {
        return baseCustomizations + EndpointConfigCustomization(
        return baseCustomizations.extendIf(codegenContext.getBuiltIn(Builtins.REGION) != null) {
            AwsEndpointShimCustomization(codegenContext)
        } + SdkEndpointCustomization(
            codegenContext,
        )
    }
@@ -125,7 +128,8 @@ class AwsEndpointDecorator : ClientCodegenDecorator {
    }

    override fun endpointCustomizations(codegenContext: ClientCodegenContext): List<EndpointCustomization> {
        return listOf(object : EndpointCustomization {
        return listOf(
            object : EndpointCustomization {
                override fun builtInDefaultValue(parameter: Parameter, configRef: String): Writable? {
                    return when (parameter.builtIn) {
                        Builtins.SDK_ENDPOINT.builtIn -> writable { rust("$configRef.endpoint_url().map(|url|url.to_string())") }
@@ -135,23 +139,18 @@ class AwsEndpointDecorator : ClientCodegenDecorator {
            },
        )
    }
}

class EndpointConfigCustomization(
    codegenContext: CodegenContext,
) :
    ConfigCustomization() {
    class AwsEndpointShimCustomization(codegenContext: ClientCodegenContext) : ConfigCustomization() {
        private val moduleUseName = codegenContext.moduleUseName()
        private val runtimeConfig = codegenContext.runtimeConfig
        private val resolveAwsEndpoint = AwsRuntimeType.awsEndpoint(runtimeConfig).resolve("ResolveAwsEndpoint")
        private val endpointShim = AwsRuntimeType.awsEndpoint(runtimeConfig).resolve("EndpointShim")
    private val moduleUseName = codegenContext.moduleUseName()
        private val codegenScope = arrayOf(
            "ResolveAwsEndpoint" to resolveAwsEndpoint,
            "EndpointShim" to endpointShim,
            "aws_types" to AwsRuntimeType.awsTypes(runtimeConfig),
        )

    override fun section(section: ServiceConfig): Writable = writable {
        override fun section(section: ServiceConfig) = writable {
            when (section) {
                ServiceConfig.BuilderImpl -> rustTemplate(
                    """
@@ -190,6 +189,32 @@ class EndpointConfigCustomization(
                        self
                    }

                    """,
                    *codegenScope,
                )

                else -> emptySection
            }
        }
    }

    class SdkEndpointCustomization(
        codegenContext: CodegenContext,
    ) :
        ConfigCustomization() {
        private val runtimeConfig = codegenContext.runtimeConfig
        private val resolveAwsEndpoint = AwsRuntimeType.awsEndpoint(runtimeConfig).resolve("ResolveAwsEndpoint")
        private val endpointShim = AwsRuntimeType.awsEndpoint(runtimeConfig).resolve("EndpointShim")
        private val codegenScope = arrayOf(
            "ResolveAwsEndpoint" to resolveAwsEndpoint,
            "EndpointShim" to endpointShim,
            "aws_types" to AwsRuntimeType.awsTypes(runtimeConfig),
        )

        override fun section(section: ServiceConfig): Writable = writable {
            when (section) {
                ServiceConfig.BuilderImpl -> rustTemplate(
                    """
                    /// Sets the endpoint url used to communicate with this service
                    ///
                    /// Note: this is used in combination with other endpoint rules, e.g. an API that applies a host-label prefix
@@ -213,13 +238,14 @@ class EndpointConfigCustomization(
                    *codegenScope,
                )

            ServiceConfig.BuilderBuild -> rust("endpoint_url: self.endpoint_url")
            ServiceConfig.BuilderStruct -> rust("endpoint_url: Option<String>")
                ServiceConfig.BuilderBuild -> rust("endpoint_url: self.endpoint_url,")
                ServiceConfig.BuilderStruct -> rust("endpoint_url: Option<String>,")
                ServiceConfig.ConfigImpl -> {
                    Attribute.AllowDeadCode.render(this)
                    rust("pub(crate) fn endpoint_url(&self) -> Option<&str> { self.endpoint_url.as_deref() }")
                }
            ServiceConfig.ConfigStruct -> rust("endpoint_url: Option<String>")

                ServiceConfig.ConfigStruct -> rust("endpoint_url: Option<String>,")
                ServiceConfig.ConfigStructAdditionalDocs -> emptySection
                ServiceConfig.Extras -> emptySection
            }
@@ -240,3 +266,4 @@ class PubUseEndpoint(private val runtimeConfig: RuntimeConfig) : LibRsCustomizat
            }
        }
    }
}
+11 −3
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSectio
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.extendIf

/* Example Generated Code */
/*
@@ -79,11 +80,15 @@ class RegionDecorator : ClientCodegenDecorator {
    override val name: String = "Region"
    override val order: Byte = 0

    private fun usesRegion(codegenContext: ClientCodegenContext) = codegenContext.getBuiltIn(Builtins.REGION) != null

    override fun configCustomizations(
        codegenContext: ClientCodegenContext,
        baseCustomizations: List<ConfigCustomization>,
    ): List<ConfigCustomization> {
        return baseCustomizations + RegionProviderConfig(codegenContext)
        return baseCustomizations.extendIf(usesRegion(codegenContext)) {
            RegionProviderConfig(codegenContext)
        }
    }

    override fun operationCustomizations(
@@ -91,17 +96,20 @@ class RegionDecorator : ClientCodegenDecorator {
        operation: OperationShape,
        baseCustomizations: List<OperationCustomization>,
    ): List<OperationCustomization> {
        return baseCustomizations + RegionConfigPlugin()
        return baseCustomizations.extendIf(usesRegion(codegenContext)) { RegionConfigPlugin() }
    }

    override fun libRsCustomizations(
        codegenContext: ClientCodegenContext,
        baseCustomizations: List<LibRsCustomization>,
    ): List<LibRsCustomization> {
        return baseCustomizations + PubUseRegion(codegenContext.runtimeConfig)
        return baseCustomizations.extendIf(usesRegion(codegenContext)) { PubUseRegion(codegenContext.runtimeConfig) }
    }

    override fun endpointCustomizations(codegenContext: ClientCodegenContext): List<EndpointCustomization> {
        if (!usesRegion(codegenContext)) {
            return listOf()
        }
        return listOf(
            object : EndpointCustomization {
                override fun builtInDefaultValue(parameter: Parameter, configRef: String): Writable? {
Loading