Commit 91d80a42 authored by Fahad Zubair's avatar Fahad Zubair
Browse files

Implement `parseEventStreamErrorMetadata`, and change client test case for event stream

parent b961f18e
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.generateRustPayloadInitializer
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.testModule
@@ -46,7 +47,7 @@ class ClientEventStreamUnmarshallerGeneratorTest {
                        "exception",
                        "UnmodeledError",
                        "${testCase.responseContentType}",
                        br#"${testCase.validUnmodeledError}"#
                        ${testCase.generateRustPayloadInitializer(testCase.validUnmodeledError)}
                    );
                    let result = $generator::new().unmarshall(&message);
                    assert!(result.is_ok(), "expected ok, got: {:?}", result);
+18 −2
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
@@ -140,9 +141,24 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol {
    override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
        RuntimeType.cborErrors(runtimeConfig).resolve("parse_error_metadata")

    // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
    override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType =
        TODO("rpcv2Cbor event streams have not yet been implemented")
        ProtocolFunctions.crossOperationFn("parse_event_stream_error_metadata") { fnName ->
            // `HeaderMap::new()` doesn't allocate.
            rustTemplate(
                """
                pub fn $fnName(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{DeserializeError}> {
                    #{cbor_errors}::parse_error_metadata(0, &#{Headers}::new(), payload)
                }
                """,
                "cbor_errors" to RuntimeType.cborErrors(runtimeConfig),
                "Bytes" to RuntimeType.Bytes,
                "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig),
                "DeserializeError" to
                    CargoDependency.smithyCbor(runtimeConfig).toType()
                        .resolve("decode::DeserializeError"),
                "Headers" to RuntimeType.headers(runtimeConfig),
            )
        }

    // Unlike other protocols, the `rpcv2Cbor` protocol requires that `Content-Length` is always set
    // unless there is no input or if the operation is an event stream, see
+35 −5
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.util.lookup
import java.util.Base64

object EventStreamUnmarshallTestCases {
    fun RustWriter.writeUnmarshallTestCases(
@@ -109,7 +110,7 @@ object EventStreamUnmarshallTestCases {
                    "event",
                    "MessageWithStruct",
                    "${testCase.responseContentType}",
                    br##"${testCase.validTestStruct}"##
                    ${testCase.generateRustPayloadInitializer(testCase.validTestStruct)}
                );
                let result = $generator::new().unmarshall(&message);
                assert!(result.is_ok(), "expected ok, got: {:?}", result);
@@ -140,7 +141,7 @@ object EventStreamUnmarshallTestCases {
                    "event",
                    "MessageWithUnion",
                    "${testCase.responseContentType}",
                    br##"${testCase.validTestUnion}"##
                    ${testCase.generateRustPayloadInitializer(testCase.validTestUnion)}
                );
                let result = $generator::new().unmarshall(&message);
                assert!(result.is_ok(), "expected ok, got: {:?}", result);
@@ -221,7 +222,7 @@ object EventStreamUnmarshallTestCases {
                    "event",
                    "MessageWithNoHeaderPayloadTraits",
                    "${testCase.responseContentType}",
                    br##"${testCase.validMessageWithNoHeaderPayloadTraits}"##
                    ${testCase.generateRustPayloadInitializer(testCase.validMessageWithNoHeaderPayloadTraits)}
                );
                let result = $generator::new().unmarshall(&message);
                assert!(result.is_ok(), "expected ok, got: {:?}", result);
@@ -246,7 +247,7 @@ object EventStreamUnmarshallTestCases {
                    "exception",
                    "SomeError",
                    "${testCase.responseContentType}",
                    br##"${testCase.validSomeError}"##
                    ${testCase.generateRustPayloadInitializer(testCase.validSomeError)}
                );
                let result = $generator::new().unmarshall(&message);
                assert!(result.is_ok(), "expected ok, got: {:?}", result);
@@ -267,7 +268,7 @@ object EventStreamUnmarshallTestCases {
                "event",
                "MessageWithBlob",
                "wrong-content-type",
                br#"${testCase.validTestStruct}"#
                ${testCase.generateRustPayloadInitializer(testCase.validTestStruct)}
            );
            let result = $generator::new().unmarshall(&message);
            assert!(result.is_err(), "expected error, got: {:?}", result);
@@ -275,6 +276,35 @@ object EventStreamUnmarshallTestCases {
            """,
        )
    }

    /**
     * Generates a Rust-compatible initializer string for a given payload.
     *
     * This function handles two different scenarios based on the event stream message content type:
     *
     * 1. For CBOR payloads (content type "application/cbor"):
     *    - The input payload is expected to be a base64 encoded CBOR value.
     *    - It decodes the base64 string and generates a Rust byte array initializer.
     *    - The output format is: &[0xFFu8, 0xFFu8, ...] where FF are hexadecimal values.
     *
     * 2. For all other content types:
     *    - It returns a Rust raw string literal initializer.
     *    - The output format is: br##"original_payload"##
     */
    fun EventStreamTestModels.TestCase.generateRustPayloadInitializer(payload: String): String {
        return if (this.eventStreamMessageContentType == "application/cbor") {
            Base64.getDecoder().decode(payload)
                .joinToString(
                    prefix = "&[",
                    postfix = "]",
                    transform = { "0x${it.toUByte().toString(16).padStart(2, '0')}u8" },
                )
        } else {
            """
            br##"$payload"##
            """
        }
    }
}

internal fun conditionalBuilderInput(
+11 −3
Original line number Diff line number Diff line
@@ -32,7 +32,7 @@ pub fn parse_error_metadata(
    _response_headers: &Headers,
    response_body: &[u8],
) -> Result<ErrorMetadataBuilder, DeserializeError> {
    fn error_code(
    fn error_code_and_message(
        mut builder: ErrorMetadataBuilder,
        decoder: &mut Decoder,
    ) -> Result<ErrorMetadataBuilder, DeserializeError> {
@@ -41,6 +41,14 @@ pub fn parse_error_metadata(
                let code = decoder.str()?;
                builder.code(sanitize_error_code(&code))
            }
            "message" | "Message" | "errorMessage" => {
                // Silently skip if `message` is not a string. This allows for custom error
                // structures that might use different types for the message field.
                match decoder.str() {
                    Ok(message) => builder.message(message),
                    Err(_) => builder
                }
            }
            _ => {
                decoder.skip()?;
                builder
@@ -60,13 +68,13 @@ pub fn parse_error_metadata(
                    break;
                }
                _ => {
                    builder = error_code(builder, decoder)?;
                    builder = error_code_and_message(builder, decoder)?;
                }
            };
        },
        Some(n) => {
            for _ in 0..n {
                builder = error_code(builder, decoder)?;
                builder = error_code_and_message(builder, decoder)?;
            }
        }
    };