Unverified Commit c9659093 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Fallback to a generic error representation for unmodeled errors (#171)

* Fallback to a generic error representation for unmodeled errors

* DRY out some dependencies

* Fix name is instant_iso8601 resource
parent edb395f6
Loading
Loading
Loading
Loading
+23 −14
Original line number Diff line number Diff line
@@ -68,28 +68,36 @@ class InlineDependency(
    companion object {
        fun forRustFile(
            name: String,
            module: String,
            filename: String,
            vararg additionalDepencies: RustDependency
            vararg additionalDependencies: RustDependency
        ): InlineDependency {
            val module = name
            val filename = "$name.rs"
            // The inline crate is loaded as a dependency on the runtime classpath
            val rustFile = this::class.java.getResource("/inlineable/src/$filename")
            check(rustFile != null) { "Rust file $filename was missing from the resource bundle!" }
            return InlineDependency(name, module, additionalDepencies.toList()) { writer ->
            return InlineDependency(name, module, additionalDependencies.toList()) { writer ->
                writer.raw(rustFile.readText())
            }
        }

        fun uuid() = forRustFile("v4", "uuid", "uuid.rs")
        fun genericError() = forRustFile("GenericError", "types", "generic_error.rs", CargoDependency.Serde)
        fun errorCode() = forRustFile("error_code", "error_code", "error_code.rs", CargoDependency.Http)
        fun docJson() = forRustFile("doc_json", "doc_json", "doc_json.rs", CargoDependency.Serde)
        fun instantEpoch() = forRustFile("instant_epoch", "instant_epoch", "instant_epoch.rs", CargoDependency.Serde)
        fun instantHttpDate() = forRustFile("instant_httpdate", "instant_httpdate", "instant_httpdate.rs", CargoDependency.Serde)
        fun instant8601() = forRustFile("instant_8601", "instant_8601", "instant_iso8601.rs", CargoDependency.Serde)
        fun awsJsonErrors(runtimeConfig: RuntimeConfig) =
            forRustFile("aws_json_errors", CargoDependency.Http, CargoDependency.SmithyTypes(runtimeConfig))

        fun idempotencyToken() = forRustFile("idempotency_token", "idempotency_token", "idempotency_token.rs", CargoDependency.Rand)
        fun blobSerde(runtimeConfig: RuntimeConfig) = forRustFile("blob_serde", "blob_serde", "blob_serde.rs", CargoDependency.Serde, CargoDependency.SmithyHttp(runtimeConfig))
        fun docJson() = forRustFile("doc_json", CargoDependency.Serde)
        fun instantEpoch() = forRustFile("instant_epoch", CargoDependency.Serde)
        fun instantHttpDate() =
            forRustFile("instant_httpdate", CargoDependency.Serde)

        fun instant8601() = forRustFile("instant_iso8601", CargoDependency.Serde)

        fun idempotencyToken() =
            forRustFile("idempotency_token", CargoDependency.Rand)

        fun blobSerde(runtimeConfig: RuntimeConfig) = forRustFile(
            "blob_serde",
            CargoDependency.Serde,
            CargoDependency.SmithyHttp(runtimeConfig)
        )
    }
}

@@ -162,7 +170,8 @@ data class CargoDependency(
            "protocol-test-helpers", Local(runtimeConfig.relativePath), scope = Dev
        )

        val SerdeJson: CargoDependency = CargoDependency("serde_json", CratesIo("1"), features = listOf("float_roundtrip"))
        val SerdeJson: CargoDependency =
            CargoDependency("serde_json", CratesIo("1"), features = listOf("float_roundtrip"))
        val Serde = CargoDependency("serde", CratesIo("1"), features = listOf("derive"))
    }
}
+2 −2
Original line number Diff line number Diff line
@@ -68,7 +68,7 @@ fun <T : CodeWriter> T.conditionalBlock(
 * Convenience wrapper that tells Intellij that the contents of this block are Rust
 */
fun <T : CodeWriter> T.rust(
    @Language("Rust", prefix = "fn foo(&self) {", suffix = "}") contents: String,
    @Language("Rust", prefix = "macro_rules! foo { () =>  {{ ", suffix = "}}}") contents: String,
    vararg args: Any
) {
    this.write(contents, *args)
@@ -80,7 +80,7 @@ fun <T : CodeWriter> T.rust(
 * This enables writing code like:
 *
 * ```kotlin
 * writer.rustCtx("""
 * writer.rustTemplate("""
 * let some_val = #{operation}::from_response(response);
 * let serialized = #{serde_json}::to_json(some_val);
 * """, "operation" to operationSymbol, "serde_json" to RuntimeType.SerdeJson)
+12 −8
Original line number Diff line number Diff line
@@ -120,6 +120,7 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n
        fun Serde(path: String) = RuntimeType(
            path, dependency = CargoDependency.Serde, namespace = "serde"
        )

        val Serialize = RuntimeType("Serialize", CargoDependency.Serde, namespace = "serde")
        val Deserialize: RuntimeType = RuntimeType("Deserialize", CargoDependency.Serde, namespace = "serde")
        val Serializer = RuntimeType("Serializer", CargoDependency.Serde, namespace = "serde")
@@ -129,19 +130,22 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n

        val SJ = RuntimeType(null, dependency = CargoDependency.SerdeJson, namespace = "serde_json")

        val GenericError = RuntimeType("GenericError", InlineDependency.genericError(), "crate::types")
        val ErrorCode = RuntimeType("error_code", dependency = InlineDependency.errorCode(), namespace = "crate")
        fun awsJsonErrors(runtimeConfig: RuntimeConfig) =
            forInlineDependency(InlineDependency.awsJsonErrors(runtimeConfig))

        val DocJson = RuntimeType("doc_json", InlineDependency.docJson(), "crate")
        val DocJson = forInlineDependency(InlineDependency.docJson())

        val InstantEpoch = RuntimeType("instant_epoch", InlineDependency.instantEpoch(), "crate")
        val InstantHttpDate = RuntimeType("instant_httpdate", InlineDependency.instantHttpDate(), "crate")
        val Instant8601 = RuntimeType("instant_8601", InlineDependency.instant8601(), "crate")
        val IdempotencyToken = RuntimeType("idempotency_token", InlineDependency.idempotencyToken(), "crate")
        val InstantEpoch = forInlineDependency(InlineDependency.instantEpoch())
        val InstantHttpDate = forInlineDependency(InlineDependency.instantHttpDate())
        val Instant8601 = forInlineDependency(InlineDependency.instant8601())
        val IdempotencyToken = forInlineDependency(InlineDependency.idempotencyToken())

        val Config = RuntimeType("config", null, "crate")

        fun BlobSerde(runtimeConfig: RuntimeConfig) = RuntimeType("blob_serde", InlineDependency.blobSerde(runtimeConfig), "crate")
        fun BlobSerde(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.blobSerde(runtimeConfig))

        private fun forInlineDependency(inlineDependency: InlineDependency) =
            RuntimeType(inlineDependency.name, inlineDependency, namespace = "crate")

        fun forInlineFun(name: String, module: String, func: (RustWriter) -> Unit) = RuntimeType(
            name = name,
+1 −1
Original line number Diff line number Diff line
@@ -230,7 +230,7 @@ class HttpProtocolTestGenerator(
                write("assert_eq!(expected_output, actual_error);")
            }
            rustBlock("else") {
                write("panic!(\"wrong variant: {:?}\", parsed);")
                write("panic!(\"wrong variant: Got: {:?}. Expected: {:?}\", parsed, expected_output);")
            }
        } else {
            write("assert_eq!(parsed.unwrap(), expected_output);")
+16 −10
Original line number Diff line number Diff line
@@ -183,22 +183,29 @@ class BasicAwsJsonGenerator(
        val errorSymbol = operationShape.errorSymbol(symbolProvider)
        val bodyId = outputShape.expectTrait(SyntheticOutputTrait::class.java).body
        val bodyShape = bodyId?.let { model.expectShape(bodyId, StructureShape::class.java) }
        val jsonErrors = RuntimeType.awsJsonErrors(protocolConfig.runtimeConfig)
        fromResponseFun(implBlockWriter, operationShape) {
            rustBlock("if #T::is_error(&response)", RuntimeType.ErrorCode) {
            rustBlock("if #T::is_error(&response)", jsonErrors) {
                // TODO: experiment with refactoring this segment into `error_code.rs`. Currently it isn't
                // to avoid the need to double deserialize the body.
                rustTemplate(
                    """
                    let body: #{sj}::Value = #{sj}::from_slice(response.body().as_ref())
                    let body = #{sj}::from_slice(response.body().as_ref())
                        .unwrap_or_else(|_|#{sj}::json!({}));
                    let error_code = #{error_code}::error_type_from_header(&response).map_err(#{error_symbol}::unhandled)?;
                    let error_code = error_code.or_else(||#{error_code}::error_type_from_body(&body));
                    let error_code = error_code.ok_or_else(||#{error_symbol}::unhandled("no error code".to_string()))?;
                    let error_code = #{error_code}::sanitize_error_code(error_code);
                    let generic = #{aws_json_errors}::parse_generic_error(&response, &body);
                    """,
                    "error_code" to RuntimeType.ErrorCode, "error_symbol" to errorSymbol, "sj" to RuntimeType.SJ
                    "aws_json_errors" to jsonErrors, "sj" to RuntimeType.SJ
                )
                if (operationShape.errors.isNotEmpty()) {
                    rustTemplate(
                        """

                    let error_code = match generic.code() {
                        Some(code) => code,
                        None => return Err(#{error_symbol}::unhandled(generic))
                    };""",
                        "error_symbol" to errorSymbol
                    )
                    withBlock("return Err(match error_code {", "})") {
                        // approx:
                        /*
@@ -210,8 +217,7 @@ class BasicAwsJsonGenerator(
                        parseErrorVariants(operationShape, errorSymbol)
                    }
                } else {
                    // TODO: this should actually be a generic error that tries to parse a message
                    write("return Err(#T::unhandled(error_code))", errorSymbol)
                    write("return Err(#T::unhandled(generic))", errorSymbol)
                }
            }
            // let body: OperationOutputBody = serde_json::from_slice(response.body()...);
@@ -260,6 +266,6 @@ class BasicAwsJsonGenerator(
                write("Err(e) => #T::unhandled(e)", errorSymbol)
            }
        }
        write("unknown => #T::unhandled(unknown)", errorSymbol)
        write("_ => #T::unhandled(generic)", errorSymbol)
    }
}
Loading