Unverified Commit cdb63348 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Fix generated query param serializers to omit 0-values (#156)

parent f2f1e358
Loading
Loading
Loading
Loading
+33 −14
Original line number Diff line number Diff line
@@ -11,7 +11,9 @@ import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.writer.CodegenWriter
import software.amazon.smithy.codegen.core.writer.CodegenWriterFactory
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.DocumentationTrait
@@ -243,21 +245,37 @@ class RustWriter private constructor(
        return this
    }

    // TODO: refactor both of these methods & add a parent method to for_each across any field type
    // generically
    fun OptionForEach(member: Symbol, outerField: String, block: CodeWriter.(field: String) -> Unit) {
        if (member.isOptional()) {
    /**
     * Generate a wrapping if statement around a field.
     *
     * - If the field is optional, it will only be called if the field is present
     * - If the field is an unboxed primitive, it will only be called if the field is non-zero
     *
     */
    fun ifSet(shape: Shape, member: Symbol, outerField: String, block: CodeWriter.(field: String) -> Unit) {
        // TODO: this API should be refactored so that we don't need to strip `&` to get reference comparisons to work.
        when {
            member.isOptional() -> {
                val derefName = safeName("inner")
            // TODO: `inner` should be custom codegenned to avoid shadowing
                rustBlock("if let Some($derefName) = $outerField") {
                    block(derefName)
                }
        } else {
            this.block(outerField)
            }
            shape is NumberShape -> rustBlock("if ${outerField.removePrefix("&")} != 0") {
                block(outerField)
            }
            shape is BooleanShape -> rustBlock("if ${outerField.removePrefix("&")}") {
                block(outerField)
            }
            else -> this.block(outerField)
        }
    }

    fun ListForEach(target: Shape, outerField: String, block: CodeWriter.(field: String, target: ShapeId) -> Unit) {
    fun ListForEach(
        target: Shape,
        outerField: String,
        block: CodeWriter.(field: String, target: ShapeId) -> Unit
    ) {
        if (target is CollectionShape) {
            val derefName = safeName("inner")
            rustBlock("for $derefName in $outerField") {
@@ -279,7 +297,8 @@ class RustWriter private constructor(
        return "${headerDocs ?: ""}\n$header\n$useDecls\n$contents\n"
    }

    fun format(r: Any): String {
    fun format(r: Any):
        String {
            return formatter.apply(r, "")
        }

+1 −1
Original line number Diff line number Diff line
@@ -55,7 +55,7 @@ class ErrorGenerator(
            rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") {
                write("write!(f, ${symbol.name.dq()})?;")
                messageShape.map {
                    OptionForEach(symbolProvider.toSymbol(it), "&self.message") { field ->
                    ifSet(it, symbolProvider.toSymbol(it), "&self.message") { field ->
                        write("""write!(f, ": {}", $field)?;""")
                    }
                }
+2 −2
Original line number Diff line number Diff line
@@ -105,7 +105,7 @@ class HttpTraitBindingGenerator(
                val memberType = model.expectShape(memberShape.target)
                val memberSymbol = symbolProvider.toSymbol(memberShape)
                val memberName = symbolProvider.toMemberName(memberShape)
                OptionForEach(memberSymbol, "&self.$memberName") { field ->
                ifSet(memberType, memberSymbol, "&self.$memberName") { field ->
                    ListForEach(memberType, field) { innerField, targetId ->
                        val innerMemberType = model.expectShape(targetId)
                        val formatted = headerFmtFun(innerMemberType, memberShape, innerField)
@@ -186,7 +186,7 @@ class HttpTraitBindingGenerator(
                val memberSymbol = symbolProvider.toSymbol(memberShape)
                val memberName = symbolProvider.toMemberName(memberShape)
                val outerTarget = model.expectShape(memberShape.target)
                OptionForEach(memberSymbol, "&self.$memberName") { field ->
                ifSet(outerTarget, memberSymbol, "&self.$memberName") { field ->
                    ListForEach(outerTarget, field) { innerField, targetId ->
                        val target = model.expectShape(targetId)
                        write(
+28 −0
Original line number Diff line number Diff line
@@ -70,6 +70,12 @@ class HttpTraitBindingGeneratorTest {
                @httpQuery("paramName")
                someValue: String,

                @httpQuery("primitive")
                primitive: PrimitiveInteger,

                @httpQuery("enabled")
                enabled: PrimitiveBoolean,

                @httpQuery("hello")
                extras: Extras,

@@ -135,6 +141,28 @@ class HttpTraitBindingGeneratorTest {
        )
    }

    @Test
    fun `generate serialize non-zero values`() {
        val writer = RustWriter.forModule("input")
        // currently rendering the operation renders the protocols—I want to separate that at some point.
        renderOperation(writer)
        writer.compileAndTest(
            """
            let ts = Instant::from_epoch_seconds(10123125);
            let inp = PutObjectInput::builder()
                .bucket_name("somebucket/ok")
                .key(ts.clone())
                .primitive(1)
                .enabled(true)
                .build().expect("build should succeed");
            let mut o = String::new();
            inp.uri_query(&mut o);
            assert_eq!(o.as_str(), "?primitive=1&enabled=true")
            """,
            clippy = true
        )
    }

    @Test
    fun `build http requests`() {
        val writer = RustWriter.forModule("input")