Unverified Commit 1d30ad06 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add initial protocol test generation (#16)

* Implement Instantiator

* Add support for dev-dependencies

* Pass the protocol in Protocol config

* Add initial protocol test generator

This commit adds an initial and very limited implementation of protocol test generation. Specifically,
it only asserts that the query string contains certain required parameters. We will update protocol test generation in concert with updates to the code generation features.

* Fixes from integration tests

* Add test of recursive maps, fix implementation of empty maps

* Add protocol test generator, pre-commit hooks

* Fix test

* Fix formatting
parent 44720f4d
Loading
Loading
Loading
Loading
+9 −1
Original line number Diff line number Diff line
@@ -93,6 +93,14 @@ tasks.register<Exec>("cargoCheck") {
    dependsOn("build")
}

tasks.register<Exec>("cargoTest") {
    workingDir("build/smithyprojections/codegen-test/")
    // disallow warnings
    environment("RUSTFLAGS", "-D warnings")
    commandLine("cargo", "test")
    dependsOn("build")
}

tasks.register<Exec>("cargoClippy") {
    workingDir("build/smithyprojections/codegen-test/")
    // disallow warnings
@@ -101,7 +109,7 @@ tasks.register<Exec>("cargoClippy") {
    dependsOn("build")
}

tasks["test"].finalizedBy("cargoCheck", "cargoClippy")
tasks["test"].finalizedBy("cargoCheck", "cargoClippy", "cargoTest")

tasks["clean"].doFirst {
    delete("smithy-build.json")
+10 −1
Original line number Diff line number Diff line
@@ -9,13 +9,18 @@ import software.amazon.smithy.codegen.core.SymbolDependency
import software.amazon.smithy.codegen.core.SymbolDependencyContainer
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig

sealed class DependencyScope
object Dev : DependencyScope()
object Compile : DependencyScope()

sealed class DependencyLocation
data class CratesIo(val version: String) : DependencyLocation()
data class Local(val path: String? = null) : DependencyLocation()

data class RustDependency(
    val name: String,
    val location: DependencyLocation
    val location: DependencyLocation,
    val scope: DependencyScope = Compile
) : SymbolDependencyContainer {
    override fun getDependencies(): List<SymbolDependency> {
        return listOf(
@@ -44,6 +49,10 @@ data class RustDependency(
            "${runtimeConfig.cratePrefix}-http", Local(runtimeConfig.relativePath)
        )

        fun ProtocolTestHelpers(runtimeConfig: RuntimeConfig) = RustDependency(
            "protocol-test-helpers", Local(runtimeConfig.relativePath), scope = Dev
        )

        private val PropKey = "rustdep"

        fun fromSymbolDependency(symbolDependency: SymbolDependency) =
+6 −1
Original line number Diff line number Diff line
@@ -27,6 +27,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ServiceGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.runCommand

private val PublicModules = listOf("error", "operation", "model")
@@ -71,7 +72,11 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
            LibRsGenerator(modules, it).render()
        }
        writers.flushWriters()
        try {
            "cargo fmt".runCommand(fileManifest.baseDir)
        } catch (_: CommandFailed) {
            logger.warning("Generated output did not parse")
        }
    }

    override fun getDefault(shape: Shape?) {
+10 −0
Original line number Diff line number Diff line
@@ -40,12 +40,14 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na

    // TODO: refactor to be RuntimeTypeProvider a la Symbol provider that packages the `RuntimeConfig` state.
    companion object {

        // val Blob = RuntimeType("Blob", RustDependency.IO_CORE, "blob")
        val From = RuntimeType("From", dependency = null, namespace = "std::convert")
        val AsRef = RuntimeType("AsRef", dependency = null, namespace = "std::convert")
        fun StdFmt(member: String) = RuntimeType("fmt::$member", dependency = null, namespace = "std")
        val StdError = RuntimeType("Error", dependency = null, namespace = "std::error")
        val HashSet = RuntimeType("HashSet", dependency = null, namespace = "std::collections")
        val HashMap = RuntimeType("HashMap", dependency = null, namespace = "std::collections")

        fun Instant(runtimeConfig: RuntimeConfig) =
            RuntimeType("Instant", RustDependency.SmithyTypes(runtimeConfig), "${runtimeConfig.cratePrefix}_types")
@@ -62,6 +64,9 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na
        fun Base64Encode(runtimeConfig: RuntimeConfig): RuntimeType =
            RuntimeType("encode", RustDependency.SmithyHttp(runtimeConfig), "${runtimeConfig.cratePrefix}_http::base64")

        fun Base64Decode(runtimeConfig: RuntimeConfig): RuntimeType =
            RuntimeType("decode", RustDependency.SmithyHttp(runtimeConfig), "${runtimeConfig.cratePrefix}_http::base64")

        fun TimestampFormat(runtimeConfig: RuntimeConfig, format: TimestampFormatTrait.Format): RuntimeType {
            val timestampFormat = when (format) {
                TimestampFormatTrait.Format.EPOCH_SECONDS -> "EpochSeconds"
@@ -76,6 +81,11 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na
            )
        }

        fun ProtocolTestHelper(runtimeConfig: RuntimeConfig, func: String): RuntimeType =
            RuntimeType(
                func, RustDependency.ProtocolTestHelpers(runtimeConfig), "protocol_test_helpers"
            )

        fun Http(path: String): RuntimeType = RuntimeType(name = path, dependency = RustDependency.Http, namespace = "http")
        val HttpRequestBuilder = Http("request::Builder")
    }
+13 −2
Original line number Diff line number Diff line
@@ -5,6 +5,8 @@

package software.amazon.smithy.rust.codegen.smithy.generators

import software.amazon.smithy.rust.codegen.lang.Compile
import software.amazon.smithy.rust.codegen.lang.Dev
import software.amazon.smithy.rust.codegen.lang.RustDependency
import software.amazon.smithy.rust.codegen.smithy.RustSettings
import software.amazon.smithy.utils.CodeWriter
@@ -20,9 +22,18 @@ class CargoTomlGenerator(private val settings: RustSettings, private val writer:

        writer.insertTrailingNewline()

        if (dependencies.isNotEmpty()) {
        val compileDependencies = dependencies.filter { it.scope == Compile }
        val devDependencies = dependencies.filter { it.scope == Dev }
        if (compileDependencies.isNotEmpty()) {
            writer.write("[dependencies]")
            dependencies.forEach {
            compileDependencies.forEach {
                writer.write(it.toString())
            }
        }

        if (devDependencies.isNotEmpty()) {
            writer.write("[dev-dependencies]")
            devDependencies.forEach {
                writer.write(it.toString())
            }
        }
Loading