Unverified Commit ddf1e01a authored by Matteo Bigoi's avatar Matteo Bigoi Committed by GitHub
Browse files

Refactor server codegen for restJson1 (#786)



* Refactor server restJson1 and allow generation of HTTP request
/ response tests

* Remove useless logging and only generate server specific tests

* Fix simple.smithy model and add tests for deserialization of URI and
method

* Don't double quote strings and just use the dq() method

Co-authored-by: default avatarRussell Cohen <russell.r.cohen@gmail.com>
parent 99fe4a76
Loading
Loading
Loading
Loading
+27 −1
Original line number Diff line number Diff line
@@ -3,6 +3,8 @@ $version: "1.0"
namespace com.amazonaws.simple

use aws.protocols#restJson1
use smithy.test#httpRequestTests
use smithy.test#httpResponseTests

@restJson1
@title("SimpleService")
@@ -20,6 +22,9 @@ service SimpleService {
@documentation("Id of the service that will be registered")
string ServiceId

@documentation("Name of the service that will be registered")
string ServiceName

@error("client")
@documentation(
    """
@@ -40,6 +45,25 @@ resource Service {
@idempotent
@http(method: "PUT", uri: "/service/{id}")
@documentation("Service register operation")
@httpRequestTests([
    {
        id: "RegisterServiceRequestTest",
        protocol: "aws.protocols#restJson1",
        uri: "/service/1",
        params: { id: "1", name: "TestService" },
        body: "{\"name\":\"TestService\"}",
        method: "PUT",
    }
])
@httpResponseTests([
    {
        id: "RegisterServiceResponseTest",
        protocol: "aws.protocols#restJson1",
        params: { id: "1", name: "TestService" },
        body: "{\"id\":\"1\",\"name\":\"TestService\"}",
        code: 200,
    }
])
operation RegisterService {
    input: RegisterServiceInputRequest,
    output: RegisterServiceOutputResponse,
@@ -51,12 +75,14 @@ structure RegisterServiceInputRequest {
    @required
    @httpLabel
    id: ServiceId,
    name: ServiceName,
}

@documentation("Service register output structure")
structure RegisterServiceOutputResponse {
    @required
    id: ServiceId
    id: ServiceId,
    name: ServiceName,
}

@readonly
+5 −3
Original line number Diff line number Diff line
@@ -24,7 +24,6 @@ import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServiceGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.RestJson1HttpDeserializerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.RestJson1HttpSerializerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerGenerator
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.DefaultPublicModules
import software.amazon.smithy.rust.codegen.smithy.RustCrate
@@ -74,8 +73,8 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
    private val protocolGeneratorFactory: ProtocolGeneratorFactory<ProtocolGenerator>
    private val protocolGenerator: ProtocolGenerator

    private val httpSerializerGenerator: ServerGenerator
    private val httpDeserializerGenerator: ServerGenerator
    private val httpSerializerGenerator: RestJson1HttpSerializerGenerator
    private val httpDeserializerGenerator: RestJson1HttpDeserializerGenerator
    private val httpBindingResolver: HttpBindingResolver

    init {
@@ -230,6 +229,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
     * Although raw strings require no code generation, enums are actually `EnumTrait` applied to string shapes.
     */
    override fun stringShape(shape: StringShape) {
        logger.info("[rust-server-codegen] Generating an enum $shape")
        shape.getTrait<EnumTrait>()?.also { enum ->
            rustCrate.useShapeWriter(shape) { writer ->
                EnumGenerator(model, symbolProvider, writer, shape, enum).render()
@@ -245,6 +245,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
     * This function _does not_ generate any serializers.
     */
    override fun unionShape(shape: UnionShape) {
        logger.info("[rust-server-codegen] Generating an union $shape")
        rustCrate.useShapeWriter(shape) {
            UnionGenerator(model, symbolProvider, it, shape).render()
        }
@@ -260,6 +261,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
     * - Operation structures
     */
    override fun serviceShape(shape: ServiceShape) {
        logger.info("[rust-server-codegen] Generating a service $shape")
        ServiceGenerator(
            rustCrate,
            protocolGenerator,
+52 −84
Original line number Diff line number Diff line
@@ -58,38 +58,21 @@ import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase
import java.util.logging.Logger

/**
* TODO: this abstract class is here temporarily as it promotes a bit of code reusal between the RestJson1
* serialzer and deserializer. This will be refactored and probably removed soon.
*/
abstract class ServerGenerator(
    codegenContext: CodegenContext,
    private val httpBindingResolver: HttpTraitHttpBindingResolver,
) {
    public val logger = Logger.getLogger(javaClass.name)
    public val error = RuntimeType("error", null, "crate")
    public val operation = RuntimeType("operation", null, "crate")
    public val runtimeConfig = codegenContext.runtimeConfig
    public val model = codegenContext.model
    public val symbolProvider = codegenContext.symbolProvider
    public val instantiator =
        with(codegenContext) { Instantiator(symbolProvider, model, runtimeConfig) }
    public val smithyHttp = CargoDependency.SmithyHttp(runtimeConfig).asType()
    public val index = HttpBindingIndex.of(model)
    public val service = codegenContext.serviceShape
    public val defaultTimestampFormat = TimestampFormatTrait.Format.EPOCH_SECONDS
    public val jsonSerializerGenerator = JsonSerializerGenerator(codegenContext, httpBindingResolver)
    public val jsonParserGenerator = JsonParserGenerator(codegenContext, httpBindingResolver)

    abstract fun render(writer: RustWriter, operationShape: OperationShape)
}

class RestJson1HttpSerializerGenerator(
    codegenContext: CodegenContext,
    private val httpBindingResolver: HttpTraitHttpBindingResolver,
) : ServerGenerator(codegenContext, httpBindingResolver) {
    private val serde = RuntimeType("json_ser", null, "crate")
) {
    private val logger = Logger.getLogger(javaClass.name)
    private val error = RuntimeType("error", null, "crate")
    private val operation = RuntimeType("operation", null, "crate")
    private val model = codegenContext.model
    private val symbolProvider = codegenContext.symbolProvider
    private val index = HttpBindingIndex.of(model)
    private val runtimeConfig = codegenContext.runtimeConfig
    private val instantiator = with(codegenContext) { Instantiator(symbolProvider, model, runtimeConfig) }
    private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType()
    private val smithyHttp = CargoDependency.SmithyHttp(runtimeConfig).asType()
    private val jsonSerializerGenerator = JsonSerializerGenerator(codegenContext, httpBindingResolver)
    private val codegenScope =
        arrayOf(
            "Bytes" to RuntimeType.Bytes,
@@ -99,11 +82,11 @@ class RestJson1HttpSerializerGenerator(
            "build_error" to runtimeConfig.operationBuildError(),
            "JsonSerdeError" to error.member("Error"),
            "JsonObjectWriter" to smithyJson.member("serialize::JsonObjectWriter"),
            "parse_http_response" to smithyHttp.member("response::ParseHttpResponse"),
            "sdk_body" to RuntimeType.sdkBody(runtimeConfig = runtimeConfig),
            "ParseHttpResponse" to smithyHttp.member("response::ParseHttpResponse"),
            "SdkBody" to RuntimeType.sdkBody(runtimeConfig = runtimeConfig),
        )

    override fun render(writer: RustWriter, operationShape: OperationShape) {
    fun render(writer: RustWriter, operationShape: OperationShape) {
        renderResponseSerializer(writer, operationShape)
        renderErrorSerializer(writer, operationShape)
        renderTests(writer, operationShape)
@@ -120,9 +103,6 @@ class RestJson1HttpSerializerGenerator(
        }
        val serializerSymbol = jsonSerializerGenerator.serverOutputSerializer(operationShape)
        if (serializerSymbol == null) {
            logger.warning(
                "[rust-server-codegen] $outputShape: response output serialization does not contain any member"
            )
            return
        }
        val outputSymbol = symbolProvider.toSymbol(outputShape)
@@ -315,20 +295,14 @@ class RestJson1HttpSerializerGenerator(
            }
            rust(
                """let response = #T(&expected).expect("failed to serialize response");""",
                serde.member(fnName)
                operation.member(fnName)
            )
            rust("assert_eq!(response.status(), ${testCase.code});")

            rustTemplate("let mut response = response.map(#{sdk_body}::from);", *codegenScope)
            rustTemplate(
                """
                    use #{parse_http_response};
                    use #{ParseHttpResponse};
                    let parser = #{op}::new();
                    let actual = parser.parse_unloaded(&mut response);
                    let actual = actual.unwrap_or_else(|| {
                        let response = response.map(|body|#{Bytes}::copy_from_slice(body.bytes().unwrap()));
                        <#{op} as #{parse_http_response}<#{sdk_body}>>::parse_loaded(&parser, &response)
                    });
                    let actual = parser.parse_loaded(&response);
                """.trimIndent(),
                *codegenScope,
                "op" to symbolProvider.toSymbol(operationShape),
@@ -354,30 +328,16 @@ class RestJson1HttpSerializerGenerator(
class RestJson1HttpDeserializerGenerator(
    private val codegenContext: CodegenContext,
    private val httpBindingResolver: HttpTraitHttpBindingResolver,
) : ServerGenerator(codegenContext, httpBindingResolver) {
    override fun render(writer: RustWriter, operationShape: OperationShape) {
        RestJson1HttpRequestDeserializerGenerator(
            codegenContext,
            httpBindingResolver,
            operationShape
        )
            .render(writer, operationShape)
    }
}

class RestJson1HttpRequestDeserializerGenerator(
    codegenContext: CodegenContext,
    private val httpBindingResolver: HttpTraitHttpBindingResolver,
    private val operationShape: OperationShape,
) : ServerGenerator(codegenContext, httpBindingResolver) {
    private val deserFnName = "deser_${operationShape.id.name.toSnakeCase()}_request"
    private val httpBindingGenerator =
        ResponseBindingGenerator(
            RestJson(codegenContext),
            codegenContext,
            operationShape,
        )
    private val httpTrait = httpBindingResolver.httpTrait(operationShape)
) {
    private val logger = Logger.getLogger(javaClass.name)
    private val error = RuntimeType("error", null, "crate")
    private val operation = RuntimeType("operation", null, "crate")
    private val model = codegenContext.model
    private val symbolProvider = codegenContext.symbolProvider
    private val index = HttpBindingIndex.of(model)
    private val runtimeConfig = codegenContext.runtimeConfig
    private val instantiator = with(codegenContext) { Instantiator(symbolProvider, model, runtimeConfig) }
    private val jsonParserGenerator = JsonParserGenerator(codegenContext, httpBindingResolver)
    private val codegenScope =
        arrayOf(
            "Bytes" to RuntimeType.Bytes,
@@ -394,13 +354,14 @@ class RestJson1HttpRequestDeserializerGenerator(
        )
    private val operationDeserModule = RustModule.public("operation_deser")

    override fun render(writer: RustWriter, operationShape: OperationShape) {
        renderRequestDeserializer(writer)
        renderRequestDeserializerTests(writer)
    fun render(writer: RustWriter, operationShape: OperationShape) {
        renderRequestDeserializer(writer, operationShape)
        renderRequestDeserializerTests(writer, operationShape)
    }

    private fun renderRequestDeserializer(writer: RustWriter) {
    private fun renderRequestDeserializer(writer: RustWriter, operationShape: OperationShape) {
        val inputShape = operationShape.inputShape(model)
        val deserFnName = "deser_${operationShape.id.name.toSnakeCase()}_request"
        if (inputShape.hasStreamingMember(model)) {
            logger.warning(
                "[rust-server-codegen] $operationShape: request deserialization does not currently support streaming shapes"
@@ -409,9 +370,6 @@ class RestJson1HttpRequestDeserializerGenerator(
        }
        val deserializerSymbol = jsonParserGenerator.serverInputParser(operationShape)
        if (deserializerSymbol == null) {
            logger.warning(
                "[rust-server-codegen] $inputShape: response output serialization does not contain any member"
            )
            return
        }
        val inputSymbol = symbolProvider.toSymbol(inputShape)
@@ -427,20 +385,20 @@ class RestJson1HttpRequestDeserializerGenerator(
                deserializerSymbol,
            )
            httpBindingResolver.requestBindings(operationShape).forEach { binding ->
                httpBindingDeserializer(binding)?.let { deserializer ->
                httpBindingDeserializer(binding, operationShape)?.let { deserializer ->
                    withBlock("input = input.${binding.member.setterName()}(", ");") {
                        deserializer(this)
                    }
                }
            }
            renderPathDeserializer(writer)
            renderPathDeserializer(writer, operationShape)
            rustTemplate("input.build().map_err(#{JsonSerdeError}::from)", *codegenScope)
        }
    }

    private fun httpBindingDeserializer(binding: HttpBindingDescriptor): Writable? {
    private fun httpBindingDeserializer(binding: HttpBindingDescriptor, operationShape: OperationShape): Writable? {
        return when (val location = binding.location) {
            HttpLocation.HEADER -> writable { renderHeaderDeserializer(this, binding) }
            HttpLocation.HEADER -> writable { renderHeaderDeserializer(this, binding, operationShape) }
            HttpLocation.LABEL -> {
                null
            }
@@ -456,7 +414,7 @@ class RestJson1HttpRequestDeserializerGenerator(
        }
    }

    private fun renderPathDeserializer(writer: RustWriter) {
    private fun renderPathDeserializer(writer: RustWriter, operationShape: OperationShape) {
        val pathBindings =
            httpBindingResolver.requestBindings(operationShape).filter {
                it.location == HttpLocation.LABEL
@@ -465,6 +423,7 @@ class RestJson1HttpRequestDeserializerGenerator(
            return
        }
        val pattern = StringBuilder()
        val httpTrait = httpBindingResolver.httpTrait(operationShape)
        httpTrait.uri.segments.forEach {
            pattern.append("/")
            if (it.isLabel) {
@@ -508,7 +467,13 @@ class RestJson1HttpRequestDeserializerGenerator(
        }
    }

    private fun renderHeaderDeserializer(writer: RustWriter, binding: HttpBindingDescriptor) {
    private fun renderHeaderDeserializer(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) {
        val httpBindingGenerator =
            ResponseBindingGenerator(
                RestJson(codegenContext),
                codegenContext,
                operationShape,
            )
        val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding)
        writer.rust(
            """
@@ -557,7 +522,7 @@ class RestJson1HttpRequestDeserializerGenerator(
            index.determineTimestampFormat(
                binding.member,
                binding.location,
                defaultTimestampFormat,
                TimestampFormatTrait.Format.EPOCH_SECONDS
            )
        val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
        return RuntimeType.forInlineFun(fnName, operationDeserModule) { writer ->
@@ -609,7 +574,7 @@ class RestJson1HttpRequestDeserializerGenerator(
        return "deser_label_${containerName}_$memberName"
    }

    private fun renderRequestDeserializerTests(writer: RustWriter) {
    private fun renderRequestDeserializerTests(writer: RustWriter, operationShape: OperationShape) {
        val testCases =
            operationShape.getTrait<HttpRequestTestsTrait>()?.getTestCasesFor(AppliesTo.SERVER)
                ?: return
@@ -627,14 +592,15 @@ class RestJson1HttpRequestDeserializerGenerator(
            )
        writer.write("")
        writer.withModule(testModuleName, moduleMeta) {
            testCases.forEach { renderRequestDeserializerTestCase(it) }
            testCases.forEach { renderRequestDeserializerTestCase(it, operationShape) }
        }
    }

    private fun RustWriter.renderRequestDeserializerTestCase(testCase: HttpRequestTestCase) {
    private fun RustWriter.renderRequestDeserializerTestCase(testCase: HttpRequestTestCase, operationShape: OperationShape) {
        Attribute.Custom("test").render(this)
        rustBlock("fn ${testCase.id.toSnakeCase()}()") {
            val inputShape = operationShape.inputShape(model)
            val deserFnName = "deser_${operationShape.id.name.toSnakeCase()}_request"
            val customToken =
                inputShape.findMemberWithTrait<IdempotencyTokenTrait>(model)?.let {
                    """.make_token("00000000-0000-4000-8000-000000000000")"""
@@ -652,6 +618,8 @@ class RestJson1HttpRequestDeserializerGenerator(
                "let request = request.map(|body| #{Bytes}::from(body.bytes().unwrap().to_vec()));",
                *codegenScope
            )
            rust("assert_eq!(request.method().as_str(), ${testCase.method.dq()});")
            rust("assert_eq!(request.uri().to_owned(), ${testCase.uri.dq()});")
            rust(
                """let actual = #T(&request).expect("failed to parse request");""",
                operation.member(deserFnName)