Unverified Commit 33348f2a authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

ebs + unsignedPayloadTrait support (#567)

* ebs + unsignedPayloadTrait support

1. Add support for the "unsignedPayoad" Smithy trait. When an operation is targetted with this trait, rather than signing the body of the operation, we will sign the literal string `Unsigned-Payload`.

This is the same behavior that will occur if the body is streaming and not directly signable.

2. Add an EBS example which utilizes this new signing behavior.

* Back out debugging change

* remove unused mut

* remove unused import

* fix clippy error

* Add protocol test for EBS

* Fix gradle build
parent 8e9b91d6
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ impl SigV4SigningStage {
    }
}

use aws_sigv4_poc::SignableBody;
use aws_types::region::SigningRegion;
use aws_types::SigningService;
use thiserror::Error;
@@ -76,12 +77,14 @@ fn signing_config(
    let signing_service = config
        .get::<SigningService>()
        .ok_or(SigningStageError::MissingSigningService)?;
    let payload_override = config.get::<SignableBody<'static>>();
    let request_config = RequestConfig {
        request_ts: config
            .get::<SystemTime>()
            .copied()
            .unwrap_or_else(SystemTime::now),
        region,
        payload_override,
        service: signing_service,
    };
    Ok((operation_config, request_config, creds))
+16 −9
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@
 */

use aws_auth::Credentials;
use aws_sigv4_poc::{PayloadChecksumKind, SignableBody, SigningSettings, UriEncoding};
use aws_sigv4_poc::{PayloadChecksumKind, SigningSettings, UriEncoding};
use aws_types::region::SigningRegion;
use aws_types::SigningService;
use http::header::HeaderName;
@@ -13,6 +13,8 @@ use std::error::Error;
use std::fmt;
use std::time::SystemTime;

pub use aws_sigv4_poc::SignableBody;

#[derive(Eq, PartialEq, Clone, Copy)]
pub enum SigningAlgorithm {
    SigV4,
@@ -77,6 +79,7 @@ pub struct RequestConfig<'a> {
    pub request_ts: SystemTime,
    pub region: &'a SigningRegion,
    pub service: &'a SigningService,
    pub payload_override: Option<&'a SignableBody<'static>>,
}

#[derive(Clone, Default)]
@@ -135,14 +138,18 @@ impl SigV4Signer {

        // A body that is already in memory can be signed directly. A  body that is not in memory
        // (any sort of streaming body) will be signed via UNSIGNED-PAYLOAD.
        // The final enhancement that will come a bit later is writing a `SignableBody::Precomputed`
        // into the property bag when we have a sha 256 middleware that can compute a streaming checksum
        // for replayable streams but currently even replayable streams will result in `UNSIGNED-PAYLOAD`
        let signable_body = request
        let signable_body = request_config
            .payload_override
            // the payload_override is a cheap clone because it contains either a
            // reference or a short checksum (we're not cloning the entire body)
            .cloned()
            .unwrap_or_else(|| {
                request
                    .body()
                    .bytes()
                    .map(SignableBody::Bytes)
            .unwrap_or(SignableBody::UnsignedPayload);
                    .unwrap_or(SignableBody::UnsignedPayload)
            });
        for (key, value) in aws_sigv4_poc::sign_core(request, signable_body, &sigv4_config)? {
            request
                .headers_mut()
+19 −4
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@
package software.amazon.smithy.rustsdk

import software.amazon.smithy.aws.traits.auth.SigV4Trait
import software.amazon.smithy.aws.traits.auth.UnsignedPayloadTrait
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
@@ -54,7 +55,7 @@ class SigV4SigningDecorator : RustCodegenDecorator {
        baseCustomizations: List<OperationCustomization>
    ): List<OperationCustomization> {
        return baseCustomizations.letIf(applies(protocolConfig)) {
            it + SigV4SigningFeature(protocolConfig.runtimeConfig, protocolConfig.serviceShape)
            it + SigV4SigningFeature(operation, protocolConfig.runtimeConfig, protocolConfig.serviceShape)
        }
    }
}
@@ -90,8 +91,15 @@ fun disableDoubleEncode(service: ServiceShape) = when {
    else -> false
}

class SigV4SigningFeature(private val runtimeConfig: RuntimeConfig, private val service: ServiceShape) :
class SigV4SigningFeature(
    private val operation: OperationShape,
    runtimeConfig: RuntimeConfig,
    private val service: ServiceShape
) :
    OperationCustomization() {
    private val codegenScope =
        arrayOf("sig_auth" to runtimeConfig.sigAuth().asType(), "aws_types" to awsTypes(runtimeConfig).asType())

    override fun section(section: OperationSection): Writable {
        return when (section) {
            is OperationSection.MutateRequest -> writable {
@@ -100,7 +108,7 @@ class SigV4SigningFeature(private val runtimeConfig: RuntimeConfig, private val
                ##[allow(unused_mut)]
                let mut signing_config = #{sig_auth}::signer::OperationSigningConfig::default_config();
                """,
                    "sig_auth" to runtimeConfig.sigAuth().asType()
                    *codegenScope
                )
                if (needsAmzSha256(service)) {
                    rust("signing_config.signing_options.content_sha256_header = true;")
@@ -108,12 +116,19 @@ class SigV4SigningFeature(private val runtimeConfig: RuntimeConfig, private val
                if (disableDoubleEncode(service)) {
                    rust("signing_config.signing_options.double_uri_encode = false;")
                }
                if (operation.hasTrait<UnsignedPayloadTrait>()) {
                    rust("signing_config.signing_options.content_sha256_header = true;")
                    rustTemplate(
                        "${section.request}.config_mut().insert(#{sig_auth}::signer::SignableBody::UnsignedPayload);",
                        *codegenScope
                    )
                }
                rustTemplate(
                    """
                ${section.request}.config_mut().insert(signing_config);
                ${section.request}.config_mut().insert(#{aws_types}::SigningService::from_static(${section.config}.signing_service()));
                """,
                    "aws_types" to awsTypes(runtimeConfig).asType()
                    *codegenScope
                )
            }
            else -> emptySection
+51 −0
Original line number Diff line number Diff line
$version: "1.0"
namespace com.amazonaws.ebs

use smithy.test#httpResponseTests

apply ValidationException @httpResponseTests([
    {
        id: "lowercase message",
        documentation: "This test case validates case insensitive parsing of `message`",
        params: {
            Message: "1 validation error detected"
        },
        bodyMediaType: "application/json",
        body: """
        {
          "message": "1 validation error detected"
        }
        """,
        protocol: "aws.protocols#restJson1",
        code: 400,
        headers:  {
            "x-amzn-requestid": "2af8f013-250a-4f6e-88ae-6dd7f6e12807",
            "x-amzn-errortype": "ValidationException:http://internal.amazon.com/coral/com.amazon.coral.validate/",
            "content-type": "application/json",
            "content-length": "77",
            "date": "Wed, 30 Jun 2021 23:42:27 GMT"
        },
    },

    {
        id: "uppercase message",
        documentation: "This test case validates case insensitive parsing of `message`",
        params: {
            Message: "Invalid volume size: 99999999999",
            Reason: "INVALID_VOLUME_SIZE"
        },
        bodyMediaType: "application/json",
        body: """
        {"Message":"Invalid volume size: 99999999999","Reason":"INVALID_VOLUME_SIZE"}
        """,
        protocol: "aws.protocols#restJson1",
        code: 400,
        headers:  {
            "x-amzn-requestid": "2af8f013-250a-4f6e-88ae-6dd7f6e12807",
            "x-amzn-errortype": "ValidationException:http://internal.amazon.com/coral/com.amazon.zeppelindataservice/",
            "content-type": "application/json",
            "content-length": "77",
            "date": "Wed, 30 Jun 2021 23:42:27 GMT"
        },
    },
])
+32 −15
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.aws.traits.ServiceTrait
import kotlin.streams.toList
import org.jetbrains.kotlin.utils.ifEmpty

extra["displayName"] = "Smithy :: Rust :: AWS-SDK"
extra["moduleName"] = "software.amazon.smithy.rust.awssdk"
@@ -81,6 +82,7 @@ val tier1Services = setOf(
    "sts",
    "cloudwatch",
    "ecr",
    "ebs",
    "config",
    "eks"
)
@@ -97,19 +99,28 @@ data class AwsService(
    fun files(): List<File> = listOf(modelFile) + extraFiles
}

val generateAllServices = project.providers.environmentVariable("GENERATE_ALL_SERVICES").orElse("")
val awsServices: Provider<List<AwsService>> = generateAllServices.map { v ->
    discoverServices(v.toLowerCase() == "true")
val generateAllServices =
    project.providers.environmentVariable("GENERATE_ALL_SERVICES").forUseAtConfigurationTime().orElse("")

val generateOnly: Provider<Set<String>> =
    project.providers.environmentVariable("GENERATE_ONLY")
        .forUseAtConfigurationTime()
        .map { envVar ->
            envVar.split(",").filter { service -> service.trim().isNotBlank() }
        }
        .orElse(listOf())
        .map { it.toSet() }

val generateOnly: Set<String>? = null
val awsServices: Provider<List<AwsService>> = generateAllServices.zip(generateOnly) { v, only ->
    discoverServices(v.toLowerCase() == "true", only)
}

/**
 * Discovers services from the `models` directory
 *
 * Do not invoke this function directly. Use the `awsServices` provider.
 */
fun discoverServices(allServices: Boolean): List<AwsService> {
fun discoverServices(allServices: Boolean, generateOnly: Set<String>): List<AwsService> {
    val models = project.file("aws-models")
    val services = fileTree(models).mapNotNull { file ->
        val model = Model.assembler().addImport(file.absolutePath).assemble().result.get()
@@ -136,14 +147,13 @@ fun discoverServices(allServices: Boolean): List<AwsService> {
            }
            AwsService(service = service.id.toString(), module = sdkId, modelFile = file, extraFiles = extras)
        }
    }.filterNot {
        disableServices.contains(it.module)
    }.filter {
        allServices || (generateOnly != null && generateOnly.contains(it.module)) || (generateOnly == null && tier1Services.contains(
    }.filterNot { disableServices.contains(it.module) }
        .filter {
            allServices || (generateOnly.isNotEmpty() && generateOnly.contains(it.module)) || (generateOnly.isEmpty() && tier1Services.contains(
                it.module
            ))
        }
    if (generateOnly == null) {
    if (generateOnly.isNotEmpty()) {
        val modules = services.map { it.module }.toSet()
        tier1Services.forEach { service ->
            check(modules.contains(service)) { "Service $service was in list of tier 1 services but not generated!" }
@@ -195,6 +205,7 @@ fun generateSmithyBuild(tests: List<AwsService>): String {

task("generateSmithyBuild") {
    description = "generate smithy-build.json"
    dependsOn(awsServices)
    doFirst {
        projectDir.resolve("smithy-build.json").writeText(generateSmithyBuild(awsServices.get()))
    }
@@ -272,9 +283,10 @@ tasks.register<Copy>("relocateRuntime") {
}

fun generateCargoWorkspace(services: List<AwsService>): String {
    val generatedModules = services.map { it.module }.toSet()
    val examples = projectDir.resolve("examples")
        .listFiles { file -> !file.name.startsWith(".") }.orEmpty().toList()
        .filter { generateOnly == null || generateOnly.contains(it.name) }
        .filter { generatedModules.contains(it.name) }
        .map { "examples/${it.name}" }

    val modules = services.map(AwsService::module) + runtimeModules + awsModules + examples.toList()
@@ -291,7 +303,10 @@ task("generateCargoWorkspace") {
        sdkOutputDir.mkdirs()
        sdkOutputDir.resolve("Cargo.toml").writeText(generateCargoWorkspace(awsServices.get()))
    }
    dependsOn(awsServices)
    inputs.dir(projectDir.resolve("examples"))
    outputs.file(sdkOutputDir.resolve("Cargo.toml"))
    outputs.upToDateWhen { false }
}

task("finalizeSdk") {
@@ -308,7 +323,9 @@ task("finalizeSdk") {
tasks["smithyBuildJar"].inputs.file(projectDir.resolve("smithy-build.json"))
tasks["smithyBuildJar"].inputs.dir(projectDir.resolve("aws-models"))
tasks["smithyBuildJar"].dependsOn("generateSmithyBuild")
tasks["smithyBuildJar"].dependsOn(awsServices)
tasks["smithyBuildJar"].dependsOn("generateCargoWorkspace")
tasks["smithyBuildJar"].outputs.upToDateWhen { false }
tasks["assemble"].dependsOn("smithyBuildJar")
tasks["assemble"].finalizedBy("finalizeSdk")

Loading