Unverified Commit 95dc365d authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Endpoint operation input tests (#2204)

* Add support for operationInput tests

* More unfication, fix tests, docs

* Set endpoint_url only when endpoint_url is used

* Fix test-util feature

* CR feedback

* fix missing path
parent 582ae855
Loading
Loading
Loading
Loading
+0 −105
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rustsdk

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.ShapeType
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rulesengine.language.EndpointRuleSet
import software.amazon.smithy.rulesengine.language.syntax.parameters.Builtins
import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter
import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType
import software.amazon.smithy.rulesengine.traits.ClientContextParamDefinition
import software.amazon.smithy.rulesengine.traits.ClientContextParamsTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointRulesetIndex
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocSection
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.util.getTrait

fun EndpointRuleSet.getBuiltIn(builtIn: Parameter) = parameters.toList().find { it.builtIn == builtIn.builtIn }
fun ClientCodegenContext.getBuiltIn(builtIn: Parameter): Parameter? {
    val idx = EndpointRulesetIndex.of(model)
    val rules = idx.endpointRulesForService(serviceShape) ?: return null
    return rules.getBuiltIn(builtIn)
}

/**
 * For legacy SDKs, there are builtIn parameters that cannot be automatically used as context parameters.
 *
 * However, for the Rust SDK, these parameters can be used directly.
 */
fun Model.promoteBuiltInToContextParam(serviceId: ShapeId, builtInSrc: Parameter): Model {
    val model = this
    // load the builtIn with a matching name from the ruleset allowing for any docs updates
    val builtIn = this.loadBuiltIn(serviceId, builtInSrc) ?: return model

    return ModelTransformer.create().mapShapes(model) { shape ->
        if (shape !is ServiceShape || shape.id != serviceId) {
            shape
        } else {
            val traitBuilder = shape.getTrait<ClientContextParamsTrait>()
                // there is a bug in the return type of the toBuilder method
                ?.let { ClientContextParamsTrait.builder().parameters(it.parameters) }
                ?: ClientContextParamsTrait.builder()
            val contextParamsTrait =
                traitBuilder.putParameter(
                    builtIn.name.asString(),
                    ClientContextParamDefinition.builder().documentation(builtIn.documentation.get()).type(
                        when (builtIn.type!!) {
                            ParameterType.STRING -> ShapeType.STRING
                            ParameterType.BOOLEAN -> ShapeType.BOOLEAN
                        },
                    ).build(),
                ).build()
            shape.toBuilder().removeTrait(ClientContextParamsTrait.ID).addTrait(contextParamsTrait).build()
        }
    }
}

fun Model.loadBuiltIn(serviceId: ShapeId, builtInSrc: Parameter): Parameter? {
    val model = this
    val idx = EndpointRulesetIndex.of(model)
    val service = model.expectShape(serviceId, ServiceShape::class.java)
    val rules = idx.endpointRulesForService(service) ?: return null
    // load the builtIn with a matching name from the ruleset allowing for any docs updates
    return rules.getBuiltIn(builtInSrc)
}

fun Model.sdkConfigSetter(serviceId: ShapeId, builtInSrc: Parameter): Pair<AdHocSection<*>, (Section) -> Writable>? {
    val builtIn = loadBuiltIn(serviceId, builtInSrc) ?: return null
    val fieldName = builtIn.name.rustName()

    return SdkConfigSection.create { section ->
        {
            rust("${section.serviceConfigBuilder}.set_$fieldName(${section.sdkConfig}.$fieldName());")
        }
    }
}

class AddFIPSDualStackDecorator : ClientCodegenDecorator {
    override val name: String = "AddFipsDualStack"
    override val order: Byte = 0

    override fun transformModel(service: ServiceShape, model: Model): Model {
        return model
            .promoteBuiltInToContextParam(service.id, Builtins.FIPS)
            .promoteBuiltInToContextParam(service.id, Builtins.DUALSTACK)
    }

    override fun extraSections(codegenContext: ClientCodegenContext): List<Pair<AdHocSection<*>, (Section) -> Writable>> {
        return listOfNotNull(
            codegenContext.model.sdkConfigSetter(codegenContext.serviceShape.id, Builtins.FIPS),
            codegenContext.model.sdkConfigSetter(codegenContext.serviceShape.id, Builtins.DUALSTACK),
        )
    }
}
+5 −1
Original line number Diff line number Diff line
@@ -17,6 +17,9 @@ import software.amazon.smithy.rustsdk.customize.route53.Route53Decorator
import software.amazon.smithy.rustsdk.customize.s3.S3Decorator
import software.amazon.smithy.rustsdk.customize.s3control.S3ControlDecorator
import software.amazon.smithy.rustsdk.customize.sts.STSDecorator
import software.amazon.smithy.rustsdk.endpoints.AwsEndpointDecorator
import software.amazon.smithy.rustsdk.endpoints.AwsEndpointsStdLib
import software.amazon.smithy.rustsdk.endpoints.OperationInputTestDecorator

val DECORATORS: List<ClientCodegenDecorator> = listOf(
    // General AWS Decorators
@@ -38,8 +41,9 @@ val DECORATORS: List<ClientCodegenDecorator> = listOf(
    AwsReadmeDecorator(),
    HttpConnectorDecorator(),
    AwsEndpointsStdLib(),
    AddFIPSDualStackDecorator(),
    *PromotedBuiltInsDecorators,
    GenericSmithySdkConfigSettings(),
    OperationInputTestDecorator(),

    // Service specific decorators
    ApiGatewayDecorator(),
+8 −1
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ package software.amazon.smithy.rustsdk

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeCrateLocation
@@ -60,10 +61,16 @@ object AwsRuntimeType {
    ).resolve("DefaultMiddleware")

    fun awsCredentialTypes(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsCredentialTypes(runtimeConfig).toType()

    fun awsCredentialTypesTestUtil(runtimeConfig: RuntimeConfig) =
        AwsCargoDependency.awsCredentialTypes(runtimeConfig).copy(scope = DependencyScope.Dev).withFeature("test-util").toType()

    fun awsEndpoint(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsEndpoint(runtimeConfig).toType()
    fun awsHttp(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsHttp(runtimeConfig).toType()
    fun awsSigAuth(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsSigAuth(runtimeConfig).toType()
    fun awsSigAuthEventStream(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsSigAuthEventStream(runtimeConfig).toType()
    fun awsSigAuthEventStream(runtimeConfig: RuntimeConfig) =
        AwsCargoDependency.awsSigAuthEventStream(runtimeConfig).toType()

    fun awsSigv4(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsSigv4(runtimeConfig).toType()
    fun awsTypes(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsTypes(runtimeConfig).toType()
}
+18 −1
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ package software.amazon.smithy.rustsdk

import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.TestUtilFeature
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
@@ -15,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocSection
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
@@ -46,6 +48,10 @@ class CredentialsProviderDecorator : ClientCodegenDecorator {
                }
            },
        )

    override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) {
        rustCrate.mergeFeature(TestUtilFeature.copy(deps = listOf("aws-credential-types/test-util")))
    }
}

/**
@@ -54,13 +60,19 @@ class CredentialsProviderDecorator : ClientCodegenDecorator {
class CredentialProviderConfig(runtimeConfig: RuntimeConfig) : ConfigCustomization() {
    private val codegenScope = arrayOf(
        "provider" to AwsRuntimeType.awsCredentialTypes(runtimeConfig).resolve("provider"),
        "Credentials" to AwsRuntimeType.awsCredentialTypes(runtimeConfig).resolve("Credentials"),
        "TestCredentials" to AwsRuntimeType.awsCredentialTypesTestUtil(runtimeConfig).resolve("Credentials"),
        "DefaultProvider" to defaultProvider(),
    )

    override fun section(section: ServiceConfig) = writable {
        when (section) {
            ServiceConfig.BuilderStruct ->
                rustTemplate("credentials_provider: Option<std::sync::Arc<dyn #{provider}::ProvideCredentials>>,", *codegenScope)
                rustTemplate(
                    "credentials_provider: Option<std::sync::Arc<dyn #{provider}::ProvideCredentials>>,",
                    *codegenScope,
                )

            ServiceConfig.BuilderImpl -> {
                rustTemplate(
                    """
@@ -80,6 +92,11 @@ class CredentialProviderConfig(runtimeConfig: RuntimeConfig) : ConfigCustomizati
                )
            }

            is ServiceConfig.DefaultForTests -> rustTemplate(
                "${section.configBuilderRef}.set_credentials_provider(Some(std::sync::Arc::new(#{TestCredentials}::for_tests())));",
                *codegenScope,
            )

            else -> emptySection
        }
    }
+179 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rustsdk

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.BooleanNode
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.StringNode
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rulesengine.language.EndpointRuleSet
import software.amazon.smithy.rulesengine.language.syntax.parameters.Builtins
import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter
import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointRulesetIndex
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigParam
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.standardConfigParam
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocSection
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.extendIf
import software.amazon.smithy.rust.codegen.core.util.orNull
import java.util.Optional

/** load a builtIn parameter from a ruleset by name */
fun EndpointRuleSet.getBuiltIn(builtIn: String) = parameters.toList().find { it.builtIn == Optional.of(builtIn) }

/** load a builtIn parameter from a ruleset. The returned builtIn is the one defined in the ruleset (including latest docs, etc.) */
fun EndpointRuleSet.getBuiltIn(builtIn: Parameter) = getBuiltIn(builtIn.builtIn.orNull()!!)
fun ClientCodegenContext.getBuiltIn(builtIn: Parameter): Parameter? = getBuiltIn(builtIn.builtIn.orNull()!!)
fun ClientCodegenContext.getBuiltIn(builtIn: String): Parameter? {
    val idx = EndpointRulesetIndex.of(model)
    val rules = idx.endpointRulesForService(serviceShape) ?: return null
    return rules.getBuiltIn(builtIn)
}

private fun toConfigParam(parameter: Parameter): ConfigParam = ConfigParam(
    parameter.name.rustName(),
    when (parameter.type!!) {
        ParameterType.STRING -> RuntimeType.String.toSymbol()
        ParameterType.BOOLEAN -> RuntimeType.Bool.toSymbol()
    },
    parameter.documentation.orNull()?.let { writable { docs(it) } },
)

fun Model.loadBuiltIn(serviceId: ShapeId, builtInSrc: Parameter): Parameter? {
    val model = this
    val idx = EndpointRulesetIndex.of(model)
    val service = model.expectShape(serviceId, ServiceShape::class.java)
    val rules = idx.endpointRulesForService(service) ?: return null
    // load the builtIn with a matching name from the ruleset allowing for any docs updates
    return rules.getBuiltIn(builtInSrc)
}

fun Model.sdkConfigSetter(
    serviceId: ShapeId,
    builtInSrc: Parameter,
    configParameterNameOverride: String?,
): Pair<AdHocSection<*>, (Section) -> Writable>? {
    val builtIn = loadBuiltIn(serviceId, builtInSrc) ?: return null
    val fieldName = configParameterNameOverride ?: builtIn.name.rustName()

    val map = when (builtIn.type!!) {
        ParameterType.STRING -> writable { rust("|s|s.to_string()") }
        ParameterType.BOOLEAN -> null
    }
    return SdkConfigSection.copyField(fieldName, map)
}

/**
 * Create a client codegen decorator that creates bindings for a builtIn parameter. Optionally, you can provide [clientParam]
 * which allows control over the config parameter that will be generated.
 */
fun decoratorForBuiltIn(
    builtIn: Parameter,
    clientParam: ConfigParam? = null,
): ClientCodegenDecorator {
    val nameOverride = clientParam?.name
    val name = nameOverride ?: builtIn.name.rustName()
    return object : ClientCodegenDecorator {
        override val name: String = "Auto${builtIn.builtIn.get()}"
        override val order: Byte = 0

        private fun rulesetContainsBuiltIn(codegenContext: ClientCodegenContext) =
            codegenContext.getBuiltIn(builtIn) != null

        override fun extraSections(codegenContext: ClientCodegenContext): List<Pair<AdHocSection<*>, (Section) -> Writable>> {
            return listOfNotNull(
                codegenContext.model.sdkConfigSetter(codegenContext.serviceShape.id, builtIn, clientParam?.name),
            )
        }

        override fun configCustomizations(
            codegenContext: ClientCodegenContext,
            baseCustomizations: List<ConfigCustomization>,
        ): List<ConfigCustomization> {
            return baseCustomizations.extendIf(rulesetContainsBuiltIn(codegenContext)) {
                standardConfigParam(
                    clientParam ?: toConfigParam(builtIn),
                )
            }
        }

        override fun endpointCustomizations(codegenContext: ClientCodegenContext): List<EndpointCustomization> = listOf(
            object : EndpointCustomization {
                override fun loadBuiltInFromServiceConfig(parameter: Parameter, configRef: String): Writable? =
                    when (parameter.builtIn) {
                        builtIn.builtIn -> writable {
                            rust("$configRef.$name")
                            if (parameter.type == ParameterType.STRING) {
                                rust(".clone()")
                            }
                        }
                        else -> null
                    }

                override fun setBuiltInOnServiceConfig(name: String, value: Node, configBuilderRef: String): Writable? {
                    if (name != builtIn.builtIn.get()) {
                        return null
                    }
                    return writable {
                        rustTemplate(
                            "let $configBuilderRef = $configBuilderRef.${nameOverride ?: builtIn.name.rustName()}(#{value});",
                            "value" to value.toWritable(),
                        )
                    }
                }
            },
        )
    }
}

private val endpointUrlDocs = writable {
    rust(
        """
        /// Sets the endpoint url used to communicate with this service

        /// Note: this is used in combination with other endpoint rules, e.g. an API that applies a host-label prefix
        /// will be prefixed onto this URL. To fully override the endpoint resolver, use
        /// [`Builder::endpoint_resolver`].
        """.trimIndent(),
    )
}

fun Node.toWritable(): Writable {
    val node = this
    return writable {
        when (node) {
            is StringNode -> rust(node.value.dq())
            is BooleanNode -> rust("${node.value}")
            else -> PANIC("unsupported value for a default: $node")
        }
    }
}

val PromotedBuiltInsDecorators =
    listOf(
        decoratorForBuiltIn(Builtins.FIPS),
        decoratorForBuiltIn(Builtins.DUALSTACK),
        decoratorForBuiltIn(
            Builtins.SDK_ENDPOINT,
            ConfigParam("endpoint_url", RuntimeType.String.toSymbol(), endpointUrlDocs),
        ),
    ).toTypedArray()
Loading