Unverified Commit c503debb authored by Harry Barber's avatar Harry Barber Committed by GitHub
Browse files

Remove usage of percent_decode_str for decoding query string key-value pairs (#1417)

* Use `form_urlencoding::parse` over `serde_urlencoded::from_str`, this removes our dependency on `serde_urlencoded`.

* Remove `percent_encoding::percent_decode_str` where possible.

* Contract `generateParsePercentEncodedStrFn` and its children into one function and make it optionally apply percent encoding.

* Add `impl From<std::convert::Infallible> for RequestRejection` to remove codegen branch.
parent fe92efbf
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -16,10 +16,10 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
 */
object ServerCargoDependency {
    val AsyncTrait: CargoDependency = CargoDependency("async-trait", CratesIo("0.1"))
    val FormUrlEncoded: CargoDependency = CargoDependency("form_urlencoded", CratesIo("1"))
    val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3"))
    val Nom: CargoDependency = CargoDependency("nom", CratesIo("7"))
    val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2"))
    val SerdeUrlEncoded: CargoDependency = CargoDependency("serde_urlencoded", CratesIo("0.7"))
    val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"))

    fun SmithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-server")
+73 −110
Original line number Diff line number Diff line
@@ -12,12 +12,13 @@ import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.node.ExpectationNotMetException
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpErrorTrait
import software.amazon.smithy.model.traits.HttpTrait
@@ -57,7 +58,6 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolPay
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.smithy.toOptional
import software.amazon.smithy.rust.codegen.smithy.wrapOptional
import software.amazon.smithy.rust.codegen.util.dq
@@ -65,7 +65,6 @@ import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.findStreamingMember
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.isStreaming
import software.amazon.smithy.rust.codegen.util.outputShape
@@ -120,6 +119,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
        "AsyncTrait" to ServerCargoDependency.AsyncTrait.asType(),
        "Cow" to ServerRuntimeType.Cow,
        "DateTime" to RuntimeType.DateTime(runtimeConfig),
        "FormUrlEncoded" to ServerCargoDependency.FormUrlEncoded.asType(),
        "HttpBody" to CargoDependency.HttpBody.asType(),
        "header_util" to CargoDependency.SmithyHttp(runtimeConfig).asType().member("header"),
        "Hyper" to CargoDependency.Hyper.asType(),
@@ -127,7 +127,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
        "Nom" to ServerCargoDependency.Nom.asType(),
        "PercentEncoding" to CargoDependency.PercentEncoding.asType(),
        "Regex" to CargoDependency.Regex.asType(),
        "SerdeUrlEncoded" to ServerCargoDependency.SerdeUrlEncoded.asType(),
        "SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(),
        "SmithyHttpServer" to ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType(),
        "RuntimeError" to ServerRuntimeType.RuntimeError(runtimeConfig),
@@ -775,7 +774,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
                .forEachIndexed { index, segment ->
                    val binding = pathBindings.find { it.memberName == segment.content }
                    if (binding != null && segment.isLabel) {
                        val deserializer = generateParsePercentEncodedStrFn(binding)
                        val deserializer = generateParseFn(binding, true)
                        rustTemplate(
                            """
                            input = input.${binding.member.setterName()}(
@@ -847,7 +846,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
            rustTemplate(
                """
                let query_string = request.uri().query().unwrap_or("");
                let pairs = #{SerdeUrlEncoded}::from_str::<Vec<(#{Cow}<'_, str>, #{Cow}<'_, str>)>>(query_string)?;
                let pairs = #{FormUrlEncoded}::parse(query_string.as_bytes());
                """.trimIndent(),
                *codegenScope
            )
@@ -870,7 +869,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(

            rustBlock("for (k, v) in pairs") {
                queryBindingsTargettingSimple.forEach {
                    val deserializer = generateParsePercentEncodedStrFn(it)
                    val deserializer = generateParseFn(it, false)
                    val memberName = symbolProvider.toMemberName(it.member)
                    rustTemplate(
                        """
@@ -891,25 +890,15 @@ private class ServerHttpBoundProtocolTraitImplGenerator(

                        when {
                            memberShape.isStringShape -> {
                                // `<_>::from()/try_from()` is necessary to convert the `&str` into:
                                //     * the Rust enum in case the `string` shape has the `enum` trait; or
                                //     * `String` in case it doesn't.
                                if (memberShape.hasTrait<EnumTrait>()) {
                                // NOTE: This path is traversed with or without @enum applied. The `try_from` is used
                                // as a common conversion.
                                rustTemplate(
                                    """
                                        let v = <#{memberShape}>::try_from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref())?;
                                    let v = <#{memberShape}>::try_from(v.as_ref())?;
                                    """,
                                    *codegenScope,
                                    "memberShape" to symbolProvider.toSymbol(memberShape),
                                )
                                } else {
                                    rustTemplate(
                                        """
                                        let v = <_>::from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref());
                                        """.trimIndent(),
                                        *codegenScope
                                    )
                                }
                            }
                            memberShape.isTimestampShape -> {
                                val index = HttpBindingIndex.of(model)
@@ -922,7 +911,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
                                val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
                                rustTemplate(
                                    """
                                    let v = #{PercentEncoding}::percent_decode_str(&v).decode_utf8()?;
                                    let v = #{DateTime}::from_str(&v, #{format})?;
                                    """.trimIndent(),
                                    *codegenScope,
@@ -1013,21 +1001,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
        )
    }

    // TODO(https://github.com/awslabs/smithy-rs/issues/1231): If this function was called to parse a query string
    // key value pair, we don't need to percent-decode it _again_.
    private fun generateParsePercentEncodedStrFn(binding: HttpBindingDescriptor): RuntimeType {
        // HTTP bindings we support that contain percent-encoded data.
        check(binding.location == HttpLocation.LABEL || binding.location == HttpLocation.QUERY)

        val target = model.expectShape(binding.member.target)
        return when {
            target.isStringShape -> generateParsePercentEncodedStrAsStringFn(binding)
            target.isTimestampShape -> generateParsePercentEncodedStrAsTimestampFn(binding)
            else -> generateParseStrAsPrimitiveFn(binding)
        }
    }

    private fun generateParsePercentEncodedStrAsStringFn(binding: HttpBindingDescriptor): RuntimeType {
    private fun generateParseFn(binding: HttpBindingDescriptor, percentDecoding: Boolean): RuntimeType {
        val output = symbolProvider.toSymbol(binding.member)
        val fnName = generateParseStrFnName(binding)
        val symbol = output.extractSymbolFromOption()
@@ -1037,41 +1011,31 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
                *codegenScope,
                "O" to output,
            ) {
                // `<_>::from()` is necessary to convert the `&str` into:
                //     * the Rust enum in case the `string` shape has the `enum` trait; or
                //     * `String` in case it doesn't.
                when (symbol.rustType()) {
                    RustType.String ->
                val target = model.expectShape(binding.member.target)

                when {
                    target.isStringShape -> {
                        // NOTE: This path is traversed with or without @enum applied. The `try_from` is used as a
                        // common conversion.
                        if (percentDecoding) {
                            rustTemplate(
                                """
                            let value = <#{T}>::from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref());
                                let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
                                let value = #{T}::try_from(value.as_ref())?;
                                """,
                                *codegenScope,
                                "T" to symbol,
                            )
                    else -> { // RustType.Opaque, the Enum
                        check(symbol.rustType() is RustType.Opaque)
                        } else {
                            rustTemplate(
                                """
                            let value = <#{T}>::try_from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref())?;
                                let value = #{T}::try_from(value)?;
                                """,
                            *codegenScope,
                                "T" to symbol,
                            )
                        }
                    }
                writer.write(
                    """
                    Ok(${symbolProvider.wrapOptional(binding.member, "value")})
                    """
                )
            }
        }
    }

    private fun generateParsePercentEncodedStrAsTimestampFn(binding: HttpBindingDescriptor): RuntimeType {
        val output = symbolProvider.toSymbol(binding.member)
        val fnName = generateParseStrFnName(binding)
                    target.isTimestampShape -> {
                        val index = HttpBindingIndex.of(model)
                        val timestampFormat =
                            index.determineTimestampFormat(
@@ -1080,45 +1044,44 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
                                protocol.defaultTimestampFormat,
                            )
                        val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
        return RuntimeType.forInlineFun(fnName, operationDeserModule) { writer ->
            writer.rustBlockTemplate(
                "pub fn $fnName(value: &str) -> std::result::Result<#{O}, #{RequestRejection}>",
                *codegenScope,
                "O" to output,
            ) {

                        if (percentDecoding) {
                            rustTemplate(
                                """
                                let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
                    let value = #{DateTime}::from_str(&value, #{format})?;
                    Ok(${symbolProvider.wrapOptional(binding.member, "value")})
                    """.trimIndent(),
                                let value = #{DateTime}::from_str(value.as_ref(), #{format})?;
                                """,
                                *codegenScope,
                                "format" to timestampFormatType,
                            )
                        } else {
                            rustTemplate(
                                """
                                let value = #{DateTime}::from_str(value, #{format})?;
                                """,
                                *codegenScope,
                                "format" to timestampFormatType,
                            )
                        }
                    }
    }

    // Function to parse a string as the data type generated for boolean, byte, short, integer, long, float, or double shapes.
    // TODO(https://github.com/awslabs/smithy-rs/issues/1232): This function can be replaced by https://docs.rs/aws-smithy-types/latest/aws_smithy_types/primitive/trait.Parse.html
    private fun generateParseStrAsPrimitiveFn(binding: HttpBindingDescriptor): RuntimeType {
        val output = symbolProvider.toSymbol(binding.member)
        val fnName = generateParseStrFnName(binding)
        return RuntimeType.forInlineFun(fnName, operationDeserModule) { writer ->
            writer.rustBlockTemplate(
                "pub fn $fnName(value: &str) -> std::result::Result<#{O}, #{RequestRejection}>",
                *codegenScope,
                "O" to output,
            ) {
                    else -> {
                        check(target is NumberShape || target is BooleanShape)
                        rustTemplate(
                            """
                            let value = std::str::FromStr::from_str(value)?;
                    Ok(${symbolProvider.wrapOptional(binding.member, "value")})
                    """.trimIndent(),
                            """,
                            *codegenScope,
                        )
                    }
                }

                writer.write(
                    """
                    Ok(${symbolProvider.wrapOptional(binding.member, "value")})
                    """
                )
            }
        }
    }

    private fun generateParseStrFnName(binding: HttpBindingDescriptor): String {
+19 −0
Original line number Diff line number Diff line
@@ -194,6 +194,25 @@ pub enum RequestRejection {

impl std::error::Error for RequestRejection {}

// Consider a conversion between `T` and `U` followed by a bubbling up of the conversion error
// through `Result<_, RequestRejection>`. This [`From`] implementation accomodates the special case
// where `T` and `U` are equal, in such cases `T`/`U` a enjoy `TryFrom<T>` with
// `Err = std::convert::Infallible`.
//
// Note that when `!` stabilizes `std::convert::Infallible` will become an alias for `!` and there
// will be a blanket `impl From<!> for T`. This will remove the need for this implementation.
//
// More details on this can be found in the following links:
// - https://doc.rust-lang.org/std/primitive.never.html
// - https://doc.rust-lang.org/std/convert/enum.Infallible.html#future-compatibility
impl From<std::convert::Infallible> for RequestRejection {
    fn from(_err: std::convert::Infallible) -> Self {
        // We opt for this `match` here rather than [`unreachable`] to assure the reader that this
        // code path is dead.
        match _err {}
    }
}

// These converters are solely to make code-generation simpler. They convert from a specific error
// type (from a runtime/third-party crate or the standard library) into a variant of the
// [`crate::rejection::RequestRejection`] enum holding the type-erased boxed [`crate::Error`]