Unverified Commit 45f27111 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Fix protocol tests against the orchestrator (#2768)



This PR fixes the protocol tests in orchestrator mode, and adds
`--all-targets` to the orchestrator CI checks.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: default avatarZelda Hessler <zhessler@amazon.com>
parent 988eb617
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Compani
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.FuturesUtil
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.HdrHistogram
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.Hound
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.HttpBody
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.SerdeJson
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.Smol
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.TempFile
@@ -122,6 +123,7 @@ class S3TestDependencies(private val codegenContext: ClientCodegenContext) : Lib
            addDependency(BytesUtils.toDevDependency())
            addDependency(FastRand.toDevDependency())
            addDependency(HdrHistogram)
            addDependency(HttpBody.toDevDependency())
            addDependency(Smol)
            addDependency(TempFile)
            addDependency(TracingAppender)
+19 −3
Original line number Diff line number Diff line
@@ -321,7 +321,7 @@ class DefaultProtocolTestGenerator(
        writeInline("let expected_output =")
        instantiator.render(this, expectedShape, testCase.params)
        write(";")
        write("let http_response = #T::new()", RuntimeType.HttpResponseBuilder)
        write("let mut http_response = #T::new()", RuntimeType.HttpResponseBuilder)
        testCase.headers.forEach { (key, value) ->
            writeWithNoFormatting(".header(${key.dq()}, ${value.dq()})")
        }
@@ -360,7 +360,9 @@ class DefaultProtocolTestGenerator(
                let de = #{OperationDeserializer};
                let parsed = de.deserialize_streaming(&mut http_response);
                let parsed = parsed.unwrap_or_else(|| {
                    let http_response = http_response.map(|body|#{copy_from_slice}(body.bytes().unwrap()));
                    let http_response = http_response.map(|body| {
                        #{SdkBody}::from(#{copy_from_slice}(body.bytes().unwrap()))
                    });
                    de.deserialize_nonstreaming(&http_response)
                });
                """,
@@ -369,12 +371,19 @@ class DefaultProtocolTestGenerator(
                "copy_from_slice" to RuntimeType.Bytes.resolve("copy_from_slice"),
                "ResponseDeserializer" to CargoDependency.smithyRuntimeApi(codegenContext.runtimeConfig).toType()
                    .resolve("client::orchestrator::ResponseDeserializer"),
                "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig),
            )
        }
        if (expectedShape.hasTrait<ErrorTrait>()) {
            val errorSymbol = codegenContext.symbolProvider.symbolForOperationError(operationShape)
            val errorVariant = codegenContext.symbolProvider.toSymbol(expectedShape).name
            rust("""let parsed = parsed.expect_err("should be error response");""")
            if (codegenContext.smithyRuntimeMode.defaultToOrchestrator) {
                rustTemplate(
                    """let parsed: &#{Error} = parsed.as_operation_error().expect("operation error").downcast_ref().unwrap();""",
                    "Error" to codegenContext.symbolProvider.symbolForOperationError(operationShape),
                )
            }
            rustBlock("if let #T::$errorVariant(parsed) = parsed", errorSymbol) {
                compareMembers(expectedShape)
            }
@@ -382,7 +391,14 @@ class DefaultProtocolTestGenerator(
                rust("panic!(\"wrong variant: Got: {:?}. Expected: {:?}\", parsed, expected_output);")
            }
        } else {
            if (codegenContext.smithyRuntimeMode.defaultToMiddleware) {
                rust("let parsed = parsed.unwrap();")
            } else {
                rustTemplate(
                    """let parsed: #{Output} = *parsed.expect("should be successful response").downcast().unwrap();""",
                    "Output" to codegenContext.symbolProvider.toSymbol(expectedShape),
                )
            }
            compareMembers(outputShape)
        }
    }
+2 −2
Original line number Diff line number Diff line
@@ -42,14 +42,14 @@ cd aws/sdk/build/aws-sdk/sdk
for service in "${services_that_compile[@]}"; do
    pushd "${service}"
    echo -e "${C_YELLOW}# Running 'cargo check --all-features' on '${service}'${C_RESET}"
    RUSTFLAGS="${RUSTFLAGS:-} --cfg aws_sdk_orchestrator_mode" cargo check --all-features
    RUSTFLAGS="${RUSTFLAGS:-} --cfg aws_sdk_orchestrator_mode" cargo check --all-features --all-targets
    popd
done

for service in "${services_that_pass_tests[@]}"; do
    pushd "${service}"
    echo -e "${C_YELLOW}# Running 'cargo test --all-features' on '${service}'${C_RESET}"
    RUSTFLAGS="${RUSTFLAGS:-} --cfg aws_sdk_orchestrator_mode" cargo test --all-features --no-fail-fast
    RUSTFLAGS="${RUSTFLAGS:-} --cfg aws_sdk_orchestrator_mode" cargo test --all-features --all-targets --no-fail-fast
    popd
done