From fc1e1799dac11d278720b2f4b923482598698552 Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Tue, 23 Aug 2022 15:19:16 -0400 Subject: [PATCH] Refactor endpoints to be Smithy-native (#1641) * Refactor endpoints to be Smithy-native * Add test to DynamoDb * Fix clippy * Fix some clippy errors and update changelog --- CHANGELOG.next.toml | 6 + aws/rust-runtime/aws-endpoint/Cargo.toml | 1 + aws/rust-runtime/aws-endpoint/src/lib.rs | 225 +++++++++++++++--- .../tests/middleware_e2e_test.rs | 12 +- .../aws-sig-auth/src/middleware.rs | 20 +- .../smithy/rustsdk/AwsEndpointDecorator.kt | 61 +++-- .../smithy/rustsdk/SigV4SigningDecorator.kt | 3 + .../EndpointConfigCustomizationTest.kt | 21 +- .../dynamodb/tests/endpoints.rs | 49 ++++ .../iam/tests/resolve-global-endpoint.rs | 28 +-- .../s3/tests/alternative-async-runtime.rs | 10 +- .../aws-smithy-client/src/dvr/replay.rs | 7 +- rust-runtime/aws-smithy-http/src/endpoint.rs | 136 +++++++++-- .../aws-smithy-types/src/date_time/mod.rs | 2 +- rust-runtime/aws-smithy-types/src/endpoint.rs | 175 ++++++++++++++ rust-runtime/aws-smithy-types/src/lib.rs | 25 ++ 16 files changed, 658 insertions(+), 123 deletions(-) create mode 100644 aws/sdk/integration-tests/dynamodb/tests/endpoints.rs create mode 100644 rust-runtime/aws-smithy-types/src/endpoint.rs diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index a6b2160e1..18ca14c2a 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -11,6 +11,12 @@ # references = ["smithy-rs#920"] # meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"} # author = "rcoh" +[[aws-sdk-rust]] +message = """Refactor endpoint resolution internals to use `aws_smithy_types::Endpoint` internally. The public internal +functions `aws_endpoint::set_endpoint_resolver` and `aws_endpoint::get_endpoint_resolver were removed.""" +references = ["smithy-rs#1641"] +meta = { breaking = true, "tada" = false, "bug" = false } +author = "rcoh" [[aws-sdk-rust]] message = """ diff --git a/aws/rust-runtime/aws-endpoint/Cargo.toml b/aws/rust-runtime/aws-endpoint/Cargo.toml index a080bd7c2..67cb4d4e6 100644 --- a/aws/rust-runtime/aws-endpoint/Cargo.toml +++ b/aws/rust-runtime/aws-endpoint/Cargo.toml @@ -9,6 +9,7 @@ repository = "https://github.com/awslabs/smithy-rs" [dependencies] aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http" } +aws-smithy-types = { path = "../../../rust-runtime/aws-smithy-types"} aws-types = { path = "../aws-types" } http = "0.2.3" regex = { version = "1.5.5", default-features = false, features = ["std"] } diff --git a/aws/rust-runtime/aws-endpoint/src/lib.rs b/aws/rust-runtime/aws-endpoint/src/lib.rs index 9b77c85a4..ab3e95491 100644 --- a/aws/rust-runtime/aws-endpoint/src/lib.rs +++ b/aws/rust-runtime/aws-endpoint/src/lib.rs @@ -10,27 +10,83 @@ pub mod partition; pub use partition::Partition; #[doc(hidden)] pub use partition::PartitionResolver; +use std::collections::HashMap; -use aws_smithy_http::endpoint::EndpointPrefix; +use aws_smithy_http::endpoint::Error as EndpointError; +use aws_smithy_http::endpoint::{apply_endpoint, EndpointPrefix, ResolveEndpoint}; use aws_smithy_http::middleware::MapRequest; use aws_smithy_http::operation::Request; -use aws_smithy_http::property_bag::PropertyBag; +use aws_smithy_types::endpoint::Endpoint as SmithyEndpoint; +use aws_smithy_types::Document; use aws_types::region::{Region, SigningRegion}; use aws_types::SigningService; +use http::header::HeaderName; +use http::{HeaderValue, Uri}; use std::error::Error; use std::fmt; use std::fmt::{Debug, Display, Formatter}; +use std::str::FromStr; use std::sync::Arc; pub use aws_types::endpoint::{AwsEndpoint, BoxError, CredentialScope, ResolveAwsEndpoint}; -type AwsEndpointResolver = Arc; -pub fn get_endpoint_resolver(properties: &PropertyBag) -> Option<&AwsEndpointResolver> { - properties.get() +#[doc(hidden)] +pub struct Params { + region: Option, +} + +impl Params { + pub fn new(region: Option) -> Self { + Self { region } + } +} + +#[doc(hidden)] +pub struct EndpointShim(Arc); +impl EndpointShim { + pub fn from_resolver(resolver: impl ResolveAwsEndpoint + 'static) -> Self { + Self(Arc::new(resolver)) + } + + pub fn from_arc(arc: Arc) -> Self { + Self(arc) + } } -pub fn set_endpoint_resolver(properties: &mut PropertyBag, provider: AwsEndpointResolver) { - properties.insert(provider); +impl ResolveEndpoint for EndpointShim { + fn resolve_endpoint( + &self, + params: &Params, + ) -> Result { + let aws_endpoint = self + .0 + .resolve_endpoint( + params + .region + .as_ref() + .ok_or_else(|| EndpointError::message("no region in params"))?, + ) + .map_err(|err| EndpointError::message("failure resolving endpoint").with_cause(err))?; + let uri = aws_endpoint.endpoint().uri(); + let mut auth_scheme = + HashMap::from([("name".to_string(), Document::String("sigv4".into()))]); + if let Some(region) = aws_endpoint.credential_scope().region() { + auth_scheme.insert( + "signingRegion".to_string(), + region.as_ref().to_string().into(), + ); + } + if let Some(service) = aws_endpoint.credential_scope().service() { + auth_scheme.insert( + "signingName".to_string(), + service.as_ref().to_string().into(), + ); + } + Ok(SmithyEndpoint::builder() + .url(uri.to_string()) + .property("authSchemes", vec![Document::Object(auth_scheme)]) + .build()) + } } /// Middleware Stage to Add an Endpoint to a Request @@ -56,6 +112,7 @@ impl Display for AwsEndpointStageError { Debug::fmt(self, f) } } + impl Error for AwsEndpointStageError {} impl MapRequest for AwsEndpointStage { @@ -63,30 +120,91 @@ impl MapRequest for AwsEndpointStage { fn apply(&self, request: Request) -> Result { request.augment(|mut http_req, props| { - let provider = - get_endpoint_resolver(props).ok_or(AwsEndpointStageError::NoEndpointResolver)?; - let region = props - .get::() - .ok_or(AwsEndpointStageError::NoRegion)?; - let endpoint = provider - .resolve_endpoint(region) - .map_err(AwsEndpointStageError::EndpointResolutionError)?; - tracing::debug!(endpoint = ?endpoint, base_region = ?region, "resolved endpoint"); - let signing_region = endpoint - .credential_scope() - .region() - .cloned() - .unwrap_or_else(|| region.clone().into()); - props.insert::(signing_region); - if let Some(signing_service) = endpoint.credential_scope().service() { - props.insert::(signing_service.clone()); + let endpoint_result = props + .get_mut::() + .ok_or(AwsEndpointStageError::NoEndpointResolver)?; + let endpoint = match endpoint_result { + // downgrade the mut ref to a shared ref + Ok(_endpoint) => props.get::() + .expect("unreachable (prevalidated that the endpoint is in the bag)") + .as_ref() + .expect("unreachable (prevalidated that this is OK)"), + Err(e) => { + // We need to own the error to return it, so take it and leave a stub error in + // its place + return Err(AwsEndpointStageError::EndpointResolutionError(std::mem::replace( + e, + aws_smithy_http::endpoint::Error::message("the original error was directly returned") + ).into())); + } + }; + let (uri, signing_scope_override, signing_service_override) = smithy_to_aws(endpoint) + .map_err(|err| AwsEndpointStageError::EndpointResolutionError(err))?; + tracing::debug!(endpoint = ?endpoint, base_region = ?signing_scope_override, "resolved endpoint"); + apply_endpoint(http_req.uri_mut(), &uri, props.get::()) + .map_err(|err|AwsEndpointStageError::EndpointResolutionError(err.into()))?; + for (header_name, header_values) in endpoint.headers() { + http_req.headers_mut().remove(header_name); + for value in header_values { + http_req.headers_mut().insert( + HeaderName::from_str(header_name) + .map_err(|err|AwsEndpointStageError::EndpointResolutionError(err.into()))?, + HeaderValue::from_str(value) + .map_err(|err|AwsEndpointStageError::EndpointResolutionError(err.into()))?, + ); + } + } + + if let Some(signing_scope) = signing_scope_override { + props.insert(signing_scope); + } + if let Some(signing_service) = signing_service_override { + props.insert(signing_service); } - endpoint.set_endpoint(http_req.uri_mut(), props.get::()); Ok(http_req) }) } } +type EndpointMetadata = (Uri, Option, Option); + +fn smithy_to_aws(value: &SmithyEndpoint) -> Result> { + let uri: Uri = value.url().parse()?; + // look for v4 as an auth scheme + let auth_schemes = match value + .properties() + .get("authSchemes") + .ok_or("no auth schemes in metadata")? + { + Document::Array(schemes) => schemes, + _other => return Err("expected an array for authSchemes".into()), + }; + let v4 = auth_schemes + .iter() + .flat_map(|doc| match doc { + Document::Object(map) + if map.get("name") == Some(&Document::String("sigv4".to_string())) => + { + Some(map) + } + _ => None, + }) + .next() + .ok_or("could not find v4 as an acceptable auth scheme")?; + + let signing_scope = match v4.get("signingRegion") { + Some(Document::String(s)) => Some(SigningRegion::from(Region::new(s.clone()))), + None => None, + _ => return Err("unexpected type".into()), + }; + let signing_service = match v4.get("signingName") { + Some(Document::String(s)) => Some(SigningService::from(s.to_string())), + None => None, + _ => return Err("unexpected type".into()), + }; + Ok((uri, signing_scope, signing_service)) +} + #[cfg(test)] mod test { use std::sync::Arc; @@ -95,13 +213,15 @@ mod test { use http::Uri; use aws_smithy_http::body::SdkBody; + use aws_smithy_http::endpoint::ResolveEndpoint; use aws_smithy_http::middleware::MapRequest; use aws_smithy_http::operation; + use aws_types::endpoint::CredentialScope; use aws_types::region::{Region, SigningRegion}; use aws_types::SigningService; use crate::partition::endpoint::{Metadata, Protocol, SignatureVersion}; - use crate::{set_endpoint_resolver, AwsEndpointStage, CredentialScope}; + use crate::{AwsEndpointStage, EndpointShim, Params}; #[test] fn default_endpoint_updates_request() { @@ -118,7 +238,10 @@ mod test { let mut props = req.properties_mut(); props.insert(region.clone()); props.insert(SigningService::from_static("kinesis")); - set_endpoint_resolver(&mut props, provider); + props.insert( + EndpointShim::from_arc(provider) + .resolve_endpoint(&Params::new(Some(region.clone()))), + ); }; let req = AwsEndpointStage.apply(req).expect("should succeed"); assert_eq!(req.properties().get(), Some(&SigningRegion::from(region))); @@ -127,12 +250,18 @@ mod test { Some(&SigningService::from_static("kinesis")) ); - let (req, _conf) = req.into_parts(); + let (req, conf) = req.into_parts(); assert_eq!( req.uri(), &Uri::from_static("https://kinesis.us-east-1.amazonaws.com") ); assert!(req.headers().get(HOST).is_none()); + assert!( + conf.acquire() + .get::() + .is_some(), + "Endpoint middleware MUST leave the result in the bag" + ); } #[test] @@ -151,9 +280,12 @@ mod test { let mut req = operation::Request::new(req); { let mut props = req.properties_mut(); - props.insert(region); - props.insert(SigningService::from_static("kinesis")); - set_endpoint_resolver(&mut props, provider); + props.insert(region.clone()); + props.insert(SigningService::from_static("qldb")); + props.insert( + EndpointShim::from_arc(provider) + .resolve_endpoint(&Params::new(Some(region.clone()))), + ); }; let req = AwsEndpointStage.apply(req).expect("should succeed"); assert_eq!( @@ -165,4 +297,35 @@ mod test { Some(&SigningService::from_static("qldb-override")) ); } + + #[test] + fn supports_fallback_when_scope_is_unset() { + let provider = Arc::new(Metadata { + uri_template: "www.service.com", + protocol: Protocol::Http, + credential_scope: CredentialScope::builder().build(), + signature_versions: SignatureVersion::V4, + }); + let req = http::Request::new(SdkBody::from("")); + let region = Region::new("us-east-1"); + let mut req = operation::Request::new(req); + { + let mut props = req.properties_mut(); + props.insert(region.clone()); + props.insert(SigningService::from_static("qldb")); + props.insert( + EndpointShim::from_arc(provider) + .resolve_endpoint(&Params::new(Some(region.clone()))), + ); + }; + let req = AwsEndpointStage.apply(req).expect("should succeed"); + assert_eq!( + req.properties().get(), + Some(&SigningRegion::from(Region::new("us-east-1"))) + ); + assert_eq!( + req.properties().get(), + Some(&SigningService::from_static("qldb")) + ); + } } diff --git a/aws/rust-runtime/aws-inlineable/tests/middleware_e2e_test.rs b/aws/rust-runtime/aws-inlineable/tests/middleware_e2e_test.rs index 3b9153fe3..bfec2313d 100644 --- a/aws/rust-runtime/aws-inlineable/tests/middleware_e2e_test.rs +++ b/aws/rust-runtime/aws-inlineable/tests/middleware_e2e_test.rs @@ -7,7 +7,6 @@ use std::convert::Infallible; use std::error::Error; use std::fmt; use std::fmt::{Display, Formatter}; -use std::sync::Arc; use std::time::{Duration, UNIX_EPOCH}; use bytes::Bytes; @@ -15,7 +14,7 @@ use http::header::{AUTHORIZATION, USER_AGENT}; use http::{self, Uri}; use aws_endpoint::partition::endpoint::{Protocol, SignatureVersion}; -use aws_endpoint::set_endpoint_resolver; +use aws_endpoint::{EndpointShim, Params}; use aws_http::retry::AwsErrorRetryPolicy; use aws_http::user_agent::AwsUserAgent; use aws_inlineable::middleware::DefaultMiddleware; @@ -23,6 +22,7 @@ use aws_sig_auth::signer::OperationSigningConfig; use aws_smithy_client::test_connection::TestConnection; use aws_smithy_http::body::SdkBody; +use aws_smithy_http::endpoint::ResolveEndpoint; use aws_smithy_http::operation; use aws_smithy_http::operation::Operation; use aws_smithy_http::response::ParseHttpResponse; @@ -83,14 +83,14 @@ fn test_operation() -> Operation { .unwrap(), ) .augment(|req, mut conf| { - set_endpoint_resolver( - &mut conf, - Arc::new(aws_endpoint::partition::endpoint::Metadata { + conf.insert( + EndpointShim::from_resolver(aws_endpoint::partition::endpoint::Metadata { uri_template: "test-service.{region}.amazonaws.com", protocol: Protocol::Https, credential_scope: Default::default(), signature_versions: SignatureVersion::V4, - }), + }) + .resolve_endpoint(&Params::new(Some(Region::new("test-region")))), ); aws_http::auth::set_provider( &mut conf, diff --git a/aws/rust-runtime/aws-sig-auth/src/middleware.rs b/aws/rust-runtime/aws-sig-auth/src/middleware.rs index 9ef1611fd..ecf0a576e 100644 --- a/aws/rust-runtime/aws-sig-auth/src/middleware.rs +++ b/aws/rust-runtime/aws-sig-auth/src/middleware.rs @@ -172,8 +172,9 @@ mod test { use crate::middleware::{SigV4SigningStage, Signature, SigningStageError}; use crate::signer::{OperationSigningConfig, SigV4Signer}; use aws_endpoint::partition::endpoint::{Protocol, SignatureVersion}; - use aws_endpoint::{set_endpoint_resolver, AwsEndpointStage}; + use aws_endpoint::{AwsEndpointStage, Params}; use aws_smithy_http::body::SdkBody; + use aws_smithy_http::endpoint::ResolveEndpoint; use aws_smithy_http::middleware::MapRequest; use aws_smithy_http::operation; use aws_types::region::{Region, SigningRegion}; @@ -181,7 +182,6 @@ mod test { use aws_types::SigningService; use http::header::AUTHORIZATION; use std::convert::Infallible; - use std::sync::Arc; use std::time::{Duration, UNIX_EPOCH}; #[test] @@ -214,12 +214,14 @@ mod test { // check that the endpoint middleware followed by signing middleware produce the expected result #[test] fn endpoint_plus_signer() { - let provider = Arc::new(aws_endpoint::partition::endpoint::Metadata { - uri_template: "kinesis.{region}.amazonaws.com", - protocol: Protocol::Https, - credential_scope: Default::default(), - signature_versions: SignatureVersion::V4, - }); + let provider = aws_endpoint::EndpointShim::from_resolver( + aws_endpoint::partition::endpoint::Metadata { + uri_template: "kinesis.{region}.amazonaws.com", + protocol: Protocol::Https, + credential_scope: Default::default(), + signature_versions: SignatureVersion::V4, + }, + ); let req = http::Request::new(SdkBody::from("")); let region = Region::new("us-east-1"); let req = operation::Request::new(req) @@ -227,7 +229,7 @@ mod test { conf.insert(region.clone()); conf.insert(UNIX_EPOCH + Duration::new(1611160427, 0)); conf.insert(SigningService::from_static("kinesis")); - set_endpoint_resolver(conf, provider); + conf.insert(provider.resolve_endpoint(&Params::new(Some(region.clone())))); Result::<_, Infallible>::Ok(req) }) .expect("succeeds"); diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsEndpointDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsEndpointDecorator.kt index 11dab4045..e4c5ae5f7 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsEndpointDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsEndpointDecorator.kt @@ -33,6 +33,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.smithy.generators.LibRsSection import software.amazon.smithy.rust.codegen.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfig +import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectTrait import software.amazon.smithy.rust.codegen.util.orNull @@ -62,7 +63,10 @@ class AwsEndpointDecorator : RustCodegenDecorator { codegenContext: ClientCodegenContext, baseCustomizations: List, ): List { - return baseCustomizations + EndpointConfigCustomization(codegenContext, endpoints(SdkSettings.from(codegenContext.settings))) + return baseCustomizations + EndpointConfigCustomization( + codegenContext, + endpoints(SdkSettings.from(codegenContext.settings)), + ) } override fun operationCustomizations( @@ -84,20 +88,35 @@ class AwsEndpointDecorator : RustCodegenDecorator { clazz.isAssignableFrom(ClientCodegenContext::class.java) } -class EndpointConfigCustomization(private val coreCodegenContext: CoreCodegenContext, private val endpointData: ObjectNode) : +class EndpointConfigCustomization( + private val coreCodegenContext: CoreCodegenContext, + private val endpointData: ObjectNode, +) : ConfigCustomization() { private val runtimeConfig = coreCodegenContext.runtimeConfig private val resolveAwsEndpoint = runtimeConfig.awsEndpoint().asType().copy(name = "ResolveAwsEndpoint") + private val smithyEndpointResolver = + CargoDependency.SmithyHttp(runtimeConfig).asType().member("endpoint::ResolveEndpoint") + private val placeholderEndpointParams = runtimeConfig.awsEndpoint().asType().member("Params") + private val endpointShim = runtimeConfig.awsEndpoint().asType().member("EndpointShim") private val moduleUseName = coreCodegenContext.moduleUseName() + private val codegenScope = arrayOf( + "SmithyResolver" to smithyEndpointResolver, + "PlaceholderParams" to placeholderEndpointParams, + "ResolveAwsEndpoint" to resolveAwsEndpoint, + "EndpointShim" to endpointShim, + "aws_types" to awsTypes(runtimeConfig).asType(), + ) + override fun section(section: ServiceConfig): Writable = writable { when (section) { - is ServiceConfig.ConfigStruct -> rust( - "pub (crate) endpoint_resolver: ::std::sync::Arc,", - resolveAwsEndpoint, + is ServiceConfig.ConfigStruct -> rustTemplate( + "pub (crate) endpoint_resolver: std::sync::Arc>,", + *codegenScope, ) is ServiceConfig.ConfigImpl -> emptySection is ServiceConfig.BuilderStruct -> - rust("endpoint_resolver: Option<::std::sync::Arc>,", resolveAwsEndpoint) + rustTemplate("endpoint_resolver: Option>>,", *codegenScope) ServiceConfig.BuilderImpl -> rustTemplate( """ @@ -118,30 +137,30 @@ class EndpointConfigCustomization(private val coreCodegenContext: CoreCodegenCon /// ).build(); /// ``` pub fn endpoint_resolver(mut self, endpoint_resolver: impl #{ResolveAwsEndpoint} + 'static) -> Self { - self.endpoint_resolver = Some(::std::sync::Arc::new(endpoint_resolver)); + self.endpoint_resolver = Some(std::sync::Arc::new(#{EndpointShim}::from_resolver(endpoint_resolver)) as _); self } /// Sets the endpoint resolver to use when making requests. pub fn set_endpoint_resolver(&mut self, endpoint_resolver: Option>) -> &mut Self { - self.endpoint_resolver = endpoint_resolver; + self.endpoint_resolver = endpoint_resolver.map(|res|std::sync::Arc::new(#{EndpointShim}::from_arc(res) ) as _); self } """, - "ResolveAwsEndpoint" to resolveAwsEndpoint, - "aws_types" to awsTypes(runtimeConfig).asType(), + *codegenScope, ) ServiceConfig.BuilderBuild -> { val resolverGenerator = EndpointResolverGenerator(coreCodegenContext, endpointData) - rust( + rustTemplate( """ endpoint_resolver: self.endpoint_resolver.unwrap_or_else(|| - ::std::sync::Arc::new(#T()) + std::sync::Arc::new(#{EndpointShim}::from_resolver(#{Resolver}())) ), """, - resolverGenerator.resolver(), + *codegenScope, "Resolver" to resolverGenerator.resolver(), ) } + else -> emptySection } } } @@ -150,14 +169,24 @@ class EndpointConfigCustomization(private val coreCodegenContext: CoreCodegenCon class EndpointResolverFeature(private val runtimeConfig: RuntimeConfig, private val operationShape: OperationShape) : OperationCustomization() { + private val placeholderEndpointParams = runtimeConfig.awsEndpoint().asType().member("Params") + private val codegenScope = arrayOf( + "PlaceholderParams" to placeholderEndpointParams, + "BuildError" to runtimeConfig.operationBuildError(), + ) override fun section(section: OperationSection): Writable { return when (section) { is OperationSection.MutateRequest -> writable { - rust( + // insert the endpoint resolution _result_ into the bag (note that this won't bail if endpoint resolution failed) + rustTemplate( """ - #T::set_endpoint_resolver(&mut ${section.request}.properties_mut(), ${section.config}.endpoint_resolver.clone()); + let endpoint_params = #{PlaceholderParams}::new(${section.config}.region.clone()); + ${section.request}.properties_mut() + .insert::(${section.config} + .endpoint_resolver + .resolve_endpoint(&endpoint_params)); """, - runtimeConfig.awsEndpoint().asType(), + *codegenScope, ) } else -> emptySection diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt index 13f517a43..50dd8e181 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt @@ -189,6 +189,9 @@ class SigV4SigningFeature( """ ${section.request}.properties_mut().insert(signing_config); ${section.request}.properties_mut().insert(#{aws_types}::SigningService::from_static(${section.config}.signing_service())); + if let Some(region) = &${section.config}.region { + ${section.request}.properties_mut().insert(#{aws_types}::region::SigningRegion::from(region.clone())); + } """, *codegenScope, ) diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointConfigCustomizationTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointConfigCustomizationTest.kt index 528c0a236..11e5f03dc 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointConfigCustomizationTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointConfigCustomizationTest.kt @@ -27,8 +27,10 @@ import software.amazon.smithy.rust.codegen.testutil.unitTest import software.amazon.smithy.rust.codegen.util.runCommand internal class EndpointConfigCustomizationTest { + private val placeholderEndpointParams = AwsTestRuntimeConfig.awsEndpoint().asType().member("Params") private val codegenScope = arrayOf( "http" to CargoDependency.Http.asType(), + "PlaceholderParams" to placeholderEndpointParams, "aws_types" to awsTypes(AwsTestRuntimeConfig).asType(), ) @@ -154,6 +156,7 @@ internal class EndpointConfigCustomizationTest { } val customization = CombinedCodegenDecorator(listOf(RequiredCustomizations(), codegenDecorator)) CodegenVisitor(context, customization).execute() + println("file:///$testDir") "cargo test".runCommand(testDir) } @@ -176,10 +179,8 @@ internal class EndpointConfigCustomizationTest { """ let conf = crate::config::Config::builder().build(); let endpoint = conf.endpoint_resolver - .resolve_endpoint(&#{aws_types}::region::Region::new("fips-ca-central-1")).expect("default resolver produces a valid endpoint"); - let mut uri = #{http}::Uri::from_static("/?k=v"); - endpoint.set_endpoint(&mut uri, None); - assert_eq!(uri, #{http}::Uri::from_static("https://access-analyzer-fips.ca-central-1.amazonaws.com/?k=v")); + .resolve_endpoint(&::#{PlaceholderParams}::new(Some(#{aws_types}::region::Region::new("fips-ca-central-1")))).expect("default resolver produces a valid endpoint"); + assert_eq!(endpoint.url(), "https://access-analyzer-fips.ca-central-1.amazonaws.com/"); """, *codegenScope, ) @@ -197,16 +198,12 @@ internal class EndpointConfigCustomizationTest { """ let conf = crate::config::Config::builder().build(); let endpoint = conf.endpoint_resolver - .resolve_endpoint(&#{aws_types}::region::Region::new("us-east-1")).expect("default resolver produces a valid endpoint"); - let mut uri = #{http}::Uri::from_static("/?k=v"); - endpoint.set_endpoint(&mut uri, None); - assert_eq!(uri, #{http}::Uri::from_static("https://iam.amazonaws.com/?k=v")); + .resolve_endpoint(&::#{PlaceholderParams}::new(Some(#{aws_types}::region::Region::new("us-east-1")))).expect("default resolver produces a valid endpoint"); + assert_eq!(endpoint.url(), "https://iam.amazonaws.com/"); let endpoint = conf.endpoint_resolver - .resolve_endpoint(&#{aws_types}::region::Region::new("iam-fips")).expect("default resolver produces a valid endpoint"); - let mut uri = #{http}::Uri::from_static("/?k=v"); - endpoint.set_endpoint(&mut uri, None); - assert_eq!(uri, #{http}::Uri::from_static("https://iam-fips.amazonaws.com/?k=v")); + .resolve_endpoint(&::#{PlaceholderParams}::new(Some(#{aws_types}::region::Region::new("iam-fips")))).expect("default resolver produces a valid endpoint"); + assert_eq!(endpoint.url(), "https://iam-fips.amazonaws.com/"); """, *codegenScope, ) diff --git a/aws/sdk/integration-tests/dynamodb/tests/endpoints.rs b/aws/sdk/integration-tests/dynamodb/tests/endpoints.rs new file mode 100644 index 000000000..fe6a99839 --- /dev/null +++ b/aws/sdk/integration-tests/dynamodb/tests/endpoints.rs @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_sdk_dynamodb::{Credentials, Region}; +use aws_smithy_http::endpoint::Endpoint; +use http::Uri; + +/// Iterative test of loading clients from shared configuration +#[tokio::test] +async fn endpoints_can_be_overridden_globally() { + let shared_config = aws_types::SdkConfig::builder() + .region(Region::new("us-east-4")) + .endpoint_resolver(Endpoint::immutable( + "http://localhost:8000".parse().unwrap(), + )) + .build(); + let conf = aws_sdk_dynamodb::config::Builder::from(&shared_config) + .credentials_provider(Credentials::new("asdf", "asdf", None, None, "test")) + .build(); + let (conn, request) = aws_smithy_client::test_connection::capture_request(None); + let svc = aws_sdk_dynamodb::Client::from_conf_conn(conf, conn); + let _ = svc.list_tables().send().await; + assert_eq!( + request.expect_request().uri(), + &Uri::from_static("http://localhost:8000") + ); +} + +#[tokio::test] +async fn endpoints_can_be_overridden_locally() { + let shared_config = aws_types::SdkConfig::builder() + .region(Region::new("us-east-4")) + .build(); + let conf = aws_sdk_dynamodb::config::Builder::from(&shared_config) + .credentials_provider(Credentials::new("asdf", "asdf", None, None, "test")) + .endpoint_resolver(Endpoint::immutable( + "http://localhost:8000".parse().unwrap(), + )) + .build(); + let (conn, request) = aws_smithy_client::test_connection::capture_request(None); + let svc = aws_sdk_dynamodb::Client::from_conf_conn(conf, conn); + let _ = svc.list_tables().send().await; + assert_eq!( + request.expect_request().uri(), + &Uri::from_static("http://localhost:8000") + ); +} diff --git a/aws/sdk/integration-tests/iam/tests/resolve-global-endpoint.rs b/aws/sdk/integration-tests/iam/tests/resolve-global-endpoint.rs index 7a219fab6..eaccafd48 100644 --- a/aws/sdk/integration-tests/iam/tests/resolve-global-endpoint.rs +++ b/aws/sdk/integration-tests/iam/tests/resolve-global-endpoint.rs @@ -3,13 +3,13 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_endpoint::get_endpoint_resolver; use aws_sdk_iam::Region; -use http::Uri; #[tokio::test] async fn correct_endpoint_resolver() { - let conf = aws_sdk_iam::Config::builder().build(); + let conf = aws_sdk_iam::Config::builder() + .region(Region::from_static("iam-fips")) + .build(); let operation = aws_sdk_iam::operation::ListRoles::builder() .build() .unwrap() @@ -17,23 +17,9 @@ async fn correct_endpoint_resolver() { .await .expect("valid operation"); let props = operation.properties(); - let resolver = get_endpoint_resolver(&props).expect("operation should have endpoint resolver"); - // test regular endpoint - { - let ep = resolver - .resolve_endpoint(&Region::new("us-east-1")) - .expect("valid endpoint"); - let mut uri = Uri::from_static("/"); - ep.set_endpoint(&mut uri, None); - assert_eq!(uri, Uri::from_static("https://iam.amazonaws.com/")); - } + let ep: &aws_smithy_http::endpoint::Result = + props.get().expect("endpoint result was not present"); + let ep = ep.as_ref().expect("ep resolved successfully"); // test fips endpoint - { - let ep = resolver - .resolve_endpoint(&Region::new("iam-fips")) - .expect("valid endpoint"); - let mut uri = Uri::from_static("/"); - ep.set_endpoint(&mut uri, None); - assert_eq!(uri, Uri::from_static("https://iam-fips.amazonaws.com/")); - } + assert_eq!(ep.url(), "https://iam-fips.amazonaws.com/"); } diff --git a/aws/sdk/integration-tests/s3/tests/alternative-async-runtime.rs b/aws/sdk/integration-tests/s3/tests/alternative-async-runtime.rs index 9e61c2989..b20718331 100644 --- a/aws/sdk/integration-tests/s3/tests/alternative-async-runtime.rs +++ b/aws/sdk/integration-tests/s3/tests/alternative-async-runtime.rs @@ -145,16 +145,16 @@ async fn retry_test(sleep_impl: Arc) -> Result<(), Box Result<(), Box>, ) -> Result<(), Box> { - let mut actual_requests = self.recorded_requests.lock().unwrap(); + let mut actual_requests = + std::mem::take(self.recorded_requests.lock().unwrap().deref_mut()); for conn_id in 0..self.verifiable_events.len() { let conn_id = ConnectionId(conn_id); let expected = self.verifiable_events.get(&conn_id).unwrap(); @@ -98,7 +100,8 @@ impl ReplayingConnection { /// Return all the recorded requests for further analysis pub async fn take_requests(self) -> Vec> { - let mut recorded_requests = self.recorded_requests.lock().unwrap(); + let mut recorded_requests = + std::mem::take(self.recorded_requests.lock().unwrap().deref_mut()); let mut out = Vec::with_capacity(recorded_requests.len()); for conn_id in 0..recorded_requests.len() { out.push( diff --git a/rust-runtime/aws-smithy-http/src/endpoint.rs b/rust-runtime/aws-smithy-http/src/endpoint.rs index 619066789..6eb8bec7c 100644 --- a/rust-runtime/aws-smithy-http/src/endpoint.rs +++ b/rust-runtime/aws-smithy-http/src/endpoint.rs @@ -4,12 +4,56 @@ */ use std::borrow::Cow; +use std::fmt::{Display, Formatter}; use std::str::FromStr; use http::uri::{Authority, Uri}; use crate::operation::BuildError; +pub type Result = std::result::Result; + +pub trait ResolveEndpoint: Send + Sync { + fn resolve_endpoint(&self, params: &Params) -> Result; +} + +/// Endpoint Resolution Error +#[derive(Debug)] +pub struct Error { + message: String, + extra: Option>, +} + +impl Error { + /// Create an [`Error`] with a message + pub fn message(message: impl Into) -> Self { + Self { + message: message.into(), + extra: None, + } + } + + pub fn with_cause(self, cause: impl Into>) -> Self { + Self { + extra: Some(cause.into()), + ..self + } + } +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.extra.as_ref().map(|err| err.as_ref() as _) + } +} + +// TODO(endpoints 2.0): when `endpoint_url` is added, deprecate & delete `Endpoint` /// API Endpoint /// /// This implements an API endpoint as specified in the @@ -25,7 +69,7 @@ pub struct Endpoint { #[derive(Clone, Debug, Eq, PartialEq)] pub struct EndpointPrefix(String); impl EndpointPrefix { - pub fn new(prefix: impl Into) -> Result { + pub fn new(prefix: impl Into) -> std::result::Result { let prefix = prefix.into(); match Authority::from_str(&prefix) { Ok(_) => Ok(EndpointPrefix(prefix)), @@ -46,6 +90,68 @@ impl EndpointPrefix { #[derive(Debug, Eq, PartialEq, Clone)] pub enum InvalidEndpoint { EndpointMustHaveAuthority, + EndpointMustHaveScheme, + FailedToConstructAuthority, + FailedToConstructUri, +} + +impl Display for InvalidEndpoint { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + InvalidEndpoint::EndpointMustHaveAuthority => { + write!(f, "Endpoint must contain an authority") + } + InvalidEndpoint::EndpointMustHaveScheme => { + write!(f, "Endpoint must contain a valid scheme") + } + InvalidEndpoint::FailedToConstructAuthority => { + write!( + f, + "Endpoint must contain a valid authority when combined with endpoint prefix" + ) + } + InvalidEndpoint::FailedToConstructUri => write!(f, "Failed to construct URI"), + } + } +} + +impl std::error::Error for InvalidEndpoint {} + +/// Apply `endpoint` to `uri` +/// +/// This method mutates `uri` by setting the `endpoint` on it +/// +/// # Panics +/// This method panics if `uri` does not have a scheme +pub fn apply_endpoint( + uri: &mut Uri, + endpoint: &Uri, + prefix: Option<&EndpointPrefix>, +) -> std::result::Result<(), InvalidEndpoint> { + let prefix = prefix.map(|p| p.0.as_str()).unwrap_or(""); + let authority = endpoint + .authority() + .as_ref() + .map(|auth| auth.as_str()) + .unwrap_or(""); + let authority = if !prefix.is_empty() { + Authority::from_str(&format!("{}{}", prefix, authority)) + } else { + Authority::from_str(authority) + } + .map_err(|_| InvalidEndpoint::FailedToConstructAuthority)?; + let scheme = *endpoint + .scheme() + .as_ref() + .ok_or(InvalidEndpoint::EndpointMustHaveScheme)?; + let new_uri = Uri::builder() + .authority(authority) + .scheme(scheme.clone()) + .path_and_query(Endpoint::merge_paths(endpoint, uri).as_ref()) + .build() + .map_err(|_| InvalidEndpoint::FailedToConstructUri)?; + *uri = new_uri; + Ok(()) } impl Endpoint { @@ -61,6 +167,11 @@ impl Endpoint { } } + /// Returns the URI of this endpoint + pub fn uri(&self) -> &Uri { + &self.uri + } + /// Create a new immutable endpoint from a URI /// /// ```rust @@ -82,26 +193,11 @@ impl Endpoint { /// Sets the endpoint on `uri`, potentially applying the specified `prefix` in the process. pub fn set_endpoint(&self, uri: &mut http::Uri, prefix: Option<&EndpointPrefix>) { - let prefix = prefix.map(|p| p.0.as_str()).unwrap_or(""); - let authority = self - .uri - .authority() - .as_ref() - .map(|auth| auth.as_str()) - .unwrap_or(""); - let authority = if !self.immutable && !prefix.is_empty() { - Authority::from_str(&format!("{}{}", prefix, authority)).expect("parts must be valid") - } else { - Authority::from_str(authority).expect("authority is valid") + let prefix = match self.immutable { + true => None, + false => prefix, }; - let scheme = *self.uri.scheme().as_ref().expect("scheme must be provided"); - let new_uri = Uri::builder() - .authority(authority) - .scheme(scheme.clone()) - .path_and_query(Self::merge_paths(&self.uri, uri).as_ref()) - .build() - .expect("valid uri"); - *uri = new_uri; + apply_endpoint(uri, &self.uri, prefix).expect("failed to set endpoint"); } fn merge_paths<'a>(endpoint: &'a Uri, uri: &'a Uri) -> Cow<'a, str> { diff --git a/rust-runtime/aws-smithy-types/src/date_time/mod.rs b/rust-runtime/aws-smithy-types/src/date_time/mod.rs index ae3686e1d..bb1120c82 100644 --- a/rust-runtime/aws-smithy-types/src/date_time/mod.rs +++ b/rust-runtime/aws-smithy-types/src/date_time/mod.rs @@ -209,7 +209,7 @@ impl DateTime { Format::DateTime => format::rfc3339::read(s)?, Format::HttpDate => format::http_date::read(s)?, Format::EpochSeconds => { - let split_point = s.find(delim).unwrap_or_else(|| s.len()); + let split_point = s.find(delim).unwrap_or(s.len()); let (s, rest) = s.split_at(split_point); (Self::from_str(s, format)?, rest) } diff --git a/rust-runtime/aws-smithy-types/src/endpoint.rs b/rust-runtime/aws-smithy-types/src/endpoint.rs new file mode 100644 index 000000000..9c33efc05 --- /dev/null +++ b/rust-runtime/aws-smithy-types/src/endpoint.rs @@ -0,0 +1,175 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +//! Smithy Endpoint Types + +use crate::Document; +use std::borrow::Cow; +use std::collections::HashMap; + +type MaybeStatic = Cow<'static, str>; + +/// Smithy Endpoint Type +/// +/// Generally, this type should not be used from user code +#[derive(Debug, Clone, PartialEq)] +pub struct Endpoint { + url: MaybeStatic, + headers: HashMap>, + properties: HashMap, +} + +#[allow(unused)] +impl Endpoint { + /// Returns the URL of this endpoint + pub fn url(&self) -> &str { + &self.url + } + + /// Returns the headers associated with this endpoint + pub fn headers(&self) -> impl Iterator)> { + self.headers + .iter() + .map(|(k, v)| (k.as_ref(), v.iter().map(|v| v.as_ref()))) + } + + /// Returns the properties associated with this endpoint + pub fn properties(&self) -> &HashMap, Document> { + &self.properties + } + + /// Converts this endpoint back into a [`Builder`] + pub fn into_builder(self) -> Builder { + Builder { endpoint: self } + } + + /// A builder for [`Endpoint`] + pub fn builder() -> Builder { + Builder::new() + } +} + +#[derive(Debug, Clone)] +/// Builder for [`Endpoint`] +pub struct Builder { + endpoint: Endpoint, +} + +#[allow(unused)] +impl Builder { + pub(crate) fn new() -> Self { + Self { + endpoint: Endpoint { + url: Default::default(), + headers: HashMap::new(), + properties: HashMap::new(), + }, + } + } + + /// Set the URL of the Endpoint + /// + /// # Examples + /// ```rust + /// use aws_smithy_types::endpoint::Endpoint; + /// let endpoint = Endpoint::builder().url("https://www.example.com").build(); + /// ``` + pub fn url(mut self, url: impl Into) -> Self { + self.endpoint.url = url.into(); + self + } + + /// Adds a header to the endpoint + /// + /// If there is already a header for this key, this header will be appended to that key + /// + /// # Examples + /// ```rust + /// use aws_smithy_types::endpoint::Endpoint; + /// let endpoint = Endpoint::builder().url("https://www.example.com").header("x-my-header", "hello").build(); + /// ``` + pub fn header(mut self, name: impl Into, value: impl Into) -> Self { + self.endpoint + .headers + .entry(name.into()) + .or_default() + .push(value.into()); + self + } + + /// Adds a property to the endpoint + /// + /// If there is already a property for this key, the existing property will be overwritten + /// + /// # Examples + /// ```rust + /// use aws_smithy_types::endpoint::Endpoint; + /// let endpoint = Endpoint::builder() + /// .url("https://www.example.com") + /// .property("x-my-header", true) + /// .build(); + /// ``` + pub fn property(mut self, key: impl Into, value: impl Into) -> Self { + self.endpoint.properties.insert(key.into(), value.into()); + self + } + + /// Constructs an [`Endpoint`] from this builder + /// + /// # Panics + /// Panics if URL is unset or empty + pub fn build(self) -> Endpoint { + assert_ne!(self.endpoint.url(), "", "URL was unset"); + self.endpoint + } +} + +#[cfg(test)] +mod test { + use crate::endpoint::Endpoint; + use crate::Document; + use std::borrow::Cow; + use std::collections::HashMap; + + #[test] + fn endpoint_builder() { + let endpoint = Endpoint::builder() + .url("https://www.amazon.com") + .header("x-amz-test", "header-value") + .property("custom", Document::Bool(true)) + .build(); + assert_eq!(endpoint.url, Cow::Borrowed("https://www.amazon.com")); + assert_eq!( + endpoint.headers, + HashMap::from([( + Cow::Borrowed("x-amz-test"), + vec![Cow::Borrowed("header-value")] + )]) + ); + assert_eq!( + endpoint.properties, + HashMap::from([(Cow::Borrowed("custom"), Document::Bool(true))]) + ); + + assert_eq!(endpoint.url(), "https://www.amazon.com"); + assert_eq!( + endpoint + .headers() + .map(|(k, v)| (k, v.collect::>())) + .collect::>(), + vec![("x-amz-test", vec!["header-value"])] + ); + } + + #[test] + fn borrowed_values() { + fn foo(a: &str) { + // borrowed values without a static lifetime need to be converted into owned values + let endpoint = Endpoint::builder().url(a.to_string()).build(); + assert_eq!(endpoint.url(), a); + } + + foo("asdf"); + } +} diff --git a/rust-runtime/aws-smithy-types/src/lib.rs b/rust-runtime/aws-smithy-types/src/lib.rs index 570e4e78f..14705db77 100644 --- a/rust-runtime/aws-smithy-types/src/lib.rs +++ b/rust-runtime/aws-smithy-types/src/lib.rs @@ -17,6 +17,7 @@ use std::collections::HashMap; pub mod base64; pub mod date_time; +pub mod endpoint; pub mod primitive; pub mod retry; pub mod timeout; @@ -76,6 +77,30 @@ pub enum Document { Null, } +impl From for Document { + fn from(value: bool) -> Self { + Document::Bool(value) + } +} + +impl From for Document { + fn from(value: String) -> Self { + Document::String(value) + } +} + +impl From> for Document { + fn from(values: Vec) -> Self { + Document::Array(values) + } +} + +impl From> for Document { + fn from(values: HashMap) -> Self { + Document::Object(values) + } +} + /// A number type that implements Javascript / JSON semantics, modeled on serde_json: /// #[derive(Debug, Clone, Copy, PartialEq)] -- GitLab