Commit 31dafaf7 authored by John DiSanti's avatar John DiSanti
Browse files

Fix re-export of `SdkError`

parent aca4d963
Loading
Loading
Loading
Loading
+19 −2
Original line number Diff line number Diff line
@@ -24,10 +24,11 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCus
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.core.rustlang.Feature
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyErrorTypes
import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyPrimitives
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.util.letIf
@@ -79,6 +80,8 @@ class RequiredCustomizations : ClientCodegenDecorator {
        baseCustomizations + AllowLintsCustomization()

    override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) {
        val rc = codegenContext.runtimeConfig

        // Add rt-tokio feature for `ByteStream::from_path`
        rustCrate.mergeFeature(Feature("rt-tokio", true, listOf("aws-smithy-http/rt-tokio")))

@@ -91,7 +94,21 @@ class RequiredCustomizations : ClientCodegenDecorator {
            pubUseSmithyPrimitives(codegenContext, codegenContext.model)(this)
        }
        rustCrate.withModule(ClientRustModule.Error) {
            pubUseSmithyErrorTypes(codegenContext)(this)
            rustTemplate(
                """
                pub type SdkError<E> = #{SdkError}<E, #{SdkErrorResponse}>;
                pub use #{DisplayErrorContext};
                pub use #{ProvideErrorMetadata};
                """,
                "SdkError" to RuntimeType.smithyHttp(rc).resolve("result::SdkError"),
                "SdkErrorResponse" to if (codegenContext.smithyRuntimeMode.generateOrchestrator) {
                    RuntimeType.smithyRuntimeApi(rc).resolve("client::orchestrator::HttpResponse")
                } else {
                    RuntimeType.HttpResponse
                },
                "DisplayErrorContext" to RuntimeType.smithyTypes(rc).resolve("error::display::DisplayErrorContext"),
                "ProvideErrorMetadata" to RuntimeType.smithyTypes(rc).resolve("error::metadata::ProvideErrorMetadata"),
            )
        }

        ClientRustModule.Meta.also { metaModule ->
+27 −63
Original line number Diff line number Diff line
@@ -9,20 +9,12 @@ import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.util.hasEventStreamMember
import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.core.util.letIf

private data class PubUseType(
    val type: RuntimeType,
    val shouldExport: (Model) -> Boolean,
    val alias: String? = null,
)

/** Returns true if the model has normal streaming operations (excluding event streams) */
private fun hasStreamingOperations(model: Model): Boolean {
@@ -48,62 +40,34 @@ private fun hasBlobs(model: Model): Boolean = structUnionMembersMatchPredicate(m
/** Returns true if the model uses any timestamp shapes */
private fun hasDateTimes(model: Model): Boolean = structUnionMembersMatchPredicate(model, Shape::isTimestampShape)

/** Returns a list of types that should be re-exported for the given model */
internal fun pubUseTypes(codegenContext: CodegenContext, model: Model): List<RuntimeType> =
    pubUseTypesThatShouldBeExported(codegenContext, model).map { it.type }

private fun pubUseTypesThatShouldBeExported(codegenContext: CodegenContext, model: Model): List<PubUseType> {
    val runtimeConfig = codegenContext.runtimeConfig
    return (
        listOf(
            PubUseType(RuntimeType.blob(runtimeConfig), ::hasBlobs),
            PubUseType(RuntimeType.dateTime(runtimeConfig), ::hasDateTimes),
            PubUseType(RuntimeType.format(runtimeConfig), ::hasDateTimes, "DateTimeFormat"),
        ) + RuntimeType.smithyHttp(runtimeConfig).let { http ->
            listOf(
                PubUseType(http.resolve("byte_stream::ByteStream"), ::hasStreamingOperations),
                PubUseType(http.resolve("byte_stream::AggregatedBytes"), ::hasStreamingOperations),
                PubUseType(http.resolve("byte_stream::error::Error"), ::hasStreamingOperations, "ByteStreamError"),
                PubUseType(http.resolve("body::SdkBody"), ::hasStreamingOperations),
            )
        }
        ).filter { pubUseType -> pubUseType.shouldExport(model) }
}

/** Adds re-export statements for Smithy primitives */
fun pubUseSmithyPrimitives(codegenContext: CodegenContext, model: Model): Writable = writable {
    val types = pubUseTypesThatShouldBeExported(codegenContext, model)
    if (types.isNotEmpty()) {
        types.forEach {
            val useStatement = if (it.alias == null) {
                "pub use #T;"
            } else {
                "pub use #T as ${it.alias};"
    val rc = codegenContext.runtimeConfig
    if (hasBlobs(model)) {
        rustTemplate("pub use #{Blob};", "Blob" to RuntimeType.blob(rc))
    }
            rust(useStatement, it.type)
        }
    }
}

/** Adds re-export statements for error types */
fun pubUseSmithyErrorTypes(codegenContext: CodegenContext): Writable = writable {
    val runtimeConfig = codegenContext.runtimeConfig
    val reexports = listOf(
        listOf(
            RuntimeType.smithyHttp(runtimeConfig).let { http ->
                PubUseType(http.resolve("result::SdkError"), { _ -> true })
            },
        ),
        RuntimeType.smithyTypes(runtimeConfig).let { types ->
            listOf(PubUseType(types.resolve("error::display::DisplayErrorContext"), { _ -> true }))
                // Only re-export `ProvideErrorMetadata` for clients
                .letIf(codegenContext.target == CodegenTarget.CLIENT) { list ->
                    list +
                        listOf(PubUseType(types.resolve("error::metadata::ProvideErrorMetadata"), { _ -> true }))
    if (hasDateTimes(model)) {
        rustTemplate(
            """
            pub use #{DateTime};
            pub use #{Format} as DateTimeFormat;
            """,
            "DateTime" to RuntimeType.dateTime(rc),
            "Format" to RuntimeType.format(rc),
        )
    }
        },
    ).flatten()
    reexports.forEach { reexport ->
        rust("pub use #T;", reexport.type)
    if (hasStreamingOperations(model)) {
        rustTemplate(
            """
            pub use #{ByteStream};
            pub use #{AggregatedBytes};
            pub use #{Error} as ByteStreamError;
            pub use #{SdkBody};
            """,
            "ByteStream" to RuntimeType.smithyHttp(rc).resolve("byte_stream::ByteStream"),
            "AggregatedBytes" to RuntimeType.smithyHttp(rc).resolve("byte_stream::AggregatedBytes"),
            "Error" to RuntimeType.smithyHttp(rc).resolve("byte_stream::error::Error"),
            "SdkBody" to RuntimeType.smithyHttp(rc).resolve("body::SdkBody"),
        )
    }
}
+37 −37
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.customizations

import org.junit.jupiter.api.Test
import software.amazon.smithy.model.Model
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGeneratorTest.Companion.model
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext
@@ -43,61 +43,61 @@ class SmithyTypesPubUseExtraTest {
        """.asSmithyModel()
    }

    private fun typesWithEmptyModel() = typesWithMember()
    private fun typesWithMember(
    private fun reexportsWithEmptyModel() = reexportsWithMember()
    private fun reexportsWithMember(
        inputMember: String = "",
        outputMember: String = "",
        unionMember: String = "",
        additionalShape: String = "",
    ) = pubUseTypes(testCodegenContext(model), modelWithMember(inputMember, outputMember, unionMember, additionalShape))
    ) = RustWriter.root().let { writer ->
        pubUseSmithyPrimitives(testCodegenContext(model), modelWithMember(inputMember, outputMember, unionMember, additionalShape))(writer)
        writer.toString()
    }

    private fun assertDoesntHaveTypes(types: List<RuntimeType>, expectedTypes: List<String>) =
        expectedTypes.forEach { assertDoesntHaveType(types, it) }
    private fun assertDoesntHaveReexports(reexports: String, expectedTypes: List<String>) =
        expectedTypes.forEach { assertDoesntHaveReexports(reexports, it) }

    private fun assertDoesntHaveType(types: List<RuntimeType>, type: String) {
        if (types.any { t -> t.fullyQualifiedName() == type }) {
    private fun assertDoesntHaveReexports(reexports: String, type: String) {
        if (reexports.contains(type)) {
            throw AssertionError("Expected $type to NOT be re-exported, but it was.")
        }
    }

    private fun assertHasTypes(types: List<RuntimeType>, expectedTypes: List<String>) =
        expectedTypes.forEach { assertHasType(types, it) }
    private fun assertHasReexports(reexports: String, expectedTypes: List<String>) =
        expectedTypes.forEach { assertHasReexport(reexports, it) }

    private fun assertHasType(types: List<RuntimeType>, type: String) {
        if (types.none { t -> t.fullyQualifiedName() == type }) {
            throw AssertionError(
                "Expected $type to be re-exported. Re-exported types: " +
                    types.joinToString { it.fullyQualifiedName() },
            )
    private fun assertHasReexport(reexports: String, type: String) {
        if (!reexports.contains(type)) {
            throw AssertionError("Expected $type to be re-exported. Re-exported types:\n$reexports")
        }
    }

    @Test
    fun `it re-exports Blob when a model uses blobs`() {
        assertDoesntHaveType(typesWithEmptyModel(), "::aws_smithy_types::Blob")
        assertHasType(typesWithMember(inputMember = "foo: Blob"), "::aws_smithy_types::Blob")
        assertHasType(typesWithMember(outputMember = "foo: Blob"), "::aws_smithy_types::Blob")
        assertHasType(
            typesWithMember(inputMember = "foo: SomeUnion", unionMember = "foo: Blob"),
        this.assertDoesntHaveReexports(reexportsWithEmptyModel(), "::aws_smithy_types::Blob")
        assertHasReexport(reexportsWithMember(inputMember = "foo: Blob"), "::aws_smithy_types::Blob")
        assertHasReexport(reexportsWithMember(outputMember = "foo: Blob"), "::aws_smithy_types::Blob")
        assertHasReexport(
            reexportsWithMember(inputMember = "foo: SomeUnion", unionMember = "foo: Blob"),
            "::aws_smithy_types::Blob",
        )
        assertHasType(
            typesWithMember(outputMember = "foo: SomeUnion", unionMember = "foo: Blob"),
        assertHasReexport(
            reexportsWithMember(outputMember = "foo: SomeUnion", unionMember = "foo: Blob"),
            "::aws_smithy_types::Blob",
        )
    }

    @Test
    fun `it re-exports DateTime when a model uses timestamps`() {
        assertDoesntHaveType(typesWithEmptyModel(), "aws_smithy_types::DateTime")
        assertHasType(typesWithMember(inputMember = "foo: Timestamp"), "::aws_smithy_types::DateTime")
        assertHasType(typesWithMember(outputMember = "foo: Timestamp"), "::aws_smithy_types::DateTime")
        assertHasType(
            typesWithMember(inputMember = "foo: SomeUnion", unionMember = "foo: Timestamp"),
        this.assertDoesntHaveReexports(reexportsWithEmptyModel(), "aws_smithy_types::DateTime")
        assertHasReexport(reexportsWithMember(inputMember = "foo: Timestamp"), "::aws_smithy_types::DateTime")
        assertHasReexport(reexportsWithMember(outputMember = "foo: Timestamp"), "::aws_smithy_types::DateTime")
        assertHasReexport(
            reexportsWithMember(inputMember = "foo: SomeUnion", unionMember = "foo: Timestamp"),
            "::aws_smithy_types::DateTime",
        )
        assertHasType(
            typesWithMember(outputMember = "foo: SomeUnion", unionMember = "foo: Timestamp"),
        assertHasReexport(
            reexportsWithMember(outputMember = "foo: SomeUnion", unionMember = "foo: Timestamp"),
            "::aws_smithy_types::DateTime",
        )
    }
@@ -108,20 +108,20 @@ class SmithyTypesPubUseExtraTest {
            listOf("::aws_smithy_http::byte_stream::ByteStream", "::aws_smithy_http::byte_stream::AggregatedBytes")
        val streamingShape = "@streaming blob Streaming"

        assertDoesntHaveTypes(typesWithEmptyModel(), streamingTypes)
        assertHasTypes(typesWithMember(additionalShape = streamingShape, inputMember = "m: Streaming"), streamingTypes)
        assertHasTypes(typesWithMember(additionalShape = streamingShape, outputMember = "m: Streaming"), streamingTypes)
        this.assertDoesntHaveReexports(reexportsWithEmptyModel(), streamingTypes)
        assertHasReexports(reexportsWithMember(additionalShape = streamingShape, inputMember = "m: Streaming"), streamingTypes)
        assertHasReexports(reexportsWithMember(additionalShape = streamingShape, outputMember = "m: Streaming"), streamingTypes)

        // Event streams don't re-export the normal streaming types
        assertDoesntHaveTypes(
            typesWithMember(
        this.assertDoesntHaveReexports(
            reexportsWithMember(
                additionalShape = "@streaming union EventStream { foo: SomeStruct }",
                inputMember = "m: EventStream",
            ),
            streamingTypes,
        )
        assertDoesntHaveTypes(
            typesWithMember(
        this.assertDoesntHaveReexports(
            reexportsWithMember(
                additionalShape = "@streaming union EventStream { foo: SomeStruct }",
                outputMember = "m: EventStream",
            ),
+13 −2
Original line number Diff line number Diff line
@@ -6,10 +6,11 @@
package software.amazon.smithy.rust.codegen.server.smithy.customizations

import software.amazon.smithy.rust.codegen.core.rustlang.Feature
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyErrorTypes
import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyPrimitives
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
@@ -34,12 +35,22 @@ class ServerRequiredCustomizations : ServerCodegenDecorator {
        baseCustomizations + AllowLintsCustomization()

    override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) {
        val rc = codegenContext.runtimeConfig

        // Add rt-tokio feature for `ByteStream::from_path`
        rustCrate.mergeFeature(Feature("rt-tokio", true, listOf("aws-smithy-http/rt-tokio")))

        rustCrate.withModule(ServerRustModule.Types) {
            pubUseSmithyPrimitives(codegenContext, codegenContext.model)(this)
            pubUseSmithyErrorTypes(codegenContext)(this)
            rustTemplate(
                """
                pub type SdkError<E> = #{SdkError}<E, #{SdkErrorResponse}>;
                pub use #{DisplayErrorContext};
                """,
                "SdkError" to RuntimeType.smithyHttp(rc).resolve("result::SdkError"),
                "SdkErrorResponse" to RuntimeType.HttpResponse,
                "DisplayErrorContext" to RuntimeType.smithyTypes(rc).resolve("error::display::DisplayErrorContext"),
            )
        }

        rustCrate.withModule(ServerRustModule.root) {