Unverified Commit 9b0e3cfd authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Source defaults from the model instead of implicitly (#2985)

## Motivation and Context
We weren't correctly computing defaults which lead to incorrect behavior
when coupled with nullability.

## Description
Minimal changeset to source defaults from the model. Other changes:
- Unify enum parsing across client and server to always use `from_str`
in protocol tests
- Extract `PrimitiveInstantiator` from `Instantiator` so it can be used
to instantiate defaults

## Testing
- regular codegen tests

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [ ] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [ ] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 96aa1fcc
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -150,3 +150,9 @@ message = "Fix regression with redacting sensitive HTTP response bodies."
references = ["smithy-rs#2926", "smithy-rs#2972"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" }
author = "ysaito1001"

[[smithy-rs]]
message = "Source defaults from the default trait instead of implicitly based on type. This has minimal changes in the generated code."
references = ["smithy-rs#2985"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" }
author = "rcoh"
+0 −8
Original line number Diff line number Diff line
@@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.shapes.MemberShape
@@ -14,18 +13,12 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName

private fun enumFromStringFn(enumSymbol: Symbol, data: String): Writable = writable {
    rust("#T::from($data)", enumSymbol)
}

class ClientBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior {
    override fun hasFallibleBuilder(shape: StructureShape): Boolean =
        BuilderGenerator.hasFallibleBuilder(shape, codegenContext.symbolProvider)
@@ -40,7 +33,6 @@ class ClientInstantiator(private val codegenContext: ClientCodegenContext) : Ins
    codegenContext.model,
    codegenContext.runtimeConfig,
    ClientBuilderKindBehavior(codegenContext),
    ::enumFromStringFn,
) {
    fun renderFluentCall(
        writer: RustWriter,
+5 −0
Original line number Diff line number Diff line
@@ -30,6 +30,11 @@ fun Writable.map(f: RustWriter.(Writable) -> Unit): Writable {
    return writable { f(self) }
}

/** Returns Some(..arg) */
fun Writable.some(): Writable {
    return this.map { rust("Some(#T)", it) }
}

fun Writable.isNotEmpty(): Boolean = !this.isEmpty()

operator fun Writable.plus(other: Writable): Writable {
+6 −0
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@
package software.amazon.smithy.rust.codegen.core.smithy

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
@@ -102,6 +103,11 @@ sealed class Default {
     * This symbol should use the Rust `std::default::Default` when unset
     */
    object RustDefault : Default()

    /**
     * This symbol has a custom default value different from `Default::default`
     */
    data class NonZeroDefault(val value: Node) : Default()
}

/**
+27 −10
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.BigDecimalShape
import software.amazon.smithy.model.shapes.BigIntegerShape
import software.amazon.smithy.model.shapes.BlobShape
@@ -37,6 +38,7 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.DefaultTrait
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
@@ -48,6 +50,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import kotlin.reflect.KClass
@@ -79,14 +82,16 @@ data class MaybeRenamed(val name: String, val renamedFrom: String?)
/**
 * Make the return [value] optional if the [member] symbol is as well optional.
 */
fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String = value.letIf(toSymbol(member).isOptional()) {
fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String =
    value.letIf(toSymbol(member).isOptional()) {
        "Some($value)"
    }

/**
 * Make the return [value] optional if the [member] symbol is not optional.
 */
fun SymbolProvider.toOptional(member: MemberShape, value: String): String = value.letIf(!toSymbol(member).isOptional()) {
fun SymbolProvider.toOptional(member: MemberShape, value: String): String =
    value.letIf(!toSymbol(member).isOptional()) {
        "Some($value)"
    }

@@ -111,7 +116,7 @@ fun Shape.contextName(serviceShape: ServiceShape?): String {
 */
open class SymbolVisitor(
    settings: CoreRustSettings,
    override val model: Model,
    final override val model: Model,
    private val serviceShape: ServiceShape?,
    override val config: RustSymbolProviderConfig,
) : RustSymbolProvider, ShapeVisitor<Symbol> {
@@ -170,7 +175,7 @@ open class SymbolVisitor(
    }

    private fun simpleShape(shape: SimpleShape): Symbol {
        return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).setDefault(Default.RustDefault).build()
        return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).build()
    }

    override fun booleanShape(shape: BooleanShape): Symbol = simpleShape(shape)
@@ -263,13 +268,20 @@ open class SymbolVisitor(

    override fun memberShape(shape: MemberShape): Symbol {
        val target = model.expectShape(shape.target)
        val defaultValue = shape.getMemberTrait(model, DefaultTrait::class.java).orNull()?.let { trait ->
            when (val value = trait.toNode()) {
                Node.from(""), Node.from(0), Node.from(false), Node.arrayNode(), Node.objectNode() -> Default.RustDefault
                Node.nullNode() -> Default.NoDefault
                else -> Default.NonZeroDefault(value)
            }
        } ?: Default.NoDefault
        // Handle boxing first, so we end up with Option<Box<_>>, not Box<Option<_>>.
        return handleOptionality(
            handleRustBoxing(toSymbol(target), shape),
            shape,
            nullableIndex,
            config.nullabilityCheckMode,
        )
        ).toBuilder().setDefault(defaultValue).build()
    }

    override fun timestampShape(shape: TimestampShape?): Symbol {
@@ -297,7 +309,12 @@ fun symbolBuilder(shape: Shape?, rustType: RustType): Symbol.Builder =
        // If we ever generate a `thisisabug.rs`, there is a bug in our symbol generation
        .definitionFile("thisisabug.rs")

fun handleOptionality(symbol: Symbol, member: MemberShape, nullableIndex: NullableIndex, nullabilityCheckMode: CheckMode): Symbol =
fun handleOptionality(
    symbol: Symbol,
    member: MemberShape,
    nullableIndex: NullableIndex,
    nullabilityCheckMode: CheckMode,
): Symbol =
    symbol.letIf(nullableIndex.isMemberNullable(member, nullabilityCheckMode)) { symbol.makeOptional() }

/**
Loading