Unverified Commit e935fbc7 authored by 82marbag's avatar 82marbag Committed by GitHub
Browse files

Implement httpLabel with nom (#996)



Move to `nom` to implement httpLabel instead of using regexes.

Issue: #938

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.com>

Co-authored-by: default avatardavid-perez <d@vidp.dev>
parent daf4b8de
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ object ServerCargoDependency {
    val AsyncTrait: CargoDependency = CargoDependency("async-trait", CratesIo("0.1"))
    val AxumCore: CargoDependency = CargoDependency("axum-core", CratesIo("0.1"))
    val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3"))
    val Nom: CargoDependency = CargoDependency("nom", CratesIo("7"))
    val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2"))
    val SerdeUrlEncoded: CargoDependency = CargoDependency("serde_urlencoded", CratesIo("0.7"))
    val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"))
+64 −29
Original line number Diff line number Diff line
@@ -107,6 +107,7 @@ private class ServerHttpProtocolImplGenerator(
        "HttpBody" to CargoDependency.HttpBody.asType(),
        "Hyper" to CargoDependency.Hyper.asType(),
        "LazyStatic" to CargoDependency.LazyStatic.asType(),
        "Nom" to ServerCargoDependency.Nom.asType(),
        "PercentEncoding" to CargoDependency.PercentEncoding.asType(),
        "Regex" to CargoDependency.Regex.asType(),
        "SerdeUrlEncoded" to ServerCargoDependency.SerdeUrlEncoded.asType(),
@@ -572,42 +573,76 @@ private class ServerHttpProtocolImplGenerator(
        if (pathBindings.isEmpty()) {
            return
        }
        val pattern = StringBuilder()
        val httpTrait = httpBindingResolver.httpTrait(operationShape)
        httpTrait.uri.segments.forEach {
            pattern.append("/")
            if (it.isLabel) {
                pattern.append("(?P<${it.content}>")
                if (it.isGreedyLabel) {
                    pattern.append(".+")
        val greedyLabelIndex = httpTrait.uri.segments.indexOfFirst { it.isGreedyLabel }
        val segments =
            if (greedyLabelIndex >= 0)
                httpTrait.uri.segments.slice(0 until (greedyLabelIndex + 1))
            else
                httpTrait.uri.segments
        val restAfterGreedyLabel =
            if (greedyLabelIndex >= 0)
                httpTrait.uri.segments.slice((greedyLabelIndex + 1) until httpTrait.uri.segments.size).joinToString(prefix = "/", separator = "/")
            else
                ""
        val labeledNames = segments
            .mapIndexed { index, segment ->
                if (segment.isLabel) { "m$index" } else { "_" }
            }
            .joinToString(prefix = (if (segments.size > 1) "(" else ""), separator = ",", postfix = (if (segments.size > 1) ")" else ""))
        val nomParser = segments
            .map { segment ->
                if (segment.isGreedyLabel) {
                    "#{Nom}::combinator::rest::<_, #{Nom}::error::Error<&str>>"
                } else if (segment.isLabel) {
                    """#{Nom}::branch::alt::<_, _, #{Nom}::error::Error<&str>, _>((#{Nom}::bytes::complete::take_until("/"), #{Nom}::combinator::rest))"""
                } else {
                    pattern.append("[^/]+")
                    """#{Nom}::bytes::complete::tag::<_, _, #{Nom}::error::Error<&str>>("${segment.content}")"""
                }
                pattern.append(")")
            } else {
                pattern.append(it.content)
            }
            .joinToString(
                // TODO: tuple() is currently limited to 21 items
                prefix = if (segments.size > 1) "#{Nom}::sequence::tuple::<_, _, #{Nom}::error::Error<&str>, _>((" else "",
                postfix = if (segments.size > 1) "))" else "",
                transform = { parser ->
                    """
                    #{Nom}::sequence::preceded(#{Nom}::bytes::complete::tag("/"),  $parser)
                    """.trimIndent()
                }
            )
        with(writer) {
            rustTemplate("let input_string = request.uri().path();")
            if (greedyLabelIndex >= 0 && greedyLabelIndex + 1 < httpTrait.uri.segments.size) {
                rustTemplate(
                    """
                #{LazyStatic}::lazy_static! {
                    static ref RE: #{Regex}::Regex = #{Regex}::Regex::new("$pattern").unwrap();
                    if !input_string.ends_with(${restAfterGreedyLabel.dq()}) {
                        return std::result::Result::Err(#{SmithyRejection}::Deserialize(
                            aws_smithy_http_server::rejection::Deserialize::from_err(format!("Postfix not found: {}", ${restAfterGreedyLabel.dq()}))));
                    }
                    let input_string = &input_string[..(input_string.len() - ${restAfterGreedyLabel.dq()}.len())];
                    """.trimIndent(),
                *codegenScope,
                    *codegenScope
                )
            rustBlock("if let Some(captures) = RE.captures(request.uri().path())") {
                pathBindings.forEach {
                    val deserializer = generateParsePercentEncodedStrFn(it)
            }
            rustTemplate(
                """
                        if let Some(m) = captures.name("${it.locationName}") {
                            input = input.${it.member.setterName()}(
                                #{deserializer}(m.as_str())?
                let (input_string, $labeledNames) = $nomParser(input_string)?;
                debug_assert_eq!("", input_string);
                """.trimIndent(),
                *codegenScope
            )
            segments
                .forEachIndexed { index, segment ->
                    val binding = pathBindings.find { it.memberName == segment.content }
                    if (binding != null && segment.isLabel) {
                        val deserializer = generateParsePercentEncodedStrFn(binding)
                        rustTemplate(
                            """
                            input = input.${binding.member.setterName()}(
                                #{deserializer}(m$index)?
                            );
                        }
                            """.trimIndent(),
                            *codegenScope,
                            "deserializer" to deserializer,
                        )
                    }
+1 −0
Original line number Diff line number Diff line
@@ -28,6 +28,7 @@ http = "0.2"
http-body = "0.4"
hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp"] }
mime = "0.3"
nom = "7"
pin-project-lite = "0.2"
regex = "1.0"
serde_urlencoded = "0.7"
+6 −0
Original line number Diff line number Diff line
@@ -227,3 +227,9 @@ impl From<serde_urlencoded::de::Error> for SmithyRejection {
        SmithyRejection::Deserialize(Deserialize::from_err(err))
    }
}

impl From<nom::Err<nom::error::Error<&str>>> for SmithyRejection {
    fn from(err: nom::Err<nom::error::Error<&str>>) -> Self {
        SmithyRejection::Deserialize(Deserialize::from_err(err.to_owned()))
    }
}