Unverified Commit 56ee9c52 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

ProtocolConfig Refactoring (#25)

Mostly mechanical refactoring to remove operation specific fields from ProtocolConfig—this allows protocol config to be used for an entire service rather than being recreated for each operation.
parent 31cfbf57
Loading
Loading
Loading
Loading
+19 −4
Original line number Diff line number Diff line
@@ -21,10 +21,14 @@ import software.amazon.smithy.rust.codegen.lang.RustDependency
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.smithy.generators.CargoTomlGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.LibRsGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.generators.ServiceGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.runCommand
@@ -40,17 +44,28 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
    private val symbolProvider: SymbolProvider
    private val writers: CodegenWriterDelegator<RustWriter>
    private val fileManifest = context.fileManifest
    val model: Model
    private val model: Model
    private val protocolConfig: ProtocolConfig
    private val protocolGenerator: ProtocolGeneratorFactory<HttpProtocolGenerator>
    private val httpGenerator: HttpProtocolGenerator
    init {
        val bootstrapProvider = SymbolVisitor(context.model, config = SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig))
        model = OperationNormalizer(bootstrapProvider).addOperationInputs(context.model)
        val service = settings.getService(context.model)
        val (protocol, generator) = ProtocolLoader.Default.protocolFor(context.model, service)
        protocolGenerator = generator

        val baseVisitor = SymbolVisitor(context.model, config = SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig))
        val normalizer = OperationNormalizer(baseVisitor)

        model = normalizer.addOperationInputs(context.model)
        symbolProvider = SymbolVisitor(model, config = SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig))
        protocolConfig = ProtocolConfig(model, symbolProvider, settings.runtimeConfig, service, protocol)
        writers = CodegenWriterDelegator(
            context.fileManifest,
            // TODO: load symbol visitor from integrations; 2d
            symbolProvider,
            RustWriter.Factory
        )
        httpGenerator = protocolGenerator.buildProtocolGenerator(protocolConfig)
    }

    fun execute() {
@@ -105,6 +120,6 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
    }

    override fun serviceShape(shape: ServiceShape) {
        ServiceGenerator(model, symbolProvider, settings.runtimeConfig, shape, writers).render()
        ServiceGenerator(writers, httpGenerator, protocolConfig).render()
    }
}
+11 −10
Original line number Diff line number Diff line
@@ -16,29 +16,30 @@ import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType

/**
 * Configuration needed to generate the client for a given Service<->Protocol pair
 */
data class ProtocolConfig(
    val model: Model,
    val symbolProvider: SymbolProvider,
    val runtimeConfig: RuntimeConfig,
    val writer: RustWriter,
    val serviceShape: ServiceShape,
    val operationShape: OperationShape,
    val inputShape: StructureShape,
    val protocol: ShapeId
)

interface ProtocolGeneratorFactory<out T : HttpProtocolGenerator> {
    fun build(protocolConfig: ProtocolConfig): T
    fun buildProtocolGenerator(protocolConfig: ProtocolConfig): T
}

abstract class HttpProtocolGenerator(
    private val symbolProvider: SymbolProvider,
    private val writer: RustWriter,
    private val inputShape: StructureShape
    protocolConfig: ProtocolConfig
) {
    fun render() {
    private val symbolProvider = protocolConfig.symbolProvider
    private val model = protocolConfig.model
    fun renderOperation(writer: RustWriter, operationShape: OperationShape) {
        val inputShape = model.expectShape(operationShape.input.get(), StructureShape::class.java)
        writer.rustBlock("impl ${symbolProvider.toSymbol(inputShape).name}") {
            toHttpRequestImpl(this)
            toHttpRequestImpl(this, operationShape, inputShape)
        }
    }

@@ -56,5 +57,5 @@ abstract class HttpProtocolGenerator(
     *
     * Your implementation MUST call `httpBuilderFun` to create the public method.
     */
    abstract fun toHttpRequestImpl(implBlockWriter: RustWriter)
    abstract fun toHttpRequestImpl(implBlockWriter: RustWriter, operationShape: OperationShape, inputShape: StructureShape)
}
+7 −6
Original line number Diff line number Diff line
package software.amazon.smithy.rust.codegen.smithy.generators

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase
import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait
import software.amazon.smithy.rust.codegen.lang.RustWriter
@@ -7,18 +8,18 @@ import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.lang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.inputShape

/**
 * Generate protocol tests for an operation
 */
class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) {
class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig, private val operationShape: OperationShape, private val writer: RustWriter) {
    private val inputShape = operationShape.inputShape(protocolConfig.model)
    fun render() {
        with(protocolConfig) {
        operationShape.getTrait(HttpRequestTestsTrait::class.java).map {
            renderHttpRequestTests(it)
        }
    }
    }

    private fun renderHttpRequestTests(httpRequestTestsTrait: HttpRequestTestsTrait) {
        with(protocolConfig) {
@@ -44,7 +45,7 @@ class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) {
        testModuleWriter.write("#[test]")
        testModuleWriter.rustBlock("fn test_${httpRequestTestCase.id.toSnakeCase()}()") {
            writeInline("let input =")
            instantiator.render(httpRequestTestCase.params, protocolConfig.inputShape, this)
            instantiator.render(httpRequestTestCase.params, inputShape, this)
            write(";")
            write("let http_request = input.build_http_request().body(()).unwrap();")
            checkQueryParams(this, httpRequestTestCase.queryParams)
+0 −14
Original line number Diff line number Diff line
@@ -93,20 +93,6 @@ class HttpTraitBindingGenerator(
        }
    }

    /**
     * Default implementation of HttpTraitBindings. A `build_http_request()` method is added that
     * simply calls `update_http_builder()`
     */
    inner class Default : HttpProtocolGenerator(symbolProvider, writer, inputShape) {
        override fun toHttpRequestImpl(implBlockWriter: RustWriter) {
            renderUpdateHttpBuilder(implBlockWriter)
            httpBuilderFun(implBlockWriter) {
                write("let builder = \$T::new();", RuntimeType.HttpRequestBuilder)
                write("self.update_http_builder(builder)")
            }
        }
    }

    /** Header Generation **/

    /**
+7 −40
Original line number Diff line number Diff line
@@ -5,57 +5,24 @@

package software.amazon.smithy.rust.codegen.smithy.generators

import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.codegen.core.writer.CodegenWriterDelegator
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.ServiceIndex
import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJson10Factory
import software.amazon.smithy.rust.codegen.smithy.protocols.AwsRestJsonFactory

class ServiceGenerator(
    private val model: Model,
    private val symbolProvider: SymbolProvider,
    private val runtimeConfig: RuntimeConfig,
    private val serviceShape: ServiceShape,
    private val writers: CodegenWriterDelegator<RustWriter>
    private val writers: CodegenWriterDelegator<RustWriter>,
    private val protocolGenerator: HttpProtocolGenerator,
    private val config: ProtocolConfig
) {
    // TODO: refactor to be runtime pluggable; 2d
    private val index = TopDownIndex(model)
    private val supportedProtocols = mapOf(
        AwsJson1_0Trait.ID to AwsJson10Factory(),
        RestJson1Trait.ID to AwsRestJsonFactory()

    )
    private val protocols: MutableMap<ShapeId, Trait> = ServiceIndex(model).getProtocols(serviceShape)
    private val matchingProtocols = protocols.keys.mapNotNull { protocolId -> supportedProtocols[protocolId]?.let { protocolId to it } }

    init {
        if (matchingProtocols.isEmpty()) {
            throw CodegenException("No matching protocol — service offers: ${protocols.keys}. We offer: ${supportedProtocols.keys}")
        }
    }
    private val index = TopDownIndex(config.model)

    fun render() {
        val operations = index.getContainedOperations(serviceShape)
        val (protocol, generator) = matchingProtocols.first()
        val operations = index.getContainedOperations(config.serviceShape)
        // TODO: refactor so that we don't need to re-instantiate the protocol for every operation
        operations.map { operation ->
            writers.useShapeWriter(operation) { writer ->
                // transform ensures that all models have input shapes
                val input = operation.input.get().let { model.expectShape(it, StructureShape::class.java) }
                val config = ProtocolConfig(model, symbolProvider, runtimeConfig, writer, serviceShape, operation, input, protocol)
                generator.build(config).render()
                HttpProtocolTestGenerator(config).render()
                protocolGenerator.renderOperation(writer, operation)
                HttpProtocolTestGenerator(config, operation, writer).render()
            }
        }
    }
Loading