Unverified Commit b1ee45ab authored by Julian Antonielli's avatar Julian Antonielli Committed by GitHub
Browse files

Fix protocol test: `RestJsonMalformedPatternSensitiveString` (#2321)

* Make `RestJsonMalformedPatternSensitiveString` pass

* Use `AwsJson11` variable instead of hardcoded string

* Refactor `errorMessage` into private function

* Add comment about remaining `@range` on floats tests

* Use `AwsJson11` variable instead of hardcoded string (missed one)

* Use `hasTrait` over `getTrait() != null`
parent 12a66c2c
Loading
Loading
Loading
Loading
+0 −12
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rust.codegen.server.smithy

import software.amazon.smithy.model.traits.PatternTrait

@Suppress("UnusedReceiverParameter")
fun PatternTrait.validationErrorMessage(): String =
    "Value {} at '{}' failed to satisfy constraint: Member must satisfy regular expression pattern: {}"
+37 −9
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.LengthTrait
import software.amazon.smithy.model.traits.PatternTrait
import software.amazon.smithy.model.traits.SensitiveTrait
import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
@@ -22,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.render
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained
@@ -29,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.core.smithy.testModuleForShape
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary
import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider
@@ -63,7 +66,7 @@ class ConstrainedStringGenerator(
    private val constraintsInfo: List<TraitInfo> =
        supportedStringConstraintTraits
            .mapNotNull { shape.getTrait(it).orNull() }
            .map { StringTraitInfo.fromTrait(symbol, it) }
            .map { StringTraitInfo.fromTrait(symbol, it, isSensitive = shape.hasTrait<SensitiveTrait>()) }
            .map(StringTraitInfo::toTraitInfo)

    fun render() {
@@ -184,6 +187,7 @@ class ConstrainedStringGenerator(
        }
    }
}

private data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() {
    override fun toTraitInfo(): TraitInfo = TraitInfo(
        tryFromCheck = { rust("Self::check_length(&value)?;") },
@@ -229,10 +233,9 @@ private data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() {
    }
}

private data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait) : StringTraitInfo() {
private data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait, val isSensitive: Boolean) :
    StringTraitInfo() {
    override fun toTraitInfo(): TraitInfo {
        val pattern = patternTrait.pattern

        return TraitInfo(
            tryFromCheck = { rust("let value = Self::check_pattern(value)?;") },
            constraintViolationVariant = {
@@ -241,13 +244,14 @@ private data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait) :
                rust("Pattern(String)")
            },
            asValidationExceptionField = {
                rust(
                rustTemplate(
                    """
                    Self::Pattern(string) => crate::model::ValidationExceptionField {
                        message: format!("${patternTrait.validationErrorMessage()}", &string, &path, r##"$pattern"##),
                    Self::Pattern(_string) => crate::model::ValidationExceptionField {
                        message: #{ErrorMessage:W},
                        path
                    },
                    """,
                    "ErrorMessage" to errorMessage(),
                )
            },
            this::renderValidationFunction,
@@ -264,6 +268,28 @@ private data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait) :
        )
    }

    private fun errorMessage(): Writable {
        val pattern = patternTrait.pattern

        return if (isSensitive) {
            writable {
                rust(
                    """
                    format!("Value at '{}' failed to satisfy constraint: Member must satisfy regular expression pattern: {}", &path, r##"$pattern"##)
                    """,
                )
            }
        } else {
            writable {
                rust(
                    """
                    format!("Value {} at '{}' failed to satisfy constraint: Member must satisfy regular expression pattern: {}", &_string, &path, r##"$pattern"##)
                    """,
                )
            }
        }
    }

    /**
     * Renders a `check_pattern` function to validate the string matches the
     * supplied regex in the `@pattern` trait.
@@ -303,14 +329,16 @@ private data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait) :

private sealed class StringTraitInfo {
    companion object {
        fun fromTrait(symbol: Symbol, trait: Trait) =
        fun fromTrait(symbol: Symbol, trait: Trait, isSensitive: Boolean) =
            when (trait) {
                is PatternTrait -> {
                    Pattern(symbol, trait)
                    Pattern(symbol, trait, isSensitive)
                }

                is LengthTrait -> {
                    Length(trait)
                }

                else -> PANIC("StringTraitInfo.fromTrait called with unsupported trait $trait")
            }
    }
+7 −13
Original line number Diff line number Diff line
@@ -774,13 +774,12 @@ class ServerProtocolTestGenerator(

            FailingTest(RestJson, "RestJsonWithBodyExpectsApplicationJsonContentType", TestType.MalformedRequest),

            // Tests involving constraint traits, which are not yet fully implemented.
            // See https://github.com/awslabs/smithy-rs/issues/1401.
            // Tests involving `@range` on floats.
            // Pending resolution from the Smithy team, see https://github.com/awslabs/smithy-rs/issues/2007.
            FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloat_case0", TestType.MalformedRequest),
            FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloat_case1", TestType.MalformedRequest),
            FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxFloat", TestType.MalformedRequest),
            FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinFloat", TestType.MalformedRequest),
            FailingTest(RestJsonValidation, "RestJsonMalformedPatternSensitiveString", TestType.MalformedRequest),

            // See https://github.com/awslabs/smithy-rs/issues/1969
            FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeShortOverride_case0", TestType.MalformedRequest),
@@ -872,16 +871,11 @@ class ServerProtocolTestGenerator(
            FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTrait", TestType.Request),

            // AwsJson1.1 failing tests.
            FailingTest("aws.protocoltests.json#JsonProtocol", "AwsJson11EndpointTraitWithHostLabel", TestType.Request),
            FailingTest("aws.protocoltests.json#JsonProtocol", "AwsJson11EndpointTrait", TestType.Request),
            FailingTest("aws.protocoltests.json#JsonProtocol", "parses_httpdate_timestamps", TestType.Response),
            FailingTest("aws.protocoltests.json#JsonProtocol", "parses_iso8601_timestamps", TestType.Response),
            FailingTest(
                "aws.protocoltests.json#JsonProtocol",
                "parses_the_request_id_from_the_response",
                TestType.Response,
            ),

            FailingTest(AwsJson11, "AwsJson11EndpointTraitWithHostLabel", TestType.Request),
            FailingTest(AwsJson11, "AwsJson11EndpointTrait", TestType.Request),
            FailingTest(AwsJson11, "parses_httpdate_timestamps", TestType.Response),
            FailingTest(AwsJson11, "parses_iso8601_timestamps", TestType.Response),
            FailingTest(AwsJson11, "parses_the_request_id_from_the_response", TestType.Response),
        )
        private val RunOnly: Set<String>? = null