Unverified Commit 8c045d27 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add support for TimeStreamWrite and TimeStreamQuery (#2707)

TODO:
- [x] docs
- [x] integration test (canary even?)
- [x] customize README for timestream
## Motivation and Context
- #613 
- https://github.com/awslabs/aws-sdk-rust/issues/114


## Description
This adds support for Timestream Write (TSW) and Timestream Query (TSQ)
by adding endpoint discovery as a customization. This is made much
simpler by the fact that endpoint discovery for these services **has no
parameters**, which means that there is no complexity from caching the
returned endpoint.

Customers call `.enable_endpoint_discovery()` on the client to create a
version of the client with endpoint discovery enabled. This returns a
new client along with a `ReloadEndpoint` handle; customers must spawn
the reload task from that handle if they want endpoint discovery to
rerun periodically.

## Testing
<!--- Please describe in detail how you tested your changes -->
<!--- Include details of your testing environment, and the tests you ran
to -->
<!--- see how your change affects other areas of the code, etc. -->

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [ ] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [ ] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 7347c584
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -115,3 +115,9 @@ no longer requires `Service::Error = OperationError<Op::Error, PollError>`, inst
references = ["smithy-rs#2457"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server" }
author = "hlbarber"

[[aws-sdk-rust]]
message = "The SDK has added support for timestreamwrite and timestreamquery. Support for these services is considered experimental at this time. In order to use these services, you MUST call `.enable_endpoint_discovery()` on the `Client` after construction."
meta = { "breaking" = false, "tada" = true, "bug" = false }
references = ["smithy-rs#2707", "aws-sdk-rust#114"]
author = "rcoh"
+1 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http" }
aws-smithy-http-tower = { path = "../../../rust-runtime/aws-smithy-http-tower" }
aws-smithy-runtime-api = { path = "../../../rust-runtime/aws-smithy-runtime-api" }
aws-smithy-types = { path = "../../../rust-runtime/aws-smithy-types" }
aws-smithy-async = { path = "../../../rust-runtime/aws-smithy-async" }
aws-types = { path = "../aws-types" }
bytes = "1"
bytes-utils = "0.1.1"
+287 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Maintain a cache of discovered endpoints

use aws_smithy_async::rt::sleep::AsyncSleep;
use aws_smithy_async::time::TimeSource;
use aws_smithy_client::erase::boxclone::BoxFuture;
use aws_smithy_http::endpoint::{ResolveEndpoint, ResolveEndpointError};
use aws_smithy_types::endpoint::Endpoint;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use tokio::sync::oneshot::error::TryRecvError;
use tokio::sync::oneshot::{Receiver, Sender};

/// Endpoint reloader
///
/// Returned alongside an `EndpointCache` from `create_cache`. Spawn
/// [`ReloadEndpoint::reload_task`] to keep the cached endpoint fresh, or call
/// [`ReloadEndpoint::reload_once`] to refresh it a single time.
#[must_use]
pub struct ReloadEndpoint {
    // Produces a future resolving to a freshly discovered endpoint plus the time it expires
    loader: Box<dyn Fn() -> BoxFuture<(Endpoint, SystemTime), ResolveEndpointError> + Send + Sync>,
    // Slot shared with `EndpointCache`: the most recently discovered endpoint
    endpoint: Arc<Mutex<Option<ExpiringEndpoint>>>,
    // Slot shared with `EndpointCache`: the most recent load error, if any
    error: Arc<Mutex<Option<ResolveEndpointError>>>,
    // Closes when the paired `Sender` held by `EndpointCache` is dropped,
    // signalling `reload_task` to stop
    rx: Receiver<()>,
    // Sleep implementation used between reload attempts
    sleep: Arc<dyn AsyncSleep>,
    // Time source used to decide whether the cached endpoint has expired
    time: Arc<dyn TimeSource>,
}

impl Debug for ReloadEndpoint {
    // Opaque representation: the loader closure and channel internals are not
    // useful to display, so only the type name is written.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "ReloadEndpoint")
    }
}

impl ReloadEndpoint {
    /// Reload the endpoint once
    ///
    /// On success the shared endpoint slot is overwritten; on failure the
    /// shared error slot is overwritten instead.
    pub async fn reload_once(&self) {
        let resolved = (self.loader)().await;
        match resolved {
            Ok((endpoint, expiry)) => {
                let mut slot = self.endpoint.lock().unwrap();
                *slot = Some(ExpiringEndpoint { endpoint, expiry });
            }
            Err(err) => {
                let mut slot = self.error.lock().unwrap();
                *slot = Some(err);
            }
        }
    }

    /// An infinite loop task that will reload the endpoint
    ///
    /// This task will terminate when the corresponding [`Client`](crate::Client) is dropped.
    pub async fn reload_task(mut self) {
        loop {
            // `Err(Empty)` means the paired sender is still alive (client not
            // dropped); any other result — an explicit message or a closed
            // channel — means it is time to stop.
            if !matches!(self.rx.try_recv(), Err(TryRecvError::Empty)) {
                break;
            }
            self.reload_increment(self.time.now()).await;
            self.sleep.sleep(Duration::from_secs(60)).await;
        }
    }

    /// Reload the endpoint if none is cached or the cached one is expired as of `now`.
    async fn reload_increment(&self, now: SystemTime) {
        // The mutex guard is a temporary of the match scrutinee, so it is
        // released before the `.await` below.
        let should_reload = match self.endpoint.lock().unwrap().as_ref() {
            Some(endpoint) => endpoint.is_expired(now),
            None => true,
        };
        if should_reload {
            tracing::debug!("reloading endpoint, previous endpoint was expired");
            self.reload_once().await;
        }
    }
}

/// Cache holding the most recently discovered endpoint, shared with the
/// [`ReloadEndpoint`] that refreshes it.
#[derive(Debug, Clone)]
pub(crate) struct EndpointCache {
    // Most recent load error; taken (consumed) when reporting a resolution failure
    error: Arc<Mutex<Option<ResolveEndpointError>>>,
    // Most recently discovered endpoint along with its expiry time
    endpoint: Arc<Mutex<Option<ExpiringEndpoint>>>,
    // When the sender is dropped, this allows the reload loop to stop
    _drop_guard: Arc<Sender<()>>,
}

impl<T> ResolveEndpoint<T> for EndpointCache {
    // The endpoint params are ignored: the cache holds a single discovered
    // endpoint that serves every request, so resolution delegates straight to
    // the inherent `resolve_endpoint`.
    fn resolve_endpoint(&self, _params: &T) -> aws_smithy_http::endpoint::Result {
        self.resolve_endpoint()
    }
}

/// A discovered endpoint paired with the time at which it expires
#[derive(Debug)]
struct ExpiringEndpoint {
    endpoint: Endpoint,
    expiry: SystemTime,
}

impl ExpiringEndpoint {
    /// Returns `true` when this endpoint is already past — or within 120
    /// seconds of — its expiry time as of `now`.
    fn is_expired(&self, now: SystemTime) -> bool {
        tracing::debug!(expiry = ?self.expiry, now = ?now, delta = ?self.expiry.duration_since(now), "checking expiry status of endpoint");
        // `duration_since` errors when `now` is already later than `expiry`.
        self.expiry
            .duration_since(now)
            .map(|remaining| remaining < Duration::from_secs(120))
            .unwrap_or(true)
    }
}

/// Build an [`EndpointCache`] / [`ReloadEndpoint`] pair wired to `loader_fn`.
///
/// Performs one initial load before returning; if that load fails, the error
/// is returned so the caller knows configuration did not work.
pub(crate) async fn create_cache<F>(
    loader_fn: impl Fn() -> F + Send + Sync + 'static,
    sleep: Arc<dyn AsyncSleep>,
    time: Arc<dyn TimeSource>,
) -> Result<(EndpointCache, ReloadEndpoint), ResolveEndpointError>
where
    F: Future<Output = Result<(Endpoint, SystemTime), ResolveEndpointError>> + Send + 'static,
{
    // The cache drops `tx` when the last clone goes away, which closes `rx`
    // and lets the reload task terminate.
    let (tx, rx) = tokio::sync::oneshot::channel();
    // Slots shared between the reloader (writer) and the cache (reader).
    let error_slot = Arc::new(Mutex::new(None));
    let endpoint_slot = Arc::new(Mutex::new(None));
    let reloader = ReloadEndpoint {
        loader: Box::new(move || Box::pin((loader_fn)()) as _),
        endpoint: endpoint_slot.clone(),
        error: error_slot.clone(),
        rx,
        sleep,
        time,
    };
    let cache = EndpointCache {
        error: error_slot,
        endpoint: endpoint_slot,
        _drop_guard: Arc::new(tx),
    };
    reloader.reload_once().await;
    // if we didn't successfully get an endpoint, bail out so the client knows
    // configuration failed to work
    cache.resolve_endpoint()?;
    Ok((cache, reloader))
}

impl EndpointCache {
    /// Return a clone of the cached endpoint, or — if none has been loaded —
    /// the most recent load error (falling back to a generic message).
    fn resolve_endpoint(&self) -> aws_smithy_http::endpoint::Result {
        let cached = self
            .endpoint
            .lock()
            .unwrap()
            .as_ref()
            .map(|e| e.endpoint.clone());
        match cached {
            Some(endpoint) => Ok(endpoint),
            None => {
                // The error is taken (not cloned) so it is reported at most once.
                let err = self
                    .error
                    .lock()
                    .unwrap()
                    .take()
                    .unwrap_or_else(|| ResolveEndpointError::message("no endpoint loaded"));
                Err(err)
            }
        }
    }
}

#[cfg(test)]
mod test {
    use crate::endpoint_discovery::create_cache;
    use aws_smithy_async::rt::sleep::TokioSleep;
    use aws_smithy_async::test_util::controlled_time_and_sleep;
    use aws_smithy_async::time::SystemTimeSource;
    use aws_smithy_types::endpoint::Endpoint;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, SystemTime, UNIX_EPOCH};
    use tokio::time::timeout;

    // Identity function that only compiles when `t` is `Send`
    fn check_send_v<T: Send>(t: T) -> T {
        t
    }

    // Verifies that both the cache and the reload-task future are `Send`, so
    // they can be spawned on a multithreaded runtime.
    #[tokio::test]
    #[allow(unused_must_use)]
    async fn check_traits() {
        let (cache, reloader) = create_cache(
            || async {
                Ok((
                    Endpoint::builder().url("http://foo.com").build(),
                    SystemTime::now(),
                ))
            },
            Arc::new(TokioSleep::new()),
            Arc::new(SystemTimeSource::new()),
        )
        .await
        .unwrap();
        check_send_v(reloader.reload_task());
        check_send_v(cache);
    }

    // Verifies that `reload_increment` only reloads once the endpoint is
    // inside the 120-second expiration buffer. The loader counts invocations
    // and embeds the count in the URL so each reload is observable.
    // NOTE(review): the name suggests an erroring-loader scenario, but the
    // loader here always succeeds — the test exercises the reload schedule.
    #[tokio::test]
    async fn erroring_endpoint_always_reloaded() {
        let expiry = UNIX_EPOCH + Duration::from_secs(123456789);
        let ct = Arc::new(AtomicUsize::new(0));
        let (cache, reloader) = create_cache(
            move || {
                let shared_ct = ct.clone();
                shared_ct.fetch_add(1, Ordering::AcqRel);
                async move {
                    Ok((
                        Endpoint::builder()
                            .url(format!("http://foo.com/{shared_ct:?}"))
                            .build(),
                        expiry,
                    ))
                }
            },
            Arc::new(TokioSleep::new()),
            Arc::new(SystemTimeSource::new()),
        )
        .await
        .expect("returns an endpoint");
        // `create_cache` performed the initial load (loader invocation #1)
        assert_eq!(
            cache.resolve_endpoint().expect("ok").url(),
            "http://foo.com/1"
        );
        // 240s before expiry is outside the 120-second buffer, so no reload:
        // the first endpoint is still served
        reloader
            .reload_increment(expiry - Duration::from_secs(240))
            .await;
        assert_eq!(
            cache.resolve_endpoint().expect("ok").url(),
            "http://foo.com/1"
        );

        // At the expiry instant the endpoint is stale — the loader runs again
        reloader.reload_increment(expiry).await;
        assert_eq!(
            cache.resolve_endpoint().expect("ok").url(),
            "http://foo.com/2"
        );
    }

    // Drives the spawned reload task with a controlled clock/sleep and checks
    // that the endpoint is reloaded once it enters the expiration buffer, and
    // that the task terminates once the cache (and its drop guard) is dropped.
    #[tokio::test]
    async fn test_advance_of_task() {
        let expiry = UNIX_EPOCH + Duration::from_secs(123456789);
        // start the clock 239s (just under 4 minutes) before the endpoint expiry
        let (time, sleep, mut gate) = controlled_time_and_sleep(expiry - Duration::from_secs(239));
        let ct = Arc::new(AtomicUsize::new(0));
        let (cache, reloader) = create_cache(
            move || {
                let shared_ct = ct.clone();
                shared_ct.fetch_add(1, Ordering::AcqRel);
                async move {
                    Ok((
                        Endpoint::builder()
                            .url(format!("http://foo.com/{shared_ct:?}"))
                            .build(),
                        expiry,
                    ))
                }
            },
            Arc::new(sleep.clone()),
            Arc::new(time.clone()),
        )
        .await
        .expect("first load success");
        let reload_task = tokio::spawn(reloader.reload_task());
        assert!(!reload_task.is_finished());
        // expiry occurs after 2 sleeps
        // t = 0
        assert_eq!(
            gate.expect_sleep().await.duration(),
            Duration::from_secs(60)
        );
        assert_eq!(cache.resolve_endpoint().unwrap().url(), "http://foo.com/1");
        // t = 60

        let sleep = gate.expect_sleep().await;
        // 179s remain until expiry — still outside the 120s buffer, so the
        // original endpoint has not been reloaded yet
        assert_eq!(cache.resolve_endpoint().unwrap().url(), "http://foo.com/1");
        assert_eq!(sleep.duration(), Duration::from_secs(60));
        sleep.allow_progress();
        // t = 120: only 119s remain, inside the buffer — a reload occurs

        let sleep = gate.expect_sleep().await;
        assert_eq!(cache.resolve_endpoint().unwrap().url(), "http://foo.com/2");
        sleep.allow_progress();

        let sleep = gate.expect_sleep().await;
        // dropping the cache drops the `Sender` drop-guard, closing the
        // channel the reload task polls — the task should now exit
        drop(cache);
        sleep.allow_progress();

        timeout(Duration::from_secs(1), reload_task)
            .await
            .expect("task finishes successfully")
            .expect("finishes");
    }
}
+9 −0
Original line number Diff line number Diff line
@@ -45,3 +45,12 @@ pub mod route53_resource_id_preprocessor;

/// Convert a streaming `SdkBody` into an aws-chunked streaming body with checksum trailers
pub mod http_body_checksum;

#[allow(dead_code)]
pub mod endpoint_discovery;

// just so docs work
#[allow(dead_code)]
/// allow docs to work
#[derive(Debug)]
pub struct Client;
+3 −0
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@ import software.amazon.smithy.rustsdk.customize.s3.S3ExtendedRequestIdDecorator
import software.amazon.smithy.rustsdk.customize.s3control.S3ControlDecorator
import software.amazon.smithy.rustsdk.customize.sso.SSODecorator
import software.amazon.smithy.rustsdk.customize.sts.STSDecorator
import software.amazon.smithy.rustsdk.customize.timestream.TimestreamDecorator
import software.amazon.smithy.rustsdk.endpoints.AwsEndpointsStdLib
import software.amazon.smithy.rustsdk.endpoints.OperationInputTestDecorator
import software.amazon.smithy.rustsdk.endpoints.RequireEndpointRules
@@ -69,6 +70,8 @@ val DECORATORS: List<ClientCodegenDecorator> = listOf(
    S3ControlDecorator().onlyApplyTo("com.amazonaws.s3control#AWSS3ControlServiceV20180820"),
    STSDecorator().onlyApplyTo("com.amazonaws.sts#AWSSecurityTokenServiceV20110615"),
    SSODecorator().onlyApplyTo("com.amazonaws.sso#SWBPortalService"),
    TimestreamDecorator().onlyApplyTo("com.amazonaws.timestreamwrite#Timestream_20181101"),
    TimestreamDecorator().onlyApplyTo("com.amazonaws.timestreamquery#Timestream_20181101"),

    // Only build docs-rs for linux to reduce load on docs.rs
    listOf(
Loading