Unverified Commit e1ba6562 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Fix 3 bugs in `@length`-constrained collection and map shapes (#2085)



* Fix 3 bugs in `@length`-constrained collection and map shapes

The 3 bugs are related, hence why a single commit to address them all.

1. #2028 added support for the `@length` constraint trait on collection
   shapes, but the code enforcing the trait was not being exercised upon
   receiving a request, specifically when converting the input
   unconstrained list into the constrained one. Note that #2028 did not
   result in the removal of the relevant protocol tests from
   `ServerProtocolTestGenerator`'s list of known failing tests.
2. Fixes code generation of `@length`-constrained list shapes whose
   members are not constrained: the converter being generated only
   worked for the case where the member was (transitively) constrained.
   The `constraints.smithy` model has been expanded to cover this case.
3. Fixes bug in code generation, when the
   `codegenConfig.publicConstrainedTypes` setting is set to `false`, of
   `@length`-constrained map shapes and collection shapes whose values
   or members (respectively) are constrained, but result in Rust types
   that would have been public regardless of the setting's value. This
   is the case only when they are modeled as structure shapes or union
   shapes. In these cases, two converters from the constrained type to
   the inner type were being generated, resulting in two `From` trait
   implementations. The `constraints.smithy` model has been expanded to
   cover this case.

* Use correct symbol provider

* ./gradlew ktlintFormat

Co-authored-by: default avatarHarry Barber <106155934+hlbarber@users.noreply.github.com>
parent f3e05625
Loading
Loading
Loading
Loading
+11 −7
Original line number Diff line number Diff line
@@ -467,13 +467,14 @@ structure ConA {
    fixedValueByte: FixedValueByte,

    conBList: ConBList,
    conBList2: ConBList2,
    lengthList: LengthList,

    // TODO(https://github.com/awslabs/smithy-rs/issues/1401): a `set` shape is
    //  just a `list` shape with `uniqueItems`, which hasn't been implemented yet.
    // conBSet: ConBSet,

    conBMap: ConBMap,
    lengthMap: LengthMap,

    mapOfMapOfListOfListOfConB: MapOfMapOfListOfListOfConB,

@@ -837,14 +838,11 @@ list RecursiveList {
}

list ConBList {
    member: NestedList
    member: LengthList
}

list ConBList2 {
    member: ConB
}

list NestedList {
@length(max: 69)
list LengthList {
    member: ConB
}

@@ -874,6 +872,12 @@ map ConBMap {
    value: LengthString
}

@length(min: 1, max: 69)
map LengthMap {
    key: String,
    value: String
}

@error("client")
structure ErrorWithLengthStringMessage {
    // TODO Doesn't work yet because constrained string types don't implement
+9 −1
Original line number Diff line number Diff line
@@ -7,6 +7,8 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.LengthTrait
import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
@@ -23,6 +25,7 @@ import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.supportedCollectionConstraintTraits
import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage

@@ -132,7 +135,12 @@ class ConstrainedCollectionGenerator(
            "ValidationFunctions" to constraintsInfo.map { it.validationFunctionDefinition(constraintViolation, inner) }.join("\n"),
        )

        if (!publicConstrainedTypes && isValueConstrained(shape, model, symbolProvider)) {
        val innerShape = model.expectShape(shape.member.target)
        if (!publicConstrainedTypes &&
            innerShape.canReachConstrainedShape(model, symbolProvider) &&
            innerShape !is StructureShape &&
            innerShape !is UnionShape
        ) {
            writer.rustTemplate(
                """
                impl #{From}<$name> for #{FullyUnconstrainedSymbol} {
+8 −1
Original line number Diff line number Diff line
@@ -7,6 +7,8 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.LengthTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata
@@ -115,7 +117,12 @@ class ConstrainedMapGenerator(
            *codegenScope,
        )

        if (!publicConstrainedTypes && isValueConstrained(shape, model, symbolProvider)) {
        val valueShape = model.expectShape(shape.value.target)
        if (!publicConstrainedTypes &&
            isValueConstrained(valueShape, model, symbolProvider) &&
            valueShape !is StructureShape &&
            valueShape !is UnionShape
        ) {
            writer.rustTemplate(
                """
                impl #{From}<$name> for #{FullyUnconstrainedSymbol} {
+52 −19
Original line number Diff line number Diff line
@@ -6,8 +6,12 @@
package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained
@@ -38,6 +42,8 @@ class UnconstrainedCollectionGenerator(
    private val symbolProvider = codegenContext.symbolProvider
    private val unconstrainedShapeSymbolProvider = codegenContext.unconstrainedShapeSymbolProvider
    private val pubCrateConstrainedShapeSymbolProvider = codegenContext.pubCrateConstrainedShapeSymbolProvider
    private val symbol = unconstrainedShapeSymbolProvider.toSymbol(shape)
    private val name = symbol.name
    private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes
    private val constraintViolationSymbolProvider =
        with(codegenContext.constraintViolationSymbolProvider) {
@@ -47,22 +53,20 @@ class UnconstrainedCollectionGenerator(
                PubCrateConstraintViolationSymbolProvider(this)
            }
        }
    private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape)
    private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider
    private val constrainedSymbol = if (shape.isDirectlyConstrained(symbolProvider)) {
        constrainedShapeSymbolProvider.toSymbol(shape)
    } else {
        pubCrateConstrainedShapeSymbolProvider.toSymbol(shape)
    }
    private val innerShape = model.expectShape(shape.member.target)

    fun render() {
        check(shape.canReachConstrainedShape(model, symbolProvider))

        val symbol = unconstrainedShapeSymbolProvider.toSymbol(shape)
        val name = symbol.name
        val innerShape = model.expectShape(shape.member.target)
        val innerUnconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(innerShape)
        val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape)
        val innerConstraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(innerShape)

        unconstrainedModuleWriter.withInlineModule(symbol.module()) {
            rustTemplate(
@@ -75,29 +79,58 @@ class UnconstrainedCollectionGenerator(
                        Self::Unconstrained(value)
                    }
                }
                """,
                "InnerUnconstrainedSymbol" to innerUnconstrainedSymbol,
                "MaybeConstrained" to constrainedSymbol.makeMaybeConstrained(),
            )

                impl #{TryFrom}<$name> for #{ConstrainedSymbol} {
                    type Error = #{ConstraintViolationSymbol};
            renderTryFromUnconstrainedForConstrained(this)
        }
    }

    private fun renderTryFromUnconstrainedForConstrained(writer: RustWriter) {
        val innerConstraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(innerShape)

        writer.rustBlock("impl std::convert::TryFrom<$name> for #{T}", constrainedSymbol) {
            rust("type Error = #T;", constraintViolationSymbol)

            rustBlock("fn try_from(value: $name) -> Result<Self, Self::Error>") {
                if (innerShape.canReachConstrainedShape(model, symbolProvider)) {
                    val resolvesToNonPublicConstrainedValueType =
                        innerShape.canReachConstrainedShape(model, symbolProvider) &&
                            !innerShape.isDirectlyConstrained(symbolProvider) &&
                            innerShape !is StructureShape &&
                            innerShape !is UnionShape
                    val innerConstrainedSymbol = if (resolvesToNonPublicConstrainedValueType) {
                        pubCrateConstrainedShapeSymbolProvider.toSymbol(innerShape)
                    } else {
                        constrainedShapeSymbolProvider.toSymbol(innerShape)
                    }

                    fn try_from(value: $name) -> Result<Self, Self::Error> {
                        let res: Result<_, (usize, #{InnerConstraintViolationSymbol})> = value
                    rustTemplate(
                        """
                        let res: Result<std::vec::Vec<#{InnerConstrainedSymbol}>, (usize, #{InnerConstraintViolationSymbol})> = value
                            .0
                            .into_iter()
                            .enumerate()
                            .map(|(idx, inner)| inner.try_into().map_err(|inner_violation| (idx, inner_violation)))
                            .collect();
                        res.map(Self)
                           .map_err(|(idx, inner_violation)| #{ConstraintViolationSymbol}::Member(idx, inner_violation))
                    }
                }
                        let inner = res.map_err(|(idx, inner_violation)| Self::Error::Member(idx, inner_violation))?;
                        """,
                "InnerUnconstrainedSymbol" to innerUnconstrainedSymbol,
                        "InnerConstrainedSymbol" to innerConstrainedSymbol,
                        "InnerConstraintViolationSymbol" to innerConstraintViolationSymbol,
                "ConstrainedSymbol" to constrainedSymbol,
                "ConstraintViolationSymbol" to constraintViolationSymbol,
                "MaybeConstrained" to constrainedSymbol.makeMaybeConstrained(),
                        "TryFrom" to RuntimeType.TryFrom,
                    )
                } else {
                    rust("let inner = value.0;")
                }

                if (shape.isDirectlyConstrained(symbolProvider)) {
                    rust("Self::try_from(inner)")
                } else {
                    rust("Ok(Self(inner))")
                }
            }
        }
    }
}
+7 −4
Original line number Diff line number Diff line
@@ -7,6 +7,8 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rust
@@ -95,11 +97,12 @@ class UnconstrainedMapGenerator(

            rustBlock("fn try_from(value: $name) -> Result<Self, Self::Error>") {
                if (isKeyConstrained(keyShape, symbolProvider) || isValueConstrained(valueShape, model, symbolProvider)) {
                    val resolveToNonPublicConstrainedValueType =
                    val resolvesToNonPublicConstrainedValueType =
                        isValueConstrained(valueShape, model, symbolProvider) &&
                            !valueShape.isDirectlyConstrained(symbolProvider) &&
                            !valueShape.isStructureShape
                    val constrainedValueSymbol = if (resolveToNonPublicConstrainedValueType) {
                            valueShape !is StructureShape &&
                            valueShape !is UnionShape
                    val constrainedValueSymbol = if (resolvesToNonPublicConstrainedValueType) {
                        pubCrateConstrainedShapeSymbolProvider.toSymbol(valueShape)
                    } else {
                        constrainedShapeSymbolProvider.toSymbol(valueShape)
@@ -154,7 +157,7 @@ class UnconstrainedMapGenerator(
                    )

                    val constrainedValueTypeIsNotFinalType =
                        resolveToNonPublicConstrainedValueType && shape.isDirectlyConstrained(symbolProvider)
                        resolvesToNonPublicConstrainedValueType && shape.isDirectlyConstrained(symbolProvider)
                    if (constrainedValueTypeIsNotFinalType) {
                        // The map is constrained. Its value shape reaches a constrained shape, but the value shape itself
                        // is not directly constrained. The value shape must be an aggregate shape. But it is not a
Loading