Loading codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt +19 −2 Original line number Diff line number Diff line Loading @@ -24,10 +24,11 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCus import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.core.rustlang.Feature import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyErrorTypes import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyPrimitives import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.core.util.letIf Loading Loading @@ -79,6 +80,8 @@ class RequiredCustomizations : ClientCodegenDecorator { baseCustomizations + AllowLintsCustomization() override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { val rc = codegenContext.runtimeConfig // Add rt-tokio feature for `ByteStream::from_path` rustCrate.mergeFeature(Feature("rt-tokio", true, listOf("aws-smithy-http/rt-tokio"))) Loading @@ -91,7 +94,21 @@ class RequiredCustomizations : ClientCodegenDecorator { pubUseSmithyPrimitives(codegenContext, codegenContext.model)(this) } rustCrate.withModule(ClientRustModule.Error) { pubUseSmithyErrorTypes(codegenContext)(this) rustTemplate( """ pub type SdkError<E> = #{SdkError}<E, #{SdkErrorResponse}>; pub use #{DisplayErrorContext}; pub use #{ProvideErrorMetadata}; """, "SdkError" to RuntimeType.smithyHttp(rc).resolve("result::SdkError"), "SdkErrorResponse" to if (codegenContext.smithyRuntimeMode.generateOrchestrator) { RuntimeType.smithyRuntimeApi(rc).resolve("client::orchestrator::HttpResponse") } else { RuntimeType.HttpResponse }, "DisplayErrorContext" to RuntimeType.smithyTypes(rc).resolve("error::display::DisplayErrorContext"), "ProvideErrorMetadata" to RuntimeType.smithyTypes(rc).resolve("error::metadata::ProvideErrorMetadata"), ) } ClientRustModule.Meta.also { metaModule -> Loading codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt +27 −63 Original line number Diff line number Diff line Loading @@ -9,20 +9,12 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.hasEventStreamMember import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember import software.amazon.smithy.rust.codegen.core.util.letIf private data class PubUseType( val type: RuntimeType, val shouldExport: (Model) -> Boolean, val alias: String? = null, ) /** Returns true if the model has normal streaming operations (excluding event streams) */ private fun hasStreamingOperations(model: Model): Boolean { Loading @@ -48,62 +40,34 @@ private fun hasBlobs(model: Model): Boolean = structUnionMembersMatchPredicate(m /** Returns true if the model uses any timestamp shapes */ private fun hasDateTimes(model: Model): Boolean = structUnionMembersMatchPredicate(model, Shape::isTimestampShape) /** Returns a list of types that should be re-exported for the given model */ internal fun pubUseTypes(codegenContext: CodegenContext, model: Model): List<RuntimeType> = pubUseTypesThatShouldBeExported(codegenContext, model).map { it.type } private fun pubUseTypesThatShouldBeExported(codegenContext: CodegenContext, model: Model): List<PubUseType> { val runtimeConfig = codegenContext.runtimeConfig return ( listOf( PubUseType(RuntimeType.blob(runtimeConfig), ::hasBlobs), PubUseType(RuntimeType.dateTime(runtimeConfig), ::hasDateTimes), PubUseType(RuntimeType.format(runtimeConfig), ::hasDateTimes, "DateTimeFormat"), ) + RuntimeType.smithyHttp(runtimeConfig).let { http -> listOf( PubUseType(http.resolve("byte_stream::ByteStream"), ::hasStreamingOperations), PubUseType(http.resolve("byte_stream::AggregatedBytes"), ::hasStreamingOperations), PubUseType(http.resolve("byte_stream::error::Error"), ::hasStreamingOperations, "ByteStreamError"), PubUseType(http.resolve("body::SdkBody"), ::hasStreamingOperations), ) } ).filter { pubUseType -> pubUseType.shouldExport(model) } } /** Adds re-export statements for Smithy primitives */ fun pubUseSmithyPrimitives(codegenContext: CodegenContext, model: Model): Writable = writable { val types = pubUseTypesThatShouldBeExported(codegenContext, model) if (types.isNotEmpty()) { types.forEach { val useStatement = if (it.alias == null) { "pub use #T;" } else { "pub use #T as ${it.alias};" val rc = codegenContext.runtimeConfig if (hasBlobs(model)) { rustTemplate("pub use #{Blob};", "Blob" to RuntimeType.blob(rc)) } rust(useStatement, it.type) } } } /** Adds re-export statements for error types */ fun pubUseSmithyErrorTypes(codegenContext: CodegenContext): Writable = writable { val runtimeConfig = codegenContext.runtimeConfig val reexports = listOf( listOf( RuntimeType.smithyHttp(runtimeConfig).let { http -> PubUseType(http.resolve("result::SdkError"), { _ -> true }) }, ), RuntimeType.smithyTypes(runtimeConfig).let { types -> listOf(PubUseType(types.resolve("error::display::DisplayErrorContext"), { _ -> true })) // Only re-export `ProvideErrorMetadata` for clients .letIf(codegenContext.target == CodegenTarget.CLIENT) { list -> list + listOf(PubUseType(types.resolve("error::metadata::ProvideErrorMetadata"), { _ -> true })) if (hasDateTimes(model)) { rustTemplate( """ pub use #{DateTime}; pub use #{Format} as DateTimeFormat; """, "DateTime" to RuntimeType.dateTime(rc), "Format" to RuntimeType.format(rc), ) } }, ).flatten() reexports.forEach { reexport -> rust("pub use #T;", reexport.type) if (hasStreamingOperations(model)) { rustTemplate( """ pub use #{ByteStream}; pub use #{AggregatedBytes}; pub use #{Error} as ByteStreamError; pub use #{SdkBody}; """, "ByteStream" to RuntimeType.smithyHttp(rc).resolve("byte_stream::ByteStream"), "AggregatedBytes" to RuntimeType.smithyHttp(rc).resolve("byte_stream::AggregatedBytes"), "Error" to RuntimeType.smithyHttp(rc).resolve("byte_stream::error::Error"), "SdkBody" to RuntimeType.smithyHttp(rc).resolve("body::SdkBody"), ) } } codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtraTest.kt +37 −37 Original line number Diff line number Diff line Loading @@ -7,7 +7,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.customizations import org.junit.jupiter.api.Test import software.amazon.smithy.model.Model import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGeneratorTest.Companion.model import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext Loading Loading @@ -43,61 +43,61 @@ class SmithyTypesPubUseExtraTest { """.asSmithyModel() } private fun typesWithEmptyModel() = typesWithMember() private fun typesWithMember( private fun reexportsWithEmptyModel() = reexportsWithMember() private fun reexportsWithMember( inputMember: String = "", outputMember: String = "", unionMember: String = "", additionalShape: String = "", ) = pubUseTypes(testCodegenContext(model), modelWithMember(inputMember, outputMember, unionMember, additionalShape)) ) = RustWriter.root().let { writer -> pubUseSmithyPrimitives(testCodegenContext(model), modelWithMember(inputMember, outputMember, unionMember, additionalShape))(writer) writer.toString() } private fun assertDoesntHaveTypes(types: List<RuntimeType>, expectedTypes: List<String>) = expectedTypes.forEach { assertDoesntHaveType(types, it) } private fun assertDoesntHaveReexports(reexports: String, expectedTypes: List<String>) = expectedTypes.forEach { assertDoesntHaveReexports(reexports, it) } private fun assertDoesntHaveType(types: List<RuntimeType>, type: String) { if (types.any { t -> t.fullyQualifiedName() == type }) { private fun assertDoesntHaveReexports(reexports: String, type: String) { if (reexports.contains(type)) { throw AssertionError("Expected $type to NOT be re-exported, but it was.") } } private fun assertHasTypes(types: List<RuntimeType>, expectedTypes: List<String>) = expectedTypes.forEach { assertHasType(types, it) } private fun assertHasReexports(reexports: String, expectedTypes: List<String>) = expectedTypes.forEach { assertHasReexport(reexports, it) } private fun assertHasType(types: List<RuntimeType>, type: String) { if (types.none { t -> t.fullyQualifiedName() == type }) { throw AssertionError( "Expected $type to be re-exported. Re-exported types: " + types.joinToString { it.fullyQualifiedName() }, ) private fun assertHasReexport(reexports: String, type: String) { if (!reexports.contains(type)) { throw AssertionError("Expected $type to be re-exported. Re-exported types:\n$reexports") } } @Test fun `it re-exports Blob when a model uses blobs`() { assertDoesntHaveType(typesWithEmptyModel(), "::aws_smithy_types::Blob") assertHasType(typesWithMember(inputMember = "foo: Blob"), "::aws_smithy_types::Blob") assertHasType(typesWithMember(outputMember = "foo: Blob"), "::aws_smithy_types::Blob") assertHasType( typesWithMember(inputMember = "foo: SomeUnion", unionMember = "foo: Blob"), this.assertDoesntHaveReexports(reexportsWithEmptyModel(), "::aws_smithy_types::Blob") assertHasReexport(reexportsWithMember(inputMember = "foo: Blob"), "::aws_smithy_types::Blob") assertHasReexport(reexportsWithMember(outputMember = "foo: Blob"), "::aws_smithy_types::Blob") assertHasReexport( reexportsWithMember(inputMember = "foo: SomeUnion", unionMember = "foo: Blob"), "::aws_smithy_types::Blob", ) assertHasType( typesWithMember(outputMember = "foo: SomeUnion", unionMember = "foo: Blob"), assertHasReexport( reexportsWithMember(outputMember = "foo: SomeUnion", unionMember = "foo: Blob"), "::aws_smithy_types::Blob", ) } @Test fun `it re-exports DateTime when a model uses timestamps`() { assertDoesntHaveType(typesWithEmptyModel(), "aws_smithy_types::DateTime") assertHasType(typesWithMember(inputMember = "foo: Timestamp"), "::aws_smithy_types::DateTime") assertHasType(typesWithMember(outputMember = "foo: Timestamp"), "::aws_smithy_types::DateTime") assertHasType( typesWithMember(inputMember = "foo: SomeUnion", unionMember = "foo: Timestamp"), this.assertDoesntHaveReexports(reexportsWithEmptyModel(), "aws_smithy_types::DateTime") assertHasReexport(reexportsWithMember(inputMember = "foo: Timestamp"), "::aws_smithy_types::DateTime") assertHasReexport(reexportsWithMember(outputMember = "foo: Timestamp"), "::aws_smithy_types::DateTime") assertHasReexport( reexportsWithMember(inputMember = "foo: SomeUnion", unionMember = "foo: Timestamp"), "::aws_smithy_types::DateTime", ) assertHasType( typesWithMember(outputMember = "foo: SomeUnion", unionMember = "foo: Timestamp"), assertHasReexport( reexportsWithMember(outputMember = "foo: SomeUnion", unionMember = "foo: Timestamp"), "::aws_smithy_types::DateTime", ) } Loading @@ -108,20 +108,20 @@ class SmithyTypesPubUseExtraTest { listOf("::aws_smithy_http::byte_stream::ByteStream", "::aws_smithy_http::byte_stream::AggregatedBytes") val streamingShape = "@streaming blob Streaming" assertDoesntHaveTypes(typesWithEmptyModel(), streamingTypes) assertHasTypes(typesWithMember(additionalShape = streamingShape, inputMember = "m: Streaming"), streamingTypes) assertHasTypes(typesWithMember(additionalShape = streamingShape, outputMember = "m: Streaming"), streamingTypes) this.assertDoesntHaveReexports(reexportsWithEmptyModel(), streamingTypes) assertHasReexports(reexportsWithMember(additionalShape = streamingShape, inputMember = "m: Streaming"), streamingTypes) assertHasReexports(reexportsWithMember(additionalShape = streamingShape, outputMember = "m: Streaming"), streamingTypes) // Event streams don't re-export the normal streaming types assertDoesntHaveTypes( typesWithMember( this.assertDoesntHaveReexports( reexportsWithMember( additionalShape = "@streaming union EventStream { foo: SomeStruct }", inputMember = "m: EventStream", ), streamingTypes, ) assertDoesntHaveTypes( typesWithMember( this.assertDoesntHaveReexports( reexportsWithMember( additionalShape = "@streaming union EventStream { foo: SomeStruct }", outputMember = "m: EventStream", ), Loading codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt +13 −2 Original line number Diff line number Diff line Loading @@ -6,10 +6,11 @@ package software.amazon.smithy.rust.codegen.server.smithy.customizations import software.amazon.smithy.rust.codegen.core.rustlang.Feature import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyErrorTypes import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyPrimitives import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext Loading @@ -34,12 +35,22 @@ class ServerRequiredCustomizations : ServerCodegenDecorator { baseCustomizations + AllowLintsCustomization() override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { val rc = codegenContext.runtimeConfig // Add rt-tokio feature for `ByteStream::from_path` rustCrate.mergeFeature(Feature("rt-tokio", true, listOf("aws-smithy-http/rt-tokio"))) rustCrate.withModule(ServerRustModule.Types) { pubUseSmithyPrimitives(codegenContext, codegenContext.model)(this) pubUseSmithyErrorTypes(codegenContext)(this) rustTemplate( """ pub type SdkError<E> = #{SdkError}<E, #{SdkErrorResponse}>; pub use #{DisplayErrorContext}; """, "SdkError" to RuntimeType.smithyHttp(rc).resolve("result::SdkError"), "SdkErrorResponse" to RuntimeType.HttpResponse, "DisplayErrorContext" to RuntimeType.smithyTypes(rc).resolve("error::display::DisplayErrorContext"), ) } rustCrate.withModule(ServerRustModule.root) { Loading Loading
codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt +19 −2 Original line number Diff line number Diff line Loading @@ -24,10 +24,11 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCus import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.core.rustlang.Feature import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyErrorTypes import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyPrimitives import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.core.util.letIf Loading Loading @@ -79,6 +80,8 @@ class RequiredCustomizations : ClientCodegenDecorator { baseCustomizations + AllowLintsCustomization() override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { val rc = codegenContext.runtimeConfig // Add rt-tokio feature for `ByteStream::from_path` rustCrate.mergeFeature(Feature("rt-tokio", true, listOf("aws-smithy-http/rt-tokio"))) Loading @@ -91,7 +94,21 @@ class RequiredCustomizations : ClientCodegenDecorator { pubUseSmithyPrimitives(codegenContext, codegenContext.model)(this) } rustCrate.withModule(ClientRustModule.Error) { pubUseSmithyErrorTypes(codegenContext)(this) rustTemplate( """ pub type SdkError<E> = #{SdkError}<E, #{SdkErrorResponse}>; pub use #{DisplayErrorContext}; pub use #{ProvideErrorMetadata}; """, "SdkError" to RuntimeType.smithyHttp(rc).resolve("result::SdkError"), "SdkErrorResponse" to if (codegenContext.smithyRuntimeMode.generateOrchestrator) { RuntimeType.smithyRuntimeApi(rc).resolve("client::orchestrator::HttpResponse") } else { RuntimeType.HttpResponse }, "DisplayErrorContext" to RuntimeType.smithyTypes(rc).resolve("error::display::DisplayErrorContext"), "ProvideErrorMetadata" to RuntimeType.smithyTypes(rc).resolve("error::metadata::ProvideErrorMetadata"), ) } ClientRustModule.Meta.also { metaModule -> Loading
codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt +27 −63 Original line number Diff line number Diff line Loading @@ -9,20 +9,12 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.hasEventStreamMember import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember import software.amazon.smithy.rust.codegen.core.util.letIf private data class PubUseType( val type: RuntimeType, val shouldExport: (Model) -> Boolean, val alias: String? = null, ) /** Returns true if the model has normal streaming operations (excluding event streams) */ private fun hasStreamingOperations(model: Model): Boolean { Loading @@ -48,62 +40,34 @@ private fun hasBlobs(model: Model): Boolean = structUnionMembersMatchPredicate(m /** Returns true if the model uses any timestamp shapes */ private fun hasDateTimes(model: Model): Boolean = structUnionMembersMatchPredicate(model, Shape::isTimestampShape) /** Returns a list of types that should be re-exported for the given model */ internal fun pubUseTypes(codegenContext: CodegenContext, model: Model): List<RuntimeType> = pubUseTypesThatShouldBeExported(codegenContext, model).map { it.type } private fun pubUseTypesThatShouldBeExported(codegenContext: CodegenContext, model: Model): List<PubUseType> { val runtimeConfig = codegenContext.runtimeConfig return ( listOf( PubUseType(RuntimeType.blob(runtimeConfig), ::hasBlobs), PubUseType(RuntimeType.dateTime(runtimeConfig), ::hasDateTimes), PubUseType(RuntimeType.format(runtimeConfig), ::hasDateTimes, "DateTimeFormat"), ) + RuntimeType.smithyHttp(runtimeConfig).let { http -> listOf( PubUseType(http.resolve("byte_stream::ByteStream"), ::hasStreamingOperations), PubUseType(http.resolve("byte_stream::AggregatedBytes"), ::hasStreamingOperations), PubUseType(http.resolve("byte_stream::error::Error"), ::hasStreamingOperations, "ByteStreamError"), PubUseType(http.resolve("body::SdkBody"), ::hasStreamingOperations), ) } ).filter { pubUseType -> pubUseType.shouldExport(model) } } /** Adds re-export statements for Smithy primitives */ fun pubUseSmithyPrimitives(codegenContext: CodegenContext, model: Model): Writable = writable { val types = pubUseTypesThatShouldBeExported(codegenContext, model) if (types.isNotEmpty()) { types.forEach { val useStatement = if (it.alias == null) { "pub use #T;" } else { "pub use #T as ${it.alias};" val rc = codegenContext.runtimeConfig if (hasBlobs(model)) { rustTemplate("pub use #{Blob};", "Blob" to RuntimeType.blob(rc)) } rust(useStatement, it.type) } } } /** Adds re-export statements for error types */ fun pubUseSmithyErrorTypes(codegenContext: CodegenContext): Writable = writable { val runtimeConfig = codegenContext.runtimeConfig val reexports = listOf( listOf( RuntimeType.smithyHttp(runtimeConfig).let { http -> PubUseType(http.resolve("result::SdkError"), { _ -> true }) }, ), RuntimeType.smithyTypes(runtimeConfig).let { types -> listOf(PubUseType(types.resolve("error::display::DisplayErrorContext"), { _ -> true })) // Only re-export `ProvideErrorMetadata` for clients .letIf(codegenContext.target == CodegenTarget.CLIENT) { list -> list + listOf(PubUseType(types.resolve("error::metadata::ProvideErrorMetadata"), { _ -> true })) if (hasDateTimes(model)) { rustTemplate( """ pub use #{DateTime}; pub use #{Format} as DateTimeFormat; """, "DateTime" to RuntimeType.dateTime(rc), "Format" to RuntimeType.format(rc), ) } }, ).flatten() reexports.forEach { reexport -> rust("pub use #T;", reexport.type) if (hasStreamingOperations(model)) { rustTemplate( """ pub use #{ByteStream}; pub use #{AggregatedBytes}; pub use #{Error} as ByteStreamError; pub use #{SdkBody}; """, "ByteStream" to RuntimeType.smithyHttp(rc).resolve("byte_stream::ByteStream"), "AggregatedBytes" to RuntimeType.smithyHttp(rc).resolve("byte_stream::AggregatedBytes"), "Error" to RuntimeType.smithyHttp(rc).resolve("byte_stream::error::Error"), "SdkBody" to RuntimeType.smithyHttp(rc).resolve("body::SdkBody"), ) } }
codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtraTest.kt +37 −37 Original line number Diff line number Diff line Loading @@ -7,7 +7,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.customizations import org.junit.jupiter.api.Test import software.amazon.smithy.model.Model import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGeneratorTest.Companion.model import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext Loading Loading @@ -43,61 +43,61 @@ class SmithyTypesPubUseExtraTest { """.asSmithyModel() } private fun typesWithEmptyModel() = typesWithMember() private fun typesWithMember( private fun reexportsWithEmptyModel() = reexportsWithMember() private fun reexportsWithMember( inputMember: String = "", outputMember: String = "", unionMember: String = "", additionalShape: String = "", ) = pubUseTypes(testCodegenContext(model), modelWithMember(inputMember, outputMember, unionMember, additionalShape)) ) = RustWriter.root().let { writer -> pubUseSmithyPrimitives(testCodegenContext(model), modelWithMember(inputMember, outputMember, unionMember, additionalShape))(writer) writer.toString() } private fun assertDoesntHaveTypes(types: List<RuntimeType>, expectedTypes: List<String>) = expectedTypes.forEach { assertDoesntHaveType(types, it) } private fun assertDoesntHaveReexports(reexports: String, expectedTypes: List<String>) = expectedTypes.forEach { assertDoesntHaveReexports(reexports, it) } private fun assertDoesntHaveType(types: List<RuntimeType>, type: String) { if (types.any { t -> t.fullyQualifiedName() == type }) { private fun assertDoesntHaveReexports(reexports: String, type: String) { if (reexports.contains(type)) { throw AssertionError("Expected $type to NOT be re-exported, but it was.") } } private fun assertHasTypes(types: List<RuntimeType>, expectedTypes: List<String>) = expectedTypes.forEach { assertHasType(types, it) } private fun assertHasReexports(reexports: String, expectedTypes: List<String>) = expectedTypes.forEach { assertHasReexport(reexports, it) } private fun assertHasType(types: List<RuntimeType>, type: String) { if (types.none { t -> t.fullyQualifiedName() == type }) { throw AssertionError( "Expected $type to be re-exported. Re-exported types: " + types.joinToString { it.fullyQualifiedName() }, ) private fun assertHasReexport(reexports: String, type: String) { if (!reexports.contains(type)) { throw AssertionError("Expected $type to be re-exported. Re-exported types:\n$reexports") } } @Test fun `it re-exports Blob when a model uses blobs`() { assertDoesntHaveType(typesWithEmptyModel(), "::aws_smithy_types::Blob") assertHasType(typesWithMember(inputMember = "foo: Blob"), "::aws_smithy_types::Blob") assertHasType(typesWithMember(outputMember = "foo: Blob"), "::aws_smithy_types::Blob") assertHasType( typesWithMember(inputMember = "foo: SomeUnion", unionMember = "foo: Blob"), this.assertDoesntHaveReexports(reexportsWithEmptyModel(), "::aws_smithy_types::Blob") assertHasReexport(reexportsWithMember(inputMember = "foo: Blob"), "::aws_smithy_types::Blob") assertHasReexport(reexportsWithMember(outputMember = "foo: Blob"), "::aws_smithy_types::Blob") assertHasReexport( reexportsWithMember(inputMember = "foo: SomeUnion", unionMember = "foo: Blob"), "::aws_smithy_types::Blob", ) assertHasType( typesWithMember(outputMember = "foo: SomeUnion", unionMember = "foo: Blob"), assertHasReexport( reexportsWithMember(outputMember = "foo: SomeUnion", unionMember = "foo: Blob"), "::aws_smithy_types::Blob", ) } @Test fun `it re-exports DateTime when a model uses timestamps`() { assertDoesntHaveType(typesWithEmptyModel(), "aws_smithy_types::DateTime") assertHasType(typesWithMember(inputMember = "foo: Timestamp"), "::aws_smithy_types::DateTime") assertHasType(typesWithMember(outputMember = "foo: Timestamp"), "::aws_smithy_types::DateTime") assertHasType( typesWithMember(inputMember = "foo: SomeUnion", unionMember = "foo: Timestamp"), this.assertDoesntHaveReexports(reexportsWithEmptyModel(), "aws_smithy_types::DateTime") assertHasReexport(reexportsWithMember(inputMember = "foo: Timestamp"), "::aws_smithy_types::DateTime") assertHasReexport(reexportsWithMember(outputMember = "foo: Timestamp"), "::aws_smithy_types::DateTime") assertHasReexport( reexportsWithMember(inputMember = "foo: SomeUnion", unionMember = "foo: Timestamp"), "::aws_smithy_types::DateTime", ) assertHasType( typesWithMember(outputMember = "foo: SomeUnion", unionMember = "foo: Timestamp"), assertHasReexport( reexportsWithMember(outputMember = "foo: SomeUnion", unionMember = "foo: Timestamp"), "::aws_smithy_types::DateTime", ) } Loading @@ -108,20 +108,20 @@ class SmithyTypesPubUseExtraTest { listOf("::aws_smithy_http::byte_stream::ByteStream", "::aws_smithy_http::byte_stream::AggregatedBytes") val streamingShape = "@streaming blob Streaming" assertDoesntHaveTypes(typesWithEmptyModel(), streamingTypes) assertHasTypes(typesWithMember(additionalShape = streamingShape, inputMember = "m: Streaming"), streamingTypes) assertHasTypes(typesWithMember(additionalShape = streamingShape, outputMember = "m: Streaming"), streamingTypes) this.assertDoesntHaveReexports(reexportsWithEmptyModel(), streamingTypes) assertHasReexports(reexportsWithMember(additionalShape = streamingShape, inputMember = "m: Streaming"), streamingTypes) assertHasReexports(reexportsWithMember(additionalShape = streamingShape, outputMember = "m: Streaming"), streamingTypes) // Event streams don't re-export the normal streaming types assertDoesntHaveTypes( typesWithMember( this.assertDoesntHaveReexports( reexportsWithMember( additionalShape = "@streaming union EventStream { foo: SomeStruct }", inputMember = "m: EventStream", ), streamingTypes, ) assertDoesntHaveTypes( typesWithMember( this.assertDoesntHaveReexports( reexportsWithMember( additionalShape = "@streaming union EventStream { foo: SomeStruct }", outputMember = "m: EventStream", ), Loading
codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt +13 −2 Original line number Diff line number Diff line Loading @@ -6,10 +6,11 @@ package software.amazon.smithy.rust.codegen.server.smithy.customizations import software.amazon.smithy.rust.codegen.core.rustlang.Feature import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyErrorTypes import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyPrimitives import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext Loading @@ -34,12 +35,22 @@ class ServerRequiredCustomizations : ServerCodegenDecorator { baseCustomizations + AllowLintsCustomization() override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { val rc = codegenContext.runtimeConfig // Add rt-tokio feature for `ByteStream::from_path` rustCrate.mergeFeature(Feature("rt-tokio", true, listOf("aws-smithy-http/rt-tokio"))) rustCrate.withModule(ServerRustModule.Types) { pubUseSmithyPrimitives(codegenContext, codegenContext.model)(this) pubUseSmithyErrorTypes(codegenContext)(this) rustTemplate( """ pub type SdkError<E> = #{SdkError}<E, #{SdkErrorResponse}>; pub use #{DisplayErrorContext}; """, "SdkError" to RuntimeType.smithyHttp(rc).resolve("result::SdkError"), "SdkErrorResponse" to RuntimeType.HttpResponse, "DisplayErrorContext" to RuntimeType.smithyTypes(rc).resolve("error::display::DisplayErrorContext"), ) } rustCrate.withModule(ServerRustModule.root) { Loading