Unverified Commit 1771dbdc authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add support for error-correcting builders (#2991)

## Motivation and Context
To implement #1767 we need to support error-correction to default values
when instantiating builders.

-
https://smithy.io/2.0/spec/aggregate-types.html?highlight=error%20correction#client-error-correction
## Description
Adds `pub(crate) correct_errors_<shape>` method that will be used in
deserialization to set default values for required fields when not set
in the serialized response. This only applies to client via
`ClientBuilderInstantiator`

## Testing
- added a new test that's fairly exhaustive

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [ ] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [ ] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent a5c1ced0
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -39,6 +39,6 @@ data class ClientCodegenContext(
) {
    val enableUserConfigurableRuntimePlugins: Boolean get() = settings.codegenConfig.enableUserConfigurableRuntimePlugins
    override fun builderInstantiator(): BuilderInstantiator {
        return ClientBuilderInstantiator(symbolProvider)
        return ClientBuilderInstantiator(this)
    }
}
+18 −7
Original line number Diff line number Diff line
@@ -13,21 +13,29 @@ import software.amazon.smithy.rust.codegen.core.rustlang.map
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator

fun ClientCodegenContext.builderInstantiator(): BuilderInstantiator = ClientBuilderInstantiator(symbolProvider)

class ClientBuilderInstantiator(private val symbolProvider: RustSymbolProvider) : BuilderInstantiator {
class ClientBuilderInstantiator(private val clientCodegenContext: ClientCodegenContext) : BuilderInstantiator {
    override fun setField(builder: String, value: Writable, field: MemberShape): Writable {
        return setFieldWithSetter(builder, value, field)
    }

    /**
     * For the client, we finalize builders with error correction enabled
     */
    override fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable?): Writable = writable {
        if (BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) {
        val correctErrors = clientCodegenContext.correctErrors(shape)
        val builderW = writable {
            when {
                correctErrors != null -> rustTemplate("#{correctErrors}($builder)", "correctErrors" to correctErrors)
                else -> rustTemplate(builder)
            }
        }
        if (BuilderGenerator.hasFallibleBuilder(shape, clientCodegenContext.symbolProvider)) {
            rustTemplate(
                "$builder.build()#{mapErr}?",
                "#{builder}.build()#{mapErr}?",
                "builder" to builderW,
                "mapErr" to (
                    mapErr?.map {
                        rust(".map_err(#T)", it)
@@ -35,7 +43,10 @@ class ClientBuilderInstantiator(private val symbolProvider: RustSymbolProvider)
                    ),
            )
        } else {
            rust("$builder.build()")
            rustTemplate(
                "#{builder}.build()",
                "builder" to builderW,
            )
        }
    }
}
+114 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.isEmpty
import software.amazon.smithy.rust.codegen.core.rustlang.map
import software.amazon.smithy.rust.codegen.core.rustlang.plus
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.some
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.PrimitiveInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed
import software.amazon.smithy.rust.codegen.core.smithy.protocols.shapeFunctionName
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isEventStream
import software.amazon.smithy.rust.codegen.core.util.isStreaming
import software.amazon.smithy.rust.codegen.core.util.letIf

/**
 * For AWS-services, the spec defines error correction semantics to recover from missing default values for required members:
 * https://smithy.io/2.0/spec/aggregate-types.html?highlight=error%20correction#client-error-correction
 */

private fun ClientCodegenContext.errorCorrectedDefault(member: MemberShape): Writable? {
    if (!member.isRequired) {
        return null
    }
    val target = model.expectShape(member.target)
    val memberSymbol = symbolProvider.toSymbol(member)
    val targetSymbol = symbolProvider.toSymbol(target)
    if (member.isEventStream(model) || member.isStreaming(model)) {
        return null
    }
    val instantiator = PrimitiveInstantiator(runtimeConfig, symbolProvider)
    return writable {
        when {
            target is EnumShape || target.hasTrait<EnumTrait>() -> rustTemplate(""""no value was set".parse::<#{Shape}>().ok()""", "Shape" to targetSymbol)
            target is BooleanShape || target is NumberShape || target is StringShape || target is DocumentShape || target is ListShape || target is MapShape -> rust("Some(Default::default())")
            target is StructureShape -> rustTemplate(
                "{ let builder = #{Builder}::default(); #{instantiate} }",
                "Builder" to symbolProvider.symbolForBuilder(target),
                "instantiate" to builderInstantiator().finalizeBuilder("builder", target).map {
                    if (BuilderGenerator.hasFallibleBuilder(target, symbolProvider)) {
                        rust("#T.ok()", it)
                    } else {
                        it.some()(this)
                    }
                }.letIf(memberSymbol.isRustBoxed()) {
                    it.plus { rustTemplate(".map(#{Box}::new)", *preludeScope) }
                },
            )
            target is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this)
            target is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this)
            target is UnionShape -> rust("Some(#T::Unknown)", targetSymbol)
        }
    }
}

fun ClientCodegenContext.correctErrors(shape: StructureShape): RuntimeType? {
    val name = symbolProvider.shapeFunctionName(serviceShape, shape) + "_correct_errors"
    val corrections = writable {
        shape.members().forEach { member ->
            val memberName = symbolProvider.toMemberName(member)
            errorCorrectedDefault(member)?.also { default ->
                rustTemplate(
                    """if builder.$memberName.is_none() { builder.$memberName = #{default} }""",
                    "default" to default,
                )
            }
        }
    }

    if (corrections.isEmpty()) {
        return null
    }

    return RuntimeType.forInlineFun(name, RustModule.private("serde_util")) {
        rustTemplate(
            """
            pub(crate) fn $name(mut builder: #{Builder}) -> #{Builder} {
                #{corrections}
                builder
            }

            """,
            "Builder" to symbolProvider.symbolForBuilder(shape),
            "corrections" to corrections,
        )
    }
}
+125 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
package software.amazon.smithy.rust.codegen.client.smithy.generators

import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.lookup

class ErrorCorrectionTest {
    private val model = """
        namespace com.example
        use aws.protocols#awsJson1_0

        @awsJson1_0
        service HelloService {
            operations: [SayHello],
            version: "1"
        }

        operation SayHello { input: TestInput }
        structure TestInput { nested: TestStruct }
        structure TestStruct {
           @required
           foo: String,
           @required
           byteValue: Byte,
           @required
           listValue: StringList,
           @required
           mapValue: ListMap,
           @required
           doubleListValue: DoubleList
           @required
           document: Document
           @required
           nested: Nested
           @required
           blob: Blob
           @required
           enum: Enum
           @required
           union: U
           notRequired: String
        }

        enum Enum {
            A,
            B,
            C
        }

        union U {
            A: Integer,
            B: String,
            C: Unit
        }

        structure Nested {
            @required
            a: String
        }

        list StringList {
            member: String
        }

        list DoubleList {
            member: StringList
        }

        map ListMap {
            key: String,
            value: StringList
        }
    """.asSmithyModel(smithyVersion = "2.0")

    @Test
    fun correctMissingFields() {
        val shape = model.lookup<StructureShape>("com.example#TestStruct")
        clientIntegrationTest(model) { ctx, crate ->
            crate.lib {
                val codegenCtx =
                    arrayOf("correct_errors" to ctx.correctErrors(shape)!!, "Shape" to ctx.symbolProvider.toSymbol(shape))
                rustTemplate(
                    """
                    /// avoid unused warnings
                pub fn use_fn_publicly() { #{correct_errors}(#{Shape}::builder()); } """,
                    *codegenCtx,
                )
                unitTest("test_default_builder") {
                    rustTemplate(
                        """
                        let builder = #{correct_errors}(#{Shape}::builder().foo("abcd"));
                        let shape = builder.build();
                        // don't override a field already set
                        assert_eq!(shape.foo(), Some("abcd"));
                        // set nested fields
                        assert_eq!(shape.nested().unwrap().a(), Some(""));
                        // don't default non-required fields
                        assert_eq!(shape.not_required(), None);

                        // set defaults for everything else
                        assert_eq!(shape.blob().unwrap().as_ref(), &[]);

                        assert_eq!(shape.list_value(), Some(&[][..]));
                        assert!(shape.map_value().unwrap().is_empty());
                        assert_eq!(shape.double_list_value(), Some(&[][..]));

                        // enums and unions become unknown variants
                        assert!(matches!(shape.r##enum(), Some(crate::types::Enum::Unknown(_))));
                        assert!(shape.union().unwrap().is_unknown());
                        """,
                        *codegenCtx,
                    )
                }
            }
        }
    }
}
+1 −1
Original line number Diff line number Diff line
@@ -138,7 +138,7 @@ internal fun RustSymbolProvider.shapeModuleName(serviceShape: ServiceShape?, sha
    )

/** Creates a unique name for a ser/de function. */
internal fun RustSymbolProvider.shapeFunctionName(serviceShape: ServiceShape?, shape: Shape): String {
fun RustSymbolProvider.shapeFunctionName(serviceShape: ServiceShape?, shape: Shape): String {
    val containerName = when (shape) {
        is MemberShape -> model.expectShape(shape.container).contextName(serviceShape).toSnakeCase()
        else -> shape.contextName(serviceShape).toSnakeCase()
Loading