Unverified Commit fe6a56fc authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Deserialize responses (#81)

* Deserialize responses

* Fix http version

* Fix test failures after merge
parent a3d5f10e
Loading
Loading
Loading
Loading
+2 −3
Original line number Diff line number Diff line
@@ -81,9 +81,8 @@ class InlineDependency(
        }

        fun uuid() = forRustFile("v4", "uuid", "uuid.rs")

        fun genericError() = forRustFile("GenericError", "types", "generic_error.rs", CargoDependency.Serde)

        fun errorCode() = forRustFile("error_code", "error_code", "error_code.rs", CargoDependency.Http)
        fun docJson() = forRustFile("doc_json", "doc_json", "doc_json.rs", CargoDependency.Serde)
    }
}
@@ -156,7 +155,7 @@ data class CargoDependency(
            "protocol-test-helpers", Local(runtimeConfig.relativePath), scope = Dev
        )

        val SerdeJson: CargoDependency = CargoDependency("serde_json", CratesIo("1"))
        val SerdeJson: CargoDependency = CargoDependency("serde_json", CratesIo("1"), features = listOf("float_roundtrip"))
        val Serde = CargoDependency("serde", CratesIo("1"), features = listOf("derive"))
    }
}
+29 −1
Original line number Diff line number Diff line
@@ -65,10 +65,38 @@ fun <T : CodeWriter> T.conditionalBlock(
/**
 * Convenience wrapper that tells Intellij that the contents of this block are Rust
 */
fun <T : CodeWriter> T.rust(@Language("Rust", prefix = "fn foo(&self) {", suffix = "}") contents: String, vararg args: Any) {
fun <T : CodeWriter> T.rust(
    @Language("Rust", prefix = "fn foo(&self) {", suffix = "}") contents: String,
    vararg args: Any
) {
    this.write(contents, *args)
}

/**
 * Experimental API for templating long blocks of Rust
 *
 * This enables writing code like:
 *
 * ```kotlin
 * writer.rustCtx("""
 * let some_val = #{operation}::from_response(response);
 * let serialized = #{serde_json}::to_json(some_val);
 * """, "operation" to operationSymbol, "serde_json" to RuntimeType.SerdeJson)
 * ```
 *
 * Before being passed to the underlying code writer, this syntax is rewritten to match the slightly more verbose
 * `CodeWriter` formatting: `#{name:T}`
 */
fun <T : CodeWriter> T.rustTemplate(
    @Language("Rust", prefix = "fn foo(&self) {", suffix = "}") contents: String,
    vararg ctx: Pair<String, Any>
) {
    this.pushState()
    this.putContext(ctx.toMap())
    this.write(contents.replace(Regex("""#\{([a-zA-Z_0-9]+)\}""")) { matchResult -> "#{${matchResult.groupValues[1]}:T}" })
    this.popState()
}

/*
 * Writes a Rust-style block, demarcated by curly braces
 */
+10 −3
Original line number Diff line number Diff line
@@ -32,10 +32,10 @@ data class RuntimeConfig(val cratePrefix: String = "smithy", val relativePath: S
    }
}

data class RuntimeType(val name: String, val dependency: RustDependency?, val namespace: String) {
data class RuntimeType(val name: String?, val dependency: RustDependency?, val namespace: String) {
    fun toSymbol(): Symbol {
        val builder = Symbol.builder().name(name).namespace(namespace, "::")
            .rustType(RustType.Opaque(name))
            .rustType(RustType.Opaque(name ?: ""))

        dependency?.run { builder.addDependency(this) }
        return builder.build()
@@ -47,7 +47,8 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na
        } else {
            "::"
        }
        return "$prefix$namespace::$name"
        val postFix = name?.let { "::$name" } ?: ""
        return "$prefix$namespace$postFix"
    }

    // TODO: refactor to be RuntimeTypeProvider a la Symbol provider that packages the `RuntimeConfig` state.
@@ -117,6 +118,9 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na
        val HttpRequestBuilder = Http("request::Builder")
        val HttpResponseBuilder = Http("response::Builder")

        fun Serde(path: String) = RuntimeType(
            path, dependency = CargoDependency.Serde, namespace = "serde"
        )
        val Serialize = RuntimeType("Serialize", CargoDependency.Serde, namespace = "serde")
        val Deserialize: RuntimeType = RuntimeType("Deserialize", CargoDependency.Serde, namespace = "serde")
        val Serializer = RuntimeType("Serializer", CargoDependency.Serde, namespace = "serde")
@@ -124,7 +128,10 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na
        fun SerdeJson(path: String) =
            RuntimeType(path, dependency = CargoDependency.SerdeJson, namespace = "serde_json")

        val SJ = RuntimeType(null, dependency = CargoDependency.SerdeJson, namespace = "serde_json")

        val GenericError = RuntimeType("GenericError", InlineDependency.genericError(), "crate::types")
        val ErrorCode = RuntimeType("error_code", dependency = InlineDependency.errorCode(), namespace = "crate")

        val DocJson = RuntimeType("doc_json", InlineDependency.docJson(), "crate")

+6 −0
Original line number Diff line number Diff line
@@ -72,6 +72,12 @@ class CombinedErrorGenerator(
            }
        }

        writer.rustBlock("impl ${symbol.name}") {
            writer.rustBlock("pub fn unhandled<E: Into<Box<dyn #T>>>(err: E) -> Self", RuntimeType.StdError) {
                write("${symbol.name}::Unhandled(err.into())")
            }
        }

        writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdError) {
            rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) {
                rustBlock("match self") {
+25 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape

/**
 * Configuration needed to generate the client for a given Service<->Protocol pair
@@ -85,6 +86,17 @@ abstract class HttpProtocolGenerator(
                write("#T::assemble(self.input.request_builder_base(), self.input.build_body())", inputSymbol)
            }

            fromResponseImpl(this, operationShape)

            rustBlock(
                "pub fn parse_response(&self, response: #T<impl AsRef<[u8]>>) -> Result<#T, #T>",
                RuntimeType.Http("response::Response"),
                symbolProvider.toSymbol(operationShape.outputShape(model)),
                operationShape.errorSymbol(symbolProvider)
            ) {
                write("Self::from_response(response)")
            }

            rustBlock("pub fn new(input: #T) -> Self", inputSymbol) {
                write("Self { input }")
            }
@@ -108,6 +120,19 @@ abstract class HttpProtocolGenerator(
        }
    }

    protected fun fromResponseFun(implBlockWriter: RustWriter, operationShape: OperationShape, f: RustWriter.() -> Unit) {
        implBlockWriter.rustBlock(
            "fn from_response(response: #T<impl AsRef<[u8]>>) -> Result<#T, #T>",
            RuntimeType.Http("response::Response"),
            symbolProvider.toSymbol(operationShape.outputShape(model)),
            operationShape.errorSymbol(symbolProvider)
        ) {
            f(this)
        }
    }

    abstract fun fromResponseImpl(implBlockWriter: RustWriter, operationShape: OperationShape)

    /**
     * Add necessary methods to the impl block to generate the request body
     *
Loading