Unverified Commit 356444bf authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Fix server SDK bug with directly constrained list/map shapes in operation output (#2761)

Fixes https://github.com/awslabs/smithy-rs/issues/2760.

## Testing

I verified that the updated `constraints.smithy` integration test does
not compile without the fix applied.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 62dca85d
Loading
Loading
Loading
Loading
+36 −0
Original line number Diff line number Diff line
@@ -11,6 +11,9 @@ use smithy.framework#ValidationException
service ConstraintsService {
    operations: [
        ConstrainedShapesOperation,
        // See https://github.com/awslabs/smithy-rs/issues/2760 for why testing operations reaching
        // constrained shapes that only lie in the output is important.
        ConstrainedShapesOnlyInOutputOperation,
        ConstrainedHttpBoundShapesOperation,
        ConstrainedHttpPayloadBoundShapeOperation,
        ConstrainedRecursiveShapesOperation,
@@ -51,6 +54,11 @@ operation ConstrainedShapesOperation {
    errors: [ValidationException]
}

@http(uri: "/constrained-shapes-only-in-output-operation", method: "POST")
operation ConstrainedShapesOnlyInOutputOperation {
    output: ConstrainedShapesOnlyInOutputOperationOutput,
}

@http(
    uri: "/constrained-http-bound-shapes-operation/{rangeIntegerLabel}/{rangeShortLabel}/{rangeLongLabel}/{rangeByteLabel}/{lengthStringLabel}/{enumStringLabel}",
    method: "POST"
@@ -935,3 +943,31 @@ map MapOfListOfListOfConB {
    key: String,
    value: ConBList
}

structure ConstrainedShapesOnlyInOutputOperationOutput {
    list: ConstrainedListInOutput
    map: ConstrainedMapInOutput
    // Unions were not affected by
    // https://github.com/awslabs/smithy-rs/issues/2760, but testing anyway for
    // good measure.
    union: ConstrainedUnionInOutput
}

@length(min: 69)
list ConstrainedListInOutput {
    member: ConstrainedUnionInOutput
}

@length(min: 69)
map ConstrainedMapInOutput {
    key: String
    value: TransitivelyConstrainedStructureInOutput
}

union ConstrainedUnionInOutput {
    structure: TransitivelyConstrainedStructureInOutput
}

structure TransitivelyConstrainedStructureInOutput {
    lengthString: LengthString
}
+1 −1
Original line number Diff line number Diff line
@@ -48,7 +48,7 @@ class CollectionConstraintViolationGenerator(

        inlineModuleCreator(constraintViolationSymbol) {
            val constraintViolationVariants = constraintsInfo.map { it.constraintViolationVariant }.toMutableList()
            if (isMemberConstrained) {
            if (shape.isReachableFromOperationInput() && isMemberConstrained) {
                constraintViolationVariants += {
                    val memberConstraintViolationSymbol =
                        constraintViolationSymbolProvider.toSymbol(targetShape).letIf(
+6 −4
Original line number Diff line number Diff line
@@ -45,10 +45,12 @@ class MapConstraintViolationGenerator(
        val constraintViolationName = constraintViolationSymbol.name

        val constraintViolationCodegenScopeMutableList: MutableList<Pair<String, Any>> = mutableListOf()
        if (isKeyConstrained(keyShape, symbolProvider)) {
        val keyConstraintViolationExists = shape.isReachableFromOperationInput() && isKeyConstrained(keyShape, symbolProvider)
        val valueConstraintViolationExists = shape.isReachableFromOperationInput() && isValueConstrained(valueShape, model, symbolProvider)
        if (keyConstraintViolationExists) {
            constraintViolationCodegenScopeMutableList.add("KeyConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(keyShape))
        }
        if (isValueConstrained(valueShape, model, symbolProvider)) {
        if (valueConstraintViolationExists) {
            constraintViolationCodegenScopeMutableList.add(
                "ValueConstraintViolationSymbol" to
                    constraintViolationSymbolProvider.toSymbol(valueShape).letIf(
@@ -78,8 +80,8 @@ class MapConstraintViolationGenerator(
                ##[derive(Debug, PartialEq)]
                pub${ if (constraintViolationVisibility == Visibility.PUBCRATE) " (crate) " else "" } enum $constraintViolationName {
                    ${if (shape.hasTrait<LengthTrait>()) "Length(usize)," else ""}
                    ${if (isKeyConstrained(keyShape, symbolProvider)) "##[doc(hidden)] Key(#{KeyConstraintViolationSymbol})," else ""}
                    ${if (isValueConstrained(valueShape, model, symbolProvider)) "##[doc(hidden)] Value(#{KeySymbol}, #{ValueConstraintViolationSymbol})," else ""}
                    ${if (keyConstraintViolationExists) "##[doc(hidden)] Key(#{KeyConstraintViolationSymbol})," else ""}
                    ${if (valueConstraintViolationExists) "##[doc(hidden)] Value(#{KeySymbol}, #{ValueConstraintViolationSymbol})," else ""}
                }
                """,
                *constraintViolationCodegenScope,
+49 −80
Original line number Diff line number Diff line
@@ -6,20 +6,11 @@
package software.amazon.smithy.rust.codegen.server.smithy.generators

import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.testModule
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule
import software.amazon.smithy.rust.codegen.server.smithy.createTestInlineModuleCreator
import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRestJsonProtocol
import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest

class UnconstrainedCollectionGeneratorTest {
    @Test
@@ -28,6 +19,25 @@ class UnconstrainedCollectionGeneratorTest {
            """
            namespace test
            
            use aws.protocols#restJson1
            use smithy.framework#ValidationException
            
            @restJson1
            service TestService {
                operations: ["Operation"]
            }
            
            @http(uri: "/operation", method: "POST")
            operation Operation {
                input: OperationInputOutput
                output: OperationInputOutput
                errors: [ValidationException]
            }
            
            structure OperationInputOutput {
                list: ListA
            }

            list ListA {
                member: ListB
            }
@@ -44,58 +54,16 @@ class UnconstrainedCollectionGeneratorTest {
                string: String
            }
            """.asSmithyModel()
        val codegenContext = serverTestCodegenContext(model)
        val symbolProvider = codegenContext.symbolProvider

        val listA = model.lookup<ListShape>("test#ListA")
        val listB = model.lookup<ListShape>("test#ListB")

        val project = TestWorkspace.testProject(symbolProvider)

        project.withModule(ServerRustModule.Model) {
            model.lookup<StructureShape>("test#StructureC").serverRenderWithModelBuilder(
                project,
                model,
                symbolProvider,
                this,
                ServerRestJsonProtocol(codegenContext),
            )
        }

        project.withModule(ServerRustModule.ConstrainedModule) {
            listOf(listA, listB).forEach {
                PubCrateConstrainedCollectionGenerator(
                    codegenContext,
                    this.createTestInlineModuleCreator(),
                    it,
                ).render()
            }
        }
        project.withModule(ServerRustModule.UnconstrainedModule) unconstrainedModuleWriter@{
            project.withModule(ServerRustModule.Model) modelsModuleWriter@{
                listOf(listA, listB).forEach {
                    UnconstrainedCollectionGenerator(
                        codegenContext,
                        this@unconstrainedModuleWriter.createTestInlineModuleCreator(),
                        it,
                    ).render()

                    CollectionConstraintViolationGenerator(
                        codegenContext,
                        this@modelsModuleWriter.createTestInlineModuleCreator(),
                        it,
                        CollectionTraitInfo.fromShape(it, codegenContext.constrainedShapeSymbolProvider),
                        SmithyValidationExceptionConversionGenerator(codegenContext),
                    ).render()
                }

                this@unconstrainedModuleWriter.unitTest(
                    name = "list_a_unconstrained_fail_to_constrain_with_first_error",
                    test = """
        serverIntegrationTest(model) { _, rustCrate ->
            rustCrate.testModule {
                unitTest("list_a_unconstrained_fail_to_constrain_with_first_error") {
                    rust(
                        """
                        let c_builder1 = crate::model::StructureC::builder().int(69);
                        let c_builder2 = crate::model::StructureC::builder().string("david".to_owned());
                        let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder1, c_builder2]);
                        let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]);
                        let list_b_unconstrained = crate::unconstrained::list_b_unconstrained::ListBUnconstrained(vec![c_builder1, c_builder2]);
                        let list_a_unconstrained = crate::unconstrained::list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]);

                        let expected_err =
                            crate::model::list_a::ConstraintViolation::Member(0, crate::model::list_b::ConstraintViolation::Member(
@@ -108,13 +76,14 @@ class UnconstrainedCollectionGeneratorTest {
                        );
                        """,
                    )
                }

                this@unconstrainedModuleWriter.unitTest(
                    name = "list_a_unconstrained_succeed_to_constrain",
                    test = """
                unitTest("list_a_unconstrained_succeed_to_constrain") {
                    rust(
                        """
                        let c_builder = crate::model::StructureC::builder().int(69).string(String::from("david"));
                        let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder]);
                        let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]);
                        let list_b_unconstrained = crate::unconstrained::list_b_unconstrained::ListBUnconstrained(vec![c_builder]);
                        let list_a_unconstrained = crate::unconstrained::list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]);

                        let expected: Vec<Vec<crate::model::StructureC>> = vec![vec![crate::model::StructureC {
                            string: "david".to_owned(),
@@ -126,20 +95,20 @@ class UnconstrainedCollectionGeneratorTest {
                        assert_eq!(expected, actual);
                        """,
                    )
                }

                this@unconstrainedModuleWriter.unitTest(
                    name = "list_a_unconstrained_converts_into_constrained",
                    test = """
                unitTest("list_a_unconstrained_converts_into_constrained") {
                    rust(
                        """
                        let c_builder = crate::model::StructureC::builder();
                        let list_b_unconstrained = list_b_unconstrained::ListBUnconstrained(vec![c_builder]);
                        let list_a_unconstrained = list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]);
                        let list_b_unconstrained = crate::unconstrained::list_b_unconstrained::ListBUnconstrained(vec![c_builder]);
                        let list_a_unconstrained = crate::unconstrained::list_a_unconstrained::ListAUnconstrained(vec![list_b_unconstrained]);
  
                        let _list_a: crate::constrained::MaybeConstrained<crate::constrained::list_a_constrained::ListAConstrained> = list_a_unconstrained.into();
                        """,
                    )
                }
            }
        project.renderInlineMemoryModules()
        project.compileAndTest()
        }
    }
}
+48 −69
Original line number Diff line number Diff line
@@ -7,20 +7,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators

import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.smithy.CoreCodegenConfig
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.testModule
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Model
import software.amazon.smithy.rust.codegen.server.smithy.createTestInlineModuleCreator
import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRestJsonProtocol
import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext

class UnconstrainedMapGeneratorTest {
@@ -30,6 +24,25 @@ class UnconstrainedMapGeneratorTest {
            """
            namespace test
            
            use aws.protocols#restJson1
            use smithy.framework#ValidationException
            
            @restJson1
            service TestService {
                operations: ["Operation"]
            }
            
            @http(uri: "/operation", method: "POST")
            operation Operation {
                input: OperationInputOutput
                output: OperationInputOutput
                errors: [ValidationException]
            }
            
            structure OperationInputOutput {
                map: MapA
            }

            map MapA {
                key: String,
                value: MapB
@@ -56,53 +69,20 @@ class UnconstrainedMapGeneratorTest {

        val project = TestWorkspace.testProject(symbolProvider, CoreCodegenConfig(debugMode = true))

        project.withModule(Model) {
            model.lookup<StructureShape>("test#StructureC").serverRenderWithModelBuilder(
                project,
                model,
                symbolProvider,
                this,
                ServerRestJsonProtocol(codegenContext),
            )
        }

        project.withModule(ServerRustModule.ConstrainedModule) {
            listOf(mapA, mapB).forEach {
                PubCrateConstrainedMapGenerator(
                    codegenContext,
                    this.createTestInlineModuleCreator(),
                    it,
                ).render()
            }
        }
        project.withModule(ServerRustModule.UnconstrainedModule) unconstrainedModuleWriter@{
            project.withModule(Model) modelsModuleWriter@{
                listOf(mapA, mapB).forEach {
                    UnconstrainedMapGenerator(
                        codegenContext,
                        this@unconstrainedModuleWriter.createTestInlineModuleCreator(), it,
                    ).render()

                    MapConstraintViolationGenerator(
                        codegenContext,
                        this@modelsModuleWriter.createTestInlineModuleCreator(),
                        it,
                        SmithyValidationExceptionConversionGenerator(codegenContext),
                    ).render()
                }

                this@unconstrainedModuleWriter.unitTest(
                    name = "map_a_unconstrained_fail_to_constrain_with_some_error",
                    test = """
        serverIntegrationTest(model) { _, rustCrate ->
            rustCrate.testModule {
                unitTest("map_a_unconstrained_fail_to_constrain_with_some_error") {
                    rust(
                        """
                        let c_builder1 = crate::model::StructureC::builder().int(69);
                        let c_builder2 = crate::model::StructureC::builder().string(String::from("david"));
                        let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained(
                        let map_b_unconstrained = crate::unconstrained::map_b_unconstrained::MapBUnconstrained(
                            std::collections::HashMap::from([
                                (String::from("KeyB1"), c_builder1),
                                (String::from("KeyB2"), c_builder2),
                            ])
                        );
                        let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained(
                        let map_a_unconstrained = crate::unconstrained::map_a_unconstrained::MapAUnconstrained(
                            std::collections::HashMap::from([
                                (String::from("KeyA"), map_b_unconstrained),
                            ])
@@ -129,17 +109,17 @@ class UnconstrainedMapGeneratorTest {
                        assert!(actual_err == missing_string_expected_err || actual_err == missing_int_expected_err);
                        """,
                    )

                this@unconstrainedModuleWriter.unitTest(
                    name = "map_a_unconstrained_succeed_to_constrain",
                    test = """
                }
                unitTest("map_a_unconstrained_succeed_to_constrain") {
                    rust(
                        """
                        let c_builder = crate::model::StructureC::builder().int(69).string(String::from("david"));
                        let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained(
                        let map_b_unconstrained = crate::unconstrained::map_b_unconstrained::MapBUnconstrained(
                            std::collections::HashMap::from([
                                (String::from("KeyB"), c_builder),
                            ])
                        );
                        let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained(
                        let map_a_unconstrained = crate::unconstrained::map_a_unconstrained::MapAUnconstrained(
                            std::collections::HashMap::from([
                                (String::from("KeyA"), map_b_unconstrained),
                            ])
@@ -160,17 +140,17 @@ class UnconstrainedMapGeneratorTest {
                        );
                        """,
                    )

                this@unconstrainedModuleWriter.unitTest(
                    name = "map_a_unconstrained_converts_into_constrained",
                    test = """
                }
                unitTest("map_a_unconstrained_converts_into_constrained") {
                    rust(
                        """
                        let c_builder = crate::model::StructureC::builder();
                        let map_b_unconstrained = map_b_unconstrained::MapBUnconstrained(
                        let map_b_unconstrained = crate::unconstrained::map_b_unconstrained::MapBUnconstrained(
                            std::collections::HashMap::from([
                                (String::from("KeyB"), c_builder),
                            ])
                        );
                        let map_a_unconstrained = map_a_unconstrained::MapAUnconstrained(
                        let map_a_unconstrained = crate::unconstrained::map_a_unconstrained::MapAUnconstrained(
                            std::collections::HashMap::from([
                                (String::from("KeyA"), map_b_unconstrained),
                            ])
@@ -181,7 +161,6 @@ class UnconstrainedMapGeneratorTest {
                    )
                }
            }
        project.renderInlineMemoryModules()
        project.compileAndTest()
        }
    }
}