Unverified Commit fb00603c authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Extract a function to set headers only when absent (#591)

parent cd31c9c2
Loading
Loading
Loading
Loading
+7 −5
Original line number Diff line number Diff line
@@ -17,8 +17,10 @@ import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.assignment
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
@@ -408,12 +410,12 @@ class HttpBoundProtocolGenerator(
            rust("let mut builder = self.update_http_builder(#T::new())?;", RuntimeType.HttpRequestBuilder)
            val additionalHeaders = listOf("content-type" to contentType) + protocol.additionalHeaders(operationShape)
            for (header in additionalHeaders) {
                rust(
                    """
                    if !builder.headers_ref().map(|h| h.contains_key(${header.first.dq()})).unwrap_or(false) {
                        builder = builder.header(${header.first.dq()}, ${header.second.dq()});
                    }
                rustTemplate(
                    """
                    builder = #{header_util}::set_header_if_absent(builder, ${header.first.dq()}, ${header.second.dq()});
                    """,
                    "header_util" to CargoDependency.SmithyHttp(runtimeConfig).asType().member("header")

                )
            }
            rust("Ok(builder)")
+33 −1
Original line number Diff line number Diff line
@@ -90,6 +90,22 @@ pub fn one_or_none<T: FromStr>(
    }
}

pub fn set_header_if_absent(
    request: http::request::Builder,
    key: &'static str,
    value: &'static str,
) -> http::request::Builder {
    if !request
        .headers_ref()
        .map(|map| map.contains_key(key))
        .unwrap_or(false)
    {
        request.header(key, value)
    } else {
        request
    }
}

/// Read one comma delimited value for `FromStr` types
fn read_one<T>(s: &[u8]) -> Result<(T, &[u8]), ParseError>
where
@@ -118,9 +134,25 @@ fn then_delim(s: &[u8]) -> Result<&[u8], ParseError> {

#[cfg(test)]
mod test {
    use crate::header::{headers_for_prefix, read_many, ParseError};
    use crate::header::{headers_for_prefix, read_many, set_header_if_absent, ParseError};
    use std::collections::HashMap;

    #[test]
    fn put_if_absent() {
        let builder = http::Request::builder().header("foo", "bar");
        let builder = set_header_if_absent(builder, "foo", "baz");
        let builder = set_header_if_absent(builder, "other", "value");
        let req = builder.body(()).expect("valid request");
        assert_eq!(
            req.headers().get_all("foo").iter().collect::<Vec<_>>(),
            vec!["bar"]
        );
        assert_eq!(
            req.headers().get_all("other").iter().collect::<Vec<_>>(),
            vec!["value"]
        );
    }

    #[test]
    fn read_many_bools() {
        let test_request = http::Request::builder()