Unverified Commit ba0d182e authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Add convenience to async provide credentials from a closure (#577)

parent 081387bd
Loading
Loading
Loading
Loading
+46 −2
Original line number Diff line number Diff line
@@ -44,17 +44,61 @@ type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;

/// An asynchronous credentials provider
///
/// If your use-case is synchronous, you should implement [ProvideCredentials] instead.
/// If your use-case is synchronous, you should implement [`ProvideCredentials`] instead. Otherwise,
/// consider using [`async_provide_credentials_fn`] with a closure rather than directly implementing
/// this trait.
pub trait AsyncProvideCredentials: Send + Sync {
    fn provide_credentials(&self) -> BoxFuture<CredentialsResult>;
}

pub type CredentialsProvider = Arc<dyn AsyncProvideCredentials>;

/// A [`AsyncProvideCredentials`] implemented by a closure.
///
/// See [`async_provide_credentials_fn`] for more details.
#[derive(Copy, Clone)]
pub struct AsyncProvideCredentialsFn<T: Send + Sync> {
    f: T,
}

impl<T, F> AsyncProvideCredentials for AsyncProvideCredentialsFn<T>
where
    T: Fn() -> F + Send + Sync,
    F: Future<Output = CredentialsResult> + Send + 'static,
{
    fn provide_credentials(&self) -> BoxFuture<CredentialsResult> {
        Box::pin((self.f)())
    }
}

/// Returns a new [`AsyncProvideCredentialsFn`] with the given closure. This allows you
/// to create an [`AsyncProvideCredentials`] implementation from an async block that returns
/// a [`CredentialsResult`].
///
/// # Example
///
/// ```
/// use aws_auth::Credentials;
/// use aws_auth::provider::async_provide_credentials_fn;
///
/// async_provide_credentials_fn(|| async {
///     // Async process to retrieve credentials goes here
///     let credentials: Credentials = todo!().await?;
///     Ok(credentials)
/// });
/// ```
pub fn async_provide_credentials_fn<T, F>(f: T) -> AsyncProvideCredentialsFn<T>
where
    T: Fn() -> F + Send + Sync,
    F: Future<Output = CredentialsResult> + Send + 'static,
{
    AsyncProvideCredentialsFn { f }
}

/// A synchronous credentials provider
///
/// This is offered as a convenience for credential provider implementations that don't
/// need to be async. Otherwise, implement [AsyncProvideCredentials].
/// need to be async. Otherwise, implement [`AsyncProvideCredentials`].
pub trait ProvideCredentials: Send + Sync {
    fn provide_credentials(&self) -> Result<Credentials, CredentialsError>;
}
+26 −73
Original line number Diff line number Diff line
@@ -3,9 +3,7 @@
 * SPDX-License-Identifier: Apache-2.0.
 */

use aws_auth::provider::{CredentialsError, ProvideCredentials};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use aws_auth::provider::{async_provide_credentials_fn, CredentialsError};
use sts::Credentials;

/// Implements a basic version of ProvideCredentials with AWS STS
@@ -14,72 +12,22 @@ use sts::Credentials;
async fn main() -> Result<(), dynamodb::Error> {
    tracing_subscriber::fmt::init();
    let client = sts::Client::from_env();
    let sts_provider = StsCredentialsProvider {
        client,
        credentials: Arc::new(Mutex::new(None)),
    };
    sts_provider.spawn_refresh_loop().await;

    // NOTE: Do not use this in production! This will grab new credentials for every request.
    // A high quality caching credential provider implementation is in the roadmap.
    let dynamodb_conf = dynamodb::Config::builder()
        .credentials_provider(sts_provider)
        .build();
    let client = dynamodb::Client::from_conf(dynamodb_conf);
    println!("tables: {:?}", client.list_tables().send().await?);
    Ok(())
}

/// This is a rough example of how you could implement ProvideCredentials with Amazon STS.
///
/// Do not use this in production! A high quality implementation is in the roadmap.
#[derive(Clone)]
struct StsCredentialsProvider {
    client: sts::Client,
    credentials: Arc<Mutex<Option<Credentials>>>,
}

impl ProvideCredentials for StsCredentialsProvider {
    fn provide_credentials(&self) -> Result<Credentials, CredentialsError> {
        let inner = self.credentials.lock().unwrap().clone();
        inner.ok_or(CredentialsError::CredentialsNotLoaded)
    }
}

impl StsCredentialsProvider {
    pub async fn spawn_refresh_loop(&self) {
        let _ = self
            .refresh()
        .credentials_provider(async_provide_credentials_fn(move || {
            let client = client.clone();
            async move {
                let session_token = client
                    .get_session_token()
                    .send()
                    .await
            .map_err(|e| eprintln!("failed to load credentials! {}", e));
        let this = self.clone();
        tokio::spawn(async move {
            loop {
                let needs_refresh = {
                    let creds = this.credentials.lock().unwrap();
                    let expiry = creds.as_ref().and_then(|creds| creds.expiry());
                    if creds.is_none() {
                        true
                    } else {
                        expiry
                            .map(|expiry| SystemTime::now() > expiry)
                            .unwrap_or(false)
                    }
                };
                if needs_refresh {
                    let _ = this
                        .refresh()
                        .await
                        .map_err(|e| eprintln!("failed to load credentials! {}", e));
                }
                tokio::time::sleep(Duration::from_secs(5)).await;
            }
        });
    }
    pub async fn refresh(&self) -> Result<(), sts::Error> {
        let session_token = self.client.get_session_token().send().await?;
                    .map_err(|err| CredentialsError::Unhandled(Box::new(err)))?;
                let sts_credentials = session_token
                    .credentials
                    .expect("should include credentials");
        *self.credentials.lock().unwrap() = Some(Credentials::new(
                Ok(Credentials::new(
                    sts_credentials.access_key_id.unwrap(),
                    sts_credentials.secret_access_key.unwrap(),
                    sts_credentials.session_token,
@@ -87,7 +35,12 @@ impl StsCredentialsProvider {
                        .expiration
                        .map(|expiry| expiry.to_system_time().expect("sts sent a time < 0")),
                    "Sts",
        ));
        Ok(())
                ))
            }
        }))
        .build();

    let client = dynamodb::Client::from_conf(dynamodb_conf);
    println!("tables: {:?}", client.list_tables().send().await?);
    Ok(())
}