Unverified Commit 53fab07a authored by Jesse Szwedko's avatar Jesse Szwedko Committed by GitHub
Browse files

Add support for `credential_process` from profiles (#1356)



Co-authored-by: default avatarJohn DiSanti <jdisanti@amazon.com>
parent fe582a8d
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -46,3 +46,9 @@ message = "Switch to [RustCrypto](https://github.com/RustCrypto)'s implementatio
references = ["smithy-rs#1404"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "petrosagg"

[[aws-sdk-rust]]
message = "Add support for `credential_process` in AWS configs for fetching credentials from an external process."
references = ["smithy-rs#1356"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "jszwedko"
+292 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Credentials Provider for external process

use crate::json_credentials::{json_parse_loop, InvalidJsonCredentials, RefreshableCredentials};
use aws_smithy_json::deserialize::Token;
use aws_smithy_types::date_time::Format;
use aws_smithy_types::DateTime;
use aws_types::credentials::{future, CredentialsError, ProvideCredentials};
use aws_types::{credentials, Credentials};
use std::fmt;
use std::process::Command;
use std::time::SystemTime;

pub(crate) struct CommandWithSensitiveArgs<T>(T);

impl<T> CommandWithSensitiveArgs<T>
where
    T: AsRef<str>,
{
    pub(crate) fn new(value: T) -> Self {
        Self(value)
    }

    pub(crate) fn unredacted(&self) -> &str {
        self.0.as_ref()
    }
}

impl<T> fmt::Display for CommandWithSensitiveArgs<T>
where
    T: AsRef<str>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Security: The arguments for command must be redacted since they can be sensitive
        let command = self.0.as_ref();
        match command.find(char::is_whitespace) {
            Some(index) => write!(f, "{} ** arguments redacted **", &command[0..index]),
            None => write!(f, "{}", command),
        }
    }
}

impl<T> fmt::Debug for CommandWithSensitiveArgs<T>
where
    T: AsRef<str>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}", format!("{}", self))
    }
}

impl<T> Clone for CommandWithSensitiveArgs<T>
where
    T: Clone,
{
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

/// External process credentials provider
///
/// This credentials provider runs a configured external process and parses
/// its output to retrieve credentials.
///
/// The external process must exit with status 0 and output the following
/// JSON format to `stdout` to provide credentials:
///
/// ```json
/// {
///     "Version:" 1,
///     "AccessKeyId": "access key id",
///     "SecretAccessKey": "secret access key",
///     "SessionToken": "session token",
///     "Expiration": "time that the expiration will expire"
/// }
/// ```
///
/// The `Version` must be set to 1. `AccessKeyId` and `SecretAccessKey` are always required.
/// `SessionToken` must be set if a session token is associated with the `AccessKeyId`.
/// The `Expiration` is optional, and must be given in the RFC 3339 date time format (e.g.,
/// `2022-05-26T12:34:56.789Z`).
///
/// If the external process exits with a non-zero status, then the contents of `stderr`
/// will be output as part of the credentials provider error message.
///
/// This credentials provider is included in the profile credentials provider, and can be
/// configured using the `credential_process` attribute. For example:
///
/// ```plain
/// [profile example]
/// credential_process = /path/to/my/process --some --arguments
/// ```
#[derive(Debug)]
pub struct CredentialProcessProvider {
    command: CommandWithSensitiveArgs<String>,
}

impl ProvideCredentials for CredentialProcessProvider {
    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
    where
        Self: 'a,
    {
        future::ProvideCredentials::new(self.credentials())
    }
}

impl CredentialProcessProvider {
    /// Create new [`CredentialProcessProvider`] with the `command` needed to execute the external process.
    pub fn new(command: String) -> Self {
        Self {
            command: CommandWithSensitiveArgs::new(command),
        }
    }

    async fn credentials(&self) -> credentials::Result {
        // Security: command arguments must be redacted at debug level
        tracing::debug!(command = %self.command, "loading credentials from external process");

        let mut command = if cfg!(windows) {
            let mut command = Command::new("cmd.exe");
            command.args(&["/C", self.command.unredacted()]);
            command
        } else {
            let mut command = Command::new("sh");
            command.args(&["-c", self.command.unredacted()]);
            command
        };

        let output = command.output().map_err(|e| {
            CredentialsError::provider_error(format!(
                "Error retrieving credentials from external process: {}",
                e
            ))
        })?;

        // Security: command arguments can be logged at trace level
        tracing::trace!(command = ?command, status = ?output.status, "executed command (unredacted)");

        if !output.status.success() {
            let reason =
                std::str::from_utf8(&output.stderr).unwrap_or("could not decode stderr as UTF-8");
            return Err(CredentialsError::provider_error(format!(
                "Error retrieving credentials: external process exited with code {}. Stderr: {}",
                output.status, reason
            )));
        }

        let output = std::str::from_utf8(&output.stdout).map_err(|e| {
            CredentialsError::provider_error(format!(
                "Error retrieving credentials from external process: could not decode output as UTF-8: {}",
                e
            ))
        })?;

        match parse_credential_process_json_credentials(output) {
            Ok(RefreshableCredentials {
                access_key_id,
                secret_access_key,
                session_token,
                expiration,
                ..
            }) => Ok(Credentials::new(
                access_key_id,
                secret_access_key,
                Some(session_token.to_string()),
                expiration.into(),
                "CredentialProcess",
            )),
            Err(invalid) => Err(CredentialsError::provider_error(format!(
                "Error retrieving credentials from external process, could not parse response: {}",
                invalid
            ))),
        }
    }
}

/// Deserialize a credential_process response from a string
///
/// Returns an error if the response cannot be successfully parsed or is missing keys.
///
/// Keys are case insensitive.
pub(crate) fn parse_credential_process_json_credentials(
    credentials_response: &str,
) -> Result<RefreshableCredentials, InvalidJsonCredentials> {
    let mut version = None;
    let mut access_key_id = None;
    let mut secret_access_key = None;
    let mut session_token = None;
    let mut expiration = None;
    json_parse_loop(credentials_response.as_bytes(), |key, value| {
        match (key, value) {
            /*
             "Version": 1,
             "AccessKeyId": "ASIARTESTID",
             "SecretAccessKey": "TESTSECRETKEY",
             "SessionToken": "TESTSESSIONTOKEN",
             "Expiration": "2022-05-02T18:36:00+00:00"
            */
            (key, Token::ValueNumber { value, .. }) if key.eq_ignore_ascii_case("Version") => {
                version = Some(value.to_i32())
            }
            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("AccessKeyId") => {
                access_key_id = Some(value.to_unescaped()?)
            }
            (key, Token::ValueString { value, .. })
                if key.eq_ignore_ascii_case("SecretAccessKey") =>
            {
                secret_access_key = Some(value.to_unescaped()?)
            }
            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("SessionToken") => {
                session_token = Some(value.to_unescaped()?)
            }
            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("Expiration") => {
                expiration = Some(value.to_unescaped()?)
            }

            _ => {}
        };
        Ok(())
    })?;

    match version {
        Some(1) => { /* continue */ }
        None => return Err(InvalidJsonCredentials::MissingField("Version")),
        Some(version) => {
            return Err(InvalidJsonCredentials::InvalidField {
                field: "version",
                err: format!("unknown version number: {}", version).into(),
            })
        }
    }

    let access_key_id = access_key_id.ok_or(InvalidJsonCredentials::MissingField("AccessKeyId"))?;
    let secret_access_key =
        secret_access_key.ok_or(InvalidJsonCredentials::MissingField("SecretAccessKey"))?;
    let session_token = session_token.ok_or(InvalidJsonCredentials::MissingField("Token"))?;
    let expiration = expiration.ok_or(InvalidJsonCredentials::MissingField("Expiration"))?;
    let expiration = SystemTime::try_from(
        DateTime::from_str(expiration.as_ref(), Format::DateTime).map_err(|err| {
            InvalidJsonCredentials::InvalidField {
                field: "Expiration",
                err: err.into(),
            }
        })?,
    )
    .map_err(|_| {
        InvalidJsonCredentials::Other(
            "credential expiration time cannot be represented by a DateTime".into(),
        )
    })?;
    Ok(RefreshableCredentials {
        access_key_id,
        secret_access_key,
        session_token,
        expiration,
    })
}

#[cfg(test)]
mod test {
    use crate::credential_process::CredentialProcessProvider;
    use aws_smithy_types::date_time::Format;
    use aws_smithy_types::DateTime;
    use aws_types::credentials::ProvideCredentials;
    use std::time::SystemTime;

    #[tokio::test]
    async fn test_credential_process() {
        let provider = CredentialProcessProvider::new(String::from(
            r#"echo '{ "Version": 1, "AccessKeyId": "ASIARTESTID", "SecretAccessKey": "TESTSECRETKEY", "SessionToken": "TESTSESSIONTOKEN", "Expiration": "2022-05-02T18:36:00+00:00" }'"#,
        ));
        let creds = provider.provide_credentials().await.expect("valid creds");
        assert_eq!(creds.access_key_id(), "ASIARTESTID");
        assert_eq!(creds.secret_access_key(), "TESTSECRETKEY");
        assert_eq!(creds.session_token(), Some("TESTSESSIONTOKEN"));
        assert_eq!(
            creds.expiry(),
            Some(
                SystemTime::try_from(
                    DateTime::from_str("2022-05-02T18:36:00+00:00", Format::DateTime)
                        .expect("static datetime")
                )
                .expect("static datetime")
            )
        );
    }
}
+3 −3
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ use aws_types::credentials::CredentialsError;
use aws_types::{credentials, Credentials};

use crate::connector::expect_connector;
use crate::json_credentials::{parse_json_credentials, JsonCredentials};
use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials};
use crate::provider_config::ProviderConfig;

use bytes::Bytes;
@@ -145,12 +145,12 @@ impl ParseStrictResponse for CredentialsResponseParser {
            std::str::from_utf8(response.body().as_ref()).map_err(CredentialsError::unhandled)?;
        let json_creds = parse_json_credentials(str_resp).map_err(CredentialsError::unhandled)?;
        match json_creds {
            JsonCredentials::RefreshableCredentials {
            JsonCredentials::RefreshableCredentials(RefreshableCredentials {
                access_key_id,
                secret_access_key,
                session_token,
                expiration,
            } => Ok(Credentials::new(
            }) => Ok(Credentials::new(
                access_key_id,
                secret_access_key,
                Some(session_token.to_string()),
+3 −3
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@

use crate::imds;
use crate::imds::client::{ImdsError, LazyClient};
use crate::json_credentials::{parse_json_credentials, JsonCredentials};
use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials};
use crate::provider_config::ProviderConfig;
use aws_smithy_client::SdkError;
use aws_types::credentials::{future, CredentialsError, ProvideCredentials};
@@ -170,13 +170,13 @@ impl ImdsCredentialsProvider {
            .await
            .map_err(CredentialsError::provider_error)?;
        match parse_json_credentials(&credentials) {
            Ok(JsonCredentials::RefreshableCredentials {
            Ok(JsonCredentials::RefreshableCredentials(RefreshableCredentials {
                access_key_id,
                secret_access_key,
                session_token,
                expiration,
                ..
            }) => Ok(Credentials::new(
            })) => Ok(Credentials::new(
                access_key_id,
                secret_access_key,
                Some(session_token.to_string()),
+62 −32
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ use aws_smithy_types::DateTime;
use std::borrow::Cow;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::fmt::{self, Display, Formatter};
use std::time::SystemTime;

#[derive(Debug)]
@@ -63,15 +63,29 @@ impl Display for InvalidJsonCredentials {

impl Error for InvalidJsonCredentials {}

#[derive(PartialEq, Eq)]
pub(crate) struct RefreshableCredentials<'a> {
    pub(crate) access_key_id: Cow<'a, str>,
    pub(crate) secret_access_key: Cow<'a, str>,
    pub(crate) session_token: Cow<'a, str>,
    pub(crate) expiration: SystemTime,
}

impl<'a> fmt::Debug for RefreshableCredentials<'a> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("RefreshableCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"** redacted **")
            .field("session_token", &"** redacted **")
            .field("expiration", &self.expiration)
            .finish()
    }
}

#[non_exhaustive]
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum JsonCredentials<'a> {
    RefreshableCredentials {
        access_key_id: Cow<'a, str>,
        secret_access_key: Cow<'a, str>,
        session_token: Cow<'a, str>,
        expiration: SystemTime,
    },
    RefreshableCredentials(RefreshableCredentials<'a>),
    Error {
        code: Cow<'a, str>,
        message: Cow<'a, str>,
@@ -115,7 +129,7 @@ pub(crate) fn parse_json_credentials(
    let mut expiration = None;
    let mut message = None;
    json_parse_loop(credentials_response.as_bytes(), |key, value| {
        match key {
        match (key, value) {
            /*
             "Code": "Success",
             "Type": "AWS-HMAC",
@@ -125,16 +139,31 @@ pub(crate) fn parse_json_credentials(
             "Expiration" : "....",
             "LastUpdated" : "2009-11-23T0:00:00Z"
            */
            c if c.eq_ignore_ascii_case("Code") => code = Some(value),
            c if c.eq_ignore_ascii_case("AccessKeyId") => access_key_id = Some(value),
            c if c.eq_ignore_ascii_case("SecretAccessKey") => secret_access_key = Some(value),
            c if c.eq_ignore_ascii_case("Token") => session_token = Some(value),
            c if c.eq_ignore_ascii_case("Expiration") => expiration = Some(value),
            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("Code") => {
                code = Some(value.to_unescaped()?);
            }
            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("AccessKeyId") => {
                access_key_id = Some(value.to_unescaped()?);
            }
            (key, Token::ValueString { value, .. })
                if key.eq_ignore_ascii_case("SecretAccessKey") =>
            {
                secret_access_key = Some(value.to_unescaped()?);
            }
            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("Token") => {
                session_token = Some(value.to_unescaped()?);
            }
            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("Expiration") => {
                expiration = Some(value.to_unescaped()?);
            }

            // Error case handling: message will be set
            c if c.eq_ignore_ascii_case("Message") => message = Some(value),
            _ => {}
            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("Message") => {
                message = Some(value.to_unescaped()?);
            }
            _ => {}
        };
        Ok(())
    })?;
    match code {
        // IMDS does not appear to reply with a `Code` missing, but documentation indicates it
@@ -161,12 +190,14 @@ pub(crate) fn parse_json_credentials(
                    "credential expiration time cannot be represented by a SystemTime".into(),
                )
            })?;
            Ok(JsonCredentials::RefreshableCredentials {
            Ok(JsonCredentials::RefreshableCredentials(
                RefreshableCredentials {
                    access_key_id,
                    secret_access_key,
                    session_token,
                    expiration,
            })
                },
            ))
        }
        Some(other) => Ok(JsonCredentials::Error {
            code: other,
@@ -177,7 +208,7 @@ pub(crate) fn parse_json_credentials(

pub(crate) fn json_parse_loop<'a>(
    input: &'a [u8],
    mut f: impl FnMut(Cow<'a, str>, Cow<'a, str>),
    mut f: impl FnMut(Cow<'a, str>, &Token<'a>) -> Result<(), InvalidJsonCredentials>,
) -> Result<(), InvalidJsonCredentials> {
    let mut tokens = json_token_iter(input).peekable();
    if !matches!(tokens.next().transpose()?, Some(Token::StartObject { .. })) {
@@ -189,10 +220,9 @@ pub(crate) fn json_parse_loop<'a>(
        match tokens.next().transpose()? {
            Some(Token::EndObject { .. }) => break,
            Some(Token::ObjectKey { key, .. }) => {
                if let Some(Ok(Token::ValueString { value, .. })) = tokens.peek() {
                if let Some(Ok(token)) = tokens.peek() {
                    let key = key.to_unescaped()?;
                    let value = value.to_unescaped()?;
                    f(key, value)
                    f(key, token)?
                }
                skip_value(&mut tokens)?;
            }
@@ -214,7 +244,7 @@ pub(crate) fn json_parse_loop<'a>(
#[cfg(test)]
mod test {
    use crate::json_credentials::{
        parse_json_credentials, InvalidJsonCredentials, JsonCredentials,
        parse_json_credentials, InvalidJsonCredentials, JsonCredentials, RefreshableCredentials,
    };
    use std::time::{Duration, UNIX_EPOCH};

@@ -233,12 +263,12 @@ mod test {
        let parsed = parse_json_credentials(response).expect("valid JSON");
        assert_eq!(
            parsed,
            JsonCredentials::RefreshableCredentials {
            JsonCredentials::RefreshableCredentials(RefreshableCredentials {
                access_key_id: "ASIARTEST".into(),
                secret_access_key: "xjtest".into(),
                session_token: "IQote///test".into(),
                expiration: UNIX_EPOCH + Duration::from_secs(1631935916),
            }
            })
        )
    }

@@ -273,12 +303,12 @@ mod test {
        let parsed = parse_json_credentials(resp).expect("code not required");
        assert_eq!(
            parsed,
            JsonCredentials::RefreshableCredentials {
            JsonCredentials::RefreshableCredentials(RefreshableCredentials {
                access_key_id: "ASIARTEST".into(),
                secret_access_key: "xjtest".into(),
                session_token: "IQote///test".into(),
                expiration: UNIX_EPOCH + Duration::from_secs(1631935916),
            }
            })
        )
    }

@@ -347,12 +377,12 @@ mod test {
        assert!(
            matches!(
                &parsed,
                JsonCredentials::RefreshableCredentials {
                JsonCredentials::RefreshableCredentials(RefreshableCredentials{
                    access_key_id: Cow::Borrowed("ASIARTEST"),
                    secret_access_key: Cow::Borrowed("SECRETTEST"),
                    session_token,
                    expiration
                } if session_token.starts_with("token") && *expiration == UNIX_EPOCH + Duration::from_secs(1234567890)
                }) if session_token.starts_with("token") && *expiration == UNIX_EPOCH + Duration::from_secs(1234567890)
            ),
            "{:?}",
            parsed
Loading