Unverified Commit 70a3526b authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add Support for Endpoint Prefix (#420)

* Add Support for Endpoint Prefix

I decided to allow the httpLabel fields to remain optional for now (in the future, hopefully we'll be able to code generate on their required status and delete some code!).

In other exciting news, there are now no failing protocol tests!

* Simplify code and add test for unset

* Update codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/EndpointTraitBindingGenerator.kt

* Fix clippy errors
parent 7844b56e
Loading
Loading
Loading
Loading
+45 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.smithy.customizations

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.EndpointTrait
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection
import software.amazon.smithy.rust.codegen.smithy.generators.EndpointTraitBindings
import software.amazon.smithy.rust.codegen.smithy.generators.OperationBuildError
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig

class EndpointPrefixGenerator(private val protocolConfig: ProtocolConfig, private val shape: OperationShape) :
    OperationCustomization() {
    override fun section(section: OperationSection): Writable = when (section) {
        is OperationSection.MutateRequest -> writable {
            shape.getTrait(EndpointTrait::class.java).map { epTrait ->
                val endpointTraitBindings = EndpointTraitBindings(
                    protocolConfig.model,
                    protocolConfig.symbolProvider,
                    protocolConfig.runtimeConfig,
                    shape,
                    epTrait
                )
                val buildError = OperationBuildError(protocolConfig.runtimeConfig)
                withBlock("let endpoint_prefix = ", ";") {
                    endpointTraitBindings.render(this, "self")
                }
                rustBlock("match endpoint_prefix") {
                    rust("Ok(prefix) => { request.config_mut().insert(prefix); },")
                    rust("Err(err) => return Err(${buildError.serializationError(this, "err")})")
                }
            }
        }
        else -> emptySection
    }
}
+5 −1
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.smithy.customize
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.smithy.customizations.AllowClippyLints
import software.amazon.smithy.rust.codegen.smithy.customizations.CrateVersionGenerator
import software.amazon.smithy.rust.codegen.smithy.customizations.EndpointPrefixGenerator
import software.amazon.smithy.rust.codegen.smithy.customizations.IdempotencyTokenGenerator
import software.amazon.smithy.rust.codegen.smithy.customizations.SmithyTypesPubUseGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization
@@ -26,7 +27,10 @@ class RequiredCustomizations : RustCodegenDecorator {
        operation: OperationShape,
        baseCustomizations: List<OperationCustomization>
    ): List<OperationCustomization> {
        return baseCustomizations + IdempotencyTokenGenerator(protocolConfig, operation)
        return baseCustomizations + IdempotencyTokenGenerator(protocolConfig, operation) + EndpointPrefixGenerator(
            protocolConfig,
            operation
        )
    }

    override fun libRsCustomizations(
+6 −0
Original line number Diff line number Diff line
@@ -39,6 +39,12 @@ fun StructureShape.builderSymbol(symbolProvider: RustSymbolProvider): RuntimeTyp

fun RuntimeConfig.operationBuildError() = RuntimeType.operationModule(this).member("BuildError")

class OperationBuildError(private val runtimeConfig: RuntimeConfig) {
    fun missingField(w: RustWriter, field: String, details: String) = "${w.format(runtimeConfig.operationBuildError())}::MissingField { field: ${field.dq()}, details: ${details.dq()} }"
    fun invalidField(w: RustWriter, field: String, details: String) = "${w.format(runtimeConfig.operationBuildError())}::InvalidField { field: ${field.dq()}, details: ${details.dq()}.to_string() }"
    fun serializationError(w: RustWriter, error: String) = "${w.format(runtimeConfig.operationBuildError())}::SerializationError($error.into())"
}

/** setter names will never hit a reserved word and therefore never need escaping */
fun MemberShape.setterName(): String = "set_${this.memberName.toSnakeCase()}"

+95 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.smithy.generators

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.EndpointTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.http.rustFormatString
import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.util.inputShape

fun EndpointTrait.prefixFormatString(): String {
    return this.hostPrefix.rustFormatString("", "")
}

fun RuntimeConfig.smithyHttp() = CargoDependency.SmithyHttp(this).asType()

class EndpointTraitBindings(
    model: Model,
    private val symbolProvider: RustSymbolProvider,
    private val runtimeConfig: RuntimeConfig,
    operationShape: OperationShape,
    private val endpointTrait: EndpointTrait
) {
    private val inputShape = operationShape.inputShape(model)
    private val smithyHttp = runtimeConfig.smithyHttp()
    private val endpointPrefix = smithyHttp.member("endpoint::EndpointPrefix")

    /**
     * Render the `EndpointPrefix` struct. [input] refers to the symbol referring to the input of this operation.
     *
     * Generates code like:
     * ```rust
     * EndpointPrefix::new(format!("{}.aws.com", input.bucket));
     * ```
     *
     * The returned expression is a `Result<EndpointPrefix, UriError>`
     */
    fun render(writer: RustWriter, input: String) {
        // the Rust format pattern to make the endpoint prefix eg. "{}.foo"
        val formatLiteral = endpointTrait.prefixFormatString()
        if (endpointTrait.hostPrefix.labels.isEmpty()) {
            // if there are no labels, we don't need string formatting
            writer.rustTemplate(
                "#{EndpointPrefix}::new($formatLiteral)",
                "EndpointPrefix" to endpointPrefix
            )
        } else {
            val operationBuildError = OperationBuildError(runtimeConfig)
            writer.rustBlock("") {
                // build a list of args: `labelname = "field"`
                // these eventually end up in the format! macro invocation:
                // ```format!("some.{endpoint}", endpoint = endpoint);```
                val args = endpointTrait.hostPrefix.labels.map { label ->
                    val memberShape = inputShape.getMember(label.content).get()
                    val field = symbolProvider.toMemberName(memberShape)
                    val invalidFieldError = operationBuildError.invalidField(
                        writer,
                        field,
                        "$field was unset or empty but must be set as part of the endpoint prefix"
                    )
                    if (symbolProvider.toSymbol(memberShape).isOptional()) {
                        rust("let $field = $input.$field.as_deref().unwrap_or_default();")
                    } else {
                        // NOTE: this is dead code until we start respecting @required
                        rust("let $field = &$input.$field;")
                    }
                    rust(
                        """
                    if $field.is_empty() {
                        return Err($invalidFieldError)
                    }
                    """
                    )
                    "${label.content} = $field"
                }
                writer.rustTemplate(
                    "#{EndpointPrefix}::new(format!($formatLiteral, ${args.joinToString()}))",
                    "EndpointPrefix" to endpointPrefix
                )
            }
        }
    }
}
+1 −3
Original line number Diff line number Diff line
@@ -100,7 +100,7 @@ abstract class HttpProtocolGenerator(
        }
        val operationName = symbolProvider.toSymbol(operationShape).name
        operationWriter.documentShape(operationShape, model)
        Attribute.Derives(setOf(RuntimeType.Clone, RuntimeType.Default)).render(operationWriter)
        Attribute.Derives(setOf(RuntimeType.Clone, RuntimeType.Default, RuntimeType.Debug)).render(operationWriter)
        operationWriter.rustBlock("pub struct $operationName") {
            write("_private: ()")
        }
@@ -118,8 +118,6 @@ abstract class HttpProtocolGenerator(
        traitImplementations(operationWriter, operationShape)
    }

    data class ResponseBody(val type: String, val mutability: String)

    protected fun httpBuilderFun(implBlockWriter: RustWriter, f: RustWriter.() -> Unit) {
        Attribute.Custom("allow(clippy::unnecessary_wraps)").render(implBlockWriter)
        implBlockWriter.rustBlock(
Loading