Unverified Commit 9c824df3 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Make `ServerProtocolLoader` delegate to `ProtocolLoader` (#1524)

While the set of supported protocols and their implementation is
different among servers and clients, the logic to load a specific
protocol given a Smithy model is identical. This commit makes
`ServerProtocolLoader` inherit from `ProtocolLoader` so that said logic
can be reused among clients and servers.
parent 7f4dad62
Loading
Loading
Loading
Loading
+3 −25
Original line number Diff line number Diff line
@@ -9,35 +9,13 @@ import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.ServiceIndex
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.rust.codegen.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap

/*
 * Protocol dispatcher, responsible for protocol selection.
 */
class ServerProtocolLoader(private val supportedProtocols: ProtocolMap<ServerCodegenContext>) {
    fun protocolFor(
        model: Model,
        serviceShape: ServiceShape,
    ): Pair<ShapeId, ProtocolGeneratorFactory<ProtocolGenerator, ServerCodegenContext>> {
        val protocols: MutableMap<ShapeId, Trait> = ServiceIndex.of(model).getProtocols(serviceShape)
        val matchingProtocols =
            protocols.keys.mapNotNull { protocolId -> supportedProtocols[protocolId]?.let { protocolId to it } }
        if (matchingProtocols.isEmpty()) {
            throw CodegenException("No matching protocol — service offers: ${protocols.keys}. We offer: ${supportedProtocols.keys}")
        }
        val pair = matchingProtocols.first()
        return Pair(pair.first, pair.second)
    }
class ServerProtocolLoader(supportedProtocols: ProtocolMap<ServerCodegenContext>) :
    ProtocolLoader<ServerCodegenContext>(supportedProtocols) {

    companion object {
        val DefaultProtocols = mapOf(
+1 −1
Original line number Diff line number Diff line
@@ -100,7 +100,7 @@ interface ProtocolGeneratorFactory<out T : ProtocolGenerator, C : CoreCodegenCon
    fun support(): ProtocolSupport
}

class ProtocolLoader<C : CoreCodegenContext>(private val supportedProtocols: ProtocolMap<C>) {
open class ProtocolLoader<C : CoreCodegenContext>(private val supportedProtocols: ProtocolMap<C>) {
    fun protocolFor(
        model: Model,
        serviceShape: ServiceShape,