Unverified Commit ae2a74b5 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Non-functional miscellaneous improvements (#1445)

This commit contains non-functional (i.e. mostly stylistic, small
refactors, documentation and consistency) improvements that I've batched
and rolled up while working on #1342.

* Make `extraDependencies` from `InlineDependency` class private
* Miscellaneous improvements to `RustTypes.kt`
* Remove dead code from `RuntimeTypes.kt`
* Style change in `RuntimeTypes.kt`
* Miscellaneous improvements to `SymbolVisitor.kt`
* Miscellaneous improvements to `BuilderGenerator.kt`
    - Make it consistent with the `ServerBuilderGenerator` from #1342.
* Miscellaneous improvements to `TestHelpers.kt`
* Add docs to `headers_for_prefix` function
    - I keep rereading the implementation to see what this returns; it's
      not intuitive from the function name.
parent a89a0832
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -55,7 +55,7 @@ sealed class RustDependency(open val name: String) : SymbolDependencyContainer {
class InlineDependency(
    name: String,
    val module: RustModule,
    val extraDependencies: List<RustDependency> = listOf(),
    private val extraDependencies: List<RustDependency> = listOf(),
    val renderer: (RustWriter) -> Unit
) : RustDependency(name) {
    override fun version(): String {
+14 −18
Original line number Diff line number Diff line
@@ -91,11 +91,11 @@ sealed class RustType {
    }

    data class Reference(val lifetime: kotlin.String?, override val member: RustType) : RustType(), Container {
        override val name: kotlin.String = member.name
        override val name = member.name
    }

    data class Option(override val member: RustType) : RustType(), Container {
        override val name: kotlin.String = "Option"
        override val name = "Option"
        override val namespace = "std::option"

        /** Convert `Option<T>` to `Option<&T>` **/
@@ -105,7 +105,7 @@ sealed class RustType {
    }

    data class Box(override val member: RustType) : RustType(), Container {
        override val name: kotlin.String = "Box"
        override val name = "Box"
        override val namespace = "std::boxed"
    }

@@ -115,7 +115,7 @@ sealed class RustType {
    }

    data class Vec(override val member: RustType) : RustType(), Container {
        override val name: kotlin.String = "Vec"
        override val name = "Vec"
        override val namespace = "std::vec"
    }

@@ -148,25 +148,21 @@ fun RustType.asArgumentType(fullyQualified: Boolean = true): String {
}

/** Format this Rust type so that it may be used as an argument type in a function definition */
fun RustType.asArgumentValue(name: String): String {
    return when (this) {
        is RustType.String,
        is RustType.Box -> "$name.into()"
fun RustType.asArgumentValue(name: String) =
    when (this) {
        is RustType.String, is RustType.Box -> "$name.into()"
        else -> name
    }
}

/**
 * For a given name, generate an `Argument` data class containing pre-formatted strings for using this type when
 * writing a Rust function
 * writing a Rust function.
 */
fun RustType.asArgument(name: String): Argument {
    return Argument(
fun RustType.asArgument(name: String) = Argument(
    "$name: ${this.asArgumentType()}",
    this.asArgumentValue(name),
    this.render(),
)
}

/**
 * Render this type, including references and generic parameters.
@@ -298,7 +294,7 @@ data class RustMetadata(
        return this
    }

    fun renderVisibility(writer: RustWriter): RustMetadata {
    private fun renderVisibility(writer: RustWriter): RustMetadata {
        writer.writeInline(
            when (visibility) {
                Visibility.PRIVATE -> ""
+1 −5
Original line number Diff line number Diff line
@@ -172,7 +172,6 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n
        val Display = stdfmt.member("Display")
        val From = RuntimeType("From", dependency = null, namespace = "std::convert")
        val TryFrom = RuntimeType("TryFrom", dependency = null, namespace = "std::convert")
        val Infallible = RuntimeType("Infallible", dependency = null, namespace = "std::convert")
        val PartialEq = std.member("cmp::PartialEq")
        val StdError = RuntimeType("Error", dependency = null, namespace = "std::error")
        val String = RuntimeType("String", dependency = null, namespace = "std::string")
@@ -235,8 +234,6 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n
        val HttpRequestBuilder = Http("request::Builder")
        val HttpResponseBuilder = Http("response::Builder")

        val Hyper = CargoDependency.Hyper.asType()

        fun eventStreamReceiver(runtimeConfig: RuntimeConfig): RuntimeType =
            RuntimeType(
                "Receiver",
@@ -244,8 +241,7 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n
                "aws_smithy_http::event_stream"
            )

        fun jsonErrors(runtimeConfig: RuntimeConfig) =
            forInlineDependency(InlineDependency.jsonErrors(runtimeConfig))
        fun jsonErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.jsonErrors(runtimeConfig))

        val IdempotencyToken by lazy { forInlineDependency(InlineDependency.idempotencyToken()) }

+23 −24
Original line number Diff line number Diff line
@@ -113,7 +113,9 @@ fun Symbol.mapRustType(f: (RustType) -> RustType): Symbol {
/** Set the symbolLocation for this symbol builder */
fun Symbol.Builder.locatedIn(symbolLocation: SymbolLocation): Symbol.Builder {
    val currentRustType = this.build().rustType()
    check(currentRustType is RustType.Opaque) { "Only Opaque can have their namespace updated" }
    check(currentRustType is RustType.Opaque) {
        "Only `Opaque` can have their namespace updated"
    }
    val newRustType = currentRustType.copy(namespace = "crate::${symbolLocation.namespace}")
    return this.definitionFile("src/${symbolLocation.filename}")
        .namespace("crate::${symbolLocation.namespace}", "::")
@@ -148,6 +150,18 @@ fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String = va
 */
fun SymbolProvider.toOptional(member: MemberShape, value: String): String = value.letIf(!toSymbol(member).isOptional()) { "Some($value)" }

/**
 * Services can rename their contained shapes. See https://awslabs.github.io/smithy/1.0/spec/core/model.html#service
 * specifically, `rename`
 */
fun Shape.contextName(serviceShape: ServiceShape?): String {
    return if (serviceShape != null) {
        id.getName(serviceShape)
    } else {
        id.name
    }
}

/**
 * Base converter from `Shape` to `Symbol`. Shapes are the direct contents of the `Smithy` model. `Symbols` carry information
 * about Rust types, namespaces, dependencies, metadata as well as other information required to render a symbol.
@@ -168,18 +182,6 @@ class SymbolVisitor(
        return shape.accept(this)
    }

    /**
     * Services can rename their contained shapes. See https://awslabs.github.io/smithy/1.0/spec/core/model.html#service
     * specifically, `rename`
     */
    private fun Shape.contextName(): String {
        return if (serviceShape != null) {
            id.getName(serviceShape)
        } else {
            id.name
        }
    }

    /**
     * Return the name of a given `enum` variant. Note that this refers to `enum` in the Smithy context
     * where enum is a trait that can be applied to [StringShape] and not in the Rust context of an algebraic data type.
@@ -238,7 +240,8 @@ class SymbolVisitor(
    override fun doubleShape(shape: DoubleShape): Symbol = simpleShape(shape)
    override fun stringShape(shape: StringShape): Symbol {
        return if (shape.hasTrait<EnumTrait>()) {
            symbolBuilder(shape, RustType.Opaque(shape.contextName().toPascalCase())).locatedIn(Models).build()
            val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase())
            symbolBuilder(shape, rustType).locatedIn(Models).build()
        } else {
            simpleShape(shape)
        }
@@ -285,7 +288,7 @@ class SymbolVisitor(
        return symbolBuilder(
            shape,
            RustType.Opaque(
                shape.contextName()
                shape.contextName(serviceShape)
                    .replaceFirstChar { it.uppercase() }
            )
        )
@@ -305,7 +308,7 @@ class SymbolVisitor(
        val isError = shape.hasTrait<ErrorTrait>()
        val isInput = shape.hasTrait<SyntheticInputTrait>()
        val isOutput = shape.hasTrait<SyntheticOutputTrait>()
        val name = shape.contextName().toPascalCase().letIf(isError && config.renameExceptions) {
        val name = shape.contextName(serviceShape).toPascalCase().letIf(isError && config.renameExceptions) {
            it.replace("Exception", "Error")
        }
        val builder = symbolBuilder(shape, RustType.Opaque(name))
@@ -318,7 +321,7 @@ class SymbolVisitor(
    }

    override fun unionShape(shape: UnionShape): Symbol {
        val name = shape.contextName().toPascalCase()
        val name = shape.contextName(serviceShape).toPascalCase()
        val builder = symbolBuilder(shape, RustType.Opaque(name)).locatedIn(Models)

        return builder.build()
@@ -355,9 +358,7 @@ private const val SHAPE_KEY = "shape"
private const val SYMBOL_DEFAULT = "symboldefault"
private const val RENAMED_FROM_KEY = "renamedfrom"

fun Symbol.Builder.rustType(rustType: RustType): Symbol.Builder {
    return this.putProperty(RUST_TYPE_KEY, rustType)
}
fun Symbol.Builder.rustType(rustType: RustType): Symbol.Builder = this.putProperty(RUST_TYPE_KEY, rustType)

fun Symbol.Builder.renamedFrom(name: String): Symbol.Builder {
    return this.putProperty(RENAMED_FROM_KEY, name)
@@ -366,9 +367,7 @@ fun Symbol.Builder.renamedFrom(name: String): Symbol.Builder {
fun Symbol.renamedFrom(): String? = this.getProperty(RENAMED_FROM_KEY, String::class.java).orNull()

fun Symbol.defaultValue(): Default = this.getProperty(SYMBOL_DEFAULT, Default::class.java).orElse(Default.NoDefault)
fun Symbol.Builder.setDefault(default: Default): Symbol.Builder {
    return this.putProperty(SYMBOL_DEFAULT, default)
}
fun Symbol.Builder.setDefault(default: Default): Symbol.Builder = this.putProperty(SYMBOL_DEFAULT, default)

/**
 * Type representing the default value for a given type. (eg. for Strings, this is `""`)
@@ -406,7 +405,7 @@ fun Symbol.extractSymbolFromOption(): Symbol = this.mapRustType { it.stripOuter<
fun Symbol.isRustBoxed(): Boolean = rustType().stripOuter<RustType.Option>() is RustType.Box

// Symbols should _always_ be created with a Rust type & shape attached
fun Symbol.rustType(): RustType = this.getProperty(RUST_TYPE_KEY, RustType::class.java).get()
fun Symbol.rustType(): RustType = this.expectProperty(RUST_TYPE_KEY, RustType::class.java)
fun Symbol.shape(): Shape = this.expectProperty(SHAPE_KEY, Shape::class.java)

/**
+33 −36
Original line number Diff line number Diff line
@@ -9,7 +9,6 @@ import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
@@ -56,8 +55,8 @@ class OperationBuildError(private val runtimeConfig: RuntimeConfig) {
    fun serializationError(w: RustWriter, error: String) = "${w.format(runtimeConfig.operationBuildError())}::SerializationError($error.into())"
}

/** setter names will never hit a reserved word and therefore never need escaping */
fun MemberShape.setterName(): String = "set_${this.memberName.toSnakeCase()}"
// Setter names will never hit a reserved word and therefore never need escaping.
fun MemberShape.setterName() = "set_${this.memberName.toSnakeCase()}"

class BuilderGenerator(
    private val model: Model,
@@ -70,21 +69,21 @@ class BuilderGenerator(

    fun render(writer: RustWriter) {
        val symbol = symbolProvider.toSymbol(shape)
        writer.docs("See #D", symbol)
        writer.docs("See #D.", symbol)
        val segments = shape.builderSymbol(symbolProvider).namespace.split("::")
        writer.withModule(segments.last()) {
            renderBuilder(this)
        }
    }

    private fun buildFn(implBlockWriter: RustWriter) {
    private fun renderBuildFn(implBlockWriter: RustWriter) {
        val fallibleBuilder = StructureGenerator.fallibleBuilder(shape, symbolProvider)
        val outputSymbol = symbolProvider.toSymbol(shape)
        val returnType = when (fallibleBuilder) {
            true -> "std::result::Result<${implBlockWriter.format(outputSymbol)}, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>"
            true -> "Result<${implBlockWriter.format(outputSymbol)}, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>"
            false -> implBlockWriter.format(outputSymbol)
        }
        implBlockWriter.docs("Consumes the builder and constructs a #D", outputSymbol)
        implBlockWriter.docs("Consumes the builder and constructs a #D.", outputSymbol)
        implBlockWriter.rustBlock("pub fn build(self) -> $returnType") {
            conditionalBlock("Ok(", ")", conditional = fallibleBuilder) {
                // If a wrapper is specified, use the `::new` associated function to construct the wrapper
@@ -94,20 +93,16 @@ class BuilderGenerator(
    }

    private fun RustWriter.missingRequiredField(field: String) {
        val detailedMessage = "$field was not specified but it is required when building ${
        symbolProvider.toSymbol(
            shape
        ).name
        }"
        val detailedMessage = "$field was not specified but it is required when building ${symbolProvider.toSymbol(shape).name}"
        rust(
            """#T::MissingField { field: ${field.dq()}, details: ${detailedMessage.dq()}}""",
            """#T::MissingField { field: "$field", details: "$detailedMessage" } """,
            runtimeConfig.operationBuildError()
        )
    }

    fun renderConvenienceMethod(implBlock: RustWriter) {
        val builderSymbol = shape.builderSymbol(symbolProvider)
        implBlock.docs("Creates a new builder-style object to manufacture #D", structureSymbol)
        implBlock.docs("Creates a new builder-style object to manufacture #D.", structureSymbol)
        implBlock.rustBlock("pub fn builder() -> #T", builderSymbol) {
            write("#T::default()", builderSymbol)
        }
@@ -115,9 +110,8 @@ class BuilderGenerator(

    // TODO(EventStream): [DX] Consider updating builders to take EventInputStream as Into<EventInputStream>
    private fun renderBuilderMember(writer: RustWriter, memberName: String, memberSymbol: Symbol) {
        // builder members are crate-public to enable using them
        // directly in serializers/deserializers. During XML deserialization, `builder.<field>.take` is used to append to
        // lists and maps
        // Builder members are crate-public to enable using them directly in serializers/deserializers.
        // During XML deserialization, `builder.<field>.take` is used to append to lists and maps.
        writer.write("pub(crate) $memberName: #T,", memberSymbol)
    }

@@ -136,15 +130,21 @@ class BuilderGenerator(
        }
    }

    /**
     * Render a `set_foo` method. This is useful as a target for code generation, because the argument type
     * is the same as the resulting member type, and is always optional.
     */
    private fun renderBuilderMemberSetterFn(
        writer: RustWriter,
        outerType: RustType,
        member: MemberShape,
        memberName: String
    ) {
        // Render a `set_foo` method. This is useful as a target for code generation, because the argument type
        // is the same as the resulting member type, and is always optional.
        // TODO(https://github.com/awslabs/smithy-rs/issues/1302): This `asOptional()` call is superfluous except in
        //  the case where the shape is a `@streaming` blob, because [StreamingTraitSymbolProvider] always generates
        //  a non `Option`al target type: in all other cases the client generates `Option`al types.
        val inputType = outerType.asOptional()

        writer.documentShape(member, model)
        writer.rustBlock("pub fn ${member.setterName()}(mut self, input: ${inputType.render(true)}) -> Self") {
            rust("self.$memberName = input; self")
@@ -154,31 +154,29 @@ class BuilderGenerator(
    private fun renderBuilder(writer: RustWriter) {
        val builderName = "Builder"

        val symbol = structureSymbol
        writer.docs("A builder for #D", symbol)
        Attribute.NonExhaustive.render(writer)
        // Matching derives to the main structure + `Default` since we are a builder and everything is optional
        val baseDerives = symbol.expectRustMetadata().derives
        writer.docs("A builder for #D.", structureSymbol)
        // Matching derives to the main structure + `Default` since we are a builder and everything is optional.
        val baseDerives = structureSymbol.expectRustMetadata().derives
        val derives = baseDerives.derives.intersect(setOf(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone)) + RuntimeType.Default
        baseDerives.copy(derives = derives).render(writer)
        writer.rustBlock("pub struct $builderName") {
            members.forEach { member ->
            for (member in members) {
                val memberName = symbolProvider.toMemberName(member)
                // All fields in the builder are optional
                // All fields in the builder are optional.
                val memberSymbol = symbolProvider.toSymbol(member).makeOptional()
                renderBuilderMember(this, memberName, memberSymbol)
            }
        }

        writer.rustBlock("impl $builderName") {
            members.forEach { member ->
                // All fields in the builder are optional
            for (member in members) {
                // All fields in the builder are optional.
                val memberSymbol = symbolProvider.toSymbol(member)
                val outerType = memberSymbol.rustType()
                val coreType = outerType.stripOuter<RustType.Option>()
                val memberName = symbolProvider.toMemberName(member)
                // Render a context-aware builder method for certain types, e.g. a method for vectors that automatically
                // appends
                // appends.
                when (coreType) {
                    is RustType.Vec -> renderVecHelper(member, memberName, coreType)
                    is RustType.HashMap -> renderMapHelper(member, memberName, coreType)
@@ -187,7 +185,7 @@ class BuilderGenerator(

                renderBuilderMemberSetterFn(this, outerType, member, memberName)
            }
            buildFn(this)
            renderBuildFn(this)
        }
    }

@@ -236,17 +234,16 @@ class BuilderGenerator(

    /**
     * The core builder of the inner type. If the structure requires a fallible builder, this may use `?` to return
     * errors
     * errors.
     * ```rust
     * SomeStruct {
     *    field: builder.field,
     *    field2: builder.field2,
     *    field3: builder.field3.unwrap_or_default()
     *    field4: builder.field4.ok_or("field4 is required when building SomeStruct")?
     *    field1: builder.field1,
     *    field2: builder.field2.unwrap_or_default()
     *    field3: builder.field3.ok_or("field3 is required when building SomeStruct")?
     * }
     * ```
     */
    protected fun coreBuilder(writer: RustWriter) {
    private fun coreBuilder(writer: RustWriter) {
        writer.rustBlock("#T", structureSymbol) {
            members.forEach { member ->
                val memberName = symbolProvider.toMemberName(member)
Loading