Unverified Commit c7b97dcc authored by Matteo Bigoi's avatar Matteo Bigoi Committed by GitHub
Browse files

Codegenerate Python application example and add explicit cast during JSON deserialization (#1520)




Co-authored-by: default avatarJohn DiSanti <jdisanti@amazon.com>
parent 2ad87c0e
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -205,3 +205,11 @@ message = "SDK crate READMEs now include an example of creating a client"
references = ["smithy-rs#1571", "smithy-rs#1385"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "jdisanti"

[[smithy-rs]]
message = """
Add explicit cast during JSON deserialization in case of custom Symbol providers.
"""
references = ["smithy-rs#1520"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "crisidev"
+77 −3
Original line number Diff line number Diff line
@@ -6,13 +6,21 @@
package software.amazon.smithy.rust.codegen.server.python.smithy.generators

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.DocumentationTrait
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.smithy.Errors
import software.amazon.smithy.rust.codegen.smithy.Inputs
import software.amazon.smithy.rust.codegen.smithy.Outputs
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase

/**
@@ -56,8 +64,10 @@ class PythonApplicationGenerator(
    coreCodegenContext: CoreCodegenContext,
    private val operations: List<OperationShape>,
) {
    private val crateName = coreCodegenContext.settings.moduleName
    private val symbolProvider = coreCodegenContext.symbolProvider
    private val runtimeConfig = coreCodegenContext.runtimeConfig
    private val model = coreCodegenContext.model
    private val codegenScope =
        arrayOf(
            "SmithyPython" to PythonServerCargoDependency.SmithyHttpServerPython(runtimeConfig).asType(),
@@ -73,10 +83,9 @@ class PythonApplicationGenerator(
        )

    fun render(writer: RustWriter) {
        renderPyApplicationRustDocs(writer)
        writer.rustTemplate(
            """
            /// Main Python application, used to register operations and context and start multiple
            /// workers on the same shared socket.
            ##[#{pyo3}::pyclass]
            ##[derive(Debug, Clone)]
            pub struct App {
@@ -127,7 +136,7 @@ class PythonApplicationGenerator(
            )
            rustBlockTemplate(
                """
                /// Dynamically codegenerate the routes, allowing to build the Smithy [Router].
                /// Dynamically codegenerate the routes, allowing to build the Smithy [#{SmithyServer}::Router].
                pub fn build_router(&mut self, py: #{pyo3}::Python) -> #{pyo3}::PyResult<()>
                """,
                *codegenScope
@@ -179,4 +188,69 @@ class PythonApplicationGenerator(
            }
        }
    }

    private fun renderPyApplicationRustDocs(writer: RustWriter) {
        writer.rust(
"""
##[allow(clippy::tabs_in_doc_comments)]
/// Main Python application, used to register operations and context and start multiple
/// workers on the same shared socket.
///
/// Operations can be registrered using the application object as a decorator (`@app.operation_name`).
///
/// Here's a full example to get you started:
///
/// ```python
${ if (operations.any { it.errors.isNotEmpty() }) {
"""/// from $crateName import ${Inputs.namespace}
/// from $crateName import ${Outputs.namespace}
/// from $crateName import ${Errors.namespace}"""
            } else {
"""/// from $crateName import ${Inputs.namespace}
/// from $crateName import ${Outputs.namespace}"""
            } }
/// from $crateName import App
///
/// @dataclass
/// class Context:
///     counter: int = 0
///
/// app = App()
/// app.context(Context())
///
${operationImplementationStubs(operations)}
///
/// app.run()
/// ```
///
/// Any of operations above can be written as well prepending the `async` keyword and
/// the Python application will automatically handle it and schedule it on the event loop for you.
            """
        )
    }

    private fun operationImplementationStubs(operations: List<OperationShape>): String =
        operations.joinToString("\n///\n") {
            val operationDocumentation = it.getTrait<DocumentationTrait>()?.value
            val ret = if (!operationDocumentation.isNullOrBlank()) {
                operationDocumentation.replace("#", "##").prependIndent("/// ## ") + "\n"
            } else ""
            ret +
                """
                    /// ${it.signature()}:
                    ///     raise NotImplementedError
                """.trimIndent()
        }

    /**
     * Returns the function signature for an operation handler implementation. Used in the documentation.
     */
    private fun OperationShape.signature(): String {
        val inputSymbol = symbolProvider.toSymbol(inputShape(model))
        val outputSymbol = symbolProvider.toSymbol(outputShape(model))
        val inputT = "${Inputs.namespace}::${inputSymbol.name}"
        val outputT = "${Outputs.namespace}::${outputSymbol.name}"
        val operationName = symbolProvider.toSymbol(this).name.toSnakeCase()
        return "@app.$operationName\n/// def $operationName(input: $inputT, ctx: Context) -> $outputT"
    }
}
+16 −4
Original line number Diff line number Diff line
@@ -65,6 +65,7 @@ class JsonParserGenerator(
    private val target = coreCodegenContext.target
    private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType()
    private val jsonDeserModule = RustModule.private("json_deser")
    private val typeConversionGenerator = TypeConversionGenerator(symbolProvider, runtimeConfig)
    private val codegenScope = arrayOf(
        "Error" to smithyJson.member("deserialize::Error"),
        "ErrorReason" to smithyJson.member("deserialize::ErrorReason"),
@@ -221,8 +222,8 @@ class JsonParserGenerator(
            is StringShape -> deserializeString(target)
            is BooleanShape -> rustTemplate("#{expect_bool_or_null}(tokens.next())?", *codegenScope)
            is NumberShape -> deserializeNumber(target)
            is BlobShape -> rustTemplate("#{expect_blob_or_null}(tokens.next())?", *codegenScope)
            is TimestampShape -> deserializeTimestamp(memberShape)
            is BlobShape -> deserializeBlob(target)
            is TimestampShape -> deserializeTimestamp(target, memberShape)
            is CollectionShape -> deserializeCollection(target)
            is MapShape -> deserializeMap(target)
            is StructureShape -> deserializeStruct(target)
@@ -236,6 +237,14 @@ class JsonParserGenerator(
        }
    }

    private fun RustWriter.deserializeBlob(target: BlobShape) {
        rustTemplate(
            "#{expect_blob_or_null}(tokens.next())?#{ConvertFrom:W}",
            "ConvertFrom" to typeConversionGenerator.convertViaFrom(target),
            *codegenScope
        )
    }

    private fun RustWriter.deserializeStringInner(target: StringShape, escapedStrName: String) {
        withBlock("$escapedStrName.to_unescaped().map(|u|", ")") {
            when (target.hasTrait<EnumTrait>()) {
@@ -266,14 +275,17 @@ class JsonParserGenerator(
        rustTemplate("#{expect_number_or_null}(tokens.next())?.map(|v| v.to_#{T}())", "T" to symbol, *codegenScope)
    }

    private fun RustWriter.deserializeTimestamp(member: MemberShape) {
    private fun RustWriter.deserializeTimestamp(shape: TimestampShape, member: MemberShape) {
        val timestampFormat =
            httpBindingResolver.timestampFormat(
                member, HttpLocation.DOCUMENT,
                TimestampFormatTrait.Format.EPOCH_SECONDS
            )
        val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
        rustTemplate("#{expect_timestamp_or_null}(tokens.next(), #{T})?", "T" to timestampFormatType, *codegenScope)
        rustTemplate(
            "#{expect_timestamp_or_null}(tokens.next(), #{T})?#{ConvertFrom:W}",
            "T" to timestampFormatType, "ConvertFrom" to typeConversionGenerator.convertViaFrom(shape), *codegenScope
        )
    }

    private fun RustWriter.deserializeCollection(shape: CollectionShape) {
+40 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rust.codegen.smithy.protocols.parse

import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.rustType

/*
 * Utility class used to force casting a non primitive type into one overriden by a new symbol provider,
 * by explicitly calling `from()`.
 *
 * For example we use this in the server Python implementation, where we override types like [Blob] and [DateTime]
 * with wrappers compatible with Python, without touching the original implementation coming from `aws-smithy-types`.
 */
class TypeConversionGenerator(private val symbolProvider: RustSymbolProvider, private val runtimeConfig: RuntimeConfig) {
    fun convertViaFrom(shape: Shape): Writable =
        writable {
            val oldSymbol = when (shape) {
                // TODO(understand what needs to be done for ByteStream)
                is BlobShape -> RuntimeType.Blob(runtimeConfig).toSymbol()
                is TimestampShape -> RuntimeType.DateTime(runtimeConfig).toSymbol()
                else -> symbolProvider.toSymbol(shape)
            }
            val newSymbol = symbolProvider.toSymbol(shape)
            if (oldSymbol.rustType() != newSymbol.rustType()) {
                rust(".map($newSymbol::from)")
            }
        }
}