Unverified Commit 206ef913 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

breaking change: improve `as_`/`is_` helpers (#527)

* breaking change: improve `as_`/`is_` helpers

- `as_` helpers for unions now return `Result` to clearly indicate what the actual value was
- Union variants now implement `TryFrom` for the parent union
- Error types now generate `is_` helpers

* bump version to v0.0.9-alpha

* Remove unsound TryFrom

* fix unit test and cargo fmt failure

* Delete unused import
parent cc13edf9
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -109,7 +109,7 @@ fun generateSmithyBuild(tests: List<AwsService>): String {
                      },
                      "service": "${it.service}",
                      "module": "aws-sdk-${it.module}",
                      "moduleVersion": "0.0.8-alpha",
                      "moduleVersion": "0.0.9-alpha",
                      "moduleAuthors": ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "Russell Cohen <rcoh@amazon.com>"],
                      "license": "Apache-2.0"
                      ${it.extraConfig ?: ""}
+2 −4
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@ use structopt::StructOpt;

// types from the AWS SDK for Rust
use aws_types::region::ProvideRegion;
use lambda::{error::InvokeErrorKind, Client, Config, Region, SdkError};
use lambda::{Client, Config, Region, SdkError};

// types from other third-party crates
use tracing_subscriber::fmt::format::FmtSpan;
@@ -137,9 +137,7 @@ async fn main() {
        //
        // For our example, we will simply print that the function doesn't
        // exist and return a non-zero exit code to indicate the failure.
        Err(SdkError::ServiceError { err, .. })
            if matches!(err.kind, InvokeErrorKind::ResourceNotFoundError(_)) =>
        {
        Err(SdkError::ServiceError { err, .. }) if err.is_resource_not_found_error() => {
            println!("This lambda function does not exist: {}", err);
            process::exit(1);
        }
+8 −8
Original line number Diff line number Diff line
@@ -30,10 +30,10 @@ class UnionGenerator(

    private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) }
    private fun renderUnion() {
        val symbol = symbolProvider.toSymbol(shape)
        val containerMeta = symbol.expectRustMetadata()
        val unionSymbol = symbolProvider.toSymbol(shape)
        val containerMeta = unionSymbol.expectRustMetadata()
        containerMeta.render(writer)
        writer.rustBlock("enum ${symbol.name}") {
        writer.rustBlock("enum ${unionSymbol.name}") {
            sortedMembers.forEach { member ->
                val memberSymbol = symbolProvider.toSymbol(member)
                documentShape(member, model)
@@ -41,17 +41,17 @@ class UnionGenerator(
                write("${member.memberName.toPascalCase()}(#T),", symbolProvider.toSymbol(member))
            }
        }
        writer.rustBlock("impl ${symbol.name}") {
        writer.rustBlock("impl ${unionSymbol.name}") {
            sortedMembers.forEach { member ->
                val memberSymbol = symbolProvider.toSymbol(member)
                val funcNamePart = member.memberName.toSnakeCase()
                val variantName = member.memberName.toPascalCase()

                writer.rustBlock("pub fn as_$funcNamePart(&self) -> Option<&#T>", memberSymbol) {
                    rust("if let ${symbol.name}::$variantName(val) = &self { Some(&val) } else { None }")
                rustBlock("pub fn as_$funcNamePart(&self) -> Result<&#T, &Self>", memberSymbol) {
                    rust("if let ${unionSymbol.name}::$variantName(val) = &self { Ok(&val) } else { Err(&self) }")
                }
                writer.rustBlock("pub fn is_$funcNamePart(&self) -> bool") {
                    rust("self.as_$funcNamePart().is_some()")
                rustBlock("pub fn is_$funcNamePart(&self) -> bool") {
                    rust("self.as_$funcNamePart().is_ok()")
                }
            }
        }
+14 −7
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.customize.Section
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.toSnakeCase

/**
 * For a given Operation ([this]), return the symbol referring to the unified error? This can be used
@@ -119,14 +120,13 @@ class CombinedErrorGenerator(
            }
        }

        writer.rustBlock("impl ${symbol.name}") {
            writer.rustTemplate(
                """
        impl ${symbol.name} {
            pub fn new(kind: ${symbol.name}Kind, meta: #{generic_error}) -> Self {
                Self { kind, meta }
            }


            pub fn unhandled(err: impl Into<Box<dyn #{std_error} + Send + Sync + 'static>>) -> Self {
                Self {
                    kind: ${symbol.name}Kind::Unhandled(err.into()),
@@ -158,10 +158,17 @@ class CombinedErrorGenerator(
            pub fn code(&self) -> Option<&str> {
                self.meta.code()
            }
        }
        """,
                "generic_error" to genericError, "std_error" to RuntimeType.StdError
            )
            errors.forEach { error ->
                val errorSymbol = symbolProvider.toSymbol(error)
                val fnName = errorSymbol.name.toSnakeCase()
                writer.rustBlock("pub fn is_$fnName(&self) -> bool") {
                    rust("matches!(&self.kind, ${symbol.name}Kind::${errorSymbol.name}(_))")
                }
            }
        }

        writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdError) {
            rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) {
+1 −1
Original line number Diff line number Diff line
@@ -389,7 +389,7 @@ class XmlBindingTraitParserGenerator(
                                    })
                                """
                            withBlock("let tmp =", ";") {
                                parseMember(member, ctx.copy(accum = current))
                                parseMember(member, ctx.copy(accum = current.trim()))
                            }
                            rust("base = Some(#T::$variantName(tmp));", symbol)
                        }
Loading