Unverified Commit fc1e1799 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Refactor endpoints to be Smithy-native (#1641)

* Refactor endpoints to be Smithy-native

* Add test to DynamoDb

* Fix clippy

* Fix some clippy errors and update changelog
parent 1c02c08b
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -11,6 +11,12 @@
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"
[[aws-sdk-rust]]
message = """Refactor endpoint resolution internals to use `aws_smithy_types::Endpoint` internally. The public internal
functions `aws_endpoint::set_endpoint_resolver` and `aws_endpoint::get_endpoint_resolver were removed."""
references = ["smithy-rs#1641"]
meta = { breaking = true, "tada" = false, "bug" = false }
author = "rcoh"

[[aws-sdk-rust]]
message = """
+1 −0
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ repository = "https://github.com/awslabs/smithy-rs"

[dependencies]
aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http" }
aws-smithy-types = { path = "../../../rust-runtime/aws-smithy-types"}
aws-types = { path = "../aws-types" }
http = "0.2.3"
regex = { version = "1.5.5", default-features = false, features = ["std"] }
+194 −31
Original line number Diff line number Diff line
@@ -10,27 +10,83 @@ pub mod partition;
pub use partition::Partition;
#[doc(hidden)]
pub use partition::PartitionResolver;
use std::collections::HashMap;

use aws_smithy_http::endpoint::EndpointPrefix;
use aws_smithy_http::endpoint::Error as EndpointError;
use aws_smithy_http::endpoint::{apply_endpoint, EndpointPrefix, ResolveEndpoint};
use aws_smithy_http::middleware::MapRequest;
use aws_smithy_http::operation::Request;
use aws_smithy_http::property_bag::PropertyBag;
use aws_smithy_types::endpoint::Endpoint as SmithyEndpoint;
use aws_smithy_types::Document;
use aws_types::region::{Region, SigningRegion};
use aws_types::SigningService;
use http::header::HeaderName;
use http::{HeaderValue, Uri};
use std::error::Error;
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::str::FromStr;
use std::sync::Arc;

pub use aws_types::endpoint::{AwsEndpoint, BoxError, CredentialScope, ResolveAwsEndpoint};

type AwsEndpointResolver = Arc<dyn ResolveAwsEndpoint>;
pub fn get_endpoint_resolver(properties: &PropertyBag) -> Option<&AwsEndpointResolver> {
    properties.get()
#[doc(hidden)]
pub struct Params {
    region: Option<Region>,
}

impl Params {
    pub fn new(region: Option<Region>) -> Self {
        Self { region }
    }
}

#[doc(hidden)]
pub struct EndpointShim(Arc<dyn ResolveAwsEndpoint>);
impl EndpointShim {
    pub fn from_resolver(resolver: impl ResolveAwsEndpoint + 'static) -> Self {
        Self(Arc::new(resolver))
    }

    pub fn from_arc(arc: Arc<dyn ResolveAwsEndpoint>) -> Self {
        Self(arc)
    }
}

pub fn set_endpoint_resolver(properties: &mut PropertyBag, provider: AwsEndpointResolver) {
    properties.insert(provider);
impl ResolveEndpoint<Params> for EndpointShim {
    fn resolve_endpoint(
        &self,
        params: &Params,
    ) -> Result<SmithyEndpoint, aws_smithy_http::endpoint::Error> {
        let aws_endpoint = self
            .0
            .resolve_endpoint(
                params
                    .region
                    .as_ref()
                    .ok_or_else(|| EndpointError::message("no region in params"))?,
            )
            .map_err(|err| EndpointError::message("failure resolving endpoint").with_cause(err))?;
        let uri = aws_endpoint.endpoint().uri();
        let mut auth_scheme =
            HashMap::from([("name".to_string(), Document::String("sigv4".into()))]);
        if let Some(region) = aws_endpoint.credential_scope().region() {
            auth_scheme.insert(
                "signingRegion".to_string(),
                region.as_ref().to_string().into(),
            );
        }
        if let Some(service) = aws_endpoint.credential_scope().service() {
            auth_scheme.insert(
                "signingName".to_string(),
                service.as_ref().to_string().into(),
            );
        }
        Ok(SmithyEndpoint::builder()
            .url(uri.to_string())
            .property("authSchemes", vec![Document::Object(auth_scheme)])
            .build())
    }
}

/// Middleware Stage to Add an Endpoint to a Request
@@ -56,6 +112,7 @@ impl Display for AwsEndpointStageError {
        Debug::fmt(self, f)
    }
}

impl Error for AwsEndpointStageError {}

impl MapRequest for AwsEndpointStage {
@@ -63,30 +120,91 @@ impl MapRequest for AwsEndpointStage {

    fn apply(&self, request: Request) -> Result<Request, Self::Error> {
        request.augment(|mut http_req, props| {
            let provider =
                get_endpoint_resolver(props).ok_or(AwsEndpointStageError::NoEndpointResolver)?;
            let region = props
                .get::<Region>()
                .ok_or(AwsEndpointStageError::NoRegion)?;
            let endpoint = provider
                .resolve_endpoint(region)
                .map_err(AwsEndpointStageError::EndpointResolutionError)?;
            tracing::debug!(endpoint = ?endpoint, base_region = ?region, "resolved endpoint");
            let signing_region = endpoint
                .credential_scope()
                .region()
                .cloned()
                .unwrap_or_else(|| region.clone().into());
            props.insert::<SigningRegion>(signing_region);
            if let Some(signing_service) = endpoint.credential_scope().service() {
                props.insert::<SigningService>(signing_service.clone());
            }
            endpoint.set_endpoint(http_req.uri_mut(), props.get::<EndpointPrefix>());
            let endpoint_result = props
                .get_mut::<aws_smithy_http::endpoint::Result>()
                .ok_or(AwsEndpointStageError::NoEndpointResolver)?;
            let endpoint = match endpoint_result {
                // downgrade the mut ref to a shared ref
                Ok(_endpoint) => props.get::<aws_smithy_http::endpoint::Result>()
                    .expect("unreachable (prevalidated that the endpoint is in the bag)")
                    .as_ref()
                    .expect("unreachable (prevalidated that this is OK)"),
                Err(e) => {
                    // We need to own the error to return it, so take it and leave a stub error in
                    // its place
                    return Err(AwsEndpointStageError::EndpointResolutionError(std::mem::replace(
                        e,
                        aws_smithy_http::endpoint::Error::message("the original error was directly returned")
                    ).into()));
                }
            };
            let (uri, signing_scope_override, signing_service_override) = smithy_to_aws(endpoint)
                .map_err(|err| AwsEndpointStageError::EndpointResolutionError(err))?;
            tracing::debug!(endpoint = ?endpoint, base_region = ?signing_scope_override, "resolved endpoint");
            apply_endpoint(http_req.uri_mut(), &uri, props.get::<EndpointPrefix>())
                .map_err(|err|AwsEndpointStageError::EndpointResolutionError(err.into()))?;
            for (header_name, header_values) in endpoint.headers() {
                http_req.headers_mut().remove(header_name);
                for value in header_values {
                    http_req.headers_mut().insert(
                        HeaderName::from_str(header_name)
                            .map_err(|err|AwsEndpointStageError::EndpointResolutionError(err.into()))?,
                        HeaderValue::from_str(value)
                            .map_err(|err|AwsEndpointStageError::EndpointResolutionError(err.into()))?,
                    );
                }
            }

            if let Some(signing_scope) = signing_scope_override {
                props.insert(signing_scope);
            }
            if let Some(signing_service) = signing_service_override {
                props.insert(signing_service);
            }
            Ok(http_req)
        })
    }
}

type EndpointMetadata = (Uri, Option<SigningRegion>, Option<SigningService>);

fn smithy_to_aws(value: &SmithyEndpoint) -> Result<EndpointMetadata, Box<dyn Error + Send + Sync>> {
    let uri: Uri = value.url().parse()?;
    // look for v4 as an auth scheme
    let auth_schemes = match value
        .properties()
        .get("authSchemes")
        .ok_or("no auth schemes in metadata")?
    {
        Document::Array(schemes) => schemes,
        _other => return Err("expected an array for authSchemes".into()),
    };
    let v4 = auth_schemes
        .iter()
        .flat_map(|doc| match doc {
            Document::Object(map)
                if map.get("name") == Some(&Document::String("sigv4".to_string())) =>
            {
                Some(map)
            }
            _ => None,
        })
        .next()
        .ok_or("could not find v4 as an acceptable auth scheme")?;

    let signing_scope = match v4.get("signingRegion") {
        Some(Document::String(s)) => Some(SigningRegion::from(Region::new(s.clone()))),
        None => None,
        _ => return Err("unexpected type".into()),
    };
    let signing_service = match v4.get("signingName") {
        Some(Document::String(s)) => Some(SigningService::from(s.to_string())),
        None => None,
        _ => return Err("unexpected type".into()),
    };
    Ok((uri, signing_scope, signing_service))
}

#[cfg(test)]
mod test {
    use std::sync::Arc;
@@ -95,13 +213,15 @@ mod test {
    use http::Uri;

    use aws_smithy_http::body::SdkBody;
    use aws_smithy_http::endpoint::ResolveEndpoint;
    use aws_smithy_http::middleware::MapRequest;
    use aws_smithy_http::operation;
    use aws_types::endpoint::CredentialScope;
    use aws_types::region::{Region, SigningRegion};
    use aws_types::SigningService;

    use crate::partition::endpoint::{Metadata, Protocol, SignatureVersion};
    use crate::{set_endpoint_resolver, AwsEndpointStage, CredentialScope};
    use crate::{AwsEndpointStage, EndpointShim, Params};

    #[test]
    fn default_endpoint_updates_request() {
@@ -118,7 +238,10 @@ mod test {
            let mut props = req.properties_mut();
            props.insert(region.clone());
            props.insert(SigningService::from_static("kinesis"));
            set_endpoint_resolver(&mut props, provider);
            props.insert(
                EndpointShim::from_arc(provider)
                    .resolve_endpoint(&Params::new(Some(region.clone()))),
            );
        };
        let req = AwsEndpointStage.apply(req).expect("should succeed");
        assert_eq!(req.properties().get(), Some(&SigningRegion::from(region)));
@@ -127,12 +250,18 @@ mod test {
            Some(&SigningService::from_static("kinesis"))
        );

        let (req, _conf) = req.into_parts();
        let (req, conf) = req.into_parts();
        assert_eq!(
            req.uri(),
            &Uri::from_static("https://kinesis.us-east-1.amazonaws.com")
        );
        assert!(req.headers().get(HOST).is_none());
        assert!(
            conf.acquire()
                .get::<aws_smithy_http::endpoint::Result>()
                .is_some(),
            "Endpoint middleware MUST leave the result in the bag"
        );
    }

    #[test]
@@ -151,9 +280,12 @@ mod test {
        let mut req = operation::Request::new(req);
        {
            let mut props = req.properties_mut();
            props.insert(region);
            props.insert(SigningService::from_static("kinesis"));
            set_endpoint_resolver(&mut props, provider);
            props.insert(region.clone());
            props.insert(SigningService::from_static("qldb"));
            props.insert(
                EndpointShim::from_arc(provider)
                    .resolve_endpoint(&Params::new(Some(region.clone()))),
            );
        };
        let req = AwsEndpointStage.apply(req).expect("should succeed");
        assert_eq!(
@@ -165,4 +297,35 @@ mod test {
            Some(&SigningService::from_static("qldb-override"))
        );
    }

    #[test]
    fn supports_fallback_when_scope_is_unset() {
        let provider = Arc::new(Metadata {
            uri_template: "www.service.com",
            protocol: Protocol::Http,
            credential_scope: CredentialScope::builder().build(),
            signature_versions: SignatureVersion::V4,
        });
        let req = http::Request::new(SdkBody::from(""));
        let region = Region::new("us-east-1");
        let mut req = operation::Request::new(req);
        {
            let mut props = req.properties_mut();
            props.insert(region.clone());
            props.insert(SigningService::from_static("qldb"));
            props.insert(
                EndpointShim::from_arc(provider)
                    .resolve_endpoint(&Params::new(Some(region.clone()))),
            );
        };
        let req = AwsEndpointStage.apply(req).expect("should succeed");
        assert_eq!(
            req.properties().get(),
            Some(&SigningRegion::from(Region::new("us-east-1")))
        );
        assert_eq!(
            req.properties().get(),
            Some(&SigningService::from_static("qldb"))
        );
    }
}
+6 −6
Original line number Diff line number Diff line
@@ -7,7 +7,6 @@ use std::convert::Infallible;
use std::error::Error;
use std::fmt;
use std::fmt::{Display, Formatter};
use std::sync::Arc;
use std::time::{Duration, UNIX_EPOCH};

use bytes::Bytes;
@@ -15,7 +14,7 @@ use http::header::{AUTHORIZATION, USER_AGENT};
use http::{self, Uri};

use aws_endpoint::partition::endpoint::{Protocol, SignatureVersion};
use aws_endpoint::set_endpoint_resolver;
use aws_endpoint::{EndpointShim, Params};
use aws_http::retry::AwsErrorRetryPolicy;
use aws_http::user_agent::AwsUserAgent;
use aws_inlineable::middleware::DefaultMiddleware;
@@ -23,6 +22,7 @@ use aws_sig_auth::signer::OperationSigningConfig;

use aws_smithy_client::test_connection::TestConnection;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::endpoint::ResolveEndpoint;
use aws_smithy_http::operation;
use aws_smithy_http::operation::Operation;
use aws_smithy_http::response::ParseHttpResponse;
@@ -83,14 +83,14 @@ fn test_operation() -> Operation<TestOperationParser, AwsErrorRetryPolicy> {
            .unwrap(),
    )
    .augment(|req, mut conf| {
        set_endpoint_resolver(
            &mut conf,
            Arc::new(aws_endpoint::partition::endpoint::Metadata {
        conf.insert(
            EndpointShim::from_resolver(aws_endpoint::partition::endpoint::Metadata {
                uri_template: "test-service.{region}.amazonaws.com",
                protocol: Protocol::Https,
                credential_scope: Default::default(),
                signature_versions: SignatureVersion::V4,
            }),
            })
            .resolve_endpoint(&Params::new(Some(Region::new("test-region")))),
        );
        aws_http::auth::set_provider(
            &mut conf,
+11 −9
Original line number Diff line number Diff line
@@ -172,8 +172,9 @@ mod test {
    use crate::middleware::{SigV4SigningStage, Signature, SigningStageError};
    use crate::signer::{OperationSigningConfig, SigV4Signer};
    use aws_endpoint::partition::endpoint::{Protocol, SignatureVersion};
    use aws_endpoint::{set_endpoint_resolver, AwsEndpointStage};
    use aws_endpoint::{AwsEndpointStage, Params};
    use aws_smithy_http::body::SdkBody;
    use aws_smithy_http::endpoint::ResolveEndpoint;
    use aws_smithy_http::middleware::MapRequest;
    use aws_smithy_http::operation;
    use aws_types::region::{Region, SigningRegion};
@@ -181,7 +182,6 @@ mod test {
    use aws_types::SigningService;
    use http::header::AUTHORIZATION;
    use std::convert::Infallible;
    use std::sync::Arc;
    use std::time::{Duration, UNIX_EPOCH};

    #[test]
@@ -214,12 +214,14 @@ mod test {
    // check that the endpoint middleware followed by signing middleware produce the expected result
    #[test]
    fn endpoint_plus_signer() {
        let provider = Arc::new(aws_endpoint::partition::endpoint::Metadata {
        let provider = aws_endpoint::EndpointShim::from_resolver(
            aws_endpoint::partition::endpoint::Metadata {
                uri_template: "kinesis.{region}.amazonaws.com",
                protocol: Protocol::Https,
                credential_scope: Default::default(),
                signature_versions: SignatureVersion::V4,
        });
            },
        );
        let req = http::Request::new(SdkBody::from(""));
        let region = Region::new("us-east-1");
        let req = operation::Request::new(req)
@@ -227,7 +229,7 @@ mod test {
                conf.insert(region.clone());
                conf.insert(UNIX_EPOCH + Duration::new(1611160427, 0));
                conf.insert(SigningService::from_static("kinesis"));
                set_endpoint_resolver(conf, provider);
                conf.insert(provider.resolve_endpoint(&Params::new(Some(region.clone()))));
                Result::<_, Infallible>::Ok(req)
            })
            .expect("succeeds");
Loading