Unverified Commit 9f0bc36f authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Make `SdkError::into_service_error` infallible (#1974)

parent 4563849d
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -13,11 +13,13 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.RetryableTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.asType
import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape
import software.amazon.smithy.rust.codegen.core.rustlang.documentShape
import software.amazon.smithy.rust.codegen.core.rustlang.rust
@@ -124,6 +126,7 @@ class CombinedErrorGenerator(
) {
    private val runtimeConfig = symbolProvider.config().runtimeConfig
    private val genericError = RuntimeType.GenericError(symbolProvider.config().runtimeConfig)
    private val createUnhandledError = CargoDependency.SmithyHttp(runtimeConfig).asType().member("result::CreateUnhandledError")

    fun render(writer: RustWriter) {
        val errorSymbol = RuntimeType("${operationSymbol.name}Error", null, "crate::error")
@@ -154,6 +157,15 @@ class CombinedErrorGenerator(
                RuntimeType.GenericError(runtimeConfig),
            )
        }
        writer.rustBlock("impl #T for ${errorSymbol.name}", createUnhandledError) {
            rustBlock("fn create_unhandled_error(source: Box<dyn std::error::Error + Send + Sync + 'static>) -> Self") {
                rustBlock("Self") {
                    rust("kind: ${errorSymbol.name}Kind::Unhandled(#T::new(source)),", unhandledError())
                    rust("meta: Default::default()")
                }
            }
        }

        writer.rust("/// Types of errors that can occur for the `${operationSymbol.name}` operation.")
        meta.render(writer)
        writer.rustBlock("enum ${errorSymbol.name}Kind") {
+20 −14
Original line number Diff line number Diff line
@@ -95,6 +95,7 @@ class TopLevelErrorGenerator(private val codegenContext: CodegenContext, private

    private fun RustWriter.renderImplFrom(errorSymbol: RuntimeType, errors: List<ShapeId>) {
        if (errors.isNotEmpty() || CodegenTarget.CLIENT == codegenContext.target) {
            val operationErrors = errors.map { model.expectShape(it) }
            rustBlock(
                "impl<R> From<#T<#T, R>> for Error where R: Send + Sync + std::fmt::Debug + 'static",
                sdkError,
@@ -106,8 +107,15 @@ class TopLevelErrorGenerator(private val codegenContext: CodegenContext, private
                    "OpError" to errorSymbol,
                ) {
                    rustBlock("match err") {
                        val operationErrors = errors.map { model.expectShape(it) }
                        rustBlock("#T::ServiceError(context) => match context.into_err().kind", sdkError) {
                        rust("#T::ServiceError(context) => Self::from(context.into_err()),", sdkError)
                        rust("_ => Error::Unhandled(#T::new(err.into())),", unhandledError())
                    }
                }
            }

            rustBlock("impl From<#T> for Error", errorSymbol) {
                rustBlock("fn from(err: #T) -> Self", errorSymbol) {
                    rustBlock("match err.kind") {
                        operationErrors.forEach { errorShape ->
                            val errSymbol = symbolProvider.toSymbol(errorShape)
                            rust(
@@ -121,8 +129,6 @@ class TopLevelErrorGenerator(private val codegenContext: CodegenContext, private
                            "unhandled" to unhandledError(),
                        )
                    }
                        rust("_ => Error::Unhandled(#T::new(err.into())),", unhandledError())
                    }
                }
            }
        }
+27 −11
Original line number Diff line number Diff line
@@ -122,6 +122,14 @@ impl<E, R> ServiceError<E, R> {
    }
}

/// Constructs the unhandled variant of a code generated error.
///
/// This trait exists so that [`SdkError::into_service_error`] can be infallible.
pub trait CreateUnhandledError {
    /// Creates an unhandled error variant with the given `source`.
    fn create_unhandled_error(source: Box<dyn Error + Send + Sync + 'static>) -> Self;
}

/// Failed SDK Result
#[non_exhaustive]
#[derive(Debug)]
@@ -179,34 +187,42 @@ impl<E, R> SdkError<E, R> {

    /// Returns the underlying service error `E` if there is one
    ///
    /// If a service error is not available (for example, the error is a network timeout),
    /// then the full `SdkError` is returned. This makes it easy to match on the service's
    /// error response while simultaneously bubbling up transient failures. For example,
    /// handling the `NoSuchKey` error for S3's `GetObject` operation may look as follows:
    /// If the `SdkError` is not a `ServiceError` (for example, the error is a network timeout),
    /// then it will be converted into an unhandled variant of `E`. This makes it easy to match
    /// on the service's error response while simultaneously bubbling up transient failures.
    /// For example, handling the `NoSuchKey` error for S3's `GetObject` operation may look as
    /// follows:
    ///
    /// ```no_run
    /// # use aws_smithy_http::result::SdkError;
    /// # use aws_smithy_http::result::{SdkError, CreateUnhandledError};
    /// # #[derive(Debug)] enum GetObjectErrorKind { NoSuchKey(()), Other(()) }
    /// # #[derive(Debug)] struct GetObjectError { kind: GetObjectErrorKind }
    /// # impl std::fmt::Display for GetObjectError {
    /// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { unimplemented!() }
    /// # }
    /// # impl std::error::Error for GetObjectError {}
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// # impl CreateUnhandledError for GetObjectError {
    /// #     fn create_unhandled_error(_: Box<dyn std::error::Error + Send + Sync + 'static>) -> Self { unimplemented!() }
    /// # }
    /// # fn example() -> Result<(), GetObjectError> {
    /// # let sdk_err = SdkError::service_error(GetObjectError { kind: GetObjectErrorKind::NoSuchKey(()) }, ());
    /// match sdk_err.into_service_error()? {
    /// match sdk_err.into_service_error() {
    ///     GetObjectError { kind: GetObjectErrorKind::NoSuchKey(_) } => {
    ///         // handle NoSuchKey
    ///     }
    ///     err @ _ => return Err(err.into()),
    ///     err @ _ => return Err(err),
    /// }
    /// # Ok(())
    /// # }
    /// ```
    pub fn into_service_error(self) -> Result<E, Self> {
    pub fn into_service_error(self) -> E
    where
        E: std::error::Error + Send + Sync + CreateUnhandledError + 'static,
        R: Debug + Send + Sync + 'static,
    {
        match self {
            Self::ServiceError(context) => Ok(context.source),
            _ => Err(self),
            Self::ServiceError(context) => context.source,
            _ => E::create_unhandled_error(self.into()),
        }
    }

+1 −4
Original line number Diff line number Diff line
@@ -34,10 +34,7 @@ pub async fn s3_canary(client: s3::Client, s3_bucket_name: String) -> anyhow::Re
                CanaryError(format!("Expected object {} to not exist in S3", test_key)).into(),
            );
        }
        Err(err) => match err
            .into_service_error()
            .context("unexpected s3::GetObject failure")?
        {
        Err(err) => match err.into_service_error() {
            GetObjectError {
                kind: GetObjectErrorKind::NoSuchKey(..),
                ..