Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt +25 −8 Original line number Diff line number Diff line Loading @@ -8,6 +8,9 @@ import software.amazon.smithy.rust.codegen.lang.withBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.util.dq /** * Generate protocol tests for an operation */ class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) { fun render() { with(protocolConfig) { Loading Loading @@ -73,11 +76,13 @@ class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) { } ) } rustWriter.write( "assert_eq!(\$T(&http_request, $variableName), Ok(()));", assertOk(rustWriter) { write( "\$T(&http_request, $variableName)", RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_headers") ) } } private fun checkRequiredQueryParams( rustWriter: RustWriter, Loading Loading @@ -106,11 +111,23 @@ class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) { rustWriter.withBlock("let $variableName = ", ";") { strSlice(this, params) } rustWriter.write( "assert_eq!(\$T(&http_request, $variableName), Ok(()));", assertOk(rustWriter) { write( "\$T(&http_request, $variableName)", RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, checkFunction) ) } } /** * wraps `inner` in a call to `protocol_test_helpers::assert_ok`, a convenience wrapper * for pretty prettying protocol test helper results */ private fun assertOk(rustWriter: RustWriter, inner: RustWriter.() -> Unit) { rustWriter.write("\$T(", RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "assert_ok")) inner(rustWriter) rustWriter.write(");") } private fun strSlice(writer: RustWriter, args: List<String>) { writer.withBlock("&[", "]") { Loading codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt +5 −5 Original line number Diff line number Diff line Loading @@ -16,7 +16,7 @@ import software.amazon.smithy.rust.testutil.shouldCompile import software.amazon.smithy.rust.testutil.testSymbolProvider class HttpProtocolTestGeneratorTest { val baseModel = """ private val baseModel = """ namespace com.example use aws.protocols#restJson1 Loading Loading @@ -133,7 +133,7 @@ class HttpProtocolTestGeneratorTest { } // Verify the test actually ran err.message shouldContain "test_say_hello ... FAILED" err.message shouldContain "MissingQueryParam" err.message shouldContain "missing query param" } @Test Loading @@ -155,7 +155,7 @@ class HttpProtocolTestGeneratorTest { } // Verify the test actually ran err.message shouldContain "test_say_hello ... FAILED" err.message shouldContain "ForbiddenQueryParam" err.message shouldContain "forbidden query param" } @Test Loading @@ -177,7 +177,7 @@ class HttpProtocolTestGeneratorTest { } // Verify the test actually ran err.message shouldContain "test_say_hello ... FAILED" err.message shouldContain "RequiredQueryParam" err.message shouldContain "required query param missing" } @Test Loading @@ -197,6 +197,6 @@ class HttpProtocolTestGeneratorTest { writer.shouldCompile(expectFailure = true) } err.message shouldContain "test_say_hello ... FAILED" err.message shouldContain "InvalidHeader" err.message shouldContain "invalid header value" } } rust-runtime/protocol-test-helpers/Cargo.toml +4 −0 Original line number Diff line number Diff line Loading @@ -7,3 +7,7 @@ description = "A collection of library functions to validate HTTP requests again [dependencies] http = "0.2.1" thiserror = "1" serde_json = "1" # Not perfect for our needs, but good for now assert-json-diff = "1" rust-runtime/protocol-test-helpers/src/lib.rs +136 −10 Original line number Diff line number Diff line use assert_json_diff::assert_json_eq_no_panic; use http::{Request, Uri}; use std::collections::HashSet; use thiserror::Error; #[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq, Error)] pub enum ProtocolTestFailure { #[error("missing query param: expected `{expected}`, found {found:?}")] MissingQueryParam { expected: String, found: Vec<String>, }, ForbiddenQueryParam { expected: String, }, RequiredQueryParam { expected: String, }, #[error("forbidden query param present: `{expected}`")] ForbiddenQueryParam { expected: String }, #[error("required query param missing: `{expected}`")] RequiredQueryParam { expected: String }, #[error("invalid header value for key `{key}`: expected `{expected}`, found `{found}`")] InvalidHeader { key: String, expected: String, found: String, }, MissingHeader { #[error("missing required header: `{expected}`")] MissingHeader { expected: String }, #[error("body did not match. Hint:\n{hint}")] BodyDidNotMatch { expected: String, found: String, hint: String, }, #[error("Expected body to be valid {expected} but instead: {found}")] InvalidBodyFormat { expected: String, found: String }, } /// Check that the protocol test succeeded & print the pretty error /// if it did not /// /// The primary motivation is making multiline debug output /// readable & using the cleaner Display implementation #[track_caller] pub fn assert_ok(inp: Result<(), ProtocolTestFailure>) { match inp { Ok(_) => (), Err(e) => { eprintln!("{}", e); panic!("Protocol test failed"); } } } #[derive(Eq, PartialEq, Hash)] Loading Loading @@ -114,6 +141,7 @@ pub fn validate_headers<B>( .join(", "); if *expected_value != actual_value { return Err(ProtocolTestFailure::InvalidHeader { key: key.to_string(), expected: expected_value.to_string(), found: actual_value, }); Loading @@ -122,11 +150,77 @@ pub fn validate_headers<B>( Ok(()) } pub enum MediaType { /// Json media types are deserialized and compared Json, /// Other media types are compared literally // TODO: XML, etc. Other(String), } impl<T: AsRef<str>> From<T> for MediaType { fn from(inp: T) -> Self { match inp.as_ref() { "application/json" => MediaType::Json, other => MediaType::Other(other.to_string()), } } } pub fn validate_body<T: AsRef<[u8]>>( actual_body: T, expected_body: &str, media_type: MediaType, ) -> Result<(), ProtocolTestFailure> { let body_str = std::str::from_utf8(actual_body.as_ref()); match (media_type, body_str) { (MediaType::Json, Ok(actual_body)) => validate_json_body(actual_body, expected_body), (MediaType::Json, Err(_)) => Err(ProtocolTestFailure::InvalidBodyFormat { expected: "json".to_owned(), found: "input was not valid UTF-8".to_owned(), }), (MediaType::Other(media_type), Ok(actual_body)) => { if actual_body != expected_body { Err(ProtocolTestFailure::BodyDidNotMatch { expected: expected_body.to_string(), found: actual_body.to_string(), hint: format!("media type: {}", media_type), }) } else { Ok(()) } } // It's not clear from the Smithy spec exactly how a binary / base64 encoded body is supposed // to work. Defer implementation for now until an actual test exists. (MediaType::Other(_), Err(_)) => { unimplemented!("binary/non-utf8 formats not yet supported") } } } fn validate_json_body(actual: &str, expected: &str) -> Result<(), ProtocolTestFailure> { let actual_json: serde_json::Value = serde_json::from_str(actual).map_err(|e| ProtocolTestFailure::InvalidBodyFormat { expected: "json".to_owned(), found: e.to_string(), })?; let expected_json: serde_json::Value = serde_json::from_str(expected).expect("expected value must be valid JSON"); match assert_json_eq_no_panic(&actual_json, &expected_json) { Ok(()) => Ok(()), Err(message) => Err(ProtocolTestFailure::BodyDidNotMatch { expected: expected.to_string(), found: actual.to_string(), hint: message, }), } } #[cfg(test)] mod tests { use crate::{ forbid_query_params, require_query_params, validate_headers, validate_query_string, ProtocolTestFailure, forbid_query_params, require_query_params, validate_body, validate_headers, validate_query_string, MediaType, ProtocolTestFailure, }; use http::Request; Loading Loading @@ -207,4 +301,36 @@ mod tests { }) ); } #[test] fn test_validate_json_body() { let expected = r#"{"abc": 5 }"#; let actual = r#" {"abc": 5 }"#; validate_body(actual.as_bytes(), expected, MediaType::Json) .expect("inputs matched as JSON"); let expected = r#"{"abc": 5 }"#; let actual = r#" {"abc": 6 }"#; validate_body(actual.as_bytes(), expected, MediaType::Json) .expect_err("bodies do not match"); } #[test] fn test_validate_non_json_body() { let expected = r#"asdf"#; let actual = r#"asdf "#; validate_body( actual.as_bytes(), expected, MediaType::from("something/else"), ) .expect_err("bodies do not match"); validate_body( expected.as_bytes(), expected, MediaType::from("something/else"), ) .expect("inputs matched exactly") } } Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt +25 −8 Original line number Diff line number Diff line Loading @@ -8,6 +8,9 @@ import software.amazon.smithy.rust.codegen.lang.withBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.util.dq /** * Generate protocol tests for an operation */ class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) { fun render() { with(protocolConfig) { Loading Loading @@ -73,11 +76,13 @@ class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) { } ) } rustWriter.write( "assert_eq!(\$T(&http_request, $variableName), Ok(()));", assertOk(rustWriter) { write( "\$T(&http_request, $variableName)", RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_headers") ) } } private fun checkRequiredQueryParams( rustWriter: RustWriter, Loading Loading @@ -106,11 +111,23 @@ class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) { rustWriter.withBlock("let $variableName = ", ";") { strSlice(this, params) } rustWriter.write( "assert_eq!(\$T(&http_request, $variableName), Ok(()));", assertOk(rustWriter) { write( "\$T(&http_request, $variableName)", RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, checkFunction) ) } } /** * wraps `inner` in a call to `protocol_test_helpers::assert_ok`, a convenience wrapper * for pretty prettying protocol test helper results */ private fun assertOk(rustWriter: RustWriter, inner: RustWriter.() -> Unit) { rustWriter.write("\$T(", RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "assert_ok")) inner(rustWriter) rustWriter.write(");") } private fun strSlice(writer: RustWriter, args: List<String>) { writer.withBlock("&[", "]") { Loading
codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt +5 −5 Original line number Diff line number Diff line Loading @@ -16,7 +16,7 @@ import software.amazon.smithy.rust.testutil.shouldCompile import software.amazon.smithy.rust.testutil.testSymbolProvider class HttpProtocolTestGeneratorTest { val baseModel = """ private val baseModel = """ namespace com.example use aws.protocols#restJson1 Loading Loading @@ -133,7 +133,7 @@ class HttpProtocolTestGeneratorTest { } // Verify the test actually ran err.message shouldContain "test_say_hello ... FAILED" err.message shouldContain "MissingQueryParam" err.message shouldContain "missing query param" } @Test Loading @@ -155,7 +155,7 @@ class HttpProtocolTestGeneratorTest { } // Verify the test actually ran err.message shouldContain "test_say_hello ... FAILED" err.message shouldContain "ForbiddenQueryParam" err.message shouldContain "forbidden query param" } @Test Loading @@ -177,7 +177,7 @@ class HttpProtocolTestGeneratorTest { } // Verify the test actually ran err.message shouldContain "test_say_hello ... FAILED" err.message shouldContain "RequiredQueryParam" err.message shouldContain "required query param missing" } @Test Loading @@ -197,6 +197,6 @@ class HttpProtocolTestGeneratorTest { writer.shouldCompile(expectFailure = true) } err.message shouldContain "test_say_hello ... FAILED" err.message shouldContain "InvalidHeader" err.message shouldContain "invalid header value" } }
rust-runtime/protocol-test-helpers/Cargo.toml +4 −0 Original line number Diff line number Diff line Loading @@ -7,3 +7,7 @@ description = "A collection of library functions to validate HTTP requests again [dependencies] http = "0.2.1" thiserror = "1" serde_json = "1" # Not perfect for our needs, but good for now assert-json-diff = "1"
rust-runtime/protocol-test-helpers/src/lib.rs +136 −10 Original line number Diff line number Diff line use assert_json_diff::assert_json_eq_no_panic; use http::{Request, Uri}; use std::collections::HashSet; use thiserror::Error; #[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq, Error)] pub enum ProtocolTestFailure { #[error("missing query param: expected `{expected}`, found {found:?}")] MissingQueryParam { expected: String, found: Vec<String>, }, ForbiddenQueryParam { expected: String, }, RequiredQueryParam { expected: String, }, #[error("forbidden query param present: `{expected}`")] ForbiddenQueryParam { expected: String }, #[error("required query param missing: `{expected}`")] RequiredQueryParam { expected: String }, #[error("invalid header value for key `{key}`: expected `{expected}`, found `{found}`")] InvalidHeader { key: String, expected: String, found: String, }, MissingHeader { #[error("missing required header: `{expected}`")] MissingHeader { expected: String }, #[error("body did not match. Hint:\n{hint}")] BodyDidNotMatch { expected: String, found: String, hint: String, }, #[error("Expected body to be valid {expected} but instead: {found}")] InvalidBodyFormat { expected: String, found: String }, } /// Check that the protocol test succeeded & print the pretty error /// if it did not /// /// The primary motivation is making multiline debug output /// readable & using the cleaner Display implementation #[track_caller] pub fn assert_ok(inp: Result<(), ProtocolTestFailure>) { match inp { Ok(_) => (), Err(e) => { eprintln!("{}", e); panic!("Protocol test failed"); } } } #[derive(Eq, PartialEq, Hash)] Loading Loading @@ -114,6 +141,7 @@ pub fn validate_headers<B>( .join(", "); if *expected_value != actual_value { return Err(ProtocolTestFailure::InvalidHeader { key: key.to_string(), expected: expected_value.to_string(), found: actual_value, }); Loading @@ -122,11 +150,77 @@ pub fn validate_headers<B>( Ok(()) } pub enum MediaType { /// Json media types are deserialized and compared Json, /// Other media types are compared literally // TODO: XML, etc. Other(String), } impl<T: AsRef<str>> From<T> for MediaType { fn from(inp: T) -> Self { match inp.as_ref() { "application/json" => MediaType::Json, other => MediaType::Other(other.to_string()), } } } pub fn validate_body<T: AsRef<[u8]>>( actual_body: T, expected_body: &str, media_type: MediaType, ) -> Result<(), ProtocolTestFailure> { let body_str = std::str::from_utf8(actual_body.as_ref()); match (media_type, body_str) { (MediaType::Json, Ok(actual_body)) => validate_json_body(actual_body, expected_body), (MediaType::Json, Err(_)) => Err(ProtocolTestFailure::InvalidBodyFormat { expected: "json".to_owned(), found: "input was not valid UTF-8".to_owned(), }), (MediaType::Other(media_type), Ok(actual_body)) => { if actual_body != expected_body { Err(ProtocolTestFailure::BodyDidNotMatch { expected: expected_body.to_string(), found: actual_body.to_string(), hint: format!("media type: {}", media_type), }) } else { Ok(()) } } // It's not clear from the Smithy spec exactly how a binary / base64 encoded body is supposed // to work. Defer implementation for now until an actual test exists. (MediaType::Other(_), Err(_)) => { unimplemented!("binary/non-utf8 formats not yet supported") } } } fn validate_json_body(actual: &str, expected: &str) -> Result<(), ProtocolTestFailure> { let actual_json: serde_json::Value = serde_json::from_str(actual).map_err(|e| ProtocolTestFailure::InvalidBodyFormat { expected: "json".to_owned(), found: e.to_string(), })?; let expected_json: serde_json::Value = serde_json::from_str(expected).expect("expected value must be valid JSON"); match assert_json_eq_no_panic(&actual_json, &expected_json) { Ok(()) => Ok(()), Err(message) => Err(ProtocolTestFailure::BodyDidNotMatch { expected: expected.to_string(), found: actual.to_string(), hint: message, }), } } #[cfg(test)] mod tests { use crate::{ forbid_query_params, require_query_params, validate_headers, validate_query_string, ProtocolTestFailure, forbid_query_params, require_query_params, validate_body, validate_headers, validate_query_string, MediaType, ProtocolTestFailure, }; use http::Request; Loading Loading @@ -207,4 +301,36 @@ mod tests { }) ); } #[test] fn test_validate_json_body() { let expected = r#"{"abc": 5 }"#; let actual = r#" {"abc": 5 }"#; validate_body(actual.as_bytes(), expected, MediaType::Json) .expect("inputs matched as JSON"); let expected = r#"{"abc": 5 }"#; let actual = r#" {"abc": 6 }"#; validate_body(actual.as_bytes(), expected, MediaType::Json) .expect_err("bodies do not match"); } #[test] fn test_validate_non_json_body() { let expected = r#"asdf"#; let actual = r#"asdf "#; validate_body( actual.as_bytes(), expected, MediaType::from("something/else"), ) .expect_err("bodies do not match"); validate_body( expected.as_bytes(), expected, MediaType::from("something/else"), ) .expect("inputs matched exactly") } }