Unverified Commit 5b93fd2f authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Improve client init time by switching to regex-lite (#3269)

Each client initialization was taking between 1 and 2 milliseconds,
regardless if the client had been constructed before or not. For
example, if a customer wants five clients with different credentials
providers, that could be 10 milliseconds of time spent in
`Client::from_conf`. Approximately 98% of this time was spent compiling
regular expressions for the endpoint partition resolver.

This change switches everything over to the regex-lite crate, which has
faster regex compile times, and shouldn't have much of an impact on
performance for our specific use-cases (small strings, only evaluated at
client initialization).

The use of regex was entirely removed in aws-sigv4 since it was overkill
for what it was being used for.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 6420816b
Loading
Loading
Loading
Loading
+11 −1
Original line number Diff line number Diff line
@@ -10,3 +10,13 @@
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"

[[aws-sdk-rust]]
message = """Client creation now takes microseconds instead of milliseconds.
Previously, it would take 2-3 milliseconds for each client instantiation due to time spent compiling regexes.
For applications that used several clients, this would increase start-up time in cases where it really matters,
such as for AWS Lambda cold starts. This time was improved by both changing regex implementation and caching the
result of the compilation."""
references = ["aws-sdk-rust#975", "smithy-rs#3269"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "jdisanti"
+0 −1
Original line number Diff line number Diff line
@@ -30,7 +30,6 @@ num-bigint = { version = "0.4", optional = true }
once_cell = "1.8"
p256 = { version = "0.11", features = ["ecdsa"], optional = true }
percent-encoding = { version = "2.1", optional = true }
regex = "1.5"
ring = { version = "0.17.5", optional = true }
sha2 = "0.10"
crypto-bigint = { version = "0.5.4", optional = true }
+38 −37
Original line number Diff line number Diff line
@@ -425,39 +425,37 @@ impl<'a> fmt::Display for CanonicalRequest<'a> {
    }
}

/// A regex for matching on 2 or more spaces that acts on bytes.
static MULTIPLE_SPACES: once_cell::sync::Lazy<regex::bytes::Regex> =
    once_cell::sync::Lazy::new(|| regex::bytes::Regex::new(r" {2,}").unwrap());

/// Removes excess spaces before and after a given byte string, and converts multiple sequential
/// spaces to a single space e.g. "  Some  example   text  " -> "Some example text".
///
/// This function ONLY affects spaces and not other kinds of whitespace.
fn trim_all(text: &[u8]) -> Cow<'_, [u8]> {
fn trim_all(text: &str) -> Cow<'_, str> {
    let text = text.trim_matches(' ');
    let requires_filter = text
        .chars()
        .zip(text.chars().skip(1))
        .any(|(a, b)| a == ' ' && b == ' ');
    if !requires_filter {
        Cow::Borrowed(text)
    } else {
        // The normal trim function will trim non-breaking spaces and other various whitespace chars.
        // S3 ONLY trims spaces so we use trim_matches to trim spaces only
    let text = trim_spaces_from_byte_string(text);
    MULTIPLE_SPACES.replace_all(text, " ".as_bytes())
        Cow::Owned(
            text.chars()
                // Filter out consecutive spaces
                .zip(text.chars().skip(1).chain(std::iter::once('!')))
                .filter(|(a, b)| *a != ' ' || *b != ' ')
                .map(|(a, _)| a)
                .collect(),
        )
    }

/// Removes excess spaces before and after a given byte string by returning a subset of those bytes.
/// Will return an empty slice if a string is composed entirely of whitespace.
fn trim_spaces_from_byte_string(bytes: &[u8]) -> &[u8] {
    let starting_index = bytes.iter().position(|b| *b != b' ').unwrap_or(0);
    let ending_offset = bytes.iter().rev().position(|b| *b != b' ').unwrap_or(0);
    let ending_index = bytes.len() - ending_offset;
    &bytes[starting_index..ending_index]
}

/// Works just like [trim_all] but acts on HeaderValues instead of bytes.
/// Will ensure that the underlying bytes are valid UTF-8.
fn normalize_header_value(header_value: &str) -> Result<HeaderValue, CanonicalRequestError> {
    let trimmed_value = trim_all(header_value.as_bytes());
    HeaderValue::from_str(
        std::str::from_utf8(&trimmed_value)
            .map_err(CanonicalRequestError::invalid_utf8_in_header_value)?,
    )
    .map_err(CanonicalRequestError::from)
    let trimmed_value = trim_all(header_value);
    HeaderValue::from_str(&trimmed_value).map_err(CanonicalRequestError::from)
}

#[derive(Debug, PartialEq, Default)]
@@ -631,6 +629,7 @@ mod tests {
    use http::{HeaderValue, Uri};
    use pretty_assertions::assert_eq;
    use proptest::{prelude::*, proptest};
    use std::borrow::Cow;
    use std::time::Duration;

    fn signing_params(identity: &Identity, settings: SigningSettings) -> SigningParams<'_> {
@@ -982,32 +981,34 @@ mod tests {

    #[test]
    fn test_trim_all_handles_spaces_correctly() {
        // Can't compare a byte array to a Cow so we convert both to slices before comparing
        let expected = &b"Some example text"[..];
        let actual = &trim_all(b"  Some  example   text  ")[..];

        assert_eq!(expected, actual);
        assert_eq!(Cow::Borrowed("don't touch me"), trim_all("don't touch me"));
        assert_eq!("trim left", trim_all("   trim left"));
        assert_eq!("trim right", trim_all("trim right "));
        assert_eq!("trim both", trim_all("   trim both  "));
        assert_eq!("", trim_all(" "));
        assert_eq!("", trim_all("  "));
        assert_eq!("a b", trim_all(" a   b "));
        assert_eq!("Some example text", trim_all("  Some  example   text  "));
    }

    #[test]
    fn test_trim_all_ignores_other_forms_of_whitespace() {
        // Can't compare a byte array to a Cow so we convert both to slices before comparing
        let expected = &b"\t\xA0Some\xA0 example \xA0text\xA0\n"[..];
        // \xA0 is a non-breaking space character
        let actual = &trim_all(b"\t\xA0Some\xA0     example   \xA0text\xA0\n")[..];

        assert_eq!(expected, actual);
        assert_eq!(
            "\t\u{A0}Some\u{A0} example \u{A0}text\u{A0}\n",
            trim_all("\t\u{A0}Some\u{A0}     example   \u{A0}text\u{A0}\n")
        );
    }

    #[test]
    fn trim_spaces_works_on_single_characters() {
        assert_eq!(trim_all(b"2").as_ref(), b"2");
        assert_eq!(trim_all("2").as_ref(), "2");
    }

    proptest! {
        #[test]
        fn test_trim_all_doesnt_elongate_strings(s in ".*") {
            assert!(trim_all(s.as_bytes()).len() <= s.len())
            assert!(trim_all(&s).len() <= s.len())
        }

        #[test]
@@ -1018,7 +1019,7 @@ mod tests {

        #[test]
        fn test_trim_all_does_nothing_when_there_are_no_spaces(s in "[^ ]*") {
            assert_eq!(trim_all(s.as_bytes()).as_ref(), s.as_bytes());
            assert_eq!(trim_all(&s).as_ref(), s);
        }
    }
}
+0 −10
Original line number Diff line number Diff line
@@ -7,7 +7,6 @@ use http::header::{InvalidHeaderName, InvalidHeaderValue};
use http::uri::InvalidUri;
use std::error::Error;
use std::fmt;
use std::str::Utf8Error;

#[derive(Debug)]
enum SigningErrorKind {
@@ -63,7 +62,6 @@ impl From<CanonicalRequestError> for SigningError {
enum CanonicalRequestErrorKind {
    InvalidHeaderName { source: InvalidHeaderName },
    InvalidHeaderValue { source: InvalidHeaderValue },
    InvalidUtf8InHeaderValue { source: Utf8Error },
    InvalidUri { source: InvalidUri },
    UnsupportedIdentityType,
}
@@ -79,7 +77,6 @@ impl fmt::Display for CanonicalRequestError {
        match self.kind {
            InvalidHeaderName { .. } => write!(f, "invalid header name"),
            InvalidHeaderValue { .. } => write!(f, "invalid header value"),
            InvalidUtf8InHeaderValue { .. } => write!(f, "invalid UTF-8 in header value"),
            InvalidUri { .. } => write!(f, "the uri was invalid"),
            UnsupportedIdentityType => {
                write!(f, "only AWS credentials are supported for signing")
@@ -94,7 +91,6 @@ impl Error for CanonicalRequestError {
        match &self.kind {
            InvalidHeaderName { source } => Some(source),
            InvalidHeaderValue { source } => Some(source),
            InvalidUtf8InHeaderValue { source } => Some(source),
            InvalidUri { source } => Some(source),
            UnsupportedIdentityType => None,
        }
@@ -102,12 +98,6 @@ impl Error for CanonicalRequestError {
}

impl CanonicalRequestError {
    pub(crate) fn invalid_utf8_in_header_value(source: Utf8Error) -> Self {
        Self {
            kind: CanonicalRequestErrorKind::InvalidUtf8InHeaderValue { source },
        }
    }

    pub(crate) fn unsupported_identity_type() -> Self {
        Self {
            kind: CanonicalRequestErrorKind::UnsupportedIdentityType,
+28 −9
Original line number Diff line number Diff line
@@ -20,10 +20,12 @@ import software.amazon.smithy.rulesengine.traits.ContextParamTrait
import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.EndpointStdLib
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.FunctionRegistry
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.toType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
@@ -43,9 +45,25 @@ fun Identifier.rustName(): String {
}

/**
 * Endpoints standard library file
 * Endpoints standard library
 */
internal fun endpointsLib(name: String, vararg additionalDependency: RustDependency) = InlineDependency.forRustFile(
object EndpointsLib {
    val DiagnosticCollector = endpointsLib("diagnostic").toType().resolve("DiagnosticCollector")
    fun PartitionResolver(runtimeConfig: RuntimeConfig) =
        endpointsLib("partition", CargoDependency.smithyJson(runtimeConfig), CargoDependency.RegexLite).toType()
            .resolve("PartitionResolver")

    val substring = endpointsLib("substring").toType().resolve("substring")
    val isValidHostLabel = endpointsLib("host").toType().resolve("is_valid_host_label")
    val parseUrl = endpointsLib("parse_url", CargoDependency.Http, CargoDependency.Url).toType().resolve("parse_url")
    val uriEncode = endpointsLib("uri_encode", CargoDependency.PercentEncoding).toType().resolve("uri_encode")

    val awsParseArn = endpointsLib("arn").toType().resolve("parse_arn")
    val awsIsVirtualHostableS3Bucket =
        endpointsLib("s3", endpointsLib("host"), CargoDependency.OnceCell, CargoDependency.RegexLite).toType()
            .resolve("is_virtual_hostable_s3_bucket")

    private fun endpointsLib(name: String, vararg additionalDependency: RustDependency) = InlineDependency.forRustFile(
        RustModule.pubCrate(
            name,
            parent = EndpointStdLib,
@@ -53,6 +71,7 @@ internal fun endpointsLib(name: String, vararg additionalDependency: RustDepende
        "/inlineable/src/endpoint_lib/$name.rs",
        *additionalDependency,
    )
}

class Types(runtimeConfig: RuntimeConfig) {
    private val smithyTypesEndpointModule = RuntimeType.smithyTypes(runtimeConfig).resolve("endpoint")
Loading