Unverified Commit 3d8c52a6 authored by ysaito1001's avatar ysaito1001 Committed by GitHub
Browse files

Follow up on DefaultEndpointResolver in the orchestrator (#2592)



## Motivation and Context
This PR incorporates post-merge feedback left in #2577.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: default avatarYuki Saito <awsaito@amazon.com>
parent dd517bc8
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@ use aws_sdk_s3::primitives::SdkBody;
use aws_smithy_client::erase::DynConnector;
use aws_smithy_client::test_connection::TestConnection;
use aws_smithy_runtime::client::connections::adapter::DynConnectorAdapter;
use aws_smithy_runtime_api::client::endpoints::DefaultEndpointResolver;
use aws_smithy_runtime::client::orchestrator::endpoints::DefaultEndpointResolver;
use aws_smithy_runtime_api::client::interceptors::{
    Interceptor, InterceptorContext, InterceptorError, Interceptors,
};
+6 −5
Original line number Diff line number Diff line
@@ -32,6 +32,7 @@ class EndpointParamsInterceptorGenerator(
        arrayOf(
            "BoxError" to runtimeApi.resolve("client::runtime_plugin::BoxError"),
            "ConfigBag" to runtimeApi.resolve("config_bag::ConfigBag"),
            "ContextAttachedError" to interceptors.resolve("error::ContextAttachedError"),
            "EndpointResolverParams" to orchestrator.resolve("EndpointResolverParams"),
            "HttpResponse" to orchestrator.resolve("HttpResponse"),
            "HttpRequest" to orchestrator.resolve("HttpRequest"),
@@ -83,10 +84,10 @@ class EndpointParamsInterceptorGenerator(
            let input = context.input()?;
            let _input = input
                .downcast_ref::<${operationInput.name}>()
                .ok_or_else(|| #{InterceptorError}::invalid_input_access())?;
                .ok_or_else(|| "failed to downcast to ${operationInput.name}")?;
            let params_builder = cfg
                .get::<#{ParamsBuilder}>()
                .ok_or(#{InterceptorError}::read_before_execution("missing endpoint params builder"))?
                .ok_or_else(|| "missing endpoint params builder")?
                .clone();
            ${"" /* TODO(EndpointResolver): Call setters on `params_builder` to update its fields by using values from `_input` */}
            cfg.put(params_builder);
@@ -111,7 +112,7 @@ class EndpointParamsInterceptorGenerator(
            )
            withBlockTemplate(
                "let endpoint_prefix = ",
                ".map_err(#{InterceptorError}::read_before_execution)?;",
                """.map_err(|err| #{ContextAttachedError}::new("endpoint prefix could not be built", err))?;""",
                *codegenScope,
            ) {
                endpointTraitBindings.render(
@@ -130,11 +131,11 @@ class EndpointParamsInterceptorGenerator(
            let _ = context;
            let params_builder = cfg
                .get::<#{ParamsBuilder}>()
                .ok_or(#{InterceptorError}::read_before_execution("missing endpoint params builder"))?
                .ok_or_else(|| "missing endpoint params builder")?
                .clone();
            let params = params_builder
                .build()
                .map_err(#{InterceptorError}::read_before_execution)?;
                .map_err(|err| #{ContextAttachedError}::new("endpoint params could not be built", err))?;
            cfg.put(
                #{EndpointResolverParams}::new(params)
            );
+1 −1
Original line number Diff line number Diff line
@@ -86,7 +86,7 @@ class ServiceRuntimePluginGenerator(
            "ConfigBagAccessors" to runtimeApi.resolve("client::orchestrator::ConfigBagAccessors"),
            "Connection" to runtimeApi.resolve("client::orchestrator::Connection"),
            "ConnectorSettings" to RuntimeType.smithyClient(rc).resolve("http_connector::ConnectorSettings"),
            "DefaultEndpointResolver" to runtimeApi.resolve("client::endpoints::DefaultEndpointResolver"),
            "DefaultEndpointResolver" to runtime.resolve("client::orchestrator::endpoints::DefaultEndpointResolver"),
            "DynConnectorAdapter" to runtime.resolve("client::connections::adapter::DynConnectorAdapter"),
            "HttpAuthSchemes" to runtimeApi.resolve("client::orchestrator::HttpAuthSchemes"),
            "IdentityResolvers" to runtimeApi.resolve("client::orchestrator::IdentityResolvers"),
+0 −3
Original line number Diff line number Diff line
@@ -21,8 +21,5 @@ pub mod retries;
/// Runtime plugin type definitions.
pub mod runtime_plugin;

/// Smithy endpoint resolution runtime plugins
pub mod endpoints;

/// Smithy auth runtime plugins
pub mod auth;
+0 −110
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use crate::client::orchestrator::{
    BoxError, EndpointResolver, EndpointResolverParams, HttpRequest,
};
use aws_smithy_http::endpoint::error::ResolveEndpointError;
use aws_smithy_http::endpoint::{
    apply_endpoint, EndpointPrefix, ResolveEndpoint, SharedEndpointResolver,
};
use http::header::HeaderName;
use http::{HeaderValue, Uri};
use std::fmt::Debug;
use std::str::FromStr;

#[derive(Debug, Clone)]
pub struct StaticUriEndpointResolver {
    endpoint: Uri,
}

impl StaticUriEndpointResolver {
    pub fn http_localhost(port: u16) -> Self {
        Self {
            endpoint: Uri::from_str(&format!("http://localhost:{port}"))
                .expect("all u16 values are valid ports"),
        }
    }

    pub fn uri(endpoint: Uri) -> Self {
        Self { endpoint }
    }
}

impl EndpointResolver for StaticUriEndpointResolver {
    fn resolve_and_apply_endpoint(
        &self,
        _params: &EndpointResolverParams,
        _endpoint_prefix: Option<&EndpointPrefix>,
        request: &mut HttpRequest,
    ) -> Result<(), BoxError> {
        apply_endpoint(request.uri_mut(), &self.endpoint, None)?;
        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct DefaultEndpointResolver<Params> {
    inner: SharedEndpointResolver<Params>,
}

impl<Params> DefaultEndpointResolver<Params> {
    pub fn new(resolve_endpoint: SharedEndpointResolver<Params>) -> Self {
        Self {
            inner: resolve_endpoint,
        }
    }
}

impl<Params> EndpointResolver for DefaultEndpointResolver<Params>
where
    Params: Debug + Send + Sync + 'static,
{
    fn resolve_and_apply_endpoint(
        &self,
        params: &EndpointResolverParams,
        endpoint_prefix: Option<&EndpointPrefix>,
        request: &mut HttpRequest,
    ) -> Result<(), BoxError> {
        let endpoint = match params.get::<Params>() {
            Some(params) => self.inner.resolve_endpoint(params)?,
            None => {
                return Err(Box::new(ResolveEndpointError::message(
                    "params of expected type was not present",
                )));
            }
        };

        let uri: Uri = endpoint.url().parse().map_err(|err| {
            ResolveEndpointError::from_source("endpoint did not have a valid uri", err)
        })?;

        apply_endpoint(request.uri_mut(), &uri, endpoint_prefix).map_err(|err| {
            ResolveEndpointError::message(format!(
                "failed to apply endpoint `{:?}` to request `{:?}`",
                uri, request,
            ))
            .with_source(Some(err.into()))
        })?;

        for (header_name, header_values) in endpoint.headers() {
            request.headers_mut().remove(header_name);
            for value in header_values {
                request.headers_mut().insert(
                    HeaderName::from_str(header_name).map_err(|err| {
                        ResolveEndpointError::message("invalid header name")
                            .with_source(Some(err.into()))
                    })?,
                    HeaderValue::from_str(value).map_err(|err| {
                        ResolveEndpointError::message("invalid header value")
                            .with_source(Some(err.into()))
                    })?,
                );
            }
        }

        Ok(())
    }
}
Loading