Unverified Commit d97defbd authored by Matteo Bigoi's avatar Matteo Bigoi Committed by GitHub
Browse files

[Python] Support more testing model (#2541)



* Remove parameter from `Protocol`s `structuredDataParser`, `structuredDataSerializer`

No implementation of the `Protocol` interface makes use of the
`OperationShape` parameter in the `structuredDataParser` and
`structuredDataSerializer` methods.

* Remove the TypeConversionGenerator class in favor of using
customizations for JsonParserGenerator and ServerHttpBoundProtocolGenerator.

Signed-off-by: default avatarBigo <1781140+crisidev@users.noreply.github.com>

* Make the additionaParserCustomizations default to empty list

* Fix merge conflict

* Fix missing ;

* Use better defaults when checking for customizations

* Use better defaults when checking for customizations

* Add HttpBindingCustomization and relax the datetime symbol check

* Support recursive shapes and add a lot more models to the tests

Signed-off-by: default avatarBigo <1781140+crisidev@users.noreply.github.com>

* Support naming obstacle course

* Add support for constrained blobs conversions

* Support constraint traits

* Try to generate the full diff

Signed-off-by: default avatarBigo <1781140+crisidev@users.noreply.github.com>

* A better way of checking if we need to go into the Timestamp branch

* Remove wheels folder

---------

Signed-off-by: default avatarBigo <1781140+crisidev@users.noreply.github.com>
Co-authored-by: default avatardavid-perez <d@vidp.dev>
parent 3aa4cc24
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -56,3 +56,6 @@ target/

# tools
.tool-versions

# python
__pycache__
+0 −54
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rust.codegen.core.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.rustType

/*
 * Utility class used to force casting a non primitive type into one overriden by a new symbol provider,
 * by explicitly calling `from()` or into().
 *
 * For example we use this in the server Python implementation, where we override types like [Blob] and [DateTime]
 * with wrappers compatible with Python, without touching the original implementation coming from `aws-smithy-types`.
 */
class TypeConversionGenerator(private val model: Model, private val symbolProvider: RustSymbolProvider, private val runtimeConfig: RuntimeConfig) {
    private fun findOldSymbol(shape: Shape): Symbol {
        return when (shape) {
            is BlobShape -> RuntimeType.blob(runtimeConfig).toSymbol()
            is TimestampShape -> RuntimeType.dateTime(runtimeConfig).toSymbol()
            else -> symbolProvider.toSymbol(shape)
        }
    }

    fun convertViaFrom(shape: Shape): Writable =
        writable {
            val oldSymbol = findOldSymbol(shape)
            val newSymbol = symbolProvider.toSymbol(shape)
            if (oldSymbol.rustType() != newSymbol.rustType()) {
                rust(".map($newSymbol::from)")
            }
        }

    fun convertViaInto(shape: Shape): Writable =
        writable {
            val oldSymbol = findOldSymbol(shape)
            val newSymbol = symbolProvider.toSymbol(shape)
            if (oldSymbol.rustType() != newSymbol.rustType()) {
                rust(".into()")
            }
        }
}
+9 −2
Original line number Diff line number Diff line
@@ -88,6 +88,9 @@ sealed class HttpBindingSection(name: String) : Section(name) {

    data class AfterDeserializingIntoAHashMapOfHttpPrefixHeaders(val memberShape: MemberShape) :
        HttpBindingSection("AfterDeserializingIntoAHashMapOfHttpPrefixHeaders")

    data class AfterDeserializingIntoADateTimeOfHttpHeaders(val memberShape: MemberShape) :
        HttpBindingSection("AfterDeserializingIntoADateTimeOfHttpHeaders")
}

typealias HttpBindingCustomization = NamedCustomization<HttpBindingSection>
@@ -353,7 +356,7 @@ class HttpBindingGenerator(
            rustType to targetShape
        }
        val parsedValue = safeName()
        if (coreType == dateTime) {
        if (coreShape.isTimestampShape()) {
            val timestampFormat =
                index.determineTimestampFormat(
                    memberShape,
@@ -362,10 +365,14 @@ class HttpBindingGenerator(
                )
            val timestampFormatType = RuntimeType.parseTimestampFormat(codegenTarget, runtimeConfig, timestampFormat)
            rust(
                "let $parsedValue: Vec<${coreType.render()}> = #T::many_dates(headers, #T)?;",
                "let $parsedValue: Vec<${coreType.render()}> = #T::many_dates(headers, #T)?",
                headerUtil,
                timestampFormatType,
            )
            for (customization in customizations) {
                customization.section(HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders(memberShape))(this)
            }
            rust(";")
        } else if (coreShape.isPrimitive()) {
            rust(
                "let $parsedValue = #T::read_many_primitive::<${coreType.render()}>(headers)?;",
+24 −7
Original line number Diff line number Diff line
@@ -39,7 +39,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
@@ -61,6 +60,12 @@ import software.amazon.smithy.utils.StringUtils
 */
sealed class JsonParserSection(name: String) : Section(name) {
    data class BeforeBoxingDeserializedMember(val shape: MemberShape) : JsonParserSection("BeforeBoxingDeserializedMember")

    data class AfterTimestampDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterTimestampDeserializedMember")

    data class AfterBlobDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterBlobDeserializedMember")

    data class AfterDocumentDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterDocumentDeserializedMember")
}

/**
@@ -94,7 +99,6 @@ class JsonParserGenerator(
    private val runtimeConfig = codegenContext.runtimeConfig
    private val codegenTarget = codegenContext.target
    private val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType()
    private val typeConversionGenerator = TypeConversionGenerator(model, symbolProvider, runtimeConfig)
    private val protocolFunctions = ProtocolFunctions(codegenContext)
    private val codegenScope = arrayOf(
        "Error" to smithyJson.resolve("deserialize::error::DeserializeError"),
@@ -276,13 +280,13 @@ class JsonParserGenerator(
            is StringShape -> deserializeString(target)
            is BooleanShape -> rustTemplate("#{expect_bool_or_null}(tokens.next())?", *codegenScope)
            is NumberShape -> deserializeNumber(target)
            is BlobShape -> deserializeBlob()
            is BlobShape -> deserializeBlob(memberShape)
            is TimestampShape -> deserializeTimestamp(target, memberShape)
            is CollectionShape -> deserializeCollection(target)
            is MapShape -> deserializeMap(target)
            is StructureShape -> deserializeStruct(target)
            is UnionShape -> deserializeUnion(target)
            is DocumentShape -> rustTemplate("Some(#{expect_document}(tokens)?)", *codegenScope)
            is DocumentShape -> deserializeDocument(memberShape)
            else -> PANIC("unexpected shape: $target")
        }
        val symbol = symbolProvider.toSymbol(memberShape)
@@ -294,11 +298,21 @@ class JsonParserGenerator(
        }
    }

    private fun RustWriter.deserializeBlob() {
    private fun RustWriter.deserializeDocument(member: MemberShape) {
        rustTemplate("Some(#{expect_document}(tokens)?)", *codegenScope)
        for (customization in customizations) {
            customization.section(JsonParserSection.AfterDocumentDeserializedMember(member))(this)
        }
    }

    private fun RustWriter.deserializeBlob(member: MemberShape) {
        rustTemplate(
            "#{expect_blob_or_null}(tokens.next())?",
            *codegenScope,
        )
        for (customization in customizations) {
            customization.section(JsonParserSection.AfterBlobDeserializedMember(member))(this)
        }
    }

    private fun RustWriter.deserializeStringInner(target: StringShape, escapedStrName: String) {
@@ -349,9 +363,12 @@ class JsonParserGenerator(
            )
        val timestampFormatType = RuntimeType.parseTimestampFormat(codegenTarget, runtimeConfig, timestampFormat)
        rustTemplate(
            "#{expect_timestamp_or_null}(tokens.next(), #{T})?#{ConvertFrom:W}",
            "T" to timestampFormatType, "ConvertFrom" to typeConversionGenerator.convertViaFrom(shape), *codegenScope,
            "#{expect_timestamp_or_null}(tokens.next(), #{T})?",
            "T" to timestampFormatType, *codegenScope,
        )
        for (customization in customizations) {
            customization.section(JsonParserSection.AfterTimestampDeserializedMember(member))(this)
        }
    }

    private fun RustWriter.deserializeCollection(shape: CollectionShape) {
+40 −5
Original line number Diff line number Diff line
@@ -54,14 +54,49 @@ val allCodegenTests = "../../codegen-core/common-test-models".let { commonModels
            // TODO(https://github.com/awslabs/smithy-rs/issues/1401) `@uniqueItems` is used.
            extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """,
        ),
        // TODO(https://github.com/awslabs/smithy-rs/issues/2476)
        CodegenTest(
            "aws.protocoltests.json#JsonProtocol",
            "json_rpc11",
            extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """,
        ),
        CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"),
        CodegenTest("aws.protocoltests.restjson#RestJson", "rest_json"),
        CodegenTest(
            "aws.protocoltests.restjson#RestJsonExtras",
            "rest_json_extras",
            imports = listOf("$commonModels/rest-json-extras.smithy"),
        ),
        // TODO(https://github.com/awslabs/smithy-rs/issues/2551)
        // CodegenTest(
        //     "aws.protocoltests.json#JsonProtocol",
        //     "json_rpc11",
        //     "aws.protocoltests.restjson.validation#RestJsonValidation",
        //     "rest_json_validation",
        //     // `@range` trait is used on floating point shapes, which we deliberately don't want to support.
        //     // See https://github.com/awslabs/smithy-rs/issues/1401.
        //     extraConfig = """, "codegen": { "ignoreUnsupportedConstraints": true } """,
        // ),
        // TODO(https://github.com/awslabs/smithy-rs/issues/2479)
        // CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"),
        CodegenTest(
            "com.amazonaws.constraints#ConstraintsService",
            "constraints",
            imports = listOf("$commonModels/constraints.smithy"),
        ),
        CodegenTest(
            "com.amazonaws.constraints#ConstraintsService",
            "constraints_without_public_constrained_types",
            imports = listOf("$commonModels/constraints.smithy"),
            extraConfig = """, "codegen": { "publicConstrainedTypes": false } """,
        ),
        CodegenTest(
            "com.amazonaws.constraints#UniqueItemsService",
            "unique_items",
            imports = listOf("$commonModels/unique-items.smithy"),
        ),
        CodegenTest(
            "naming_obs_structs#NamingObstacleCourseStructs",
            "naming_test_structs",
            imports = listOf("$commonModels/naming-obstacle-course-structs.smithy"),
        ),
        CodegenTest("casing#ACRONYMInside_Service", "naming_test_casing", imports = listOf("$commonModels/naming-obstacle-course-casing.smithy")),
        CodegenTest("crate#Config", "naming_test_ops", imports = listOf("$commonModels/naming-obstacle-course-ops.smithy")),
    )
}

Loading