Unverified Commit c7d4c031 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add Body Validation Support to Protocol Test Helpers (#21)

* Use thiserror to make prettier errors

* Add body validation support

* Fixes for clippy
parent b27ce8b7
Loading
Loading
Loading
Loading
+25 −8
Original line number Diff line number Diff line
@@ -8,6 +8,9 @@ import software.amazon.smithy.rust.codegen.lang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.util.dq

/**
 * Generate protocol tests for an operation
 */
class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) {
    fun render() {
        with(protocolConfig) {
@@ -73,11 +76,13 @@ class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) {
                }
            )
        }
        rustWriter.write(
            "assert_eq!(\$T(&http_request, $variableName), Ok(()));",
        assertOk(rustWriter) {
            write(
                "\$T(&http_request, $variableName)",
                RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_headers")
            )
        }
    }

    private fun checkRequiredQueryParams(
        rustWriter: RustWriter,
@@ -106,11 +111,23 @@ class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) {
        rustWriter.withBlock("let $variableName = ", ";") {
            strSlice(this, params)
        }
        rustWriter.write(
            "assert_eq!(\$T(&http_request, $variableName), Ok(()));",
        assertOk(rustWriter) {
            write(
                "\$T(&http_request, $variableName)",
                RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, checkFunction)
            )
        }
    }

    /**
     * wraps `inner` in a call to `protocol_test_helpers::assert_ok`, a convenience wrapper
     * for pretty prettying protocol test helper results
     */
    private fun assertOk(rustWriter: RustWriter, inner: RustWriter.() -> Unit) {
        rustWriter.write("\$T(", RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "assert_ok"))
        inner(rustWriter)
        rustWriter.write(");")
    }

    private fun strSlice(writer: RustWriter, args: List<String>) {
        writer.withBlock("&[", "]") {
+5 −5
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@ import software.amazon.smithy.rust.testutil.shouldCompile
import software.amazon.smithy.rust.testutil.testSymbolProvider

class HttpProtocolTestGeneratorTest {
    val baseModel = """
    private val baseModel = """
        namespace com.example

        use aws.protocols#restJson1
@@ -133,7 +133,7 @@ class HttpProtocolTestGeneratorTest {
        }
        // Verify the test actually ran
        err.message shouldContain "test_say_hello ... FAILED"
        err.message shouldContain "MissingQueryParam"
        err.message shouldContain "missing query param"
    }

    @Test
@@ -155,7 +155,7 @@ class HttpProtocolTestGeneratorTest {
        }
        // Verify the test actually ran
        err.message shouldContain "test_say_hello ... FAILED"
        err.message shouldContain "ForbiddenQueryParam"
        err.message shouldContain "forbidden query param"
    }

    @Test
@@ -177,7 +177,7 @@ class HttpProtocolTestGeneratorTest {
        }
        // Verify the test actually ran
        err.message shouldContain "test_say_hello ... FAILED"
        err.message shouldContain "RequiredQueryParam"
        err.message shouldContain "required query param missing"
    }

    @Test
@@ -197,6 +197,6 @@ class HttpProtocolTestGeneratorTest {
            writer.shouldCompile(expectFailure = true)
        }
        err.message shouldContain "test_say_hello ... FAILED"
        err.message shouldContain "InvalidHeader"
        err.message shouldContain "invalid header value"
    }
}
+4 −0
Original line number Diff line number Diff line
@@ -7,3 +7,7 @@ description = "A collection of library functions to validate HTTP requests again

[dependencies]
http = "0.2.1"
thiserror = "1"
serde_json = "1"
# Not perfect for our needs, but good for now
assert-json-diff = "1"
+136 −10
Original line number Diff line number Diff line
use assert_json_diff::assert_json_eq_no_panic;
use http::{Request, Uri};
use std::collections::HashSet;
use thiserror::Error;

#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, Error)]
pub enum ProtocolTestFailure {
    #[error("missing query param: expected `{expected}`, found {found:?}")]
    MissingQueryParam {
        expected: String,
        found: Vec<String>,
    },
    ForbiddenQueryParam {
        expected: String,
    },
    RequiredQueryParam {
        expected: String,
    },
    #[error("forbidden query param present: `{expected}`")]
    ForbiddenQueryParam { expected: String },
    #[error("required query param missing: `{expected}`")]
    RequiredQueryParam { expected: String },

    #[error("invalid header value for key `{key}`: expected `{expected}`, found `{found}`")]
    InvalidHeader {
        key: String,
        expected: String,
        found: String,
    },
    MissingHeader {
    #[error("missing required header: `{expected}`")]
    MissingHeader { expected: String },
    #[error("body did not match. Hint:\n{hint}")]
    BodyDidNotMatch {
        expected: String,
        found: String,
        hint: String,
    },
    #[error("Expected body to be valid {expected} but instead: {found}")]
    InvalidBodyFormat { expected: String, found: String },
}

/// Check that the protocol test succeeded & print the pretty error
/// if it did not
///
/// The primary motivation is making multiline debug output
/// readable & using the cleaner Display implementation
#[track_caller]
pub fn assert_ok(inp: Result<(), ProtocolTestFailure>) {
    match inp {
        Ok(_) => (),
        Err(e) => {
            eprintln!("{}", e);
            panic!("Protocol test failed");
        }
    }
}

#[derive(Eq, PartialEq, Hash)]
@@ -114,6 +141,7 @@ pub fn validate_headers<B>(
            .join(", ");
        if *expected_value != actual_value {
            return Err(ProtocolTestFailure::InvalidHeader {
                key: key.to_string(),
                expected: expected_value.to_string(),
                found: actual_value,
            });
@@ -122,11 +150,77 @@ pub fn validate_headers<B>(
    Ok(())
}

pub enum MediaType {
    /// Json media types are deserialized and compared
    Json,
    /// Other media types are compared literally
    // TODO: XML, etc.
    Other(String),
}

impl<T: AsRef<str>> From<T> for MediaType {
    fn from(inp: T) -> Self {
        match inp.as_ref() {
            "application/json" => MediaType::Json,
            other => MediaType::Other(other.to_string()),
        }
    }
}

pub fn validate_body<T: AsRef<[u8]>>(
    actual_body: T,
    expected_body: &str,
    media_type: MediaType,
) -> Result<(), ProtocolTestFailure> {
    let body_str = std::str::from_utf8(actual_body.as_ref());
    match (media_type, body_str) {
        (MediaType::Json, Ok(actual_body)) => validate_json_body(actual_body, expected_body),
        (MediaType::Json, Err(_)) => Err(ProtocolTestFailure::InvalidBodyFormat {
            expected: "json".to_owned(),
            found: "input was not valid UTF-8".to_owned(),
        }),
        (MediaType::Other(media_type), Ok(actual_body)) => {
            if actual_body != expected_body {
                Err(ProtocolTestFailure::BodyDidNotMatch {
                    expected: expected_body.to_string(),
                    found: actual_body.to_string(),
                    hint: format!("media type: {}", media_type),
                })
            } else {
                Ok(())
            }
        }
        // It's not clear from the Smithy spec exactly how a binary / base64 encoded body is supposed
        // to work. Defer implementation for now until an actual test exists.
        (MediaType::Other(_), Err(_)) => {
            unimplemented!("binary/non-utf8 formats not yet supported")
        }
    }
}

fn validate_json_body(actual: &str, expected: &str) -> Result<(), ProtocolTestFailure> {
    let actual_json: serde_json::Value =
        serde_json::from_str(actual).map_err(|e| ProtocolTestFailure::InvalidBodyFormat {
            expected: "json".to_owned(),
            found: e.to_string(),
        })?;
    let expected_json: serde_json::Value =
        serde_json::from_str(expected).expect("expected value must be valid JSON");
    match assert_json_eq_no_panic(&actual_json, &expected_json) {
        Ok(()) => Ok(()),
        Err(message) => Err(ProtocolTestFailure::BodyDidNotMatch {
            expected: expected.to_string(),
            found: actual.to_string(),
            hint: message,
        }),
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        forbid_query_params, require_query_params, validate_headers, validate_query_string,
        ProtocolTestFailure,
        forbid_query_params, require_query_params, validate_body, validate_headers,
        validate_query_string, MediaType, ProtocolTestFailure,
    };
    use http::Request;

@@ -207,4 +301,36 @@ mod tests {
            })
        );
    }

    #[test]
    fn test_validate_json_body() {
        let expected = r#"{"abc": 5 }"#;
        let actual = r#"   {"abc":   5 }"#;
        validate_body(actual.as_bytes(), expected, MediaType::Json)
            .expect("inputs matched as JSON");

        let expected = r#"{"abc": 5 }"#;
        let actual = r#"   {"abc":   6 }"#;
        validate_body(actual.as_bytes(), expected, MediaType::Json)
            .expect_err("bodies do not match");
    }

    #[test]
    fn test_validate_non_json_body() {
        let expected = r#"asdf"#;
        let actual = r#"asdf "#;
        validate_body(
            actual.as_bytes(),
            expected,
            MediaType::from("something/else"),
        )
        .expect_err("bodies do not match");

        validate_body(
            expected.as_bytes(),
            expected,
            MediaType::from("something/else"),
        )
        .expect("inputs matched exactly")
    }
}