Unverified Commit 48b4ef6a authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Generated Combined Error Shapes (#71)

* Add combined error generator

* Deterministic operation generation

* Test generator refactoring

* Fix test failure in inlineable

* Back out needless changes to Rust.kt and merge error types

* Add some docs
parent 075d4094
Loading
Loading
Loading
Loading
+22 −4
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@ data class Local(val basePath: String) : DependencyLocation()

sealed class RustDependency(open val name: String) : SymbolDependencyContainer {
    abstract fun version(): String
    open fun dependencies(): List<RustDependency> = listOf()
    override fun getDependencies(): List<SymbolDependency> {
        return listOf(
            SymbolDependency
@@ -27,7 +28,7 @@ sealed class RustDependency(open val name: String) : SymbolDependencyContainer {
                .packageName(name).version(version())
                // We rely on retrieving the structured dependency from the symbol later
                .putProperty(PropertyKey, this).build()
        )
        ) + dependencies().flatMap { it.dependencies }
    }

    companion object {
@@ -48,24 +49,41 @@ sealed class RustDependency(open val name: String) : SymbolDependencyContainer {
 *
 * CodegenVisitor deduplicates inline dependencies by (module, name) during code generation.
 */
class InlineDependency(name: String, val module: String, val renderer: (RustWriter) -> Unit) : RustDependency(name) {
class InlineDependency(
    name: String,
    val module: String,
    val extraDependencies: List<RustDependency> = listOf(),
    val renderer: (RustWriter) -> Unit
) : RustDependency(name) {
    override fun version(): String {
        return renderer(RustWriter.forModule("_")).hashCode().toString()
    }

    override fun dependencies(): List<RustDependency> {
        return extraDependencies
    }

    fun key() = "$module::$name"

    companion object {
        fun forRustFile(name: String, module: String, filename: String): InlineDependency {
        fun forRustFile(
            name: String,
            module: String,
            filename: String,
            vararg additionalDepencies: RustDependency
        ): InlineDependency {
            // The inline crate is loaded as a dependency on the runtime classpath
            val rustFile = this::class.java.getResource("/inlineable/src/$filename")
            check(rustFile != null)
            return InlineDependency(name, module) { writer ->
            return InlineDependency(name, module, additionalDepencies.toList()) { writer ->
                writer.raw(rustFile.readText())
            }
        }

        fun uuid() = forRustFile("v4", "uuid", "uuid.rs")

        // TODO: putting this in the "error" module risks conflicting with a modeled error named "GenericError"
        fun genericError() = forRustFile("GenericError", "types", "generic_error.rs", CargoDependency.Serde)
    }
}

+1 −0
Original line number Diff line number Diff line
@@ -131,6 +131,7 @@ class RustWriter private constructor(
) :
    CodegenWriter<RustWriter, UseDeclarations>(null, UseDeclarations(namespace)) {
    companion object {
        fun root() = forModule(null)
        fun forModule(module: String?): RustWriter = if (module == null) {
            RustWriter("lib.rs", "crate")
        } else {
+3 −1
Original line number Diff line number Diff line
@@ -122,9 +122,11 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na
        fun SerdeJson(path: String) =
            RuntimeType(path, dependency = CargoDependency.SerdeJson, namespace = "serde_json")

        val GenericError = RuntimeType("GenericError", InlineDependency.genericError(), "crate::types")

        fun forInlineFun(name: String, module: String, func: (RustWriter) -> Unit) = RuntimeType(
            name = name,
            dependency = InlineDependency(name, module, func),
            dependency = InlineDependency(name, module, listOf(), func),
            namespace = "crate::$module"
        )
    }
+87 −0
Original line number Diff line number Diff line
package software.amazon.smithy.rust.codegen.smithy.generators

import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.lang.Attribute
import software.amazon.smithy.rust.codegen.lang.Derives
import software.amazon.smithy.rust.codegen.lang.RustMetadata
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.lang.rust
import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType

/**
 * For a given Operation ([this]), return the symbol referring to the unified error? This can be used
 * if you, eg. want to return a unfied error from a function:
 *
 * ```kotlin
 * rustWriter.rustBlock("fn get_error() -> #T", operation.errorSymbol(symbolProvider)) {
 *   write("todo!() // function body")
 * }
 * ```
 */
fun OperationShape.errorSymbol(symbolProvider: SymbolProvider): RuntimeType {
    val symbol = symbolProvider.toSymbol(this)
    return RuntimeType("${symbol.name}Error", null, "crate::error")
}

/**
 * Generates a unified error enum for [operation]. [ErrorGenerator] handles generating the individual variants,
 * but we must still combine those variants into an enum covering all possible errors for a given operation.
 */
class CombinedErrorGenerator(
    model: Model,
    private val symbolProvider: SymbolProvider,
    private val operation: OperationShape
) {

    private val operationIndex = OperationIndex.of(model)
    fun render(writer: RustWriter) {
        val errors = operationIndex.getErrors(operation)
        val symbol = operation.errorSymbol(symbolProvider)
        val meta = RustMetadata(
            derives = Derives(setOf(RuntimeType.StdFmt("Debug"))),
            additionalAttributes = listOf(Attribute.NonExhaustive),
            public = true
        )
        meta.render(writer)
        writer.rustBlock("enum ${symbol.name}") {
            errors.forEach { errorVariant ->
                val errorSymbol = symbolProvider.toSymbol(errorVariant)
                write("${errorSymbol.name}(#T),", errorSymbol)
            }
            rust(
                """
                /// An unexpected error, eg. invalid JSON returned by the service
                Unhandled(Box<dyn #T>),
            """,
                RuntimeType.StdError
            )
        }
        writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdFmt("Display")) {
            rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") {
                rustBlock("match self") {
                    errors.forEach {
                        val errorSymbol = symbolProvider.toSymbol(it)
                        rust("""${symbol.name}::${errorSymbol.name}(inner) => inner.fmt(f),""")
                    }
                    rust("${symbol.name}::Unhandled(inner) => inner.fmt(f)")
                }
            }
        }

        writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdError) {
            rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) {
                rustBlock("match self") {
                    errors.forEach {
                        val errorSymbol = symbolProvider.toSymbol(it)
                        rust("""${symbol.name}::${errorSymbol.name}(inner) => Some(inner),""")
                    }
                    rust("${symbol.name}::Unhandled(inner) => Some(inner.as_ref())")
                }
            }
        }
    }
}
+13 −7
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.RetryableTrait
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.lang.rust
import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType.Companion.StdError
import software.amazon.smithy.rust.codegen.smithy.RuntimeType.Companion.StdFmt
@@ -37,19 +38,24 @@ class ErrorGenerator(
            error.isServerError -> "ErrorCause::Server"
            else -> "ErrorCause::Unknown(${error.value.dq()})"
        }
        val messageShape = shape.getMember("message")
        val message = messageShape.map { "self.message.as_deref()" }.orElse("None")
        writer.rustBlock("impl ${symbol.name}") {
            write("// TODO: create shared runtime crate")
            write("// fn at_fault(&self) -> ErrorCause { $errorCause }")
            write("pub fn retryable(&self) -> bool { $retryable }")
            write("pub fn throttling(&self) -> bool { $throttling }")
            rust(
                """
            pub fn retryable(&self) -> bool { $retryable }
            pub fn throttling(&self) -> bool { $throttling }
            pub fn code(&self) -> &str { ${shape.id.name.dq()} }
            pub fn message(&self) -> Option<&str> { $message }
                """
            )
        }

        writer.rustBlock("impl #T for ${symbol.name}", StdFmt("Display")) {
            rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") {
                val message = shape.getMember("message")
                write("write!(f, ${symbol.name.dq()})?;")
                if (message.isPresent) {
                    OptionForEach(symbolProvider.toSymbol(message.get()), "&self.message") { field ->
                messageShape.map {
                    OptionForEach(symbolProvider.toSymbol(it), "&self.message") { field ->
                        write("""write!(f, ": {}", $field)?;""")
                    }
                }
Loading