Unverified Commit 64b9b91f authored by Fahad Zubair's avatar Fahad Zubair Committed by GitHub
Browse files

Fix errors for unions with unit target membershape (#3547)

## Motivation and Context

Unions that have a [unit target member
shape](https://smithy.io/2.0/spec/model.html#unit-type) do not have an
associated data in the generated Rust enum.

Closes Issue:
[2546](https://github.com/smithy-lang/smithy-rs/issues/2546

)

## Description

On the **server** side, when the union has constrained members, the code
generated for the conversion from the `Unconstrained` type to the
`Constrained` type incorrectly assumed that each Rust enum would have
associated data.

```
rust-server-codegen/src/unconstrained.rs:31:129
  |
  |               crate::unconstrained::some_union_with_unit_unconstrained::SomeUnionWithUnitUnconstrained::Option1(unconstrained) => Self::Option1(
    |                                                                                                                                   -^^^^^^^^^^^^- help: consider using a semicolon here to finish the statement: `;`
    |  _________________________________________________________________________________________________________________________________|
  | |
  | |                 unconstrained
  | |             ),
  | |_____________- call expression requires function
    |
   ::: rust-server-codegen/src/model.rs:152:5
    |
    |       Option1,
    |       ------- `SomeUnionWithUnit::Option1` defined here
```

The marshaling code for event streams with unit target types incorrectly
assumed that the variant would have associated data.

```
rust-server-codegen/src/event_stream_serde.rs

impl ::aws_smithy_eventstream::frame::MarshallMessage for TestEventMarshaller {
        fn marshal() {
            let payload = match input {
                Self::Input::KeepAlive(inner) => {
```

On the **client** side, the `event_stream_serde` code incorrectly
assumes that a union member, which has the `@streaming` trait applied to
it, takes a `model::Unit` type.

```
rust-client-codegen/src/event_stream_serde.rs:
    |
    |                     crate::types::TestEvent::KeepAlive(crate::types::Unit::builder().build()),
    |                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^---------------------------------------
    |                     |
    |                     call expression requires function
    |
   ::: rust-client-codegen/src/types/_test_event.rs
    |
    |     KeepAlive,
    |     --------- `TestEvent::KeepAlive` defined here
```

## Testing

A unit test has been added that tests the following model:

```
            $version: "2"
            namespace com.example
            use aws.protocols#restJson1
            use smithy.framework#ValidationException
            
            @restJson1 @title("Test Service") 
            service TestService { 
                version: "0.1", 
                operations: [ 
                    TestOperation
                    TestSimpleUnionWithUnit
                ] 
            }
            
            @http(uri: "/testunit", method: "POST")
            operation TestSimpleUnionWithUnit {
                input := {
                    @required
                    request: SomeUnionWithUnit
                }
                output := {
                    result : SomeUnionWithUnit
                }
                errors: [
                    ValidationException
                ]
            }
            
            @length(min: 13)
            string StringRestricted
            
            union SomeUnionWithUnit {
                Option1: Unit
                Option2: StringRestricted
            }

            @http(uri: "/test", method: "POST")
            operation TestOperation {
                input := { payload: String }
                output := {
                    @httpPayload
                    events: TestEvent
                },
                errors: [ValidationException]
            }
            
            @streaming
            union TestEvent {
                KeepAlive: Unit,
                Response: TestResponseEvent,
            }
            
            structure TestResponseEvent { 
                data: String 
            }            

```

---------

Co-authored-by: default avatarFahad Zubair <fahadzub@amazon.com>
parent 69efe3a7
Loading
Loading
Loading
Loading
+6 −1
Original line number Original line Diff line number Diff line
@@ -10,7 +10,6 @@
# references = ["smithy-rs#920"]
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"
# author = "rcoh"

[[smithy-rs]]
[[smithy-rs]]
message = """
message = """
Stalled stream protection now supports request upload streams. It is currently off by default, but will be enabled by default in a future release. To enable it now, you can do the following:
Stalled stream protection now supports request upload streams. It is currently off by default, but will be enabled by default in a future release. To enable it now, you can do the following:
@@ -52,3 +51,9 @@ message = "Stalled stream protection on downloads will now only trigger if the u
references = ["smithy-rs#3485"]
references = ["smithy-rs#3485"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "jdisanti"
author = "jdisanti"

[[smithy-rs]]
message = "Unions with unit target member shape are now fully supported"
references = ["smithy-rs#2546"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "all"}
author = "drganjoo"
+15 −6
Original line number Original line Diff line number Diff line
@@ -41,6 +41,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamE
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isTargetUnit
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toPascalCase


fun RustModule.Companion.eventStreamSerdeModule(): RustModule.LeafModule = private("event_stream_serde")
fun RustModule.Companion.eventStreamSerdeModule(): RustModule.LeafModule = private("event_stream_serde")
@@ -189,6 +190,13 @@ class EventStreamUnmarshallerGenerator(
            // Don't attempt to parse the payload for an empty struct. The payload can be empty, or if the model was
            // Don't attempt to parse the payload for an empty struct. The payload can be empty, or if the model was
            // updated since the code was generated, it can have content that would not be understood.
            // updated since the code was generated, it can have content that would not be understood.
            empty -> {
            empty -> {
                if (unionMember.isTargetUnit()) {
                    rustTemplate(
                        "Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName))",
                        "Output" to unionSymbol,
                        *codegenScope,
                    )
                } else {
                    rustTemplate(
                    rustTemplate(
                        "Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(#{UnionStruct}::builder().build())))",
                        "Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(#{UnionStruct}::builder().build())))",
                        "Output" to unionSymbol,
                        "Output" to unionSymbol,
@@ -196,6 +204,7 @@ class EventStreamUnmarshallerGenerator(
                        *codegenScope,
                        *codegenScope,
                    )
                    )
                }
                }
            }


            payloadOnly -> {
            payloadOnly -> {
                withBlock("let parsed = ", ";") {
                withBlock("let parsed = ", ";") {
+19 −2
Original line number Original line Diff line number Diff line
@@ -42,6 +42,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.eventStre
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isTargetUnit
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toPascalCase


open class EventStreamMarshallerGenerator(
open class EventStreamMarshallerGenerator(
@@ -107,7 +108,15 @@ open class EventStreamMarshallerGenerator(
                rustBlock("let payload = match input") {
                rustBlock("let payload = match input") {
                    for (member in unionShape.members()) {
                    for (member in unionShape.members()) {
                        val eventType = member.memberName // must be the original name, not the Rust-safe name
                        val eventType = member.memberName // must be the original name, not the Rust-safe name
                        rustBlock("Self::Input::${symbolProvider.toMemberName(member)}(inner) => ") {
                        // Union members targeting the Smithy `Unit` type do not have associated data in the
                        // Rust enum generated for the type.
                        val mayHaveInner =
                            if (!member.isTargetUnit()) {
                                "(inner)"
                            } else {
                                ""
                            }
                        rustBlock("Self::Input::${symbolProvider.toMemberName(member)}$mayHaveInner => ") {
                            addStringHeader(":event-type", "${eventType.dq()}.into()")
                            addStringHeader(":event-type", "${eventType.dq()}.into()")
                            val target = model.expectShape(member.target, StructureShape::class.java)
                            val target = model.expectShape(member.target, StructureShape::class.java)
                            renderMarshallEvent(member, target)
                            renderMarshallEvent(member, target)
@@ -147,7 +156,15 @@ open class EventStreamMarshallerGenerator(
            renderMarshallEventPayload("inner.$memberName", payloadMember, target, serializerFn)
            renderMarshallEventPayload("inner.$memberName", payloadMember, target, serializerFn)
        } else if (headerMembers.isEmpty()) {
        } else if (headerMembers.isEmpty()) {
            val serializerFn = serializerGenerator.payloadSerializer(unionMember)
            val serializerFn = serializerGenerator.payloadSerializer(unionMember)
            renderMarshallEventPayload("inner", unionMember, eventStruct, serializerFn)
            // Union members targeting the Smithy `Unit` type do not have associated data in the
            // Rust enum generated for the type. For these, we need to pass the `crate::model::Unit` data type.
            val inner =
                if (unionMember.isTargetUnit()) {
                    "crate::model::Unit::builder().build()"
                } else {
                    "inner"
                }
            renderMarshallEventPayload(inner, unionMember, eventStruct, serializerFn)
        } else {
        } else {
            rust("Vec::new()")
            rust("Vec::new()")
        }
        }
+79 −57
Original line number Original line Diff line number Diff line
@@ -25,6 +25,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained
import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed
import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed
import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait
import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isTargetUnit
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator
import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator
@@ -86,12 +87,18 @@ class UnconstrainedUnionGenerator(
                """,
                """,
            ) {
            ) {
                sortedMembers.forEach { member ->
                sortedMembers.forEach { member ->
                    if (member.isTargetUnit()) {
                        rust(
                            "${unconstrainedShapeSymbolProvider.toMemberName(member)},",
                        )
                    } else {
                        rust(
                        rust(
                            "${unconstrainedShapeSymbolProvider.toMemberName(member)}(#T),",
                            "${unconstrainedShapeSymbolProvider.toMemberName(member)}(#T),",
                            unconstrainedShapeSymbolProvider.toSymbol(member),
                            unconstrainedShapeSymbolProvider.toSymbol(member),
                        )
                        )
                    }
                    }
                }
                }
            }


            rustTemplate(
            rustTemplate(
                """
                """
@@ -198,6 +205,15 @@ class UnconstrainedUnionGenerator(
                withBlock("match value {", "}") {
                withBlock("match value {", "}") {
                    sortedMembers.forEach { member ->
                    sortedMembers.forEach { member ->
                        val memberName = unconstrainedShapeSymbolProvider.toMemberName(member)
                        val memberName = unconstrainedShapeSymbolProvider.toMemberName(member)
                        if (member.isTargetUnit()) {
                            // Unit type within Unions do not have associated data.
                            rustTemplate(
                                """
                                #{UnconstrainedUnion}::$memberName => Self::$memberName,
                                """,
                                "UnconstrainedUnion" to symbol,
                            )
                        } else {
                            withBlockTemplate(
                            withBlockTemplate(
                                "#{UnconstrainedUnion}::$memberName(unconstrained) => Self::$memberName(",
                                "#{UnconstrainedUnion}::$memberName(unconstrained) => Self::$memberName(",
                                "),",
                                "),",
@@ -206,6 +222,17 @@ class UnconstrainedUnionGenerator(
                                if (!member.canReachConstrainedShape(model, symbolProvider)) {
                                if (!member.canReachConstrainedShape(model, symbolProvider)) {
                                    rust("unconstrained")
                                    rust("unconstrained")
                                } else {
                                } else {
                                    generateTryFromImplForReachableConstrainedShape(member).invoke(this)
                                }
                            }
                        }
                    }
                }
            }
        }

    private fun generateTryFromImplForReachableConstrainedShape(member: MemberShape) =
        writable {
            val targetShape = model.expectShape(member.target)
            val targetShape = model.expectShape(member.target)
            val resolveToNonPublicConstrainedType =
            val resolveToNonPublicConstrainedType =
                targetShape !is StructureShape && targetShape !is UnionShape && !targetShape.hasTrait<EnumTrait>() &&
                targetShape !is StructureShape && targetShape !is UnionShape && !targetShape.hasTrait<EnumTrait>() &&
@@ -255,8 +282,3 @@ class UnconstrainedUnionGenerator(
            }
            }
        }
        }
}
}
                    }
                }
            }
        }
}
+77 −0
Original line number Original line Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rust.codegen.server.smithy

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest

class UnionWithUnitTest {
    @Test
    fun `a constrained union that has a unit member should compile`() {
        val model =
            """
            ${'$'}version: "2"
            namespace com.example
            use aws.protocols#restJson1
            use smithy.framework#ValidationException
            
            @restJson1 @title("Test Service") 
            service TestService { 
                version: "0.1", 
                operations: [ 
                    TestOperation
                    TestSimpleUnionWithUnit
                ] 
            }
            
            @http(uri: "/testunit", method: "POST")
            operation TestSimpleUnionWithUnit {
                input := {
                    @required
                    request: SomeUnionWithUnit
                }
                output := {
                    result : SomeUnionWithUnit
                }
                errors: [
                    ValidationException
                ]
            }
            
            @length(min: 13)
            string StringRestricted
            
            union SomeUnionWithUnit {
                Option1: Unit
                Option2: StringRestricted
            }

            @http(uri: "/test", method: "POST")
            operation TestOperation {
                input := { payload: String }
                output := {
                    @httpPayload
                    events: TestEvent
                },
                errors: [ValidationException]
            }
            
            @streaming
            union TestEvent {
                KeepAlive: Unit,
                Response: TestResponseEvent,
            }
            
            structure TestResponseEvent { 
                data: String 
            }            
            """.asSmithyModel()

        // Ensure the generated SDK compiles.
        serverIntegrationTest(model) { _, _ -> }
    }
}