Unverified Commit 9609cc3a authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Support the remaining fields on request protocol tests (#50)

* Support the remaining fields on request protocol tests

Add support for remaining requirements and & make some small changes to support them.

* Satisfy rustfmt
parent 6011a0fd
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -106,7 +106,7 @@ class RustWriter private constructor(private val filename: String, val namespace
        filename.removeSuffix(".rs").split('/').last()
    } else null

    private fun safeName(prefix: String = "var"): String {
    fun safeName(prefix: String = "var"): String {
        n += 1
        return "${prefix}_$n"
    }
+6 −1
Original line number Diff line number Diff line
@@ -52,12 +52,17 @@ abstract class HttpProtocolGenerator(
            val shapeId = inputShape.expectTrait(SyntheticInputTrait::class.java).body
            val body = shapeId?.let { model.expectShape(it, StructureShape::class.java) }
            toBodyImpl(this, inputShape, body)
            // TODO: streaming shapes need special support
            rustBlock("pub fn assemble(builder: \$T, body: Vec<u8>) -> \$T<Vec<u8>>", RuntimeType.HttpRequestBuilder, RuntimeType.Http("request::Request")) {
                write("builder.header(\$T, body.len()).body(body)", RuntimeType.Http("header::CONTENT_LENGTH"))
                write(""".expect("http request should be valid")""")
            }
        }
    }

    protected fun httpBuilderFun(implBlockWriter: RustWriter, f: RustWriter.() -> Unit) {
        implBlockWriter.rustBlock(
            "pub fn build_http_request(&self) -> \$T",
            "pub fn request_builder_base(&self) -> \$T",
            RuntimeType.HttpRequestBuilder
        ) {
            f(this)
+27 −1
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ import software.amazon.smithy.rust.codegen.lang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.inputShape
import java.util.logging.Logger

data class ProtocolSupport(
    val requestBodySerialization: Boolean
@@ -24,6 +25,7 @@ class HttpProtocolTestGenerator(
    private val operationShape: OperationShape,
    private val writer: RustWriter
) {
    private val logger = Logger.getLogger(javaClass.name)
    // TODO: remove these once Smithy publishes fixes.
    // These tests are not even attempted to be compiled
    val DisableTests = setOf(
@@ -102,7 +104,11 @@ class HttpProtocolTestGenerator(
            writeInline("let input =")
            instantiator.render(this, inputShape, httpRequestTestCase.params)
            write(";")
            write("let http_request = input.build_http_request().body(()).unwrap();")
            if (protocolSupport.requestBodySerialization) {
                write("let http_request = ${protocolConfig.symbolProvider.toSymbol(inputShape).name}::assemble(input.request_builder_base(), input.build_body());")
            } else {
                write("let http_request = ${protocolConfig.symbolProvider.toSymbol(inputShape).name}::assemble(input.request_builder_base(), vec![]);")
            }
            with(httpRequestTestCase) {
                write(
                    """
@@ -115,11 +121,31 @@ class HttpProtocolTestGenerator(
            checkForbidQueryParams(this, httpRequestTestCase.forbidQueryParams)
            checkRequiredQueryParams(this, httpRequestTestCase.requireQueryParams)
            checkHeaders(this, httpRequestTestCase.headers)
            checkForbidHeaders(this, httpRequestTestCase.forbidHeaders)
            checkRequiredHeaders(this, httpRequestTestCase.requireHeaders)
            if (protocolSupport.requestBodySerialization) {
                checkBody(this, httpRequestTestCase.body.orElse(""), httpRequestTestCase.bodyMediaType.orElse(null))
            }

            // Explicitly warn if the test case defined parameters that we aren't doing anything with
            with(httpRequestTestCase) {
                if (authScheme.isPresent) {
                    logger.warning("Test case provided authScheme but this was ignored")
                }
                if (!httpRequestTestCase.vendorParams.isEmpty) {
                    logger.warning("Test case provided vendorParams but these were ignored")
                }
            }
        }
    }

    private fun checkRequiredHeaders(rustWriter: RustWriter, requireHeaders: List<String>) {
        basicCheck(requireHeaders, rustWriter, "required_headers", "require_headers")
    }

    private fun checkForbidHeaders(rustWriter: RustWriter, forbidHeaders: List<String>) {
        basicCheck(forbidHeaders, rustWriter, "forbidden_headers", "forbid_headers")
    }

    private fun checkBody(rustWriter: RustWriter, body: String, mediaType: String?) {
        if (body == "") {
+5 −1
Original line number Diff line number Diff line
@@ -121,10 +121,14 @@ class HttpTraitBindingGenerator(
                    ListForEach(memberType, field) { innerField, targetId ->
                        val innerMemberType = model.expectShape(targetId)
                        val formatted = headerFmtFun(innerMemberType, memberShape, innerField)
                        val safeName = safeName("formatted")
                        write("let $safeName = $formatted;")
                        rustBlock("if !$safeName.is_empty()") {
                            write("builder = builder.header(${httpBinding.locationName.dq()}, $formatted);")
                        }
                    }
                }
            }
            write("builder")
        }
        return true
+1 −1
Original line number Diff line number Diff line
@@ -108,7 +108,7 @@ class BasicAwsJsonGenerator(
        operationShape: OperationShape,
        inputShape: StructureShape
    ) {
        implBlockWriter.rustBlock("pub fn build_http_request(&self) -> \$T", RuntimeType.HttpRequestBuilder) {
        httpBuilderFun(implBlockWriter) {
            write("let builder = \$T::new();", RuntimeType.HttpRequestBuilder)
            write(
                """
Loading