Unverified Commit 9707c034 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Increase scope of Protocol test generation (#18)

* Setup pre-commit

* Run precommit hook across all files

* Increase scope of Protocol test generaton

Added support for the following fields in protocol tests:
- forbidQueryParams
- requireQueryParams
- headers
parent 86dc5db5
Loading
Loading
Loading
Loading
+61 −7
Original line number Diff line number Diff line
@@ -44,6 +44,10 @@ class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) {
            instantiator.render(httpRequestTestCase.params, protocolConfig.inputShape, this)
            write(";")
            write("let http_request = input.build_http_request().body(()).unwrap();")
            checkQueryParams(this, httpRequestTestCase.queryParams)
            checkForbidQueryParams(this, httpRequestTestCase.forbidQueryParams)
            checkRequiredQueryParams(this, httpRequestTestCase.requireQueryParams)
            checkHeaders(this, httpRequestTestCase.headers)
            with(httpRequestTestCase) {
                write(
                    """
@@ -51,16 +55,66 @@ class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) {
                    assert_eq!(http_request.uri().path(), ${uri.dq()});
                """
                )
                withBlock("let expected_query_params = vec![", "];") {
                    write(queryParams.joinToString(",") { it.dq() })
                // TODO: assert on the body contents
                write("/* BODY:\n ${body.orElse("[ No Body ]")} */")
            }
        }
    }

    private fun checkHeaders(rustWriter: RustWriter, headers: Map<String, String>) {
        if (headers.isEmpty()) {
            return
        }
        val variableName = "expected_headers"
        rustWriter.withBlock("let $variableName = &[", "];") {
            write(
                    "\$T(&http_request, expected_query_params.as_slice()).unwrap();",
                    RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_query_string")
                headers.entries.joinToString(",") {
                    "(${it.key.dq()}, ${it.value.dq()})"
                }
            )
                // TODO: assert on the body contents
                write("/* BODY:\n ${body.orElse("[ No Body ]")} */")
        }
        rustWriter.write(
            "assert_eq!(\$T(&http_request, $variableName), Ok(()));",
            RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_headers")
        )
    }

    private fun checkRequiredQueryParams(
        rustWriter: RustWriter,
        requiredParams: List<String>
    ) = basicCheck(requiredParams, rustWriter, "required_params", "require_query_params")

    private fun checkForbidQueryParams(
        rustWriter: RustWriter,
        forbidParams: List<String>
    ) = basicCheck(forbidParams, rustWriter, "forbid_params", "forbid_query_params")

    private fun checkQueryParams(
        rustWriter: RustWriter,
        queryParams: List<String>
    ) = basicCheck(queryParams, rustWriter, "expected_query_params", "validate_query_string")

    private fun basicCheck(
        params: List<String>,
        rustWriter: RustWriter,
        variableName: String,
        checkFunction: String
    ) {
        if (params.isEmpty()) {
            return
        }
        rustWriter.withBlock("let $variableName = ", ";") {
            strSlice(this, params)
        }
        rustWriter.write(
            "assert_eq!(\$T(&http_request, $variableName), Ok(()));",
            RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, checkFunction)
        )
    }

    private fun strSlice(writer: RustWriter, args: List<String>) {
        writer.withBlock("&[", "]") {
            write(args.joinToString(",") { it.dq() })
        }
    }
}
+97 −46
Original line number Diff line number Diff line
@@ -43,6 +43,10 @@ class HttpProtocolTestGeneratorTest {
                queryParams: [
                    "Hi=Hello%20there"
                ],
                forbidQueryParams: [
                    "goodbye"
                ],
                requireQueryParams: ["required"],
                headers: {
                    "X-Greeting": "Hi",
                },
@@ -68,30 +72,15 @@ class HttpProtocolTestGeneratorTest {
    private val symbolProvider = testSymbolProvider(model)
    private val runtimeConfig = TestRuntimeConfig

    private fun fakeInput(writer: RustWriter, body: String) {
        StructureGenerator(model, symbolProvider, writer, model.lookup("com.example#SayHelloInput")).render()
        writer.rustBlock("impl SayHelloInput") {
    private fun writeHttpImpl(writer: RustWriter, body: String) {
        writer.withModule("operation") {
            StructureGenerator(model, symbolProvider, this, model.lookup("com.example#SayHelloInput")).render()
            rustBlock("impl SayHelloInput") {
                rustBlock("pub fn build_http_request(&self) -> \$T", RuntimeType.HttpRequestBuilder) {
                    write("\$T::new()", RuntimeType.HttpRequestBuilder)
                    write(body)
                }
            }
    }

    @Test
    fun `passing e2e protocol request test`() {
        val writer = RustWriter.forModule("lib")

        // Hard coded implementation for this 1 test
        writer.withModule("operation") {
            fakeInput(
                this,
                """
                        .uri("/?Hi=Hello%20there")
                        .header("X-Greeting", "Hi")
                        .method("POST")
                    """
            )
            val protocolConfig = ProtocolConfig(
                model,
                symbolProvider,
@@ -106,6 +95,19 @@ class HttpProtocolTestGeneratorTest {
                protocolConfig
            ).render()
        }
    }

    @Test
    fun `passing e2e protocol request test`() {
        val writer = RustWriter.forModule("lib")
        writeHttpImpl(
            writer,
            """
                    .uri("/?Hi=Hello%20there&required")
                    .header("X-Greeting", "Hi")
                    .method("POST")
                """
        )

        val testOutput = writer.shouldCompile()
        // Verify the test actually ran
@@ -113,39 +115,88 @@ class HttpProtocolTestGeneratorTest {
    }

    @Test
    fun `failing e2e protocol test`() {
    fun `test invalid url parameter`() {
        val writer = RustWriter.forModule("lib")

        // Hard coded implementation for this 1 test
        writer.withModule("operation") {
            fakeInput(
                this,
        writeHttpImpl(
            writer,
            """
                        .uri("/?Hi=INCORRECT")
                    .uri("/?Hi=INCORRECT&required")
                    .header("X-Greeting", "Hi")
                    .method("POST")
                """
        )
            val protocolConfig = ProtocolConfig(
                model,
                symbolProvider,
                runtimeConfig,
                this,
                model.lookup("com.example#HelloService"),
                model.lookup("com.example#SayHello"),
                model.lookup("com.example#SayHelloInput"),
                RestJson1Trait.ID

        val err = assertThrows<CommandFailed> {
            writer.shouldCompile(expectFailure = true)
        }
        // Verify the test actually ran
        err.message shouldContain "test_say_hello ... FAILED"
        err.message shouldContain "MissingQueryParam"
    }

    @Test
    fun `test forbidden url parameter`() {
        val writer = RustWriter.forModule("lib")

        // Hard coded implementation for this 1 test
        writeHttpImpl(
            writer,
            """
                    .uri("/?goodbye&Hi=Hello%20there&required")
                    .header("X-Greeting", "Hi")
                    .method("POST")
                """
        )
            HttpProtocolTestGenerator(
                protocolConfig
            ).render()

        val err = assertThrows<CommandFailed> {
            writer.shouldCompile(expectFailure = true)
        }
        // Verify the test actually ran
        err.message shouldContain "test_say_hello ... FAILED"
        err.message shouldContain "ForbiddenQueryParam"
    }

    @Test
    fun `test required url parameter`() {
        val writer = RustWriter.forModule("lib")

        // Hard coded implementation for this 1 test
        writeHttpImpl(
            writer,
            """
                    .uri("/?Hi=Hello%20there")
                    .header("X-Greeting", "Hi")
                    .method("POST")
                """
        )

        val err = assertThrows<CommandFailed> {
            writer.shouldCompile(expectFailure = true)
        }
        // Verify the test actually ran
        err.message shouldContain "test_say_hello ... FAILED"
        err.message shouldContain "MissingQueryParam"
        err.message shouldContain "RequiredQueryParam"
    }

    @Test
    fun `invalid header`() {
        val writer = RustWriter.forModule("lib")
        writeHttpImpl(
            writer,
            """
                    .uri("/?Hi=Hello%20there&required")
                    // should be "Hi"
                    .header("X-Greeting", "Hey")
                    .method("POST")
                """
        )

        val err = assertThrows<CommandFailed> {
            writer.shouldCompile(expectFailure = true)
        }
        err.message shouldContain "test_say_hello ... FAILED"
        err.message shouldContain "InvalidHeader"
    }
}
+169 −33
Original line number Diff line number Diff line
use http::Request;
use http::{Request, Uri};
use std::collections::HashSet;

#[derive(Debug)]
#[derive(Debug, PartialEq, Eq)]
pub enum ProtocolTestFailure {
    MissingQueryParam {
        expected: String,
        found: Vec<String>,
    },
    ForbiddenQueryParam {
        expected: String,
    },
    RequiredQueryParam {
        expected: String,
    },
    InvalidHeader {
        expected: String,
        found: String,
    },
    MissingHeader {
        expected: String,
    },
}

#[derive(Eq, PartialEq, Hash)]
struct QueryParam<'a> {
    key: &'a str,
    value: Option<&'a str>,
}

impl<'a> QueryParam<'a> {
    fn parse(s: &'a str) -> Self {
        let mut parsed = s.split('=');
        QueryParam {
            key: parsed.next().unwrap(),
            value: parsed.next(),
        }
    }
}

fn extract_params(uri: &Uri) -> HashSet<&str> {
    uri.query().unwrap_or_default().split('&').collect()
}

pub fn validate_query_string<B>(
    request: &Request<B>,
    params: &[&str],
    expected_params: &[&str],
) -> Result<(), ProtocolTestFailure> {
    let query_str = request.uri().query().unwrap_or_default();
    let request_params: HashSet<&str> = query_str.split('&').collect();
    let expected: HashSet<&str> = params.iter().copied().collect();
    for param in expected {
        if !request_params.contains(param) {
    let actual_params = extract_params(request.uri());
    for param in expected_params {
        if !actual_params.contains(param) {
            return Err(ProtocolTestFailure::MissingQueryParam {
                expected: param.to_owned(),
                found: request_params
                    .clone()
                    .into_iter()
                    .map(|x| x.to_owned())
                    .collect(),
                expected: param.to_string(),
                found: actual_params.iter().map(|s| s.to_string()).collect(),
            });
        }
    }
    Ok(())
}

pub fn forbid_query_params<B>(
    request: &Request<B>,
    forbid_keys: &[&str],
) -> Result<(), ProtocolTestFailure> {
    let actual_keys: HashSet<&str> = extract_params(request.uri())
        .iter()
        .map(|param| QueryParam::parse(param).key)
        .collect();
    for key in forbid_keys {
        if actual_keys.contains(*key) {
            return Err(ProtocolTestFailure::ForbiddenQueryParam {
                expected: key.to_string(),
            });
        }
    }
    Ok(())
}

pub fn require_query_params<B>(
    request: &Request<B>,
    require_keys: &[&str],
) -> Result<(), ProtocolTestFailure> {
    let actual_keys: HashSet<&str> = extract_params(request.uri())
        .iter()
        .map(|param| QueryParam::parse(param).key)
        .collect();
    for key in require_keys {
        if !actual_keys.contains(*key) {
            return Err(ProtocolTestFailure::RequiredQueryParam {
                expected: key.to_string(),
            });
        }
    }
    Ok(())
}

pub fn validate_headers<B>(
    request: &Request<B>,
    expected_headers: &[(&str, &str)],
) -> Result<(), ProtocolTestFailure> {
    for (key, expected_value) in expected_headers {
        // Protocol tests store header lists as comma-delimited
        if !request.headers().contains_key(*key) {
            return Err(ProtocolTestFailure::MissingHeader {
                expected: key.to_string(),
            });
        }
        let actual_value: String = request
            .headers()
            .get_all(*key)
            .iter()
            .map(|hv| hv.to_str().unwrap())
            .collect::<Vec<_>>()
            .join(", ");
        if *expected_value != actual_value {
            return Err(ProtocolTestFailure::InvalidHeader {
                expected: expected_value.to_string(),
                found: actual_value,
            });
        }
    }
@@ -33,14 +124,17 @@ pub fn validate_query_string<B>(

#[cfg(test)]
mod tests {
    use crate::validate_query_string;
    use crate::{
        forbid_query_params, require_query_params, validate_headers, validate_query_string,
        ProtocolTestFailure,
    };
    use http::Request;

    #[test]
    fn test_validate_empty_query_string() {
        let request = Request::builder().uri("/foo").body(()).unwrap();
        validate_query_string(&request, &vec![]).expect("no required params should pass");
        validate_query_string(&request, &vec!["a"])
        validate_query_string(&request, &[]).expect("no required params should pass");
        validate_query_string(&request, &["a"])
            .err()
            .expect("no params provided");
    }
@@ -51,24 +145,66 @@ mod tests {
            .uri("/foo?a=b&c&d=efg&hello=a%20b")
            .body(())
            .unwrap();
        validate_query_string(&request, &vec!["a=b"]).expect("a=b is in the query string");
        validate_query_string(&request, &vec!["c", "a=b"])
        validate_query_string(&request, &["a=b"]).expect("a=b is in the query string");
        validate_query_string(&request, &["c", "a=b"])
            .expect("both params are in the query string");
        validate_query_string(&request, &vec!["a=b", "c", "d=efg", "hello=a%20b"])
        validate_query_string(&request, &["a=b", "c", "d=efg", "hello=a%20b"])
            .expect("all params are in the query string");
        validate_query_string(&request, &vec![]).expect("no required params should pass");
        validate_query_string(&request, &[]).expect("no required params should pass");

        validate_query_string(&request, &vec!["a"])
            .err()
            .expect("no parameter should match");
        validate_query_string(&request, &vec!["a=bc"])
            .err()
            .expect("no parameter should match");
        validate_query_string(&request, &vec!["a=bc"])
            .err()
            .expect("no parameter should match");
        validate_query_string(&request, &vec!["hell=a%20"])
            .err()
            .expect("no parameter should match");
        validate_query_string(&request, &["a"]).expect_err("no parameter should match");
        validate_query_string(&request, &["a=bc"]).expect_err("no parameter should match");
        validate_query_string(&request, &["a=bc"]).expect_err("no parameter should match");
        validate_query_string(&request, &["hell=a%20"]).expect_err("no parameter should match");
    }

    #[test]
    fn test_forbid_query_param() {
        let request = Request::builder()
            .uri("/foo?a=b&c&d=efg&hello=a%20b")
            .body(())
            .unwrap();
        forbid_query_params(&request, &["a"]).expect_err("a is a query param");
        forbid_query_params(&request, &["not_included"]).expect("query param not included");
        forbid_query_params(&request, &["a=b"]).expect("should be matching against keys");
        forbid_query_params(&request, &["c"]).expect_err("c is a query param");
    }

    #[test]
    fn test_require_query_param() {
        let request = Request::builder()
            .uri("/foo?a=b&c&d=efg&hello=a%20b")
            .body(())
            .unwrap();
        require_query_params(&request, &["a"]).expect("a is a query param");
        require_query_params(&request, &["not_included"]).expect_err("query param not included");
        require_query_params(&request, &["a=b"]).expect_err("should be matching against keys");
        require_query_params(&request, &["c"]).expect("c is a query param");
    }

    #[test]
    fn test_validate_headers() {
        let request = Request::builder()
            .uri("/")
            .header("X-Foo", "foo")
            .header("X-Foo-List", "foo")
            .header("X-Foo-List", "bar")
            .header("X-Inline", "inline, other")
            .body(())
            .unwrap();

        validate_headers(&request, &[("X-Foo", "foo")]).expect("header present");
        validate_headers(&request, &[("X-Foo", "Foo")]).expect_err("case sensitive");
        validate_headers(&request, &[("x-foo-list", "foo, bar")]).expect("list concat");
        validate_headers(&request, &[("X-Foo-List", "foo")])
            .expect_err("all list members must be specified");
        validate_headers(&request, &[("X-Inline", "inline, other")])
            .expect("inline header lists also work");
        assert_eq!(
            validate_headers(&request, &[("missing", "value")]),
            Err(ProtocolTestFailure::MissingHeader {
                expected: "missing".to_owned()
            })
        );
    }
}