Unverified Commit a9ef40b9 authored by 82marbag's avatar 82marbag Committed by GitHub
Browse files

Lifetimes in builders (#3249)



## Motivation and Context
We're not handling lifetimes in builders of structures.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>
parent 8754c99b
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -69,7 +69,7 @@ class RustReservedWordSymbolProvider(
            "RustReservedWordSymbolProvider should only run once"
        }

        var renamedSymbol = internal.toSymbol(shape)
        val renamedSymbol = internal.toSymbol(shape)
        return when (shape) {
            is MemberShape -> {
                val container = model.expectShape(shape.container)
+30 −24
Original line number Diff line number Diff line
@@ -88,32 +88,16 @@ open class StructureGenerator(
        renderStructure()
    }

    /**
     * Search for lifetimes used by the members of the struct and generate a declaration.
     * e.g. `<'a, 'b>`
     */
    private fun lifetimeDeclaration(): String {
        val lifetimes = members
            .map { symbolProvider.toSymbol(it).rustType().innerReference() }
            .mapNotNull {
                when (it) {
                    is RustType.Reference -> it.lifetime
                    else -> null
                }
            }.toSet().sorted()
        return if (lifetimes.isNotEmpty()) {
            "<${lifetimes.joinToString { "'$it" }}>"
        } else {
            ""
        }
    }

    /**
     * Render a custom debug implementation
     * When [SensitiveTrait] support is required, render a custom debug implementation to redact sensitive data
     */
    private fun renderDebugImpl() {
        writer.rustBlock("impl ${lifetimeDeclaration()} #T for $name ${lifetimeDeclaration()}", RuntimeType.Debug) {
        val lifetime = shape.lifetimeDeclaration(symbolProvider)
        writer.rustBlock(
            "impl ${shape.lifetimeDeclaration(symbolProvider)} #T for $name $lifetime",
            RuntimeType.Debug,
        ) {
            writer.rustBlock("fn fmt(&self, f: &mut #1T::Formatter<'_>) -> #1T::Result", RuntimeType.stdFmt) {
                rust("""let mut formatter = f.debug_struct(${name.dq()});""")
                members.forEach { member ->
@@ -134,8 +118,13 @@ open class StructureGenerator(
        if (accessorMembers.isEmpty()) {
            return
        }
        val lifetimes = lifetimeDeclaration()
        writer.rustBlock("impl $lifetimes $name $lifetimes") {
        writer.rustBlock(
            "impl ${shape.lifetimeDeclaration(symbolProvider)} $name ${
                shape.lifetimeDeclaration(
                    symbolProvider,
                )
            }",
        ) {
            // Render field accessor methods
            forEachMember(accessorMembers) { member, memberName, memberSymbol ->
                val memberType = memberSymbol.rustType()
@@ -146,6 +135,7 @@ open class StructureGenerator(
                        unwrapOrDefault = true
                        memberType.stripOuter<RustType.Option>().asDeref().asRef()
                    }

                    memberType.isCopy() -> memberType
                    memberType is RustType.Option && memberType.member.isDeref() -> memberType.asDeref()
                    memberType.isDeref() -> memberType.asDeref().asRef()
@@ -188,7 +178,7 @@ open class StructureGenerator(
        writer.deprecatedShape(shape)
        containerMeta.render(writer)

        writer.rustBlock("struct $name ${lifetimeDeclaration()}") {
        writer.rustBlock("struct $name ${shape.lifetimeDeclaration(symbolProvider)}") {
            writer.forEachMember(members) { member, memberName, memberSymbol ->
                renderStructureMember(writer, member, memberName, memberSymbol)
            }
@@ -223,3 +213,19 @@ open class StructureGenerator(
        )
    }
}

/**
 * Search for lifetimes used by the members of the struct and generate a declaration.
 * e.g. `<'a, 'b>`
 */
fun StructureShape.lifetimeDeclaration(symbolProvider: RustSymbolProvider): String {
    val lifetimes = this.members()
        .mapNotNull { symbolProvider.toSymbol(it).rustType().innerReference()?.let { it as RustType.Reference } }
        .mapNotNull { it.lifetime }
        .toSet().sorted()
    return if (lifetimes.isNotEmpty()) {
        "<${lifetimes.joinToString { "'$it" }}>"
    } else {
        ""
    }
}
+1 −1
Original line number Diff line number Diff line
@@ -62,7 +62,7 @@ object OperationNormalizer {
        check(
            shapeConflict == null,
        ) {
            "shape $shapeConflict conflicted with an existing shape in the model (${model.getShape(shapeConflict!!.id)}. This is a bug."
            "shape $shapeConflict conflicted with an existing shape in the model (${model.expectShape(shapeConflict!!.id)}). This is a bug."
        }
        val modelWithOperationInputs = model.toBuilder().addShapes(newShapes).build()
        return transformer.mapShapes(modelWithOperationInputs) {
+31 −7
Original line number Diff line number Diff line
@@ -32,6 +32,7 @@ import software.amazon.smithy.model.traits.LengthTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings
import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker
@@ -41,6 +42,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.lifetimeDeclaration
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
@@ -243,7 +245,10 @@ open class ServerCodegenVisitor(
                logger.log(logMessage.level, logMessage.message)
            }
            if (validationResult.shouldAbort) {
                throw CodegenException("Unsupported constraints feature used; see error messages above for resolution", validationResult)
                throw CodegenException(
                    "Unsupported constraints feature used; see error messages above for resolution",
                    validationResult,
                )
            }
        }

@@ -328,7 +333,8 @@ open class ServerCodegenVisitor(
            serverBuilderGenerator.render(rustCrate, writer)

            if (codegenContext.settings.codegenConfig.publicConstrainedTypes) {
                writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
                val lifetimes = shape.lifetimeDeclaration(codegenContext.symbolProvider)
                writer.rustBlock("impl $lifetimes ${codegenContext.symbolProvider.toSymbol(shape).name} $lifetimes") {
                    serverBuilderGenerator.renderConvenienceMethod(this)
                }
            }
@@ -372,7 +378,11 @@ open class ServerCodegenVisitor(

        if (renderUnconstrainedList) {
            logger.info("[rust-server-codegen] Generating an unconstrained type for collection shape $shape")
            rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.UnconstrainedModule, shape, codegenContext) {
            rustCrate.withModuleOrWithStructureBuilderModule(
                ServerRustModule.UnconstrainedModule,
                shape,
                codegenContext,
            ) {
                UnconstrainedCollectionGenerator(
                    codegenContext,
                    rustCrate.createInlineModuleCreator(),
@@ -382,7 +392,11 @@ open class ServerCodegenVisitor(

            if (!isDirectlyConstrained) {
                logger.info("[rust-server-codegen] Generating a constrained type for collection shape $shape")
                rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.ConstrainedModule, shape, codegenContext) {
                rustCrate.withModuleOrWithStructureBuilderModule(
                    ServerRustModule.ConstrainedModule,
                    shape,
                    codegenContext,
                ) {
                    PubCrateConstrainedCollectionGenerator(
                        codegenContext,
                        rustCrate.createInlineModuleCreator(),
@@ -427,7 +441,11 @@ open class ServerCodegenVisitor(

        if (renderUnconstrainedMap) {
            logger.info("[rust-server-codegen] Generating an unconstrained type for map $shape")
            rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.UnconstrainedModule, shape, codegenContext) {
            rustCrate.withModuleOrWithStructureBuilderModule(
                ServerRustModule.UnconstrainedModule,
                shape,
                codegenContext,
            ) {
                UnconstrainedMapGenerator(
                    codegenContext,
                    rustCrate.createInlineModuleCreator(),
@@ -437,7 +455,11 @@ open class ServerCodegenVisitor(

            if (!isDirectlyConstrained) {
                logger.info("[rust-server-codegen] Generating a constrained type for map $shape")
                rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.ConstrainedModule, shape, codegenContext) {
                rustCrate.withModuleOrWithStructureBuilderModule(
                    ServerRustModule.ConstrainedModule,
                    shape,
                    codegenContext,
                ) {
                    PubCrateConstrainedMapGenerator(
                        codegenContext,
                        rustCrate.createInlineModuleCreator(),
@@ -575,7 +597,9 @@ open class ServerCodegenVisitor(
     */
    open fun protocolTests() {
        rustCrate.withModule(ServerRustModule.Operation) {
            ServerProtocolTestGenerator(codegenContext, protocolGeneratorFactory.support(), protocolGenerator).render(this)
            ServerProtocolTestGenerator(codegenContext, protocolGeneratorFactory.support(), protocolGenerator).render(
                this,
            )
        }
    }

+10 −8
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.generators.lifetimeDeclaration
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed
import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained
@@ -147,6 +148,7 @@ class ServerBuilderGenerator(
    private val isBuilderFallible = hasFallibleBuilder(shape, model, symbolProvider, takeInUnconstrainedTypes)
    private val serverBuilderConstraintViolations =
        ServerBuilderConstraintViolations(codegenContext, shape, takeInUnconstrainedTypes, customValidationExceptionWithReasonConversionGenerator)
    private val lifetime = shape.lifetimeDeclaration(symbolProvider)

    private val codegenScope = arrayOf(
        "RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
@@ -196,11 +198,11 @@ class ServerBuilderGenerator(
            it == RuntimeType.Debug || it == RuntimeType.Clone
        } + RuntimeType.Default
        Attribute(derive(builderDerives)).render(writer)
        writer.rustBlock("${visibility.toRustQualifier()} struct Builder") {
        writer.rustBlock("${visibility.toRustQualifier()} struct Builder$lifetime") {
            members.forEach { renderBuilderMember(this, it) }
        }

        writer.rustBlock("impl Builder") {
        writer.rustBlock("impl $lifetime Builder $lifetime") {
            for (member in members) {
                if (publicConstrainedTypes) {
                    renderBuilderMemberFn(this, member)
@@ -262,7 +264,7 @@ class ServerBuilderGenerator(
                self.build_enforcing_all_constraints()
            }
            """,
            "ReturnType" to buildFnReturnType(isBuilderFallible, structureSymbol),
            "ReturnType" to buildFnReturnType(isBuilderFallible, structureSymbol, lifetime),
        )
        renderBuildEnforcingAllConstraintsFn(implBlockWriter)
    }
@@ -270,7 +272,7 @@ class ServerBuilderGenerator(
    private fun renderBuildEnforcingAllConstraintsFn(implBlockWriter: RustWriter) {
        implBlockWriter.rustBlockTemplate(
            "fn build_enforcing_all_constraints(self) -> #{ReturnType:W}",
            "ReturnType" to buildFnReturnType(isBuilderFallible, structureSymbol),
            "ReturnType" to buildFnReturnType(isBuilderFallible, structureSymbol, lifetime),
        ) {
            conditionalBlock("Ok(", ")", conditional = isBuilderFallible) {
                coreBuilder(this)
@@ -280,7 +282,7 @@ class ServerBuilderGenerator(

    fun renderConvenienceMethod(implBlock: RustWriter) {
        implBlock.docs("Creates a new builder-style object to manufacture #D.", structureSymbol)
        implBlock.rustBlock("pub fn builder() -> #T", builderSymbol) {
        implBlock.rustBlock("pub fn builder() -> #T $lifetime", builderSymbol) {
            write("#T::default()", builderSymbol)
        }
    }
@@ -413,10 +415,10 @@ class ServerBuilderGenerator(
    private fun renderTryFromBuilderImpl(writer: RustWriter) {
        writer.rustTemplate(
            """
            impl #{TryFrom}<Builder> for #{Structure} {
            impl #{TryFrom}<Builder $lifetime> for #{Structure}$lifetime {
                type Error = ConstraintViolation;

                fn try_from(builder: Builder) -> Result<Self, Self::Error> {
                fn try_from(builder: Builder $lifetime) -> Result<Self, Self::Error> {
                    builder.build()
                }
            }
@@ -428,7 +430,7 @@ class ServerBuilderGenerator(
    private fun renderFromBuilderImpl(writer: RustWriter) {
        writer.rustTemplate(
            """
            impl #{From}<Builder> for #{Structure} {
            impl #{From}<Builder $lifetime> for #{Structure} $lifetime {
                fn from(builder: Builder) -> Self {
                    builder.build()
                }
Loading