Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/CargoDependency.kt +22 −4 Original line number Diff line number Diff line Loading @@ -20,6 +20,7 @@ data class Local(val basePath: String) : DependencyLocation() sealed class RustDependency(open val name: String) : SymbolDependencyContainer { abstract fun version(): String open fun dependencies(): List<RustDependency> = listOf() override fun getDependencies(): List<SymbolDependency> { return listOf( SymbolDependency Loading @@ -27,7 +28,7 @@ sealed class RustDependency(open val name: String) : SymbolDependencyContainer { .packageName(name).version(version()) // We rely on retrieving the structured dependency from the symbol later .putProperty(PropertyKey, this).build() ) ) + dependencies().flatMap { it.dependencies } } companion object { Loading @@ -48,24 +49,41 @@ sealed class RustDependency(open val name: String) : SymbolDependencyContainer { * * CodegenVisitor deduplicates inline dependencies by (module, name) during code generation. */ class InlineDependency(name: String, val module: String, val renderer: (RustWriter) -> Unit) : RustDependency(name) { class InlineDependency( name: String, val module: String, val extraDependencies: List<RustDependency> = listOf(), val renderer: (RustWriter) -> Unit ) : RustDependency(name) { override fun version(): String { return renderer(RustWriter.forModule("_")).hashCode().toString() } override fun dependencies(): List<RustDependency> { return extraDependencies } fun key() = "$module::$name" companion object { fun forRustFile(name: String, module: String, filename: String): InlineDependency { fun forRustFile( name: String, module: String, filename: String, vararg additionalDepencies: RustDependency ): InlineDependency { // The inline crate is loaded as a dependency on the runtime classpath val rustFile = this::class.java.getResource("/inlineable/src/$filename") check(rustFile != null) return InlineDependency(name, module) { writer -> return InlineDependency(name, module, additionalDepencies.toList()) { writer -> writer.raw(rustFile.readText()) } } fun uuid() = forRustFile("v4", "uuid", "uuid.rs") // TODO: putting this in the "error" module risks conflicting with a modeled error named "GenericError" fun genericError() = forRustFile("GenericError", "types", "generic_error.rs", CargoDependency.Serde) } } Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustWriter.kt +1 −0 Original line number Diff line number Diff line Loading @@ -131,6 +131,7 @@ class RustWriter private constructor( ) : CodegenWriter<RustWriter, UseDeclarations>(null, UseDeclarations(namespace)) { companion object { fun root() = forModule(null) fun forModule(module: String?): RustWriter = if (module == null) { RustWriter("lib.rs", "crate") } else { Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt +3 −1 Original line number Diff line number Diff line Loading @@ -122,9 +122,11 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na fun SerdeJson(path: String) = RuntimeType(path, dependency = CargoDependency.SerdeJson, namespace = "serde_json") val GenericError = RuntimeType("GenericError", InlineDependency.genericError(), "crate::types") fun forInlineFun(name: String, module: String, func: (RustWriter) -> Unit) = RuntimeType( name = name, dependency = InlineDependency(name, module, func), dependency = InlineDependency(name, module, listOf(), func), namespace = "crate::$module" ) } Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/CombinedErrorGenerator.kt 0 → 100644 +87 −0 Original line number Diff line number Diff line package software.amazon.smithy.rust.codegen.smithy.generators import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.lang.Attribute import software.amazon.smithy.rust.codegen.lang.Derives import software.amazon.smithy.rust.codegen.lang.RustMetadata import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.lang.rust import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType /** * For a given Operation ([this]), return the symbol referring to the unified error? This can be used * if you, eg. want to return a unfied error from a function: * * ```kotlin * rustWriter.rustBlock("fn get_error() -> #T", operation.errorSymbol(symbolProvider)) { * write("todo!() // function body") * } * ``` */ fun OperationShape.errorSymbol(symbolProvider: SymbolProvider): RuntimeType { val symbol = symbolProvider.toSymbol(this) return RuntimeType("${symbol.name}Error", null, "crate::error") } /** * Generates a unified error enum for [operation]. [ErrorGenerator] handles generating the individual variants, * but we must still combine those variants into an enum covering all possible errors for a given operation. */ class CombinedErrorGenerator( model: Model, private val symbolProvider: SymbolProvider, private val operation: OperationShape ) { private val operationIndex = OperationIndex.of(model) fun render(writer: RustWriter) { val errors = operationIndex.getErrors(operation) val symbol = operation.errorSymbol(symbolProvider) val meta = RustMetadata( derives = Derives(setOf(RuntimeType.StdFmt("Debug"))), additionalAttributes = listOf(Attribute.NonExhaustive), public = true ) meta.render(writer) writer.rustBlock("enum ${symbol.name}") { errors.forEach { errorVariant -> val errorSymbol = symbolProvider.toSymbol(errorVariant) write("${errorSymbol.name}(#T),", errorSymbol) } rust( """ /// An unexpected error, eg. invalid JSON returned by the service Unhandled(Box<dyn #T>), """, RuntimeType.StdError ) } writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdFmt("Display")) { rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") { rustBlock("match self") { errors.forEach { val errorSymbol = symbolProvider.toSymbol(it) rust("""${symbol.name}::${errorSymbol.name}(inner) => inner.fmt(f),""") } rust("${symbol.name}::Unhandled(inner) => inner.fmt(f)") } } } writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdError) { rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) { rustBlock("match self") { errors.forEach { val errorSymbol = symbolProvider.toSymbol(it) rust("""${symbol.name}::${errorSymbol.name}(inner) => Some(inner),""") } rust("${symbol.name}::Unhandled(inner) => Some(inner.as_ref())") } } } } } codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ErrorGenerator.kt +13 −7 Original line number Diff line number Diff line Loading @@ -11,6 +11,7 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.RetryableTrait import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.lang.rust import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType.Companion.StdError import software.amazon.smithy.rust.codegen.smithy.RuntimeType.Companion.StdFmt Loading @@ -37,19 +38,24 @@ class ErrorGenerator( error.isServerError -> "ErrorCause::Server" else -> "ErrorCause::Unknown(${error.value.dq()})" } val messageShape = shape.getMember("message") val message = messageShape.map { "self.message.as_deref()" }.orElse("None") writer.rustBlock("impl ${symbol.name}") { write("// TODO: create shared runtime crate") write("// fn at_fault(&self) -> ErrorCause { $errorCause }") write("pub fn retryable(&self) -> bool { $retryable }") write("pub fn throttling(&self) -> bool { $throttling }") rust( """ pub fn retryable(&self) -> bool { $retryable } pub fn throttling(&self) -> bool { $throttling } pub fn code(&self) -> &str { ${shape.id.name.dq()} } pub fn message(&self) -> Option<&str> { $message } """ ) } writer.rustBlock("impl #T for ${symbol.name}", StdFmt("Display")) { rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") { val message = shape.getMember("message") write("write!(f, ${symbol.name.dq()})?;") if (message.isPresent) { OptionForEach(symbolProvider.toSymbol(message.get()), "&self.message") { field -> messageShape.map { OptionForEach(symbolProvider.toSymbol(it), "&self.message") { field -> write("""write!(f, ": {}", $field)?;""") } } Loading Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/CargoDependency.kt +22 −4 Original line number Diff line number Diff line Loading @@ -20,6 +20,7 @@ data class Local(val basePath: String) : DependencyLocation() sealed class RustDependency(open val name: String) : SymbolDependencyContainer { abstract fun version(): String open fun dependencies(): List<RustDependency> = listOf() override fun getDependencies(): List<SymbolDependency> { return listOf( SymbolDependency Loading @@ -27,7 +28,7 @@ sealed class RustDependency(open val name: String) : SymbolDependencyContainer { .packageName(name).version(version()) // We rely on retrieving the structured dependency from the symbol later .putProperty(PropertyKey, this).build() ) ) + dependencies().flatMap { it.dependencies } } companion object { Loading @@ -48,24 +49,41 @@ sealed class RustDependency(open val name: String) : SymbolDependencyContainer { * * CodegenVisitor deduplicates inline dependencies by (module, name) during code generation. */ class InlineDependency(name: String, val module: String, val renderer: (RustWriter) -> Unit) : RustDependency(name) { class InlineDependency( name: String, val module: String, val extraDependencies: List<RustDependency> = listOf(), val renderer: (RustWriter) -> Unit ) : RustDependency(name) { override fun version(): String { return renderer(RustWriter.forModule("_")).hashCode().toString() } override fun dependencies(): List<RustDependency> { return extraDependencies } fun key() = "$module::$name" companion object { fun forRustFile(name: String, module: String, filename: String): InlineDependency { fun forRustFile( name: String, module: String, filename: String, vararg additionalDepencies: RustDependency ): InlineDependency { // The inline crate is loaded as a dependency on the runtime classpath val rustFile = this::class.java.getResource("/inlineable/src/$filename") check(rustFile != null) return InlineDependency(name, module) { writer -> return InlineDependency(name, module, additionalDepencies.toList()) { writer -> writer.raw(rustFile.readText()) } } fun uuid() = forRustFile("v4", "uuid", "uuid.rs") // TODO: putting this in the "error" module risks conflicting with a modeled error named "GenericError" fun genericError() = forRustFile("GenericError", "types", "generic_error.rs", CargoDependency.Serde) } } Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustWriter.kt +1 −0 Original line number Diff line number Diff line Loading @@ -131,6 +131,7 @@ class RustWriter private constructor( ) : CodegenWriter<RustWriter, UseDeclarations>(null, UseDeclarations(namespace)) { companion object { fun root() = forModule(null) fun forModule(module: String?): RustWriter = if (module == null) { RustWriter("lib.rs", "crate") } else { Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt +3 −1 Original line number Diff line number Diff line Loading @@ -122,9 +122,11 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na fun SerdeJson(path: String) = RuntimeType(path, dependency = CargoDependency.SerdeJson, namespace = "serde_json") val GenericError = RuntimeType("GenericError", InlineDependency.genericError(), "crate::types") fun forInlineFun(name: String, module: String, func: (RustWriter) -> Unit) = RuntimeType( name = name, dependency = InlineDependency(name, module, func), dependency = InlineDependency(name, module, listOf(), func), namespace = "crate::$module" ) } Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/CombinedErrorGenerator.kt 0 → 100644 +87 −0 Original line number Diff line number Diff line package software.amazon.smithy.rust.codegen.smithy.generators import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.lang.Attribute import software.amazon.smithy.rust.codegen.lang.Derives import software.amazon.smithy.rust.codegen.lang.RustMetadata import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.lang.rust import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType /** * For a given Operation ([this]), return the symbol referring to the unified error? This can be used * if you, eg. want to return a unfied error from a function: * * ```kotlin * rustWriter.rustBlock("fn get_error() -> #T", operation.errorSymbol(symbolProvider)) { * write("todo!() // function body") * } * ``` */ fun OperationShape.errorSymbol(symbolProvider: SymbolProvider): RuntimeType { val symbol = symbolProvider.toSymbol(this) return RuntimeType("${symbol.name}Error", null, "crate::error") } /** * Generates a unified error enum for [operation]. [ErrorGenerator] handles generating the individual variants, * but we must still combine those variants into an enum covering all possible errors for a given operation. */ class CombinedErrorGenerator( model: Model, private val symbolProvider: SymbolProvider, private val operation: OperationShape ) { private val operationIndex = OperationIndex.of(model) fun render(writer: RustWriter) { val errors = operationIndex.getErrors(operation) val symbol = operation.errorSymbol(symbolProvider) val meta = RustMetadata( derives = Derives(setOf(RuntimeType.StdFmt("Debug"))), additionalAttributes = listOf(Attribute.NonExhaustive), public = true ) meta.render(writer) writer.rustBlock("enum ${symbol.name}") { errors.forEach { errorVariant -> val errorSymbol = symbolProvider.toSymbol(errorVariant) write("${errorSymbol.name}(#T),", errorSymbol) } rust( """ /// An unexpected error, eg. invalid JSON returned by the service Unhandled(Box<dyn #T>), """, RuntimeType.StdError ) } writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdFmt("Display")) { rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") { rustBlock("match self") { errors.forEach { val errorSymbol = symbolProvider.toSymbol(it) rust("""${symbol.name}::${errorSymbol.name}(inner) => inner.fmt(f),""") } rust("${symbol.name}::Unhandled(inner) => inner.fmt(f)") } } } writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdError) { rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) { rustBlock("match self") { errors.forEach { val errorSymbol = symbolProvider.toSymbol(it) rust("""${symbol.name}::${errorSymbol.name}(inner) => Some(inner),""") } rust("${symbol.name}::Unhandled(inner) => Some(inner.as_ref())") } } } } }
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ErrorGenerator.kt +13 −7 Original line number Diff line number Diff line Loading @@ -11,6 +11,7 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.RetryableTrait import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.lang.rust import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType.Companion.StdError import software.amazon.smithy.rust.codegen.smithy.RuntimeType.Companion.StdFmt Loading @@ -37,19 +38,24 @@ class ErrorGenerator( error.isServerError -> "ErrorCause::Server" else -> "ErrorCause::Unknown(${error.value.dq()})" } val messageShape = shape.getMember("message") val message = messageShape.map { "self.message.as_deref()" }.orElse("None") writer.rustBlock("impl ${symbol.name}") { write("// TODO: create shared runtime crate") write("// fn at_fault(&self) -> ErrorCause { $errorCause }") write("pub fn retryable(&self) -> bool { $retryable }") write("pub fn throttling(&self) -> bool { $throttling }") rust( """ pub fn retryable(&self) -> bool { $retryable } pub fn throttling(&self) -> bool { $throttling } pub fn code(&self) -> &str { ${shape.id.name.dq()} } pub fn message(&self) -> Option<&str> { $message } """ ) } writer.rustBlock("impl #T for ${symbol.name}", StdFmt("Display")) { rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") { val message = shape.getMember("message") write("write!(f, ${symbol.name.dq()})?;") if (message.isPresent) { OptionForEach(symbolProvider.toSymbol(message.get()), "&self.message") { field -> messageShape.map { OptionForEach(symbolProvider.toSymbol(it), "&self.message") { field -> write("""write!(f, ": {}", $field)?;""") } } Loading