diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 078439706cc0ff92f80fbcf4ab9b9ff9bc7fd0d3..2f07111d0c72b0f62b3955ce7640ca0f44b82fbd 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -26,7 +26,7 @@ author = "Velfi" [[smithy-rs]] message = """ -Previously, the config customizations that added functionality related to retry configs, timeout configs, and the +Previously, the config customizations that added functionality related to retry configs, timeout configs, and the async sleep impl were defined in the smithy codegen module but were being loaded in the AWS codegen module. They have now been updated to be loaded during smithy codegen. The affected classes are all defined in the `software.amazon.smithy.rust.codegen.smithy.customizations` module of smithy codegen.` This change does not affect @@ -96,3 +96,14 @@ implies that we're using `aws-smithy-http` 0.47.1 specifically. For the rest of references = ["smithy-rs#1635", "smithy-rs#1416"] meta = { "breaking" = true, "tada" = true, "bug" = false } author = "weihanglo" + +[[smithy-rs]] +message = """ +Servers now allow requests' ACCEPT header values to be: +- `*/*` +- `type/*` +- `type/subtype` +""" +references = ["smithy-rs#1544"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server" } +author = "82marbag" diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 088ed67d9dcce96a1bc427343cc25bcb9bc02a96..57178a870cf5b281bea502bd606fe0508b1219b4 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -168,15 +168,11 @@ private class ServerHttpBoundProtocolTraitImplGenerator( httpBindingResolver.responseContentType(operationShape)?.also { contentType -> rustTemplate( """ - if let Some(headers) = req.headers() { - if let Some(accept) = headers.get(#{http}::header::ACCEPT) { - if accept != "$contentType" { - return Err(#{RuntimeError} { - protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()}, - kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::NotAcceptable, - }) - } - } + if ! #{SmithyHttpServer}::protocols::accept_header_classifier(req, ${contentType.dq()}) { + return Err(#{RuntimeError} { + protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()}, + kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::NotAcceptable, + }) } """, *codegenScope, diff --git a/rust-runtime/aws-smithy-http-server/src/protocols.rs b/rust-runtime/aws-smithy-http-server/src/protocols.rs index 4a088d46df51f82d0cb563e934678c19af8767f1..5f099f8eb42c4fc089df00c5cf9f4a253f5abd80 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocols.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocols.rs @@ -39,12 +39,55 @@ pub fn check_content_type( } } +pub fn accept_header_classifier(req: &RequestParts, content_type: &'static str) -> bool { + // Allow no ACCEPT header + if req.headers().is_none() { + return true; + } + let headers = req.headers().unwrap(); + if !headers.contains_key(http::header::ACCEPT) { + return true; + } + // Must be of the form: type/subtype + let content_type = content_type + .parse::() + .expect("BUG: MIME parsing failed, content_type is not valid"); + headers + .get_all(http::header::ACCEPT) + .into_iter() + .flat_map(|header| { + header + .to_str() + .ok() + .into_iter() + /* + * turn a header value of: "type0/subtype0, type1/subtype1, ..." + * into: ["type0/subtype0", "type1/subtype1", ...] + * and remove the optional "; q=x" parameters + * NOTE: the unwrap() is safe, because it takes the first element (if there's nothing to split, returns the string) + */ + .flat_map(|s| s.split(',').map(|typ| typ.split(';').next().unwrap().trim())) + }) + .filter_map(|h| h.parse::().ok()) + .any(|mim| { + let typ = content_type.type_(); + let subtype = content_type.subtype(); + // Accept: */*, type/*, type/subtype + match (mim.type_(), mim.subtype()) { + (t, s) if t == typ && s == subtype => true, + (t, mime::STAR) if t == typ => true, + (mime::STAR, mime::STAR) => true, + _ => false, + } + }) +} + #[cfg(test)] mod tests { use super::*; use http::Request; - fn req(content_type: &str) -> RequestParts<&str> { + fn req_content_type(content_type: &str) -> RequestParts<&str> { let request = Request::builder() .header("content-type", content_type) .body("") @@ -52,12 +95,17 @@ mod tests { RequestParts::new(request) } + fn req_accept(content_type: &str) -> RequestParts<&str> { + let request = Request::builder().header("accept", content_type).body("").unwrap(); + RequestParts::new(request) + } + static EXPECTED_MIME_APPLICATION_JSON: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| "application/json".parse::().unwrap()); #[test] fn check_valid_content_type() { - let valid_request = req("application/json"); + let valid_request = req_content_type("application/json"); assert!(check_content_type(&valid_request, &EXPECTED_MIME_APPLICATION_JSON).is_ok()); } @@ -65,7 +113,7 @@ mod tests { fn check_invalid_content_type() { let invalid = vec!["application/ajson", "text/xml"]; for invalid_mime in invalid { - let request = req(invalid_mime); + let request = req_content_type(invalid_mime); let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON); // Validates the rejection type since we cannot implement `PartialEq` @@ -98,7 +146,7 @@ mod tests { #[test] fn check_not_parsable_content_type() { - let request = req("123"); + let request = req_content_type("123"); let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON); assert!(matches!( result.unwrap_err(), @@ -108,8 +156,51 @@ mod tests { #[test] fn check_non_ascii_visible_characters_content_type() { - let request = req("application/💩"); + let request = req_content_type("application/💩"); let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON); assert!(matches!(result.unwrap_err(), MissingContentTypeReason::ToStrError(_))); } + + #[test] + fn valid_accept_header_classifier_multiple_values() { + let valid_request = req_accept("text/strings, application/json, invalid"); + assert!(accept_header_classifier(&valid_request, "application/json")); + } + + #[test] + fn invalid_accept_header_classifier() { + let invalid_request = req_accept("text/invalid, invalid, invalid/invalid"); + assert!(!accept_header_classifier(&invalid_request, "application/json")); + } + + #[test] + fn valid_accept_header_classifier_star() { + let valid_request = req_accept("application/*"); + assert!(accept_header_classifier(&valid_request, "application/json")); + } + + #[test] + fn valid_accept_header_classifier_star_star() { + let valid_request = req_accept("*/*"); + assert!(accept_header_classifier(&valid_request, "application/json")); + } + + #[test] + fn valid_empty_accept_header_classifier() { + let valid_request = Request::builder().body("").unwrap(); + let valid_request = RequestParts::new(valid_request); + assert!(accept_header_classifier(&valid_request, "application/json")); + } + + #[test] + fn valid_accept_header_classifier_with_params() { + let valid_request = req_accept("application/json; q=30, */*"); + assert!(accept_header_classifier(&valid_request, "application/json")); + } + + #[test] + fn valid_accept_header_classifier() { + let valid_request = req_accept("application/json"); + assert!(accept_header_classifier(&valid_request, "application/json")); + } }