Unverified Commit dc1ffb8a authored by Fahad Zubair's avatar Fahad Zubair Committed by GitHub
Browse files

Refactor `UnconstrainedUnionGenerator` to use `ValidationExceptionConversionGenerator` (#3733)



`UnconstrainedUnionGenerator` should use
`ValidationExceptionConversionGenerator` to generate the
`as_validation_exception` method.

---------

Co-authored-by: default avatarFahad Zubair <fahadzub@amazon.com>
parent 9af72f5f
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -212,6 +212,7 @@ class PythonServerCodegenVisitor(
                    rustCrate.createInlineModuleCreator(),
                    this@modelsModuleWriter,
                    shape,
                    validationExceptionConversionGenerator,
                ).render()
            }
        }
+1 −0
Original line number Diff line number Diff line
@@ -596,6 +596,7 @@ open class ServerCodegenVisitor(
                    rustCrate.createInlineModuleCreator(),
                    this@modelsModuleWriter,
                    shape,
                    validationExceptionConversionGenerator,
                ).render()
            }
        }
+17 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
@@ -31,6 +32,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.Length
import software.amazon.smithy.rust.codegen.server.smithy.generators.Pattern
import software.amazon.smithy.rust.codegen.server.smithy.generators.Range
import software.amazon.smithy.rust.codegen.server.smithy.generators.StringTraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnionConstraintTraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
@@ -320,4 +322,19 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex
            "AsValidationExceptionFields" to validationExceptionFields.join("\n"),
        )
    }

    override fun unionShapeConstraintViolationImplBlock(
        unionConstraintTraitInfo: Collection<UnionConstraintTraitInfo>,
    ) = writable {
        rustBlockTemplate(
            "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField",
            "String" to RuntimeType.String,
        ) {
            withBlock("match self {", "}") {
                for (constraintViolation in unionConstraintTraitInfo) {
                    rust("""Self::${constraintViolation.name()}(inner) => inner.as_validation_exception_field(path + "/${constraintViolation.forMember.memberName}"),""")
                }
            }
        }
    }
}
+17 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
@@ -30,6 +31,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstraintVi
import software.amazon.smithy.rust.codegen.server.smithy.generators.Range
import software.amazon.smithy.rust.codegen.server.smithy.generators.StringTraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.TraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnionConstraintTraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
@@ -244,4 +246,19 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S
            "AsValidationExceptionFields" to validationExceptionFields.join(""),
        )
    }

    override fun unionShapeConstraintViolationImplBlock(
        unionConstraintTraitInfo: Collection<UnionConstraintTraitInfo>,
    ) = writable {
        rustBlockTemplate(
            "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField",
            "String" to RuntimeType.String,
        ) {
            withBlock("match self {", "}") {
                for (constraintViolation in unionConstraintTraitInfo) {
                    rust("""Self::${constraintViolation.name()}(inner) => inner.as_validation_exception_field(path + "/${constraintViolation.forMember.memberName}"),""")
                }
            }
        }
    }
}
+20 −23
Original line number Diff line number Diff line
@@ -15,7 +15,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
@@ -53,6 +52,7 @@ class UnconstrainedUnionGenerator(
    private val inlineModuleCreator: InlineModuleCreator,
    private val modelsModuleWriter: RustWriter,
    val shape: UnionShape,
    private val validationExceptionConversionGenerator: ValidationExceptionConversionGenerator,
) {
    private val model = codegenContext.model
    private val symbolProvider = codegenContext.symbolProvider
@@ -172,18 +172,15 @@ class UnconstrainedUnionGenerator(
            )

            if (shape.isReachableFromOperationInput()) {
                rustBlock("impl $constraintViolationName") {
                    rustBlockTemplate(
                        "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField",
                        "String" to RuntimeType.String,
                    ) {
                        withBlock("match self {", "}") {
                            for (constraintViolation in constraintViolations()) {
                                rust("""Self::${constraintViolation.name()}(inner) => inner.as_validation_exception_field(path + "/${constraintViolation.forMember.memberName}"),""")
                            }
                        }
                    }
                rustTemplate(
                    """
                    impl $constraintViolationName {
                        #{UnionShapeConstraintViolationImplBlock:W}
                    }
                    """,
                    "UnionShapeConstraintViolationImplBlock" to
                        validationExceptionConversionGenerator.unionShapeConstraintViolationImplBlock(constraintViolations()),
                )
            }
        }
    }
@@ -199,30 +196,26 @@ class UnconstrainedUnionGenerator(
            }
        }

    data class ConstraintViolation(val forMember: MemberShape) {
        fun name() = forMember.memberName.toPascalCase()
    }

    private fun constraintViolations() =
        sortedMembers
            .filter { it.targetCanReachConstrainedShape(model, symbolProvider) }
            .map { ConstraintViolation(it) }
            .map { UnionConstraintTraitInfo(it) }

    private fun renderConstraintViolation(
        writer: RustWriter,
        constraintViolation: ConstraintViolation,
        unionConstraintTraitInfo: UnionConstraintTraitInfo,
    ) {
        val targetShape = model.expectShape(constraintViolation.forMember.target)
        val targetShape = model.expectShape(unionConstraintTraitInfo.forMember.target)

        val constraintViolationSymbol =
            constraintViolationSymbolProvider.toSymbol(targetShape)
                // Box this constraint violation symbol if necessary.
                .letIf(constraintViolation.forMember.hasTrait<ConstraintViolationRustBoxTrait>()) {
                .letIf(unionConstraintTraitInfo.forMember.hasTrait<ConstraintViolationRustBoxTrait>()) {
                    it.makeRustBoxed()
                }

        writer.rust(
            "${constraintViolation.name()}(#T),",
            "${unionConstraintTraitInfo.name()}(#T),",
            constraintViolationSymbol,
        )
    }
@@ -291,7 +284,7 @@ class UnconstrainedUnionGenerator(
                    {
                        let constrained: #{ConstrainedSymbol} = $unconstrainedVar
                            .try_into()$boxIt$boxErr
                            .map_err(Self::Error::${ConstraintViolation(member).name()})?;
                            .map_err(Self::Error::${UnionConstraintTraitInfo(member).name()})?;
                        constrained.into()
                    }
                    """,
@@ -304,9 +297,13 @@ class UnconstrainedUnionGenerator(
                        .try_into()
                        $boxIt
                        $boxErr
                        .map_err(Self::Error::${ConstraintViolation(member).name()})?
                        .map_err(Self::Error::${UnionConstraintTraitInfo(member).name()})?
                    """,
                )
            }
        }
}

data class UnionConstraintTraitInfo(val forMember: MemberShape) {
    fun name() = forMember.memberName.toPascalCase()
}
Loading