Unverified Commit b11cede4 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

*breaking change*: Refactor httpLabel behavior around optionality (#537)

* *breaking change*: Refactor httpLabel behavior around optionality

**This change is breaking because fields with the `httpLabel` go from `T` to `Option<T>` in generated models.**

Previously, existence of `httpLabel` would cause a non-optional field to be generated. But:
1. This is wrong. Protocols should not impact models.
2. This causes issues when generating the transcribe service because the `httpLabel` trait is attached to the model.
3. This leads to a bad user experience of the field is unset—A default value is inserted but that leads to a signing error down the line.

This change will cause a failure during request construction if fields targetted with `httpLabel` are either unset or empty.

An integration test validating this behavior for S3 was also added.

* Remove unused container parameter

* Ignore clippy another clippy lint

* Fix clippy lint name
parent 3d61226b
Loading
Loading
Loading
Loading
+7 −1
Original line number Diff line number Diff line
@@ -33,10 +33,16 @@ fun <T : CodeWriter> T.withBlock(
    return conditionalBlock(textBeforeNewLine, textAfterNewLine, conditional = true, block = block, args = args)
}

fun <T : CodeWriter> T.assignment(variableName: String, vararg ctx: Pair<String, Any>, block: T.() -> Unit) {
    withBlockTemplate("let $variableName =", ";", *ctx) {
        block()
    }
}

fun <T : CodeWriter> T.withBlockTemplate(
    textBeforeNewLine: String,
    textAfterNewLine: String,
    vararg ctx: Pair<String, RuntimeType>,
    vararg ctx: Pair<String, Any>,
    block: T.() -> Unit
): T {
    return withTemplate(textBeforeNewLine, ctx) { header ->
+3 −6
Original line number Diff line number Diff line
@@ -37,7 +37,6 @@ import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpLabelTrait
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
@@ -161,12 +160,10 @@ class SymbolVisitor(
        return RuntimeType.Blob(config.runtimeConfig).toSymbol()
    }

    private fun handleOptionality(symbol: Symbol, member: MemberShape, container: Shape): Symbol {
    private fun handleOptionality(symbol: Symbol, member: MemberShape): Symbol {
        // If a field has the httpLabel trait and we are generating
        // an Input shape, then the field is _not optional_.
        val httpLabeledInput =
            container.hasTrait<SyntheticInputTrait>() && member.hasTrait<HttpLabelTrait>()
        return if (nullableIndex.isNullable(member) && !httpLabeledInput) {
        return if (nullableIndex.isNullable(member)) {
            symbol.makeOptional()
        } else symbol
    }
@@ -283,7 +280,7 @@ class SymbolVisitor(
        return targetSymbol.letIf(config.handleRustBoxing) {
            handleRustBoxing(it, shape)
        }.letIf(config.handleOptionality) {
            handleOptionality(it, shape, model.expectShape(shape.container))
            handleOptionality(it, shape)
        }
    }

+10 −3
Original line number Diff line number Diff line
@@ -11,19 +11,26 @@ import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.smithy.generators.LibRsSection

val ClippyAllowLints = listOf(
    // Sometimes a operation be named the same as our module eg. output leading to `output::output`
    // Sometimes operations are named the same as our module eg. output leading to `output::output`
    "module_inception",

    // Currently, we don't recase acronyms in models, eg. SSEVersion
    "upper_case_acronyms",

    // Large errors trigger this warning, we are unlikely to optimize this case currently
    "large_enum_variant",

    // Some models have members with `is` in the name, which leads to builder functions with the wrong self convention
    "wrong_self_convention",

    // models like ecs use method names like "add()" which confuses clippy
    "should_implement_trait"
    "should_implement_trait",

    // protocol tests use silly names like `baz`, don't flag that
    "blacklisted_name"
)

class AllowClippyLints() : LibRsCustomization() {
class AllowClippyLints : LibRsCustomization() {
    override fun section(section: LibRsSection) = when (section) {
        is LibRsSection.Attributes -> writable {
            ClippyAllowLints.forEach {
+54 −24
Original line number Diff line number Diff line
@@ -20,15 +20,18 @@ import software.amazon.smithy.model.traits.MediaTypeTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.assignment
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.OperationBuildError
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError
import software.amazon.smithy.rust.codegen.smithy.generators.redactIfNecessary
import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectMember
@@ -100,7 +103,7 @@ class RequestBindingGenerator(
            buildError
        ) {
            write("let mut uri = String::new();")
            write("self.uri_base(&mut uri);")
            write("self.uri_base(&mut uri)?;")
            if (hasQuery) {
                write("self.uri_query(&mut uri);")
            }
@@ -252,14 +255,23 @@ class RequestBindingGenerator(
     */
    private fun uriBase(writer: RustWriter) {
        val formatString = httpTrait.uriFormatString()
        // name of a local variable containing this member's component of the URI
        val local = { member: MemberShape -> symbolProvider.toMemberName(member) }
        val args = httpTrait.uri.labels.map { label ->
            val member = inputShape.expectMember(label.content)
            "${label.content} = ${labelFmtFun(writer, model.expectShape(member.target), member, label)}"
            "${label.content} = ${local(member)}"
        }
        val combinedArgs = listOf(formatString, *args.toTypedArray())
        writer.addImport(RuntimeType.stdfmt.member("Write").toSymbol(), null)
        writer.rustBlock("fn uri_base(&self, output: &mut String)") {
            write("write!(output, ${combinedArgs.joinToString(", ")}).expect(\"formatting should succeed\")")
        writer.rustBlock("fn uri_base(&self, output: &mut String) -> Result<(), #T>", runtimeConfig.operationBuildError()) {
            httpTrait.uri.labels.map { label ->
                val member = inputShape.expectMember(label.content)
                assignment(local(member)) {
                    serializeLabel(member, label)
                }
            }
            rust("""write!(output, ${combinedArgs.joinToString(", ")}).expect("formatting should succeed");""")
            rust("Ok(())")
        }
    }

@@ -368,29 +380,47 @@ class RequestBindingGenerator(
        }
    }

    /**
     * Format [member] when used as an HTTP Label (`/bucket/{key}`)
     */
    private fun labelFmtFun(writer: RustWriter, target: Shape, member: MemberShape, label: SmithyPattern.Segment): String {
        val memberName = symbolProvider.toMemberName(member)
        return when {
    private fun RustWriter.serializeLabel(member: MemberShape, label: SmithyPattern.Segment) {
        val target = model.expectShape(member.target)
        val symbol = symbolProvider.toSymbol(member)
        val buildError = {
            OperationBuildError(runtimeConfig).missingField(
                this,
                symbolProvider.toMemberName(member),
                "cannot be empty or unset"
            )
        }
        rustBlock("") {
            rust("let input = &self.${symbolProvider.toMemberName(member)};")
            if (symbol.isOptional()) {
                rust("let input = input.as_ref().ok_or(${buildError()})?;")
            }
            when {
                target.isStringShape -> {
                val func = writer.format(RuntimeType.LabelFormat(runtimeConfig, "fmt_string"))
                "$func(&self.$memberName, ${label.isGreedyLabel})"
                    val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_string"))
                    rust("let formatted = $func(input, ${label.isGreedyLabel});")
                }
                target.isTimestampShape -> {
                    val timestampFormat =
                        index.determineTimestampFormat(member, HttpBinding.Location.LABEL, defaultTimestampFormat)
                    val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
                val func = writer.format(RuntimeType.LabelFormat(runtimeConfig, "fmt_timestamp"))
                "$func(&self.$memberName, ${writer.format(timestampFormatType)})"
                    val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_timestamp"))
                    rust("let formatted = $func(&input, ${format(timestampFormatType)});")
                }
                else -> {
                val func = writer.format(RuntimeType.LabelFormat(runtimeConfig, "fmt_default"))
                "$func(&self.$memberName)"
                    val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_default"))
                    rust("let formatted = $func(input);")
                }
            }
            rust(
                """
                if formatted.is_empty() {
                    return Err(${buildError()})
                }
                formatted
            """
            )
        }
    }

    /** End URI generation **/
}
+53 −0
Original line number Diff line number Diff line
@@ -276,4 +276,57 @@ class RequestBindingGeneratorTest {
        """
        )
    }

    @Test
    fun `missing uri label produces an error`() {
        val writer = RustWriter.forModule("input")
        renderOperation(writer)
        writer.compileAndTest(
            """
        let ts = smithy_types::Instant::from_epoch_seconds(10123125);
        let inp = PutObjectInput::builder()
            // don't set bucket
            // .bucket_name("buk")
            .key(ts.clone())
            .build().unwrap();
        let err = inp.request_builder_base().expect_err("can't build request with bucket unset");
        assert!(matches!(err, ${writer.format(TestRuntimeConfig.operationBuildError())}::MissingField { .. }))
        """
        )
    }

    @Test
    fun `missing timestamp uri label produces an error`() {
        val writer = RustWriter.forModule("input")
        renderOperation(writer)
        writer.compileAndTest(
            """
        let ts = smithy_types::Instant::from_epoch_seconds(10123125);
        let inp = PutObjectInput::builder()
            .bucket_name("buk")
            // don't set key
            // .key(ts.clone())
            .build().unwrap();
        let err = inp.request_builder_base().expect_err("can't build request with bucket unset");
        assert!(matches!(err, ${writer.format(TestRuntimeConfig.operationBuildError())}::MissingField { .. }))
        """
        )
    }

    @Test
    fun `empty uri label produces an error`() {
        val writer = RustWriter.forModule("input")
        renderOperation(writer)
        writer.compileAndTest(
            """
        let ts = smithy_types::Instant::from_epoch_seconds(10123125);
        let inp = PutObjectInput::builder()
            .bucket_name("")
            .key(ts.clone())
            .build().unwrap();
        let err = inp.request_builder_base().expect_err("can't build request with bucket unset");
        assert!(matches!(err, ${writer.format(TestRuntimeConfig.operationBuildError())}::MissingField { .. }))
        """
        )
    }
}