Unverified Commit b7a8eb58 authored by Nikhil Benesch's avatar Nikhil Benesch Committed by GitHub
Browse files

Add endpoint resolver to SdkConfig (#1300)



* Add endpoint resolver to SdkConfig

This commit adds an `endpoint_resolver` field to SdkConfig, which flows
through to the `endpoint_resolver` field on each individual service
configuration. To avoid a circular dependency between the aws-types
crate and the aws-endpoint crate, the `AwsEndpoint` type and its
dependencies are moved from the aws-endpoint crate to the aws-types
crate, then rexported in aws-endpoint.

Fix aws-labs/aws-sdk-rust#396.

* impl From<'static str> for signing service and signing region

* Code generate for the static string implementation

* Add EndpointResolver to aws_config

* Add docs and adapt test

* Fix test

* Fix endpoint config customizations test to remove dead code

* Re-add `.set_endpoint`

* Remove print and update changelog

Co-authored-by: default avatarRussell Cohen <rcoh@amazon.com>
parent 3c256647
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -58,3 +58,9 @@ message = "Update urlencoding crate to v2.1.0"
references = ["smithy-rs#1301"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "benesch"

[[aws-sdk-rust]]
message = "Add endpoint resolver to SdkConfig. This enables overriding the endpoint resolver for all services build from a single SdkConfig."
references = ["smithy-rs#1300"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "benesch"
+29 −0
Original line number Diff line number Diff line
@@ -165,6 +165,7 @@ mod loader {
    use aws_smithy_types::timeout;
    use aws_types::app_name::AppName;
    use aws_types::credentials::{ProvideCredentials, SharedCredentialsProvider};
    use aws_types::endpoint::ResolveAwsEndpoint;
    use aws_types::SdkConfig;

    use crate::default_provider::{app_name, credentials, region, retry_config, timeout_config};
@@ -181,6 +182,7 @@ mod loader {
    pub struct ConfigLoader {
        app_name: Option<AppName>,
        credentials_provider: Option<SharedCredentialsProvider>,
        endpoint_resolver: Option<Arc<dyn ResolveAwsEndpoint>>,
        region: Option<Box<dyn ProvideRegion>>,
        retry_config: Option<RetryConfig>,
        sleep: Option<Arc<dyn AsyncSleep>>,
@@ -285,6 +287,30 @@ mod loader {
            self
        }

        /// Override the endpoint resolver used for **all** AWS Services
        ///
        /// This method will override the endpoint resolver used for **all** AWS services. This mainly
        /// exists to set a static endpoint for tools like `LocalStack`. For live traffic, AWS services
        /// require the service-specific endpoint resolver they load by default.
        ///
        /// # Examples
        ///
        /// Use a static endpoint for all services
        /// ```no_run
        /// # async fn doc() {
        /// use aws_smithy_http::endpoint::Endpoint;
        /// let sdk_config = aws_config::from_env()
        ///   .endpoint_resolver(Endpoint::immutable("http://localhost:1234".parse().expect("valid URI")))
        ///   .load().await;
        /// # }
        pub fn endpoint_resolver(
            mut self,
            endpoint_resolver: impl ResolveAwsEndpoint + 'static,
        ) -> Self {
            self.endpoint_resolver = Some(Arc::new(endpoint_resolver));
            self
        }

        /// Set configuration for all sub-loaders (credentials, region etc.)
        ///
        /// Update the `ProviderConfig` used for all nested loaders. This can be used to override
@@ -390,6 +416,8 @@ mod loader {
                SharedCredentialsProvider::new(builder.build().await)
            };

            let endpoint_resolver = self.endpoint_resolver;

            let mut builder = SdkConfig::builder()
                .region(region)
                .retry_config(retry_config)
@@ -397,6 +425,7 @@ mod loader {
                .credentials_provider(credentials_provider)
                .http_connector(http_connector);

            builder.set_endpoint_resolver(endpoint_resolver);
            builder.set_app_name(app_name);
            builder.set_sleep_impl(sleep_impl);
            builder.build()
+10 −126
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ pub use partition::Partition;
#[doc(hidden)]
pub use partition::PartitionResolver;

use aws_smithy_http::endpoint::{Endpoint, EndpointPrefix};
use aws_smithy_http::endpoint::EndpointPrefix;
use aws_smithy_http::middleware::MapRequest;
use aws_smithy_http::operation::Request;
use aws_smithy_http::property_bag::PropertyBag;
@@ -22,122 +22,7 @@ use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;

/// Endpoint to connect to an AWS Service
///
/// An `AwsEndpoint` captures all necessary information needed to connect to an AWS service, including:
/// - The URI of the endpoint (needed to actually send the request)
/// - The name of the service (needed downstream for signing)
/// - The signing region (which may differ from the actual region)
#[derive(Clone, Debug)]
pub struct AwsEndpoint {
    endpoint: Endpoint,
    credential_scope: CredentialScope,
}

impl AwsEndpoint {
    pub fn set_endpoint(&self, mut uri: &mut http::Uri, endpoint_prefix: Option<&EndpointPrefix>) {
        self.endpoint.set_endpoint(&mut uri, endpoint_prefix);
    }
}

pub type BoxError = Box<dyn Error + Send + Sync + 'static>;

/// Resolve the AWS Endpoint for a given region
///
/// To provide a static endpoint, [`Endpoint`](aws_smithy_http::endpoint::Endpoint) implements this trait.
/// Example usage:
/// ```rust
/// # mod dynamodb {
/// # use aws_endpoint::ResolveAwsEndpoint;
/// # pub struct ConfigBuilder;
/// # impl ConfigBuilder {
/// #     pub fn endpoint(&mut self, resolver: impl ResolveAwsEndpoint + 'static) {
/// #         // ...
/// #     }
/// # }
/// # pub struct Config;
/// # impl Config {
/// #     pub fn builder() -> ConfigBuilder {
/// #         ConfigBuilder
/// #     }
/// # }
/// # }
/// use aws_smithy_http::endpoint::Endpoint;
/// use http::Uri;
/// let config = dynamodb::Config::builder()
///     .endpoint(
///         Endpoint::immutable(Uri::from_static("http://localhost:8080"))
///     );
/// ```
/// In the future, each AWS service will generate their own implementation of `ResolveAwsEndpoint`. This implementation
/// may use endpoint discovery. The list of supported regions for a given service
/// will be codegenerated from `endpoints.json`.
pub trait ResolveAwsEndpoint: Send + Sync {
    // TODO(https://github.com/awslabs/smithy-rs/issues/866): Create `ResolveEndpointError`
    fn resolve_endpoint(&self, region: &Region) -> Result<AwsEndpoint, BoxError>;
}

#[derive(Clone, Default, Debug)]
pub struct CredentialScope {
    region: Option<SigningRegion>,
    service: Option<SigningService>,
}

impl CredentialScope {
    pub fn builder() -> credential_scope::Builder {
        credential_scope::Builder::default()
    }
}

pub mod credential_scope {
    use crate::CredentialScope;
    use aws_types::region::SigningRegion;
    use aws_types::SigningService;

    #[derive(Debug, Default)]
    pub struct Builder {
        region: Option<SigningRegion>,
        service: Option<SigningService>,
    }

    impl Builder {
        pub fn region(mut self, region: &'static str) -> Self {
            self.region = Some(SigningRegion::from_static(region));
            self
        }

        pub fn service(mut self, service: &'static str) -> Self {
            self.service = Some(SigningService::from_static(service));
            self
        }

        pub fn build(self) -> CredentialScope {
            CredentialScope {
                region: self.region,
                service: self.service,
            }
        }
    }
}

impl CredentialScope {
    pub fn merge(&self, other: &CredentialScope) -> CredentialScope {
        CredentialScope {
            region: self.region.clone().or_else(|| other.region.clone()),
            service: self.service.clone().or_else(|| other.service.clone()),
        }
    }
}

/// An `Endpoint` can be its own resolver to support static endpoints
impl ResolveAwsEndpoint for Endpoint {
    fn resolve_endpoint(&self, _region: &Region) -> Result<AwsEndpoint, BoxError> {
        Ok(AwsEndpoint {
            endpoint: self.clone(),
            credential_scope: Default::default(),
        })
    }
}
pub use aws_types::endpoint::{AwsEndpoint, BoxError, CredentialScope, ResolveAwsEndpoint};

type AwsEndpointResolver = Arc<dyn ResolveAwsEndpoint>;
pub fn get_endpoint_resolver(properties: &PropertyBag) -> Option<&AwsEndpointResolver> {
@@ -188,16 +73,15 @@ impl MapRequest for AwsEndpointStage {
                .map_err(AwsEndpointStageError::EndpointResolutionError)?;
            tracing::debug!(endpoint = ?endpoint, base_region = ?region, "resolved endpoint");
            let signing_region = endpoint
                .credential_scope
                .region
                .credential_scope()
                .region()
                .cloned()
                .unwrap_or_else(|| region.clone().into());
            props.insert::<SigningRegion>(signing_region);
            if let Some(signing_service) = endpoint.credential_scope.service {
                props.insert::<SigningService>(signing_service);
            if let Some(signing_service) = endpoint.credential_scope().service() {
                props.insert::<SigningService>(signing_service.clone());
            }
            endpoint
                .endpoint
                .set_endpoint(http_req.uri_mut(), props.get::<EndpointPrefix>());
            endpoint.set_endpoint(http_req.uri_mut(), props.get::<EndpointPrefix>());
            Ok(http_req)
        })
    }
@@ -257,8 +141,8 @@ mod test {
            uri_template: "www.service.com",
            protocol: Protocol::Http,
            credential_scope: CredentialScope::builder()
                .service("qldb-override")
                .region("us-east-override")
                .service(SigningService::from_static("qldb-override"))
                .region(SigningRegion::from_static("us-east-override"))
                .build(),
            signature_versions: SignatureVersion::V4,
        });
+11 −13
Original line number Diff line number Diff line
@@ -3,8 +3,8 @@
 * SPDX-License-Identifier: Apache-2.0.
 */

use crate::{AwsEndpoint, BoxError, CredentialScope, ResolveAwsEndpoint};
use aws_smithy_http::endpoint::Endpoint;
use aws_types::endpoint::{AwsEndpoint, BoxError, CredentialScope, ResolveAwsEndpoint};
use aws_types::region::Region;

/// Endpoint metadata
@@ -55,17 +55,15 @@ impl ResolveAwsEndpoint for Metadata {
        let uri = self.uri_template.replace("{region}", region.as_ref());
        let uri = format!("{}://{}", self.protocol.as_str(), uri);
        let endpoint = Endpoint::mutable(uri.parse()?);
        let ep = AwsEndpoint {
            endpoint,
            credential_scope: CredentialScope {
                service: self.credential_scope.service.clone(),
                region: self
                    .credential_scope
                    .region
                    .clone()
                    .or_else(|| Some(region.clone().into())),
            },
        };
        Ok(ep)
        let mut credential_scope = CredentialScope::builder().region(
            self.credential_scope
                .region()
                .cloned()
                .unwrap_or_else(|| region.clone().into()),
        );
        if let Some(service) = self.credential_scope.service() {
            credential_scope = credential_scope.service(service.clone());
        }
        Ok(AwsEndpoint::new(endpoint, credential_scope.build()))
    }
}
+26 −25
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@

pub mod endpoint;

use crate::{AwsEndpoint, BoxError, ResolveAwsEndpoint};
use aws_types::endpoint::{AwsEndpoint, BoxError, ResolveAwsEndpoint};
use aws_types::region::Region;
use regex::Regex;
use std::collections::HashMap;
@@ -19,6 +19,7 @@ use std::iter;
///
/// Once a partition has been identified, endpoint resolution is delegated to the underlying
/// partition.
#[derive(Debug)]
pub struct PartitionResolver {
    /// Base partition used if no partitions match the region regex
    base: Partition,
@@ -200,10 +201,10 @@ mod test {
                Metadata {
                    uri_template: "service-alt.us-west-1.amazonaws.com",
                    protocol: Http,
                    credential_scope: CredentialScope {
                        region: Some(SigningRegion::from_static("us-west-1")),
                        service: Some(SigningService::from_static("foo")),
                    },
                    credential_scope: CredentialScope::builder()
                        .region(SigningRegion::from_static("us-west-1"))
                        .service(SigningService::from_static("foo"))
                        .build(),
                    signature_versions: V4,
                },
            )
@@ -218,10 +219,9 @@ mod test {
            .default_endpoint(Metadata {
                uri_template: "service.{region}.amazonaws.com",
                protocol: Https,
                credential_scope: CredentialScope {
                    service: Some(SigningService::from_static("foo")),
                    ..Default::default()
                },
                credential_scope: CredentialScope::builder()
                    .service(SigningService::from_static("foo"))
                    .build(),
                signature_versions: SignatureVersion::V4,
            })
            .partition_endpoint("partition")
@@ -231,10 +231,10 @@ mod test {
                Metadata {
                    uri_template: "some-global-thing.amazonaws.cn",
                    protocol: Https,
                    credential_scope: CredentialScope {
                        region: Some(SigningRegion::from_static("cn-east-1")),
                        service: Some(SigningService::from_static("foo")),
                    },
                    credential_scope: CredentialScope::builder()
                        .region(SigningRegion::from_static("cn-east-1"))
                        .service(SigningService::from_static("foo"))
                        .build(),
                    signature_versions: SignatureVersion::V4,
                },
            )
@@ -243,10 +243,9 @@ mod test {
                Metadata {
                    uri_template: "fips.amazonaws.cn",
                    protocol: Https,
                    credential_scope: CredentialScope {
                        region: Some(SigningRegion::from_static("cn-fips")),
                        service: None,
                    },
                    credential_scope: CredentialScope::builder()
                        .region(SigningRegion::from_static("cn-fips"))
                        .build(),
                    signature_versions: SignatureVersion::V4,
                },
            )
@@ -269,10 +268,9 @@ mod test {
                uri_template: "service.{region}.amazonaws.com",
                protocol: Https,
                signature_versions: V4,
                credential_scope: CredentialScope {
                    service: Some(SigningService::from_static("foo")),
                    ..Default::default()
                },
                credential_scope: CredentialScope::builder()
                    .service(SigningService::from_static("foo"))
                    .build(),
            })
            .build()
            .expect("valid partition")
@@ -378,12 +376,15 @@ mod test {
        endpoint.set_endpoint(&mut test_uri, None);
        assert_eq!(test_uri, Uri::from_static(test_case.uri));
        assert_eq!(
            endpoint.credential_scope.region,
            Some(SigningRegion::from_static(test_case.signing_region))
            endpoint.credential_scope().region(),
            Some(&SigningRegion::from_static(test_case.signing_region))
        );
        assert_eq!(
            endpoint.credential_scope.service,
            test_case.signing_service.map(SigningService::from_static)
            endpoint.credential_scope().service(),
            test_case
                .signing_service
                .map(SigningService::from_static)
                .as_ref()
        )
    }
}
Loading