Unverified Commit b4f57b7c authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Introduce HttpBindingResolver to abstract HttpBindingIndex (#448)

* Introduce HttpBindingResolver to abstract HttpBindingIndex

* Generalize HttpBindingResolver for protocols that don't have HTTP traits

* Sort members in HttpTraitHttpBindingResolver
parent 1be3fa3d
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -285,7 +285,7 @@ class RustWriter private constructor(
     * - If the field is an unboxed primitive, it will only be called if the field is non-zero
     *
     */
    fun ifSet(shape: Shape, member: Symbol, outerField: String, block: CodeWriter.(field: String) -> Unit) {
    fun ifSet(shape: Shape, member: Symbol, outerField: String, block: RustWriter.(field: String) -> Unit) {
        // TODO: this API should be refactored so that we don't need to strip `&` to get reference comparisons to work.
        when {
            member.isOptional() -> {
@@ -307,7 +307,7 @@ class RustWriter private constructor(
    fun listForEach(
        target: Shape,
        outerField: String,
        block: CodeWriter.(field: String, target: ShapeId) -> Unit
        block: RustWriter.(field: String, target: ShapeId) -> Unit
    ) {
        if (target is CollectionShape) {
            val derefName = safeName("inner")
+27 −15
Original line number Diff line number Diff line
@@ -26,8 +26,10 @@ import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError
import software.amazon.smithy.rust.codegen.smithy.generators.redactIfNecessary
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectMember
import software.amazon.smithy.rust.codegen.util.hasTrait
@@ -59,16 +61,30 @@ class RequestBindingGenerator(
    val model: Model,
    private val symbolProvider: RustSymbolProvider,
    private val runtimeConfig: RuntimeConfig,
    private val writer: RustWriter,
    private val defaultTimestampFormat: TimestampFormatTrait.Format,
    private val shape: OperationShape,
    private val inputShape: StructureShape,
    private val httpTrait: HttpTrait
    private val httpTrait: HttpTrait,
) {
    // TODO: make defaultTimestampFormat configurable
    private val defaultTimestampFormat = TimestampFormatTrait.Format.EPOCH_SECONDS
    private val index = HttpBindingIndex.of(model)
    private val buildError = runtimeConfig.operationBuildError()

    constructor(
        protocolConfig: ProtocolConfig,
        defaultTimestampFormat: TimestampFormatTrait.Format,
        httpBindingResolver: HttpBindingResolver,
        shape: OperationShape,
        inputShape: StructureShape,
    ) : this(
        protocolConfig.model,
        protocolConfig.symbolProvider,
        protocolConfig.runtimeConfig,
        defaultTimestampFormat,
        shape,
        inputShape,
        httpBindingResolver.httpTrait(shape),
    )

    /**
     * Generates `update_http_builder` and all necessary dependency functions into the impl block provided by
     * [implBlockWriter]. The specific behavior is configured by [httpTrait].
@@ -145,7 +161,7 @@ class RequestBindingGenerator(
                        #{build_error}::InvalidField { field: ${memberName.dq()}, details: format!("`{}` cannot be used as a header name: {}", k, err)}
                    })?;
                    use std::convert::TryFrom;
                    let header_value = ${headerFmtFun(target, memberShape, "v")};
                    let header_value = ${headerFmtFun(this, target, memberShape, "v")};
                    let header_value = http::header::HeaderValue::try_from(header_value).map_err(|err| {
                        #{build_error}::InvalidField {
                            field: ${memberName.dq()},
@@ -174,7 +190,7 @@ class RequestBindingGenerator(
        ifSet(memberType, memberSymbol, "&self.$memberName") { field ->
            listForEach(memberType, field) { innerField, targetId ->
                val innerMemberType = model.expectShape(targetId)
                val formatted = headerFmtFun(innerMemberType, memberShape, innerField)
                val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField)
                val safeName = safeName("formatted")
                write("let $safeName = $formatted;")
                rustBlock("if !$safeName.is_empty()") {
@@ -203,7 +219,7 @@ class RequestBindingGenerator(
    /**
     * Format [member] in the when used as an HTTP header
     */
    private fun headerFmtFun(target: Shape, member: MemberShape, targetName: String): String {
    private fun headerFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String): String {
        return when {
            target.isStringShape -> {
                if (target.hasTrait<MediaTypeTrait>()) {
@@ -238,7 +254,7 @@ class RequestBindingGenerator(
        val formatString = httpTrait.uriFormatString()
        val args = httpTrait.uri.labels.map { label ->
            val member = inputShape.expectMember(label.content)
            "${label.content} = ${labelFmtFun(model.expectShape(member.target), member, label)}"
            "${label.content} = ${labelFmtFun(writer, model.expectShape(member.target), member, label)}"
        }
        val combinedArgs = listOf(formatString, *args.toTypedArray())
        writer.addImport(RuntimeType.stdfmt.member("Write").toSymbol(), null)
@@ -316,11 +332,7 @@ class RequestBindingGenerator(
                        val target = model.expectShape(targetId)
                        rust(
                            "query.push_kv(${param.locationName.dq()}, &${
                            paramFmtFun(
                                target,
                                memberShape,
                                innerField
                            )
                            paramFmtFun(writer, target, memberShape, innerField)
                            });"
                        )
                    }
@@ -333,7 +345,7 @@ class RequestBindingGenerator(
    /**
     * Format [member] when used as a queryParam
     */
    private fun paramFmtFun(target: Shape, member: MemberShape, targetName: String): String {
    private fun paramFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String): String {
        return when {
            target.isStringShape -> {
                val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_string"))
@@ -359,7 +371,7 @@ class RequestBindingGenerator(
    /**
     * Format [member] when used as an HTTP Label (`/bucket/{key}`)
     */
    private fun labelFmtFun(target: Shape, member: MemberShape, label: SmithyPattern.Segment): String {
    private fun labelFmtFun(writer: RustWriter, target: Shape, member: MemberShape, label: SmithyPattern.Segment): String {
        val memberName = symbolProvider.toMemberName(member)
        return when {
            target.isStringShape -> {
+10 −7
Original line number Diff line number Diff line
@@ -34,6 +34,8 @@ import software.amazon.smithy.rust.codegen.smithy.Default
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.defaultValue
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.hasTrait
@@ -62,8 +64,8 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
     * }
     * ```
     */
    fun generateDeserializeHeaderFn(binding: HttpBinding): RuntimeType {
        check(binding.location == HttpBinding.Location.HEADER)
    fun generateDeserializeHeaderFn(binding: HttpBindingDescriptor): RuntimeType {
        check(binding.location == HttpLocation.HEADER)
        val outputT = symbolProvider.toSymbol(binding.member)
        val fnName = "deser_header_${fnName(operationShape, binding)}"
        return RuntimeType.forInlineFun(fnName, "http_serde") { writer ->
@@ -79,7 +81,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
        }
    }

    fun generateDeserializePrefixHeaderFn(binding: HttpBinding): RuntimeType {
    fun generateDeserializePrefixHeaderFn(binding: HttpBindingDescriptor): RuntimeType {
        check(binding.location == HttpBinding.Location.PREFIX_HEADERS)
        val outputT = symbolProvider.toSymbol(binding.member)
        check(outputT.rustType().stripOuter<RustType.Option>() is RustType.HashMap) { outputT.rustType() }
@@ -122,7 +124,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
     * Generate a function to deserialize `[binding]` from the response payload
     */
    fun generateDeserializePayloadFn(
        binding: HttpBinding,
        binding: HttpBindingDescriptor,
        errorT: RuntimeType,
        // Deserialize a single structure or union member marked as a payload
        structuredHandler: RustWriter.(String) -> Unit,
@@ -156,7 +158,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
    }

    private fun RustWriter.deserializeStreamingBody(
        binding: HttpBinding,
        binding: HttpBindingDescriptor,
    ) {
        val member = binding.member
        val targetShape = model.expectShape(member.target)
@@ -172,7 +174,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
    }

    private fun RustWriter.deserializePayloadBody(
        binding: HttpBinding,
        binding: HttpBindingDescriptor,
        errorSymbol: RuntimeType,
        structuredHandler: RustWriter.(String) -> Unit,
        docShapeHandler: RustWriter.(String) -> Unit
@@ -312,5 +314,6 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
     * Generate a unique name for the deserializer function for a given operationShape -> member pair
     */
    // rename here technically not required, operations and members cannot be renamed
    private fun fnName(operationShape: OperationShape, binding: HttpBinding) = "${operationShape.id.getName(service).toSnakeCase()}_${binding.memberName.toSnakeCase()}"
    private fun fnName(operationShape: OperationShape, binding: HttpBindingDescriptor) =
        "${operationShape.id.getName(service).toSnakeCase()}_${binding.memberName.toSnakeCase()}"
}
+37 −1
Original line number Diff line number Diff line
@@ -8,10 +8,13 @@ package software.amazon.smithy.rust.codegen.smithy.protocols
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.pattern.UriPattern
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.RustType
@@ -46,6 +49,7 @@ import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStream
import software.amazon.smithy.rust.codegen.smithy.transformers.StructureModifier
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.orNull
import software.amazon.smithy.rust.codegen.util.outputShape

sealed class AwsJsonVersion {
@@ -148,6 +152,38 @@ class SyntheticBodySymbolProvider(private val model: Model, private val base: Ru
    }
}

class AwsJsonHttpBindingResolver(
    private val model: Model,
    private val awsJsonVersion: AwsJsonVersion,
) : HttpBindingResolver {
    private val httpTrait = HttpTrait.builder()
        .code(200)
        .method("POST")
        .uri(UriPattern.parse("/"))
        .build()

    private fun bindings(shape: ToShapeId?) =
        shape?.let { model.expectShape(it.toShapeId()) }?.members()
            ?.map { HttpBindingDescriptor(it, HttpLocation.DOCUMENT, "document") }
            ?.toList()
            ?: emptyList()

    override fun httpTrait(operationShape: OperationShape): HttpTrait = httpTrait

    override fun requestBindings(operationShape: OperationShape): List<HttpBindingDescriptor> =
        bindings(operationShape.input.orNull())

    override fun responseBindings(operationShape: OperationShape): List<HttpBindingDescriptor> =
        bindings(operationShape.output.orNull())

    override fun errorResponseBindings(errorShape: ToShapeId): List<HttpBindingDescriptor> =
        bindings(errorShape)

    override fun requestContentType(operationShape: OperationShape): String =
        "application/x-amz-json-${awsJsonVersion.value}"
}

// TODO: Refactor to use HttpBoundProtocolGenerator
class BasicAwsJsonGenerator(
    private val protocolConfig: ProtocolConfig,
    private val awsJsonVersion: AwsJsonVersion
@@ -198,7 +234,7 @@ class BasicAwsJsonGenerator(
    }

    override fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata {
        val generator = JsonSerializerGenerator(protocolConfig)
        val generator = JsonSerializerGenerator(protocolConfig, AwsJsonHttpBindingResolver(model, awsJsonVersion))
        val serializer = generator.operationSerializer(operationShape)
        serializer?.also { sym ->
            rustTemplate(
+86 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.smithy.protocols

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.pattern.UriPattern
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.AwsQueryParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.AwsQuerySerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations

class AwsQueryFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
    override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpBoundProtocolGenerator =
        HttpBoundProtocolGenerator(protocolConfig, AwsQueryProtocol(protocolConfig))

    override fun transformModel(model: Model): Model {
        return OperationNormalizer(model).transformModel(
            inputBodyFactory = OperationNormalizer.NoBody,
            outputBodyFactory = OperationNormalizer.NoBody
        ).let(RemoveEventStreamOperations::transform)
    }

    override fun support(): ProtocolSupport {
        return ProtocolSupport(
            requestSerialization = true,
            requestBodySerialization = true,
            responseDeserialization = true,
            errorDeserialization = true,
        )
    }
}

class AwsQueryProtocol(private val protocolConfig: ProtocolConfig) : Protocol {
    private val runtimeConfig = protocolConfig.runtimeConfig
    private val awsQueryErrors: RuntimeType = RuntimeType.wrappedXmlErrors(runtimeConfig)
    override val httpBindingResolver: HttpBindingResolver = StaticHttpBindingResolver(
        protocolConfig.model,
        HttpTrait.builder()
            .code(200)
            .method("POST")
            .uri(UriPattern.parse("/"))
            .build(),
        "application/x-www-form-urlencoded"
    )

    override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME

    override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator =
        AwsQueryParserGenerator(protocolConfig, awsQueryErrors)

    override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
        AwsQuerySerializerGenerator(protocolConfig)

    override fun parseGenericError(operationShape: OperationShape): RuntimeType {
        /**
         fn parse_generic(response: &Response<Bytes>) -> Result<smithy_types::error::Generic, T: Error>
         **/
        return RuntimeType.forInlineFun("parse_generic_error", "xml_deser") {
            it.rustBlockTemplate(
                "pub fn parse_generic_error(response: &#{Response}<#{Bytes}>) -> Result<#{Error}, #{XmlError}>",
                "Response" to RuntimeType.http.member("Response"),
                "Bytes" to RuntimeType.Bytes,
                "Error" to RuntimeType.GenericError(runtimeConfig),
                "XmlError" to CargoDependency.smithyXml(runtimeConfig).asType().member("decode::XmlError")
            ) {
                rust("#T::parse_generic_error(response.body().as_ref())", awsQueryErrors)
            }
        }
    }
}
Loading