Unverified Commit 081387bd authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Breaking: Make it possible to asynchronously provide credentials (#572)

* Make it possible to asynchronously provide credentials

* CR feedback

* Add access key ID to credential debug print
parent 45f4e3d6
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -10,3 +10,7 @@ edition = "2018"
[dependencies]
smithy-http = { path = "../../../rust-runtime/smithy-http" }
zeroize = "1.2.0"

[dev-dependencies]
http = "0.2.3"
tokio = { version = "1.0", features = ["rt", "macros"] }
+95 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

use std::fmt;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use std::time::SystemTime;
use zeroize::Zeroizing;

/// AWS SDK Credentials
///
/// An opaque struct representing credentials that may be used in an AWS SDK, modeled on
/// the [CRT credentials implementation](https://github.com/awslabs/aws-c-auth/blob/main/source/credentials.c).
///
/// When `Credentials` is dropped, its contents are zeroed in memory. Credentials uses an interior Arc to ensure
/// that even when cloned, credentials don't exist in multiple memory locations.
#[derive(Clone)]
pub struct Credentials(Arc<Inner>);

struct Inner {
    access_key_id: Zeroizing<String>,
    secret_access_key: Zeroizing<String>,
    session_token: Zeroizing<Option<String>>,

    /// Credential Expiry
    ///
    /// A timepoint at which the credentials should no longer
    /// be used because they have expired. The primary purpose of this value is to allow
    /// credentials to communicate to the caching provider when they need to be refreshed.
    ///
    /// If these credentials never expire, this value will be set to `None`
    expires_after: Option<SystemTime>,

    provider_name: &'static str,
}

impl Debug for Credentials {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut creds = f.debug_struct("Credentials");
        creds.field("provider_name", &self.0.provider_name);
        creds.field("access_key_id", &self.0.access_key_id);
        creds.finish()
    }
}

const STATIC_CREDENTIALS: &str = "Static";
impl Credentials {
    pub fn new(
        access_key_id: impl Into<String>,
        secret_access_key: impl Into<String>,
        session_token: Option<String>,
        expires_after: Option<SystemTime>,
        provider_name: &'static str,
    ) -> Self {
        Credentials(Arc::new(Inner {
            access_key_id: Zeroizing::new(access_key_id.into()),
            secret_access_key: Zeroizing::new(secret_access_key.into()),
            session_token: Zeroizing::new(session_token),
            expires_after,
            provider_name,
        }))
    }

    pub fn from_keys(
        access_key_id: impl Into<String>,
        secret_access_key: impl Into<String>,
        session_token: Option<String>,
    ) -> Self {
        Self::new(
            access_key_id,
            secret_access_key,
            session_token,
            None,
            STATIC_CREDENTIALS,
        )
    }

    pub fn access_key_id(&self) -> &str {
        &self.0.access_key_id
    }

    pub fn secret_access_key(&self) -> &str {
        &self.0.secret_access_key
    }

    pub fn expiry(&self) -> Option<SystemTime> {
        self.0.expires_after
    }

    pub fn session_token(&self) -> Option<&str> {
        self.0.session_token.as_deref()
    }
}
+3 −155
Original line number Diff line number Diff line
@@ -3,160 +3,8 @@
 * SPDX-License-Identifier: Apache-2.0.
 */

mod credentials;
pub mod middleware;
pub mod provider;

use smithy_http::property_bag::PropertyBag;
use std::error::Error;
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;
use std::time::SystemTime;
use zeroize::Zeroizing;

/// AWS SDK Credentials
///
/// An opaque struct representing credentials that may be used in an AWS SDK, modeled on
/// the [CRT credentials implementation](https://github.com/awslabs/aws-c-auth/blob/main/source/credentials.c).
///
/// When `Credentials` is dropped, its contents are zeroed in memory. Credentials uses an interior Arc to ensure
/// that even when cloned, credentials don't exist in multiple memory locations.
#[derive(Clone)]
pub struct Credentials(Arc<Inner>);

struct Inner {
    access_key_id: Zeroizing<String>,
    secret_access_key: Zeroizing<String>,
    session_token: Zeroizing<Option<String>>,

    /// Credential Expiry
    ///
    /// A timepoint at which the credentials should no longer
    /// be used because they have expired. The primary purpose of this value is to allow
    /// credentials to communicate to the caching provider when they need to be refreshed.
    ///
    /// If these credentials never expire, this value will be set to `None`
    expires_after: Option<SystemTime>,

    provider_name: &'static str,
}

impl Debug for Credentials {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut creds = f.debug_struct("Credentials");
        creds.field("provider_name", &self.0.provider_name);
        creds.finish()
    }
}

const STATIC_CREDENTIALS: &str = "Static";
impl Credentials {
    pub fn new(
        access_key_id: impl Into<String>,
        secret_access_key: impl Into<String>,
        session_token: Option<String>,
        expires_after: Option<SystemTime>,
        provider_name: &'static str,
    ) -> Self {
        Credentials(Arc::new(Inner {
            access_key_id: Zeroizing::new(access_key_id.into()),
            secret_access_key: Zeroizing::new(secret_access_key.into()),
            session_token: Zeroizing::new(session_token),
            expires_after,
            provider_name,
        }))
    }

    pub fn from_keys(
        access_key_id: impl Into<String>,
        secret_access_key: impl Into<String>,
        session_token: Option<String>,
    ) -> Self {
        Self::new(
            access_key_id,
            secret_access_key,
            session_token,
            None,
            STATIC_CREDENTIALS,
        )
    }

    pub fn access_key_id(&self) -> &str {
        &self.0.access_key_id
    }

    pub fn secret_access_key(&self) -> &str {
        &self.0.secret_access_key
    }

    pub fn expiry(&self) -> Option<SystemTime> {
        self.0.expires_after
    }

    pub fn session_token(&self) -> Option<&str> {
        self.0.session_token.as_deref()
    }
}

#[derive(Debug)]
#[non_exhaustive]
pub enum CredentialsError {
    CredentialsNotLoaded,
    Unhandled(Box<dyn Error + Send + Sync + 'static>),
}

impl Display for CredentialsError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            CredentialsError::CredentialsNotLoaded => write!(f, "CredentialsNotLoaded"),
            CredentialsError::Unhandled(err) => write!(f, "{}", err),
        }
    }
}

impl Error for CredentialsError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            CredentialsError::Unhandled(e) => Some(e.as_ref() as _),
            _ => None,
        }
    }
}

pub type CredentialsProvider = Arc<dyn ProvideCredentials>;

/// A Credentials Provider
///
/// This interface is intentionally NOT async. Credential providers should provide a separate
/// async method to drive refresh (eg. in a background task).
///
/// Pending future design iteration, an async credentials provider may be introduced.
pub trait ProvideCredentials: Send + Sync {
    fn provide_credentials(&self) -> Result<Credentials, CredentialsError>;
}

pub fn default_provider() -> impl ProvideCredentials {
    // TODO: this should be a chain based on the CRT
    provider::EnvironmentVariableCredentialsProvider::new()
}

impl ProvideCredentials for Credentials {
    fn provide_credentials(&self) -> Result<Credentials, CredentialsError> {
        Ok(self.clone())
    }
}

pub fn set_provider(config: &mut PropertyBag, provider: Arc<dyn ProvideCredentials>) {
    config.insert(provider);
}

#[cfg(test)]
mod test {
    use crate::Credentials;

    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn creds_are_send_sync() {
        assert_send_sync::<Credentials>()
    }
}
pub use credentials::Credentials;
+124 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

use crate::provider::CredentialsProvider;
use smithy_http::middleware::AsyncMapRequest;
use smithy_http::operation::Request;
use std::future::Future;
use std::pin::Pin;

/// Middleware stage that requests credentials from a [CredentialsProvider] and places them in
/// the property bag of the request.
///
/// [CredentialsStage] implements [`AsyncMapRequest`](smithy_http::middleware::AsyncMapRequest), and:
/// 1. Retrieves a `CredentialsProvider` from the property bag.
/// 2. Calls the credential provider's `provide_credentials` and awaits its result.
/// 3. Places returned `Credentials` into the property bad to drive downstream signing middleware.
#[derive(Clone, Default)]
#[non_exhaustive]
pub struct CredentialsStage;

impl CredentialsStage {
    pub fn new() -> Self {
        CredentialsStage
    }
}

mod error {
    use crate::provider::CredentialsError;
    use std::error::Error as StdError;
    use std::fmt;

    #[derive(Debug)]
    pub enum CredentialsStageError {
        MissingCredentialsProvider,
        CredentialsLoadingError(CredentialsError),
    }

    impl StdError for CredentialsStageError {}

    impl fmt::Display for CredentialsStageError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            use CredentialsStageError::*;
            match self {
                MissingCredentialsProvider => {
                    write!(f, "No credentials provider in the property bag")
                }
                CredentialsLoadingError(err) => write!(
                    f,
                    "Failed to load credentials from the credentials provider: {}",
                    err
                ),
            }
        }
    }

    impl From<CredentialsError> for CredentialsStageError {
        fn from(err: CredentialsError) -> Self {
            CredentialsStageError::CredentialsLoadingError(err)
        }
    }
}

pub use error::*;

type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;

impl AsyncMapRequest for CredentialsStage {
    type Error = CredentialsStageError;
    type Future = Pin<Box<dyn Future<Output = Result<Request, Self::Error>> + Send + 'static>>;

    fn apply(&self, mut request: Request) -> BoxFuture<Result<Request, Self::Error>> {
        Box::pin(async move {
            let cred_future = {
                let config = request.config();
                let credential_provider = config
                    .get::<CredentialsProvider>()
                    .ok_or(CredentialsStageError::MissingCredentialsProvider)?;
                credential_provider.provide_credentials()
            };
            let credentials = cred_future.await?;
            request.config_mut().insert(credentials);
            Ok(request)
        })
    }
}

#[cfg(test)]
mod tests {
    use super::CredentialsStage;
    use crate::provider::set_provider;
    use crate::Credentials;
    use smithy_http::body::SdkBody;
    use smithy_http::middleware::AsyncMapRequest;
    use smithy_http::operation;
    use std::sync::Arc;

    #[tokio::test]
    async fn async_map_request_apply_requires_credential_provider() {
        let req = operation::Request::new(http::Request::new(SdkBody::from("some body")));
        CredentialsStage::new()
            .apply(req)
            .await
            .expect_err("should fail if there's no credential provider in the bag");
    }

    #[tokio::test]
    async fn async_map_request_apply_populates_credentials() {
        let mut req = operation::Request::new(http::Request::new(SdkBody::from("some body")));
        set_provider(
            &mut req.config_mut(),
            Arc::new(Credentials::from_keys("test", "test", None)),
        );
        let req = CredentialsStage::new()
            .apply(req)
            .await
            .expect("credential provider is in the bag; should succeed");
        assert!(
            req.config().get::<Credentials>().is_some(),
            "it should set credentials on the request config"
        );
    }
}
+66 −103
Original line number Diff line number Diff line
@@ -3,132 +3,95 @@
 * SPDX-License-Identifier: Apache-2.0.
 */

use crate::{Credentials, CredentialsError, ProvideCredentials};
use std::collections::HashMap;
use std::env::VarError;
pub mod env;

/// Load Credentials from Environment Variables
pub struct EnvironmentVariableCredentialsProvider {
    env: Box<dyn Fn(&str) -> Result<String, VarError> + Send + Sync>,
}
use crate::Credentials;
use smithy_http::property_bag::PropertyBag;
use std::error::Error;
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::future::{self, Future};
use std::pin::Pin;
use std::sync::Arc;

impl EnvironmentVariableCredentialsProvider {
    pub fn new() -> Self {
        EnvironmentVariableCredentialsProvider { env: Box::new(var) }
#[derive(Debug)]
#[non_exhaustive]
pub enum CredentialsError {
    CredentialsNotLoaded,
    Unhandled(Box<dyn Error + Send + Sync + 'static>),
}

    /// Create a EnvironmentVariable provider from a HashMap for testing
    pub fn for_map(env: HashMap<String, String>) -> Self {
        EnvironmentVariableCredentialsProvider {
            env: Box::new(move |key: &str| {
                env.get(key)
                    .ok_or(VarError::NotPresent)
                    .map(|k| k.to_string())
            }),
impl Display for CredentialsError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            CredentialsError::CredentialsNotLoaded => write!(f, "CredentialsNotLoaded"),
            CredentialsError::Unhandled(err) => write!(f, "{}", err),
        }
    }
}

impl Default for EnvironmentVariableCredentialsProvider {
    fn default() -> Self {
        Self::new()
impl Error for CredentialsError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            CredentialsError::Unhandled(e) => Some(e.as_ref() as _),
            _ => None,
        }
    }

fn var(key: &str) -> Result<String, VarError> {
    std::env::var(key)
}

const ENV_PROVIDER: &str = "EnvironmentVariable";
pub type CredentialsResult = Result<Credentials, CredentialsError>;
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;

impl ProvideCredentials for EnvironmentVariableCredentialsProvider {
    fn provide_credentials(&self) -> Result<Credentials, CredentialsError> {
        let access_key = (self.env)("AWS_ACCESS_KEY_ID").map_err(to_cred_error)?;
        let secret_key = (self.env)("AWS_SECRET_ACCESS_KEY")
            .or_else(|_| (self.env)("SECRET_ACCESS_KEY"))
            .map_err(to_cred_error)?;
        let session_token = (self.env)("AWS_SESSION_TOKEN").ok();
        Ok(Credentials::new(
            access_key,
            secret_key,
            session_token,
            None,
            ENV_PROVIDER,
        ))
/// An asynchronous credentials provider
///
/// If your use-case is synchronous, you should implement [ProvideCredentials] instead.
pub trait AsyncProvideCredentials: Send + Sync {
    fn provide_credentials(&self) -> BoxFuture<CredentialsResult>;
}

pub type CredentialsProvider = Arc<dyn AsyncProvideCredentials>;

/// A synchronous credentials provider
///
/// This is offered as a convenience for credential provider implementations that don't
/// need to be async. Otherwise, implement [AsyncProvideCredentials].
pub trait ProvideCredentials: Send + Sync {
    fn provide_credentials(&self) -> Result<Credentials, CredentialsError>;
}

fn to_cred_error(err: VarError) -> CredentialsError {
    match err {
        VarError::NotPresent => CredentialsError::CredentialsNotLoaded,
        e @ VarError::NotUnicode(_) => CredentialsError::Unhandled(Box::new(e)),
impl<T> AsyncProvideCredentials for T
where
    T: ProvideCredentials,
{
    fn provide_credentials(&self) -> BoxFuture<CredentialsResult> {
        let result = self.provide_credentials();
        Box::pin(future::ready(result))
    }
}

#[cfg(test)]
mod test {
    use crate::provider::EnvironmentVariableCredentialsProvider;
    use crate::{CredentialsError, ProvideCredentials};
    use std::collections::HashMap;

    #[test]
    fn valid_no_token() {
        let mut env = HashMap::new();
        env.insert("AWS_ACCESS_KEY_ID".to_owned(), "access".to_owned());
        env.insert("AWS_SECRET_ACCESS_KEY".to_owned(), "secret".to_owned());

        let provider = EnvironmentVariableCredentialsProvider::for_map(env);
        let creds = provider.provide_credentials().expect("valid credentials");
        assert_eq!(creds.session_token(), None);
        assert_eq!(creds.access_key_id(), "access");
        assert_eq!(creds.secret_access_key(), "secret");
pub fn default_provider() -> impl AsyncProvideCredentials {
    // TODO: this should be a chain based on the CRT
    env::EnvironmentVariableCredentialsProvider::new()
}

    #[test]
    fn valid_with_token() {
        let mut env = HashMap::new();
        env.insert("AWS_ACCESS_KEY_ID".to_owned(), "access".to_owned());
        env.insert("AWS_SECRET_ACCESS_KEY".to_owned(), "secret".to_owned());
        env.insert("AWS_SESSION_TOKEN".to_owned(), "token".to_owned());

        let provider = EnvironmentVariableCredentialsProvider::for_map(env);
        let creds = provider.provide_credentials().expect("valid credentials");
        assert_eq!(creds.session_token().unwrap(), "token");
        assert_eq!(creds.access_key_id(), "access");
        assert_eq!(creds.secret_access_key(), "secret");
impl ProvideCredentials for Credentials {
    fn provide_credentials(&self) -> Result<Credentials, CredentialsError> {
        Ok(self.clone())
    }

    #[test]
    fn secret_key_fallback() {
        let mut env = HashMap::new();
        env.insert("AWS_ACCESS_KEY_ID".to_owned(), "access".to_owned());
        env.insert("SECRET_ACCESS_KEY".to_owned(), "secret".to_owned());
        env.insert("AWS_SESSION_TOKEN".to_owned(), "token".to_owned());

        let provider = EnvironmentVariableCredentialsProvider::for_map(env);
        let creds = provider.provide_credentials().expect("valid credentials");
        assert_eq!(creds.session_token().unwrap(), "token");
        assert_eq!(creds.access_key_id(), "access");
        assert_eq!(creds.secret_access_key(), "secret");
}

    #[test]
    fn missing() {
        let env = HashMap::new();
        let provider = EnvironmentVariableCredentialsProvider::for_map(env);
        let err = provider
            .provide_credentials()
            .expect_err("no credentials defined");
        match err {
            CredentialsError::Unhandled(_) => panic!("wrong error type"),
            _ => (),
        };
pub fn set_provider(config: &mut PropertyBag, provider: Arc<dyn AsyncProvideCredentials>) {
    config.insert(provider);
}

#[cfg(test)]
mod test {
    use crate::Credentials;

    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn real_environment() {
        let provider = EnvironmentVariableCredentialsProvider::new();
        // we don't know what's in the env, just make sure it doesn't crash.
        let _ = provider.provide_credentials();
    fn creds_are_send_sync() {
        assert_send_sync::<Credentials>()
    }
}
Loading