Unverified Commit 6beec943 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Endpoints 2.0 Integration pre-work (#2063)

* Split endpoint resolution middleware into two parts & refactor endpoint generation

* Endpoints 2.0 Integration pre-work

This PR does a 3 bits of pre-work ahead of ep2 integration:
1. Split endpoint resolution into two separate middlewares:
  1. A smithy native middleware that applies URI and headers
  2. An AWS middleware that applies the auth schemes
2. Add vendorParams support to the ProtocolTestGenerator so that protocol tests can insert a region.
3. Simplify endpoint resolution logic by allowing `make_operation` to fail when an endpoint cannot be resolved.

* Back out previous change to insert endpoint directly into the bag

* backout changes to property bag

* Update changelog & add more docs

* Fix AWS test

* Fix test
parent 4eccaada
Loading
Loading
Loading
Loading
+17 −5
Original line number Diff line number Diff line
@@ -637,3 +637,15 @@ message = "`SdkBody` callbacks have been removed. If you were using these, pleas
references = ["smithy-rs#2065"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"

[[aws-sdk-rust]]
message = "`AwsEndpointStage`, a middleware which set endpoints and auth has been split into `AwsAuthStage` and `SmithyEndpointStage`. Related types have also been renamed."
references = ["smithy-rs#2063"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "rcoh"

[[smithy-rs]]
message = "Added SmithyEndpointStage which can be used to set an endpoint for smithy-native clients"
references = ["smithy-rs#2063"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "rcoh"
+46 −84
Original line number Diff line number Diff line
@@ -13,18 +13,15 @@ pub use partition::PartitionResolver;
use std::collections::HashMap;

use aws_smithy_http::endpoint::error::ResolveEndpointError;
use aws_smithy_http::endpoint::{apply_endpoint, EndpointPrefix, ResolveEndpoint};
use aws_smithy_http::endpoint::ResolveEndpoint;
use aws_smithy_http::middleware::MapRequest;
use aws_smithy_http::operation::Request;
use aws_smithy_types::endpoint::Endpoint as SmithyEndpoint;
use aws_smithy_types::Document;
use aws_types::region::{Region, SigningRegion};
use aws_types::SigningService;
use http::header::HeaderName;
use http::{HeaderValue, Uri};
use std::error::Error;
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;

pub use aws_types::endpoint::{AwsEndpoint, BoxError, CredentialScope, ResolveAwsEndpoint};
@@ -87,41 +84,39 @@ impl ResolveEndpoint<Params> for EndpointShim {
    }
}

/// Middleware Stage to Add an Endpoint to a Request
/// Middleware Stage to add authentication information from a Smithy endpoint into the property bag
///
/// AwsEndpointStage implements [`MapRequest`](aws_smithy_http::middleware::MapRequest). It will:
/// 1. Load an endpoint provider from the property bag.
/// 2. Load an endpoint given the [`Region`](aws_types::region::Region) in the property bag.
/// 3. Apply the endpoint to the URI in the request
/// 4. Set the `SigningRegion` and `SigningService` in the property bag to drive downstream
/// AwsAuthStage implements [`MapRequest`](MapRequest). It will:
/// 1. Load an endpoint from the property bag
/// 2. Set the `SigningRegion` and `SigningService` in the property bag to drive downstream
/// signing middleware.
#[derive(Clone, Debug)]
pub struct AwsEndpointStage;
pub struct AwsAuthStage;

#[derive(Debug)]
enum AwsEndpointStageErrorKind {
enum AwsAuthStageErrorKind {
    NoEndpointResolver,
    EndpointResolutionError(BoxError),
}

#[derive(Debug)]
pub struct AwsEndpointStageError {
    kind: AwsEndpointStageErrorKind,
pub struct AwsAuthStageError {
    kind: AwsAuthStageErrorKind,
}

impl fmt::Display for AwsEndpointStageError {
impl fmt::Display for AwsAuthStageError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use AwsEndpointStageErrorKind::*;
        use AwsAuthStageErrorKind::*;
        match &self.kind {
            NoEndpointResolver => write!(f, "endpoint resolution failed: no endpoint resolver"),
            NoEndpointResolver => write!(f, "endpoint resolution failed: no endpoint present"),
            EndpointResolutionError(_) => write!(f, "endpoint resolution failed"),
        }
    }
}

impl Error for AwsEndpointStageError {
impl Error for AwsAuthStageError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        use AwsEndpointStageErrorKind::*;
        use AwsAuthStageErrorKind::*;
        match &self.kind {
            EndpointResolutionError(source) => Some(source.as_ref() as _),
            NoEndpointResolver => None,
@@ -129,51 +124,22 @@ impl Error for AwsEndpointStageError {
    }
}

impl From<AwsEndpointStageErrorKind> for AwsEndpointStageError {
    fn from(kind: AwsEndpointStageErrorKind) -> Self {
impl From<AwsAuthStageErrorKind> for AwsAuthStageError {
    fn from(kind: AwsAuthStageErrorKind) -> Self {
        Self { kind }
    }
}

impl MapRequest for AwsEndpointStage {
    type Error = AwsEndpointStageError;
impl MapRequest for AwsAuthStage {
    type Error = AwsAuthStageError;

    fn apply(&self, request: Request) -> Result<Request, Self::Error> {
        request.augment(|mut http_req, props| {
            let endpoint_result = props
                .get_mut::<aws_smithy_http::endpoint::Result>()
                .ok_or(AwsEndpointStageErrorKind::NoEndpointResolver)?;
            let endpoint = match endpoint_result {
                // downgrade the mut ref to a shared ref
                Ok(_endpoint) => props.get::<aws_smithy_http::endpoint::Result>()
                    .expect("unreachable (prevalidated that the endpoint is in the bag)")
                    .as_ref()
                    .expect("unreachable (prevalidated that this is OK)"),
                Err(e) => {
                    // We need to own the error to return it, so take it and leave a stub error in
                    // its place
                    return Err(AwsEndpointStageErrorKind::EndpointResolutionError(std::mem::replace(
                        e,
                        ResolveEndpointError::message("the original error was directly returned")
                    ).into()).into());
                }
            };
            let (uri, signing_scope_override, signing_service_override) = smithy_to_aws(endpoint)
                .map_err(|err| AwsEndpointStageErrorKind::EndpointResolutionError(err))?;
            tracing::debug!(endpoint = ?endpoint, base_region = ?signing_scope_override, "resolved endpoint");
            apply_endpoint(http_req.uri_mut(), &uri, props.get::<EndpointPrefix>())
                .map_err(|err| AwsEndpointStageErrorKind::EndpointResolutionError(err.into()))?;
            for (header_name, header_values) in endpoint.headers() {
                http_req.headers_mut().remove(header_name);
                for value in header_values {
                    http_req.headers_mut().insert(
                        HeaderName::from_str(header_name)
                            .map_err(|err| AwsEndpointStageErrorKind::EndpointResolutionError(err.into()))?,
                        HeaderValue::from_str(value)
                            .map_err(|err| AwsEndpointStageErrorKind::EndpointResolutionError(err.into()))?,
                    );
                }
            }
        request.augment(|http_req, props| {
            let endpoint = props
                .get::<aws_smithy_types::endpoint::Endpoint>()
                .ok_or(AwsAuthStageErrorKind::NoEndpointResolver)?;
            let (signing_scope_override, signing_service_override) = smithy_to_aws(endpoint)
                .map_err(|err| AwsAuthStageErrorKind::EndpointResolutionError(err))?;

            if let Some(signing_scope) = signing_scope_override {
                props.insert(signing_scope);
@@ -186,17 +152,14 @@ impl MapRequest for AwsEndpointStage {
    }
}

type EndpointMetadata = (Uri, Option<SigningRegion>, Option<SigningService>);
type EndpointMetadata = (Option<SigningRegion>, Option<SigningService>);

fn smithy_to_aws(value: &SmithyEndpoint) -> Result<EndpointMetadata, Box<dyn Error + Send + Sync>> {
    let uri: Uri = value.url().parse()?;
    // look for v4 as an auth scheme
    let auth_schemes = match value
        .properties()
        .get("authSchemes")
        .ok_or("no auth schemes in metadata")?
    {
        Document::Array(schemes) => schemes,
    let auth_schemes = match value.properties().get("authSchemes") {
        Some(Document::Array(schemes)) => schemes,
        // no auth schemes:
        None => return Ok((None, None)),
        _other => return Err("expected an array for authSchemes".into()),
    };
    let v4 = auth_schemes
@@ -210,7 +173,7 @@ fn smithy_to_aws(value: &SmithyEndpoint) -> Result<EndpointMetadata, Box<dyn Err
            _ => None,
        })
        .next()
        .ok_or("could not find v4 as an acceptable auth scheme")?;
        .ok_or("could not find v4 as an acceptable auth scheme (the SDK does not support Bearer Auth at this time)")?;

    let signing_scope = match v4.get("signingRegion") {
        Some(Document::String(s)) => Some(SigningRegion::from(Region::new(s.clone()))),
@@ -222,7 +185,7 @@ fn smithy_to_aws(value: &SmithyEndpoint) -> Result<EndpointMetadata, Box<dyn Err
        None => None,
        _ => return Err("unexpected type".into()),
    };
    Ok((uri, signing_scope, signing_service))
    Ok((signing_scope, signing_service))
}

#[cfg(test)]
@@ -230,7 +193,6 @@ mod test {
    use std::sync::Arc;

    use http::header::HOST;
    use http::Uri;

    use aws_smithy_http::body::SdkBody;
    use aws_smithy_http::endpoint::ResolveEndpoint;
@@ -241,7 +203,7 @@ mod test {
    use aws_types::SigningService;

    use crate::partition::endpoint::{Metadata, Protocol, SignatureVersion};
    use crate::{AwsEndpointStage, EndpointShim, Params};
    use crate::{AwsAuthStage, EndpointShim, Params};

    #[test]
    fn default_endpoint_updates_request() {
@@ -260,25 +222,21 @@ mod test {
            props.insert(SigningService::from_static("kinesis"));
            props.insert(
                EndpointShim::from_arc(provider)
                    .resolve_endpoint(&Params::new(Some(region.clone()))),
                    .resolve_endpoint(&Params::new(Some(region.clone())))
                    .unwrap(),
            );
        };
        let req = AwsEndpointStage.apply(req).expect("should succeed");
        let req = AwsAuthStage.apply(req).expect("should succeed");
        assert_eq!(req.properties().get(), Some(&SigningRegion::from(region)));
        assert_eq!(
            req.properties().get(),
            Some(&SigningService::from_static("kinesis"))
        );

        let (req, conf) = req.into_parts();
        assert_eq!(
            req.uri(),
            &Uri::from_static("https://kinesis.us-east-1.amazonaws.com")
        );
        assert!(req.headers().get(HOST).is_none());
        assert!(req.http().headers().get(HOST).is_none());
        assert!(
            conf.acquire()
                .get::<aws_smithy_http::endpoint::Result>()
            req.properties()
                .get::<aws_smithy_types::endpoint::Endpoint>()
                .is_some(),
            "Endpoint middleware MUST leave the result in the bag"
        );
@@ -303,10 +261,12 @@ mod test {
            props.insert(region.clone());
            props.insert(SigningService::from_static("qldb"));
            props.insert(
                EndpointShim::from_arc(provider).resolve_endpoint(&Params::new(Some(region))),
                EndpointShim::from_arc(provider)
                    .resolve_endpoint(&Params::new(Some(region)))
                    .unwrap(),
            );
        };
        let req = AwsEndpointStage.apply(req).expect("should succeed");
        let req = AwsAuthStage.apply(req).expect("should succeed");
        assert_eq!(
            req.properties().get(),
            Some(&SigningRegion::from(Region::new("us-east-override")))
@@ -333,10 +293,12 @@ mod test {
            props.insert(region.clone());
            props.insert(SigningService::from_static("qldb"));
            props.insert(
                EndpointShim::from_arc(provider).resolve_endpoint(&Params::new(Some(region))),
                EndpointShim::from_arc(provider)
                    .resolve_endpoint(&Params::new(Some(region)))
                    .unwrap(),
            );
        };
        let req = AwsEndpointStage.apply(req).expect("should succeed");
        let req = AwsAuthStage.apply(req).expect("should succeed");
        assert_eq!(
            req.properties().get(),
            Some(&SigningRegion::from(Region::new("us-east-1")))
+24 −16
Original line number Diff line number Diff line
@@ -5,30 +5,36 @@

//! Base Middleware Stack

use aws_endpoint::AwsEndpointStage;
use aws_endpoint::AwsAuthStage;
use aws_http::auth::CredentialsStage;
use aws_http::recursion_detection::RecursionDetectionStage;
use aws_http::user_agent::UserAgentStage;
use aws_sig_auth::middleware::SigV4SigningStage;
use aws_sig_auth::signer::SigV4Signer;
use aws_smithy_http::endpoint::middleware::SmithyEndpointStage;
use aws_smithy_http_tower::map_request::{AsyncMapRequestLayer, MapRequestLayer};
use std::fmt::Debug;
use tower::layer::util::{Identity, Stack};
use tower::ServiceBuilder;

type DefaultMiddlewareStack = Stack<
/// Macro to generate the tower stack type. Arguments should be in reverse order
macro_rules! stack_type {
    ($first: ty, $($rest:ty),+) => {
        tower::layer::util::Stack<$first, stack_type!($($rest),+)>
    };
    ($only: ty) => {
        tower::layer::util::Stack<$only, tower::layer::util::Identity>
    }
}

// Note: the layers here appear in reverse order
type DefaultMiddlewareStack = stack_type!(
    MapRequestLayer<RecursionDetectionStage>,
    Stack<
    MapRequestLayer<SigV4SigningStage>,
        Stack<
    AsyncMapRequestLayer<CredentialsStage>,
            Stack<
    MapRequestLayer<UserAgentStage>,
                Stack<MapRequestLayer<AwsEndpointStage>, Identity>,
            >,
        >,
    >,
>;
    MapRequestLayer<AwsAuthStage>,
    MapRequestLayer<SmithyEndpointStage>
);

/// AWS Middleware Stack
///
@@ -54,7 +60,8 @@ impl DefaultMiddleware {
fn base() -> ServiceBuilder<DefaultMiddlewareStack> {
    let credential_provider = AsyncMapRequestLayer::for_mapper(CredentialsStage::new());
    let signer = MapRequestLayer::for_mapper(SigV4SigningStage::new(SigV4Signer::new()));
    let endpoint_resolver = MapRequestLayer::for_mapper(AwsEndpointStage);
    let endpoint_stage = MapRequestLayer::for_mapper(SmithyEndpointStage::new());
    let auth_stage = MapRequestLayer::for_mapper(AwsAuthStage);
    let user_agent = MapRequestLayer::for_mapper(UserAgentStage::new());
    let recursion_detection = MapRequestLayer::for_mapper(RecursionDetectionStage::new());
    // These layers can be considered as occurring in order, that is:
@@ -64,7 +71,8 @@ fn base() -> ServiceBuilder<DefaultMiddlewareStack> {
    // 4. Sign with credentials
    // (5. Dispatch over the wire)
    ServiceBuilder::new()
        .layer(endpoint_resolver)
        .layer(endpoint_stage)
        .layer(auth_stage)
        .layer(user_agent)
        .layer(credential_provider)
        .layer(signer)
+11 −4
Original line number Diff line number Diff line
@@ -187,7 +187,7 @@ mod test {
    };
    use crate::signer::{OperationSigningConfig, SigV4Signer};
    use aws_endpoint::partition::endpoint::{Protocol, SignatureVersion};
    use aws_endpoint::{AwsEndpointStage, Params};
    use aws_endpoint::{AwsAuthStage, Params};
    use aws_smithy_http::body::SdkBody;
    use aws_smithy_http::endpoint::ResolveEndpoint;
    use aws_smithy_http::middleware::MapRequest;
@@ -237,19 +237,26 @@ mod test {
                signature_versions: SignatureVersion::V4,
            },
        );
        let req = http::Request::new(SdkBody::from(""));
        let req = http::Request::builder()
            .uri("https://kinesis.us-east-1.amazonaws.com")
            .body(SdkBody::from(""))
            .unwrap();
        let region = Region::new("us-east-1");
        let req = operation::Request::new(req)
            .augment(|req, conf| {
                conf.insert(region.clone());
                conf.insert(UNIX_EPOCH + Duration::new(1611160427, 0));
                conf.insert(SigningService::from_static("kinesis"));
                conf.insert(provider.resolve_endpoint(&Params::new(Some(region.clone()))));
                conf.insert(
                    provider
                        .resolve_endpoint(&Params::new(Some(region.clone())))
                        .unwrap(),
                );
                Result::<_, Infallible>::Ok(req)
            })
            .expect("succeeds");

        let endpoint = AwsEndpointStage;
        let endpoint = AwsAuthStage;
        let signer = SigV4SigningStage::new(SigV4Signer::new());
        let mut req = endpoint.apply(req).expect("add endpoint should succeed");
        let mut errs = vec![signer
+5 −7
Original line number Diff line number Diff line
@@ -74,7 +74,7 @@ class AwsEndpointDecorator : RustCodegenDecorator<ClientProtocolGenerator, Clien
        operation: OperationShape,
        baseCustomizations: List<OperationCustomization>,
    ): List<OperationCustomization> {
        return baseCustomizations + EndpointResolverFeature(codegenContext.runtimeConfig, operation)
        return baseCustomizations + EndpointResolverFeature(codegenContext.runtimeConfig)
    }

    override fun libRsCustomizations(
@@ -179,9 +179,7 @@ class EndpointConfigCustomization(
    }
}

// This is an experiment in a slightly different way to create runtime types. All code MAY be refactored to use this pattern

class EndpointResolverFeature(private val runtimeConfig: RuntimeConfig, private val operationShape: OperationShape) :
class EndpointResolverFeature(runtimeConfig: RuntimeConfig) :
    OperationCustomization() {
    private val placeholderEndpointParams = runtimeConfig.awsEndpoint().toType().member("Params")
    private val codegenScope = arrayOf(
@@ -196,9 +194,9 @@ class EndpointResolverFeature(private val runtimeConfig: RuntimeConfig, private
                    """
                    let endpoint_params = #{PlaceholderParams}::new(${section.config}.region.clone());
                    ${section.request}.properties_mut()
                        .insert::<aws_smithy_http::endpoint::Result>(${section.config}
                            .endpoint_resolver
                            .resolve_endpoint(&endpoint_params));
                        .insert::<aws_smithy_http::endpoint::Result>(
                            ${section.config}.endpoint_resolver.resolve_endpoint(&endpoint_params)
                        );
                    """,
                    *codegenScope,
                )
Loading