Unverified Commit d293d1f7 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Add lifetimes to async traits that take args by reference (#3061)

This PR adds lifetimes to the `IdentityResolver`, `DnsResolver` (renamed
to `ResolveDns`), and `EndpointResolver` traits so that lifetime
gymnastics aren't needed when implementing those traits. For example,
`IdentityResolver::resolve_identity` takes `&ConfigBag` as an argument,
which means you have to pull things out of the ConfigBag outside of any
returned async block in order for the compiler to be satisfied. This
change removes that consideration and makes implementing these traits a
lot easier.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 8439f2ae
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -343,3 +343,15 @@ message = """
references = ["smithy-rs#3032"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "ysaito1001"

[[aws-sdk-rust]]
message = "Lifetimes have been added to `EndpointResolver` and `IdentityResolver` traits."
references = ["smithy-rs#3061"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"

[[smithy-rs]]
message = "Lifetimes have been added to the `EndpointResolver` trait."
references = ["smithy-rs#3061"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"
+1 −1
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@ allowed_external_types = [
   "aws_smithy_http::endpoint",
   "aws_smithy_http::endpoint::error::InvalidEndpointError",
   "aws_smithy_http::result::SdkError",
   "aws_smithy_runtime_api::client::dns::DnsResolver",
   "aws_smithy_runtime_api::client::dns::ResolveDns",
   "aws_smithy_runtime_api::client::dns::SharedDnsResolver",
   "aws_smithy_runtime_api::client::http::HttpClient",
   "aws_smithy_runtime_api::client::http::SharedHttpClient",
+10 −10
Original line number Diff line number Diff line
@@ -50,7 +50,7 @@ use crate::http_credential_provider::HttpCredentialProvider;
use crate::provider_config::ProviderConfig;
use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
use aws_smithy_http::endpoint::apply_endpoint;
use aws_smithy_runtime_api::client::dns::{DnsResolver, ResolveDnsError, SharedDnsResolver};
use aws_smithy_runtime_api::client::dns::{ResolveDns, ResolveDnsError, SharedDnsResolver};
use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
use aws_smithy_runtime_api::shared::IntoShared;
use aws_smithy_types::error::display::DisplayErrorContext;
@@ -272,9 +272,9 @@ impl Builder {

    /// Override the DNS resolver used to validate URIs
    ///
    /// URIs must refer to loopback addresses. The [`DnsResolver`](aws_smithy_runtime_api::client::dns::DnsResolver)
    /// is used to retrieve IP addresses for a given domain.
    pub fn dns(mut self, dns: impl DnsResolver + 'static) -> Self {
    /// URIs must refer to loopback addresses. The [`ResolveDns`](aws_smithy_runtime_api::client::dns::ResolveDns)
    /// implementation is used to retrieve IP addresses for a given domain.
    pub fn dns(mut self, dns: impl ResolveDns + 'static) -> Self {
        self.dns = Some(dns.into_shared());
        self
    }
@@ -399,7 +399,7 @@ async fn validate_full_uri(
        Ok(addr) => addr.is_loopback(),
        Err(_domain_name) => {
            let dns = dns.ok_or(InvalidFullUriErrorKind::NoDnsResolver)?;
            dns.resolve_dns(host.to_owned())
            dns.resolve_dns(host)
                .await
                .map_err(|err| InvalidFullUriErrorKind::DnsLookupFailed(ResolveDnsError::new(err)))?
                .iter()
@@ -751,16 +751,16 @@ mod test {
        }
    }

    impl DnsResolver for TestDns {
        fn resolve_dns(&self, name: String) -> DnsFuture {
            DnsFuture::ready(Ok(self.addrs.get(&name).unwrap_or(&self.fallback).clone()))
    impl ResolveDns for TestDns {
        fn resolve_dns<'a>(&'a self, name: &'a str) -> DnsFuture<'a> {
            DnsFuture::ready(Ok(self.addrs.get(name).unwrap_or(&self.fallback).clone()))
        }
    }

    #[derive(Debug)]
    struct NeverDns;
    impl DnsResolver for NeverDns {
        fn resolve_dns(&self, _name: String) -> DnsFuture {
    impl ResolveDns for NeverDns {
        fn resolve_dns<'a>(&'a self, _name: &'a str) -> DnsFuture<'a> {
            DnsFuture::new(async {
                Never::new().await;
                unreachable!()
+3 −4
Original line number Diff line number Diff line
@@ -524,11 +524,10 @@ struct ImdsEndpointResolver {
}

impl EndpointResolver for ImdsEndpointResolver {
    fn resolve_endpoint(&self, _: &EndpointResolverParams) -> EndpointFuture {
        let this = self.clone();
    fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
        EndpointFuture::new(async move {
            this.endpoint_source
                .endpoint(this.mode_override)
            self.endpoint_source
                .endpoint(self.mode_override.clone())
                .await
                .map(|uri| Endpoint::builder().url(uri.to_string()).build())
                .map_err(|err| err.into())
+6 −10
Original line number Diff line number Diff line
@@ -192,23 +192,19 @@ fn parse_token_response(response: &HttpResponse, now: SystemTime) -> Result<Toke
}

impl IdentityResolver for TokenResolver {
    fn resolve_identity(&self, _config_bag: &ConfigBag) -> IdentityFuture {
        let this = self.clone();
        IdentityFuture::new(async move {
            let preloaded_token = this
    fn resolve_identity<'a>(&'a self, _config_bag: &'a ConfigBag) -> IdentityFuture<'a> {
        IdentityFuture::new(async {
            let preloaded_token = self
                .inner
                .cache
                .yield_or_clear_if_expired(this.inner.time_source.now())
                .yield_or_clear_if_expired(self.inner.time_source.now())
                .await;
            let token = match preloaded_token {
                Some(token) => Ok(token),
                None => {
                    this.inner
                    self.inner
                        .cache
                        .get_or_load(|| {
                            let this = this.clone();
                            async move { this.get_token().await }
                        })
                        .get_or_load(|| async { self.get_token().await })
                        .await
                }
            }?;
Loading