Unverified Commit 76f0fa96 authored by Julian Antonielli's avatar Julian Antonielli Committed by GitHub
Browse files

Compile regexes for `@pattern` strings early (#2058)



* Add docs for `compile_regex()`

* Try to pass `@pattern` string shapes to service visitor

* Add `compile_rege` calls

* Remove unwanted changes

* Fix python code generator

* Add hack for `@pattern` + `@enum` trait

* Generate tests for regexes

* Move some functionality to `TraitInfo`

* Add docs for `Symbol.testModule`

* Improve docs for `Symbol.testModule`

* Update codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt

Co-authored-by: default avatarLuca Palmieri <20745048+LukeMathWalker@users.noreply.github.com>

* Refactor `regex_compiles` test rendering

* Refactor `Symbol.testModule` -> `SymbolProvider.testModuleForShape`

* Use `RustWriter.unitTest` instead of manually annotating `#[test]`s

* Refactor `mapNotNull` -> `filter` + `map`

* Only render doc comment when there's `@pattern` strings

* Rename `String.toRustName` -> `´String.unsafeToRustName`

* Add kotlin test for pattern with lookahead

* Update `Attribute.Cfg` to new `Attribute.CfgTest`

* Prepend `@pattern` docs

* Fix test name

Co-authored-by: default avatarLuca Palmieri <20745048+LukeMathWalker@users.noreply.github.com>
parent aaac83f9
Loading
Loading
Loading
Loading
+3 −6
Original line number Diff line number Diff line
@@ -15,14 +15,13 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.Fun
import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.unsafeToRustName
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase

data class Context(val functionRegistry: FunctionRegistry, val runtimeConfig: RuntimeConfig)

@@ -30,7 +29,7 @@ data class Context(val functionRegistry: FunctionRegistry, val runtimeConfig: Ru
 * Utility function to convert an [Identifier] into a valid Rust identifier (snake case)
 */
fun Identifier.rustName(): String {
    return this.toString().stringToRustName()
    return this.toString().unsafeToRustName()
}

/**
@@ -53,8 +52,6 @@ class Types(runtimeConfig: RuntimeConfig) {
    val resolveEndpointError = smithyHttpEndpointModule.resolve("ResolveEndpointError")
}

private fun String.stringToRustName(): String = RustReservedWords.escapeIfNeeded(this.toSnakeCase())

/**
 * Returns the memberName() for a given [Parameter]
 */
@@ -62,7 +59,7 @@ fun Parameter.memberName(): String {
    return name.rustName()
}

fun ContextParamTrait.memberName(): String = this.name.stringToRustName()
fun ContextParamTrait.memberName(): String = this.name.unsafeToRustName()

/**
 * Returns the symbol for a given parameter. This enables [RustWriter] to generate the correct [RustType].
+2 −1
Original line number Diff line number Diff line
@@ -56,10 +56,11 @@ sealed class RustModule {
            documentation: String? = null,
            inline: Boolean = false,
            parent: RustModule = LibRs,
            additionalAttributes: List<Attribute> = listOf(),
        ): LeafModule {
            return LeafModule(
                RustReservedWords.escapeIfNeeded(name),
                RustMetadata(visibility = visibility),
                RustMetadata(visibility = visibility, additionalAttributes = additionalAttributes),
                documentation,
                inline = inline,
                parent = parent,
+31 −0
Original line number Diff line number Diff line
@@ -38,8 +38,11 @@ import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
@@ -393,6 +396,28 @@ fun Symbol.Builder.rustType(rustType: RustType): Symbol.Builder = this.putProper
fun Symbol.Builder.module(module: RustModule.LeafModule): Symbol.Builder = this.putProperty(RUST_MODULE_KEY, module)
fun Symbol.module(): RustModule.LeafModule = this.expectProperty(RUST_MODULE_KEY, RustModule.LeafModule::class.java)

/**
 * Creates a test module for this symbol.
 * For example if the symbol represents the name for the struct `struct MyStruct { ... }`,
 * this function will create the following inline module:
 * ```rust
 *  #[cfg(test)]
 *  mod test_my_struct { ... }
 * ```
 */
fun SymbolProvider.testModuleForShape(shape: Shape): RustModule.LeafModule {
    val symbol = toSymbol(shape)
    val rustName = symbol.name.unsafeToRustName()

    return RustModule.new(
        name = "test_$rustName",
        visibility = Visibility.PRIVATE,
        inline = true,
        parent = symbol.module(),
        additionalAttributes = listOf(Attribute.CfgTest),
    )
}

fun Symbol.Builder.renamedFrom(name: String): Symbol.Builder {
    return this.putProperty(RENAMED_FROM_KEY, name)
}
@@ -435,3 +460,9 @@ fun Symbol.isRustBoxed(): Boolean = rustType().stripOuter<RustType.Option>() is
// Symbols should _always_ be created with a Rust type & shape attached
fun Symbol.rustType(): RustType = this.expectProperty(RUST_TYPE_KEY, RustType::class.java)
fun Symbol.shape(): Shape = this.expectProperty(SHAPE_KEY, Shape::class.java)

/**
 *  You should rarely need this function, rust names in general should be symbol-aware,
 *  this is "automatic" if you use things like [software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate].
 */
fun String.unsafeToRustName(): String = RustReservedWords.escapeIfNeeded(this.toSnakeCase())
+2 −2
Original line number Diff line number Diff line
@@ -8,9 +8,9 @@ package software.amazon.smithy.rust.codegen.server.python.smithy.generators
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerServiceGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator
@@ -26,7 +26,7 @@ class PythonServerServiceGenerator(
    protocolGenerator: ServerProtocolGenerator,
    protocolSupport: ProtocolSupport,
    protocol: ServerProtocol,
    private val context: CodegenContext,
    private val context: ServerCodegenContext,
) : ServerServiceGenerator(rustCrate, protocolGenerator, protocolSupport, protocol, context) {

    override fun renderCombinedErrors(writer: RustWriter, operation: OperationShape) {
+38 −5
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@
package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.LengthTrait
import software.amazon.smithy.model.traits.PatternTrait
@@ -25,6 +26,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.core.smithy.testModuleForShape
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary
@@ -56,14 +59,14 @@ class ConstrainedStringGenerator(
                PubCrateConstraintViolationSymbolProvider(this)
            }
        }
    private val symbol = constrainedShapeSymbolProvider.toSymbol(shape)
    private val constraintsInfo: List<TraitInfo> =
        supportedStringConstraintTraits
            .mapNotNull { shape.getTrait(it).orNull() }
            .map(StringTraitInfo::fromTrait)
            .map { StringTraitInfo.fromTrait(symbol, it) }
            .map(StringTraitInfo::toTraitInfo)

    fun render() {
        val symbol = constrainedShapeSymbolProvider.toSymbol(shape)
        val name = symbol.name
        val inner = RustType.String.render()
        val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape)
@@ -145,6 +148,8 @@ class ConstrainedStringGenerator(
        writer.withInlineModule(constraintViolation.module()) {
            renderConstraintViolationEnum(this, shape, constraintViolation)
        }

        renderTests(shape)
    }

    private fun renderConstraintViolationEnum(writer: RustWriter, shape: StringShape, constraintViolation: Symbol) {
@@ -174,6 +179,22 @@ class ConstrainedStringGenerator(
            )
        }
    }

    private fun renderTests(shape: Shape) {
        val testCases = TraitInfo.testCases(constraintsInfo)

        if (testCases.isNotEmpty()) {
            val testModule = constrainedShapeSymbolProvider.testModuleForShape(shape)
            writer.withInlineModule(testModule) {
                rustTemplate(
                    """
                    #{TestCases:W}
                    """,
                    "TestCases" to testCases.join("\n"),
                )
            }
        }
    }
}
private data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() {
    override fun toTraitInfo(): TraitInfo = TraitInfo(
@@ -217,7 +238,7 @@ private data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() {
    }
}

private data class Pattern(val patternTrait: PatternTrait) : StringTraitInfo() {
private data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait) : StringTraitInfo() {
    override fun toTraitInfo(): TraitInfo {
        val pattern = patternTrait.pattern

@@ -239,6 +260,16 @@ private data class Pattern(val patternTrait: PatternTrait) : StringTraitInfo() {
                )
            },
            this::renderValidationFunction,
            testCases = listOf {
                unitTest("regex_compiles") {
                    rustTemplate(
                        """
                        #{T}::compile_regex();
                        """,
                        "T" to symbol,
                    )
                }
            },
        )
    }

@@ -264,6 +295,8 @@ private data class Pattern(val patternTrait: PatternTrait) : StringTraitInfo() {
                    }
                }

                /// Attempts to compile the regex for this constrained type's `@pattern`.
                /// This can fail if the specified regex is not supported by the `#{Regex}` crate.
                pub fn compile_regex() -> &'static #{Regex}::Regex {
                    static REGEX: #{OnceCell}::sync::Lazy<#{Regex}::Regex> = #{OnceCell}::sync::Lazy::new(|| #{Regex}::Regex::new(r##"$pattern"##).expect(r##"$errorMessageForUnsupportedRegex"##));

@@ -279,10 +312,10 @@ private data class Pattern(val patternTrait: PatternTrait) : StringTraitInfo() {

private sealed class StringTraitInfo {
    companion object {
        fun fromTrait(trait: Trait): StringTraitInfo =
        fun fromTrait(symbol: Symbol, trait: Trait) =
            when (trait) {
                is PatternTrait -> {
                    Pattern(trait)
                    Pattern(symbol, trait)
                }
                is LengthTrait -> {
                    Length(trait)
Loading