Unverified Commit df77d5ff authored by Fahad Zubair's avatar Fahad Zubair Committed by GitHub
Browse files

Enforce constraints for unnamed enums (#3884)

### Enforces Constraints for Unnamed Enums

This PR addresses the issue where, on the server side, unnamed enums
were incorrectly treated as infallible during deserialization, allowing
any string value to be converted without validation. The solution
introduces a `ConstraintViolation` and `TryFrom` implementation for
unnamed enums, ensuring that deserialized values conform to the enum
variants defined in the Smithy model.

The following is an example of an unnamed enum:

```smithy
@enum([
    { value: "MONDAY" },
    { value: "TUESDAY" }
])
string UnnamedDayOfWeek
```

On the server side the following type is generated for the Smithy shape:

```rust
pub struct UnnamedDayOfWeek(String);

impl ::std::convert::TryFrom<::std::string::String> for UnnamedDayOfWeek {
    type Error = crate::model::unnamed_day_of_week::ConstraintViolation;

    fn try_from(
        s: ::std::string::String,
    ) -> ::std::result::Result<Self, <Self as ::std::convert::TryFrom<::std::string::String>>::Error>
    {
        match s.as_str() {
            "MONDAY" | "TUESDAY" => Ok(Self(s)),
            _ => Err(crate::model::unnamed_day_of_week::ConstraintViolation(s)),
        }
    }
}
```

This change prevents invalid values from being deserialized into unnamed
enums and raises appropriate constraint violations when necessary.

There is one difference between the Rust code generated for
`TryFrom<String>` for named enums versus unnamed enums. The
implementation for unnamed enums passes the ownership of the `String`
parameter to the generated structure, and the implementation for
`TryFrom<&str>` delegates to `TryFrom<String>`.

```rust
impl ::std::convert::TryFrom<::std::string::String> for UnnamedDayOfWeek {
    type Error = crate::model::unnamed_day_of_week::ConstraintViolation;
    fn try_from(
        s: ::std::string::String,
    ) -> ::std::result::Result<Self, <Self as ::std::convert::TryFrom<::std::string::String>>::Error>
    {
        match s.as_str() {
            "MONDAY" | "TUESDAY" => Ok(Self(s)),
            _ => Err(crate::model::unnamed_day_of_week::ConstraintViolation(s)),
        }
    }
}

impl ::std::convert::TryFrom<&str> for UnnamedDayOfWeek {
    type Error = crate::model::unnamed_day_of_week::ConstraintViolation;
    fn try_from(
        s: &str,
    ) -> ::std::result::Result<Self, <Self as ::std::convert::TryFrom<&str>>::Error> {
        s.to_owned().try_into()
    }
}
```

On the client side, the behaviour is unchanged, and the client does not
validate for backward compatibility reasons. An [existing
test](https://github.com/smithy-lang/smithy-rs/pull/3884/files#diff-021ec60146cfe231105d21a7389f2dffcd546595964fbb3f0684ebf068325e48R82

)
has been modified to ensure this.

```rust
#[test]
fn generate_unnamed_enums() {
    let result = "t2.nano"
        .parse::<crate::types::UnnamedEnum>()
        .expect("static value validated to member");
    assert_eq!(result, UnnamedEnum("t2.nano".to_owned()));
    let result = "not-a-valid-variant"
        .parse::<crate::types::UnnamedEnum>()
        .expect("static value validated to member");
    assert_eq!(result, UnnamedEnum("not-a-valid-variant".to_owned()));
}
```

Fixes issue #3880

---------

Co-authored-by: default avatarFahad Zubair <fahadzub@amazon.com>
parent 8cf9ebdd
Loading
Loading
Loading
Loading

.changelog/4329788.md

0 → 100644
+18 −0
Original line number Original line Diff line number Diff line
---
applies_to: ["server"]
authors: ["drganjoo"]
references: ["smithy-rs#3880"]
breaking: true
new_feature: false
bug_fix: true
---
Unnamed enums now validate assigned values and will raise a `ConstraintViolation` if an unknown variant is set.

The following is an example of an unnamed enum:
```smithy
@enum([
    { value: "MONDAY" },
    { value: "TUESDAY" }
])
string UnnamedDayOfWeek
```
+31 −0
Original line number Original line Diff line number Diff line
@@ -79,6 +79,37 @@ data class InfallibleEnumType(
            )
            )
        }
        }


    override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
        writable {
            rustTemplate(
                """
                impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
                    fn from(s: T) -> Self {
                        ${context.enumName}(s.as_ref().to_owned())
                    }
                }
                """,
                *preludeScope,
            )
        }

    override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
        writable {
            // Add an infallible FromStr implementation for uniformity
            rustTemplate(
                """
                impl ::std::str::FromStr for ${context.enumName} {
                    type Err = ::std::convert::Infallible;
    
                    fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
                        #{Ok}(${context.enumName}::from(s))
                    }
                }
                """,
                *preludeScope,
            )
        }

    override fun additionalEnumImpls(context: EnumGeneratorContext): Writable =
    override fun additionalEnumImpls(context: EnumGeneratorContext): Writable =
        writable {
        writable {
            // `try_parse` isn't needed for unnamed enums
            // `try_parse` isn't needed for unnamed enums
+7 −1
Original line number Original line Diff line number Diff line
@@ -69,6 +69,8 @@ internal class ClientInstantiatorTest {
        val shape = model.lookup<StringShape>("com.test#UnnamedEnum")
        val shape = model.lookup<StringShape>("com.test#UnnamedEnum")
        val sut = ClientInstantiator(codegenContext)
        val sut = ClientInstantiator(codegenContext)
        val data = Node.parse("t2.nano".dq())
        val data = Node.parse("t2.nano".dq())
        // The client SDK should accept unknown variants as valid.
        val notValidVariant = Node.parse("not-a-valid-variant".dq())


        val project = TestWorkspace.testProject(symbolProvider)
        val project = TestWorkspace.testProject(symbolProvider)
        project.moduleFor(shape) {
        project.moduleFor(shape) {
@@ -77,7 +79,11 @@ internal class ClientInstantiatorTest {
                withBlock("let result = ", ";") {
                withBlock("let result = ", ";") {
                    sut.render(this, shape, data)
                    sut.render(this, shape, data)
                }
                }
                rust("""assert_eq!(result, UnnamedEnum("t2.nano".to_owned()));""")
                rust("""assert_eq!(result, UnnamedEnum("$data".to_owned()));""")
                withBlock("let result = ", ";") {
                    sut.render(this, shape, notValidVariant)
                }
                rust("""assert_eq!(result, UnnamedEnum("$notValidVariant".to_owned()));""")
            }
            }
        }
        }
        project.compileAndTest()
        project.compileAndTest()
+10 −26
Original line number Original line Diff line number Diff line
@@ -59,6 +59,12 @@ abstract class EnumType {
    /** Returns a writable that implements `FromStr` for the enum */
    /** Returns a writable that implements `FromStr` for the enum */
    abstract fun implFromStr(context: EnumGeneratorContext): Writable
    abstract fun implFromStr(context: EnumGeneratorContext): Writable


    /** Returns a writable that implements `From<&str>` and/or `TryFrom<&str>` for the unnamed enum */
    abstract fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable

    /** Returns a writable that implements `FromStr` for the unnamed enum */
    abstract fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable

    /** Optionally adds additional documentation to the `enum` docs */
    /** Optionally adds additional documentation to the `enum` docs */
    open fun additionalDocs(context: EnumGeneratorContext): Writable = writable {}
    open fun additionalDocs(context: EnumGeneratorContext): Writable = writable {}


@@ -237,32 +243,10 @@ open class EnumGenerator(
                    rust("&self.0")
                    rust("&self.0")
                },
                },
        )
        )

        // impl From<str> for Blah { ... }
        // Add an infallible FromStr implementation for uniformity
        enumType.implFromForStrForUnnamedEnum(context)(this)
        rustTemplate(
        // impl FromStr for Blah { ... }
            """
        enumType.implFromStrForUnnamedEnum(context)(this)
            impl ::std::str::FromStr for ${context.enumName} {
                type Err = ::std::convert::Infallible;

                fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
                    #{Ok}(${context.enumName}::from(s))
                }
            }
            """,
            *preludeScope,
        )

        rustTemplate(
            """
            impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
                fn from(s: T) -> Self {
                    ${context.enumName}(s.as_ref().to_owned())
                }
            }

            """,
            *preludeScope,
        )
    }
    }


    private fun RustWriter.renderEnum() {
    private fun RustWriter.renderEnum() {
+10 −0
Original line number Original line Diff line number Diff line
@@ -494,6 +494,16 @@ class EnumGeneratorTest {
                        // intentional no-op
                        // intentional no-op
                    }
                    }


                override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
                    writable {
                        // intentional no-op
                    }

                override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
                    writable {
                        // intentional no-op
                    }

                override fun additionalEnumMembers(context: EnumGeneratorContext): Writable =
                override fun additionalEnumMembers(context: EnumGeneratorContext): Writable =
                    writable {
                    writable {
                        rust("// additional enum members")
                        rust("// additional enum members")
Loading