Unverified Commit adf30a8e authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Derive `Eq` and `Hash` wherever possible (#2223)

In server SDKs, these traits can be implemented by any shape _except_ if
the shape's closure contains:

1. A `float`, `double`, or `document` shape: floating point types in
   Rust do not implement `Eq`. Similarly, [`document` shapes] may
   contain arbitrary JSON-like data containing floating point values.
2. A [@streaming] shape: all the streaming data would need to be
   buffered first to compare it.

Additionally, the `Hash` trait cannot be implemented by shapes whose
closure contains:

1. A `map` shape: we render `map` shapes as `std::collections::HashMap`,
   which _do not_ implement `Hash`. See
   https://github.com/awslabs/smithy/issues/1567.

In **client SDKs, these traits cannot be derived on any code-generated
Rust types corresponding to Smithy shapes**, since e.g. adding new
optional members to a structure [is a backwards-compatible change], and
doing so alters the semantics of these traits.

However, this commit does implement these traits for the
`aws_smithy_types::date_time::DateTime` and `aws_smithy_types::Blob`
runtime types.

This change is necessary to efficiently implement the `@uniqueItems`
constraint trait in server SDKs.

[`Eq`]: https://doc.rust-lang.org/std/cmp/trait.Eq.html
[`Hash`]: https://doc.rust-lang.org/std/hash/trait.Hash.html
[`document` shapes]: https://smithy.io/2.0/spec/simple-types.html#document
[@streaming]: https://smithy.io/2.0/spec/streaming.html
[is a backwards-compatible change]: https://smithy.io/2.0/guides/evolving-models.html#updating-structures
parent 1d8934f6
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -370,3 +370,15 @@ set_credentials_cache(
references = ["smithy-rs#2122"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "ysaito1001"

[[smithy-rs]]
message = "`aws_smithy_types::date_time::DateTime`, `aws_smithy_types::Blob` now implement the `Eq` and `Hash` traits"
references = ["smithy-rs#2223"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "all"}
author = "david-perez"

[[smithy-rs]]
message = "Code-generated types for server SDKs now implement the `Eq` and `Hash` traits when possible"
references = ["smithy-rs#2223"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "server"}
author = "david-perez"
+1 −1
Original line number Diff line number Diff line
@@ -80,7 +80,7 @@ class RustClientCodegenPlugin : DecoratableBuildPlugin() {
                .let { StreamingShapeSymbolProvider(it, model) }
                // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes
                .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf(NonExhaustive)) }
                // Streaming shapes need different derives (e.g. they cannot derive Eq)
                // Streaming shapes need different derives (e.g. they cannot derive `PartialEq`)
                .let { StreamingShapeMetadataProvider(it, model) }
                // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot
                // be the name of an operation input
+13 −9
Original line number Diff line number Diff line
@@ -8,7 +8,10 @@ package software.amazon.smithy.rust.codegen.core.smithy
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
@@ -27,7 +30,7 @@ class StreamingShapeSymbolProvider(private val base: RustSymbolProvider, private
    WrappingSymbolProvider(base) {
    override fun toSymbol(shape: Shape): Symbol {
        val initial = base.toSymbol(shape)
        // We are only targetting member shapes
        // We are only targeting member shapes
        if (shape !is MemberShape) {
            return initial
        }
@@ -49,7 +52,7 @@ class StreamingShapeSymbolProvider(private val base: RustSymbolProvider, private
}

/**
 * SymbolProvider to drop the clone and PartialEq bounds in streaming shapes
 * SymbolProvider to drop the `Clone` and `PartialEq` bounds in streaming shapes.
 *
 * Streaming shapes cannot be cloned and equality cannot be checked without reading the body. Because of this, these shapes
 * do not implement `Clone` or `PartialEq`.
@@ -60,10 +63,6 @@ class StreamingShapeMetadataProvider(
    private val base: RustSymbolProvider,
    private val model: Model,
) : SymbolMetadataProvider(base) {
    override fun memberMeta(memberShape: MemberShape): RustMetadata {
        return base.toSymbol(memberShape).expectRustMetadata()
    }

    override fun structureMeta(structureShape: StructureShape): RustMetadata {
        val baseMetadata = base.toSymbol(structureShape).expectRustMetadata()
        return if (structureShape.hasStreamingMember(model)) {
@@ -78,7 +77,12 @@ class StreamingShapeMetadataProvider(
        } else baseMetadata
    }

    override fun enumMeta(stringShape: StringShape): RustMetadata {
        return base.toSymbol(stringShape).expectRustMetadata()
    }
    override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata()
    override fun enumMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata()

    override fun listMeta(listShape: ListShape) = base.toSymbol(listShape).expectRustMetadata()
    override fun mapMeta(mapShape: MapShape) = base.toSymbol(mapShape).expectRustMetadata()
    override fun stringMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata()
    override fun numberMeta(numberShape: NumberShape) = base.toSymbol(numberShape).expectRustMetadata()
    override fun blobMeta(blobShape: BlobShape) = base.toSymbol(blobShape).expectRustMetadata()
}
+71 −57
Original line number Diff line number Diff line
@@ -8,7 +8,12 @@ package software.amazon.smithy.rust.codegen.core.smithy
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
@@ -55,9 +60,15 @@ abstract class SymbolMetadataProvider(private val base: RustSymbolProvider) : Wr
            is MemberShape -> memberMeta(shape)
            is StructureShape -> structureMeta(shape)
            is UnionShape -> unionMeta(shape)
            is ListShape -> listMeta(shape)
            is MapShape -> mapMeta(shape)
            is NumberShape -> numberMeta(shape)
            is BlobShape -> blobMeta(shape)
            is StringShape -> if (shape.hasTrait<EnumTrait>()) {
                enumMeta(shape)
            } else null
            } else {
                stringMeta(shape)
            }

            else -> null
        }
@@ -68,98 +79,101 @@ abstract class SymbolMetadataProvider(private val base: RustSymbolProvider) : Wr
    abstract fun structureMeta(structureShape: StructureShape): RustMetadata
    abstract fun unionMeta(unionShape: UnionShape): RustMetadata
    abstract fun enumMeta(stringShape: StringShape): RustMetadata

    abstract fun listMeta(listShape: ListShape): RustMetadata
    abstract fun mapMeta(mapShape: MapShape): RustMetadata
    abstract fun stringMeta(stringShape: StringShape): RustMetadata
    abstract fun numberMeta(numberShape: NumberShape): RustMetadata
    abstract fun blobMeta(blobShape: BlobShape): RustMetadata
}

/**
 * The base metadata supports a list of attributes that are used by generators to decorate code.
 * By default, we apply ```#[non_exhaustive]``` only to client structures since model changes should
 * be considered as breaking only when generating server code.
 */
class BaseSymbolMetadataProvider(
    base: RustSymbolProvider,
    private val model: Model,
    private val additionalAttributes: List<Attribute>,
) : SymbolMetadataProvider(base) {
    private fun containerDefault(shape: Shape): RustMetadata {
fun containerDefaultMetadata(
    shape: Shape,
    model: Model,
    additionalAttributes: List<Attribute> = emptyList(),
): RustMetadata {
    val defaultDerives = setOf(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone)

    val isSensitive = shape.hasTrait<SensitiveTrait>() ||
        // Checking the shape's direct members for the sensitive trait should suffice.
        // Whether their descendants, i.e. a member's member, is sensitive does not
            // affect the inclusion/exclusion of the derived Debug trait of _this_ container
        // affect the inclusion/exclusion of the derived `Debug` trait of _this_ container
        // shape; any sensitive descendant should still be printed as redacted.
        shape.members().any { it.getMemberTrait(model, SensitiveTrait::class.java).isPresent }

        val derives = if (isSensitive) {
    val setOfDerives = if (isSensitive) {
        defaultDerives - RuntimeType.Debug
    } else {
        defaultDerives
    }
    return RustMetadata(
            derives,
        setOfDerives,
        additionalAttributes,
        Visibility.PUBLIC,
    )
}

    override fun memberMeta(memberShape: MemberShape): RustMetadata {
        val container = model.expectShape(memberShape.container)
        return when {
            container.isStructureShape -> {
/**
 * The base metadata supports a set of attributes that are used by generators to decorate code.
 *
 * By default we apply `#[non_exhaustive]` in [additionalAttributes] only to client structures since breaking model
 * changes are fine when generating server code.
 */
class BaseSymbolMetadataProvider(
    base: RustSymbolProvider,
    private val model: Model,
    private val additionalAttributes: List<Attribute>,
) : SymbolMetadataProvider(base) {

    override fun memberMeta(memberShape: MemberShape): RustMetadata =
        when (val container = model.expectShape(memberShape.container)) {
            is StructureShape -> {
                // TODO(https://github.com/awslabs/smithy-rs/issues/943): Once streaming accessors are usable,
                // then also make streaming members `#[doc(hidden)]`
                if (memberShape.getMemberTrait(model, StreamingTrait::class.java).isPresent) {
                    RustMetadata(visibility = Visibility.PUBLIC)
                } else {
                    RustMetadata(
                        // At some point, visibility will be made PRIVATE, so make these `#[doc(hidden)]` for now
                        // At some point, visibility _may_ be made `PRIVATE`, so make these `#[doc(hidden)]` for now.
                        visibility = Visibility.PUBLIC,
                        additionalAttributes = listOf(Attribute.DocHidden),
                    )
                }
            }

            container.isUnionShape ||
                container.isListShape ||
                container.isSetShape ||
                container.isMapShape
            -> RustMetadata(visibility = Visibility.PUBLIC)

            is UnionShape, is CollectionShape, is MapShape -> RustMetadata(visibility = Visibility.PUBLIC)
            else -> TODO("Unrecognized container type: $container")
        }
    }

    override fun structureMeta(structureShape: StructureShape): RustMetadata {
        return containerDefault(structureShape)
    }

    override fun unionMeta(unionShape: UnionShape): RustMetadata {
        return containerDefault(unionShape)
    }
    override fun structureMeta(structureShape: StructureShape) = containerDefaultMetadata(structureShape, model, additionalAttributes)
    override fun unionMeta(unionShape: UnionShape) = containerDefaultMetadata(unionShape, model, additionalAttributes)

    override fun enumMeta(stringShape: StringShape): RustMetadata {
        return containerDefault(stringShape).withDerives(
            RuntimeType.Hash,
            // enums can be Eq because they can only contain ints and strings
    override fun enumMeta(stringShape: StringShape): RustMetadata =
        containerDefaultMetadata(stringShape, model, additionalAttributes).withDerives(
            // Smithy's `enum` shapes can additionally be `Eq`, `PartialOrd`, `Ord`, and `Hash` because they can
            // only contain strings.
            RuntimeType.Eq,
            // enums can be PartialOrd/Ord because they can only contain ints and strings
            RuntimeType.PartialOrd,
            RuntimeType.Ord,
            RuntimeType.Hash,
        )
    }

    companion object {
        private val defaultDerives by lazy {
            setOf(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone)
        }
    }
    // Only the server subproject uses these, so we provide a sane and conservative default implementation here so that
    // the rest of symbol metadata providers can just delegate to it.
    private val defaultRustMetadata = RustMetadata(visibility = Visibility.PRIVATE)

    override fun listMeta(listShape: ListShape) = defaultRustMetadata
    override fun mapMeta(mapShape: MapShape) = defaultRustMetadata
    override fun stringMeta(stringShape: StringShape) = defaultRustMetadata
    override fun numberMeta(numberShape: NumberShape) = defaultRustMetadata
    override fun blobMeta(blobShape: BlobShape) = defaultRustMetadata
}

private const val META_KEY = "meta"
fun Symbol.Builder.meta(rustMetadata: RustMetadata?): Symbol.Builder {
    return this.putProperty(META_KEY, rustMetadata)
}
fun Symbol.Builder.meta(rustMetadata: RustMetadata?): Symbol.Builder = this.putProperty(META_KEY, rustMetadata)

fun Symbol.expectRustMetadata(): RustMetadata = this.getProperty(META_KEY, RustMetadata::class.java).orElseThrow {
    CodegenException(
        "Expected $this to have metadata attached but it did not. ",
        "Expected `$this` to have metadata attached but it did not.",
    )
}
+7 −1
Original line number Diff line number Diff line
@@ -17,7 +17,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.server.python.smithy.customizations.DECORATORS
import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.DeriveEqAndHashSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations
import software.amazon.smithy.rust.codegen.server.smithy.customize.CombinedServerCodegenDecorator
import java.util.logging.Level
@@ -54,7 +56,7 @@ class PythonCodegenServerPlugin : SmithyBuildPlugin {
    }

    companion object {
        /** SymbolProvider
        /**
         * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider
         *
         * The Symbol provider is composed of a base [SymbolVisitor] which handles the core functionality, then is layered
@@ -77,8 +79,12 @@ class PythonCodegenServerPlugin : SmithyBuildPlugin {
                .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) }
                // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes
                .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf()) }
                // Constrained shapes generate newtypes that need the same derives we place on types generated from aggregate shapes.
                .let { ConstrainedShapeSymbolMetadataProvider(it, model, constrainedTypes) }
                // Streaming shapes need different derives (e.g. they cannot derive Eq)
                .let { PythonStreamingShapeMetadataProvider(it, model) }
                // Derive `Eq` and `Hash` if possible.
                .let { DeriveEqAndHashSymbolMetadataProvider(it, model) }
                // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot
                // be the name of an operation input
                .let { RustReservedWordSymbolProvider(it, model) }
Loading