Unverified Commit 653e4cd7 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Revamp errors for `aws-sigv4` (#1937)

parent 12f217b3
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
allowed_external_types = [
    "aws_sigv4::http_request::sign::SignableBody",
    "aws_sigv4::http_request::error::SigningError",
    "aws_smithy_http::*",
    "aws_types::*",
    "http::request::Request",
+2 −2
Original line number Diff line number Diff line
@@ -103,7 +103,7 @@ impl From<SigningError> for SigningStageError {
impl Error for SigningStageError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match &self {
            SigningStageError::SigningFailure(err) => Some(err.as_ref()),
            SigningStageError::SigningFailure(err) => Some(err),
            _ => None,
        }
    }
@@ -160,7 +160,7 @@ impl MapRequest for SigV4SigningStage {
            let signature = self
                .signer
                .sign(operation_config, &request_config, &creds, &mut req)
                .map_err(|err| SigningStageError::SigningFailure(err))?;
                .map_err(SigningStageError::SigningFailure)?;
            config.insert(signature);
            Ok(req)
        })
+1 −3
Original line number Diff line number Diff line
@@ -12,11 +12,11 @@ use aws_smithy_http::body::SdkBody;
use aws_types::region::SigningRegion;
use aws_types::Credentials;
use aws_types::SigningService;
use std::error::Error;
use std::fmt;
use std::time::{Duration, SystemTime};

pub use aws_sigv4::http_request::SignableBody;
pub type SigningError = aws_sigv4::http_request::SigningError;

const EXPIRATION_WARNING: &str = "Presigned request will expire before the given \
    `expires_in` duration because the credentials used to sign it will expire first.";
@@ -121,8 +121,6 @@ impl fmt::Debug for SigV4Signer {
    }
}

pub type SigningError = Box<dyn Error + Send + Sync>;

impl SigV4Signer {
    pub fn new() -> Self {
        SigV4Signer { _private: () }
+13 −6
Original line number Diff line number Diff line
@@ -3,12 +3,13 @@
 * SPDX-License-Identifier: Apache-2.0
 */

use super::query_writer::QueryWriter;
use super::{Error, PayloadChecksumKind, SignableBody, SignatureLocation, SigningParams};
use crate::date_time::{format_date, format_date_time};
use crate::http_request::error::CanonicalRequestError;
use crate::http_request::query_writer::QueryWriter;
use crate::http_request::sign::SignableRequest;
use crate::http_request::url_escape::percent_encode_path;
use crate::http_request::PercentEncodingMode;
use crate::http_request::{PayloadChecksumKind, SignableBody, SignatureLocation, SigningParams};
use crate::sign::sha256_hex_string;
use http::header::{HeaderName, HOST};
use http::{HeaderMap, HeaderValue, Method, Uri};
@@ -124,7 +125,7 @@ impl<'a> CanonicalRequest<'a> {
    pub(super) fn from<'b>(
        req: &'b SignableRequest<'b>,
        params: &'b SigningParams<'b>,
    ) -> Result<CanonicalRequest<'b>, Error> {
    ) -> Result<CanonicalRequest<'b>, CanonicalRequestError> {
        // Path encoding: if specified, re-encode % as %25
        // Set method and path into CanonicalRequest
        let path = req.uri().path();
@@ -182,7 +183,7 @@ impl<'a> CanonicalRequest<'a> {
        params: &SigningParams<'_>,
        payload_hash: &str,
        date_time: &str,
    ) -> Result<(Vec<CanonicalHeaderName>, HeaderMap), Error> {
    ) -> Result<(Vec<CanonicalHeaderName>, HeaderMap), CanonicalRequestError> {
        // Header computation:
        // The canonical request will include headers not present in the input. We need to clone and
        // normalize the headers from the original request and add:
@@ -375,9 +376,15 @@ fn trim_spaces_from_byte_string(bytes: &[u8]) -> &[u8] {

/// Works just like [trim_all] but acts on HeaderValues instead of bytes.
/// Will ensure that the underlying bytes are valid UTF-8.
fn normalize_header_value(header_value: &HeaderValue) -> Result<HeaderValue, Error> {
fn normalize_header_value(
    header_value: &HeaderValue,
) -> Result<HeaderValue, CanonicalRequestError> {
    let trimmed_value = trim_all(header_value.as_bytes());
    HeaderValue::from_str(std::str::from_utf8(&trimmed_value)?).map_err(Error::from)
    HeaderValue::from_str(
        std::str::from_utf8(&trimmed_value)
            .map_err(CanonicalRequestError::invalid_utf8_in_header_value)?,
    )
    .map_err(CanonicalRequestError::from)
}

#[derive(Debug, PartialEq, Default)]
+104 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use http::header::{InvalidHeaderName, InvalidHeaderValue};
use std::error::Error;
use std::fmt;
use std::str::Utf8Error;

#[derive(Debug)]
enum SigningErrorKind {
    FailedToCreateCanonicalRequest { source: CanonicalRequestError },
}

/// Error signing request
#[derive(Debug)]
pub struct SigningError {
    kind: SigningErrorKind,
}

impl fmt::Display for SigningError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.kind {
            SigningErrorKind::FailedToCreateCanonicalRequest { .. } => {
                write!(f, "failed to create canonical request")
            }
        }
    }
}

impl Error for SigningError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match &self.kind {
            SigningErrorKind::FailedToCreateCanonicalRequest { source } => Some(source),
        }
    }
}

impl From<CanonicalRequestError> for SigningError {
    fn from(source: CanonicalRequestError) -> Self {
        Self {
            kind: SigningErrorKind::FailedToCreateCanonicalRequest { source },
        }
    }
}

#[derive(Debug)]
enum CanonicalRequestErrorKind {
    InvalidHeaderName { source: InvalidHeaderName },
    InvalidHeaderValue { source: InvalidHeaderValue },
    InvalidUtf8InHeaderValue { source: Utf8Error },
}

#[derive(Debug)]
pub(crate) struct CanonicalRequestError {
    kind: CanonicalRequestErrorKind,
}

impl fmt::Display for CanonicalRequestError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use CanonicalRequestErrorKind::*;
        match self.kind {
            InvalidHeaderName { .. } => write!(f, "invalid header name"),
            InvalidHeaderValue { .. } => write!(f, "invalid header value"),
            InvalidUtf8InHeaderValue { .. } => write!(f, "invalid UTF-8 in header value"),
        }
    }
}

impl Error for CanonicalRequestError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        use CanonicalRequestErrorKind::*;
        match &self.kind {
            InvalidHeaderName { source } => Some(source),
            InvalidHeaderValue { source } => Some(source),
            InvalidUtf8InHeaderValue { source } => Some(source),
        }
    }
}

impl CanonicalRequestError {
    pub(crate) fn invalid_utf8_in_header_value(source: Utf8Error) -> Self {
        Self {
            kind: CanonicalRequestErrorKind::InvalidUtf8InHeaderValue { source },
        }
    }
}

impl From<InvalidHeaderName> for CanonicalRequestError {
    fn from(source: InvalidHeaderName) -> Self {
        Self {
            kind: CanonicalRequestErrorKind::InvalidHeaderName { source },
        }
    }
}

impl From<InvalidHeaderValue> for CanonicalRequestError {
    fn from(source: InvalidHeaderValue) -> Self {
        Self {
            kind: CanonicalRequestErrorKind::InvalidHeaderValue { source },
        }
    }
}
Loading