diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt index 70c02c32fe5e50b63361b9f5e8d8bb462c373f42..08cb80ed4aaacd5cb1ce060d211fadf47985837c 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt @@ -182,6 +182,9 @@ data class CargoDependency( "protocol-test-helpers", Local(runtimeConfig.relativePath), scope = DependencyScope.Dev ) + fun smithyXml(runtimeConfig: RuntimeConfig): CargoDependency = + CargoDependency("${runtimeConfig.cratePrefix}-xml", Local(runtimeConfig.relativePath)) + val SerdeJson: CargoDependency = CargoDependency("serde_json", CratesIo("1"), features = listOf("float_roundtrip")) val Serde = CargoDependency("serde", CratesIo("1"), features = listOf("derive")) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt index fd747667f99886f5b888fc9e23d506e202684a3f..3ae8fe908706b6c1247a7ed7513314a62243fba8 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt @@ -73,6 +73,24 @@ fun T.rust( this.write(contents, *args) } +/** + * Sibling method to [rustBlock] that enables `#{variablename}` style templating + */ +fun T.rustBlockTemplate( + @Language("Rust", prefix = "macro_rules! foo { () => {{ ", suffix = "}}}") contents: String, + vararg ctx: Pair, + block: T.() -> Unit +) { + check(ctx.distinctBy { it.first.toLowerCase() }.size == ctx.size) { "Duplicate cased keys not supported" } + this.pushState() + this.putContext(ctx.toMap().mapKeys { (k, _) -> k.toLowerCase() }) + val header = contents.replace(Regex("""#\{([a-zA-Z_0-9]+)\}""")) { matchResult -> "#{${matchResult.groupValues[1].toLowerCase()}:T}" } + this.openBlock("$header {") + block(this) + closeBlock("}") + this.popState() +} + /** * API for templating long blocks of Rust * diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt index 7bad1b86964cf89b53ed560564c7b09af2c96aed..d640218a87c383e0086b77a7a1986186cb661112 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt @@ -40,6 +40,7 @@ import software.amazon.smithy.model.traits.HttpLabelTrait import software.amazon.smithy.rust.codegen.rustlang.RustType import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.Writable +import software.amazon.smithy.rust.codegen.rustlang.stripOuter import software.amazon.smithy.rust.codegen.smithy.traits.InputBodyTrait import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait @@ -338,6 +339,8 @@ fun Symbol.isOptional(): Boolean = when (this.rustType()) { else -> false } +fun Symbol.isBoxed(): Boolean = rustType().stripOuter() is RustType.Box + // Symbols should _always_ be created with a Rust type & shape attached fun Symbol.rustType(): RustType = this.getProperty(RUST_TYPE_KEY, RustType::class.java).get() fun Symbol.shape(): Shape = this.expectProperty(SHAPE_KEY, Shape::class.java) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/BuilderGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/BuilderGenerator.kt index 94cd3f416a70f09e6c476c34ee75d6725ea24d17..afbed98bbf9c74db178ca75e0c02463aa7a2ce50 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/BuilderGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/BuilderGenerator.kt @@ -54,7 +54,6 @@ class BuilderGenerator( val symbol = symbolProvider.toSymbol(shape) // TODO: figure out exactly what docs we want on a the builder module writer.docs("See #D", symbol) - // check(writer.namespace == shape.builderSymbol(symbolProvider).namespace) val segments = shape.builderSymbol(symbolProvider).namespace.split("::") writer.withModule(segments.last()) { renderBuilder(this) @@ -112,8 +111,9 @@ class BuilderGenerator( val memberName = symbolProvider.toMemberName(member) // All fields in the builder are optional val memberSymbol = symbolProvider.toSymbol(member).makeOptional() - // TODO: should the builder members be public? - write("$memberName: #T,", memberSymbol) + // builder members are crate-public to enable using them + // directly in serializers/deserializers + write("pub(crate) $memberName: #T,", memberSymbol) } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/XmlBindingTraitParserGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/XmlBindingTraitParserGenerator.kt new file mode 100644 index 0000000000000000000000000000000000000000..5c9ba25e66b3eedfd87e96e40cc54838a914dd6d --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/XmlBindingTraitParserGenerator.kt @@ -0,0 +1,610 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.smithy.protocols + +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.model.knowledge.HttpBinding +import software.amazon.smithy.model.knowledge.HttpBindingIndex +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.model.traits.XmlAttributeTrait +import software.amazon.smithy.model.traits.XmlFlattenedTrait +import software.amazon.smithy.model.traits.XmlNameTrait +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.rustlang.RustType +import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.asType +import software.amazon.smithy.rust.codegen.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.rustlang.escape +import software.amazon.smithy.rust.codegen.rustlang.rust +import software.amazon.smithy.rust.codegen.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.rustlang.withBlock +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig +import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol +import software.amazon.smithy.rust.codegen.smithy.generators.setterName +import software.amazon.smithy.rust.codegen.smithy.isBoxed +import software.amazon.smithy.rust.codegen.smithy.isOptional +import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.util.dq +import software.amazon.smithy.rust.codegen.util.expectMember +import software.amazon.smithy.rust.codegen.util.orNull +import software.amazon.smithy.rust.codegen.util.outputShape +import software.amazon.smithy.rust.codegen.util.toPascalCase +import software.amazon.smithy.rust.codegen.util.toSnakeCase + +class XmlBindingTraitParserGenerator(protocolConfig: ProtocolConfig) { + + /** + * Abstraction to represent an XML element name: + * `[prefix]:[local]` + */ + data class XmlName(val local: String, val prefix: String? = null) { + override fun toString(): String { + return prefix?.let { "$it:" }.orEmpty() + local + } + + companion object { + fun parse(v: String): XmlName { + val split = v.indexOf(':') + return if (split == -1) { + XmlName(local = v, prefix = null) + } else { + XmlName(v.substring(split + 1), prefix = v.substring(0, split)) + } + } + } + } + + /** + * Codegeneration Context + * + * [tag]: The symbol name of the current tag + * [accum]: Flattened lists and maps need to be written into an accumulator. When a flattened list / map + * is possible, `[accum]` contains an expression to mutably access the accumulator. Specifically, this is an + * option to the collection st. the caller can evaluate `accum.unwrap_or_default()` to get a collection to write + * data into. + */ + data class Ctx(val tag: String, val accum: String?) + + private val symbolProvider = protocolConfig.symbolProvider + private val smithyXml = CargoDependency.smithyXml(protocolConfig.runtimeConfig).asType() + private val xmlError = smithyXml.member("decode::XmlError") + + private val scopedDecoder = smithyXml.member("decode::ScopedDecoder") + private val runtimeConfig = protocolConfig.runtimeConfig + + // The symbols we want all the time + private val codegenScope = arrayOf( + "Blob" to RuntimeType.Blob(runtimeConfig), + "Document" to smithyXml.member("decode::Document"), + "XmlError" to xmlError, + "next_start_element" to smithyXml.member("decode::next_start_element"), + "try_data" to smithyXml.member("decode::try_data"), + "ScopedDecoder" to scopedDecoder + ) + private val model = protocolConfig.model + private val index = HttpBindingIndex.of(model) + + /** + * Generate a parse function for a given targeted as a payload. + * Entry point for payload-based parsing. + * Roughly: + * ```rust + * fn parse_my_struct(input: &[u8]) -> Result { + * ... + * } + * ``` + */ + fun payloadParser(member: MemberShape): RuntimeType { + val shape = model.expectShape(member.target) + check(shape is UnionShape || shape is StructureShape) { "structure parser should only be used on structures & unions" } + val fnName = shape.id.name.toString().toSnakeCase() + return RuntimeType.forInlineFun(fnName, "xml_deser") { + it.rustBlock( + "pub fn $fnName(inp: &[u8]) -> Result<#1T, #2T>", + symbolProvider.toSymbol(shape), + xmlError + ) { + val shapeName = + member.getMemberTrait(model, XmlNameTrait::class.java).orNull()?.let { XmlName.parse(it.value) } + ?: XmlName(local = shape.id.name) + rustTemplate( + """ + use std::convert::TryFrom; + let mut doc = #{Document}::try_from(inp)?; + let mut decoder = doc.root_element()?; + let start_el = decoder.start_el(); + if !(${shapeName.compareTo("start_el")}) { + return Err(#{XmlError}::custom(format!("invalid root, expected $shapeName got {:?}", start_el))) + } + """, + *codegenScope + ) + val ctx = Ctx("decoder", accum = null) + withBlock("Ok(", ")") { + when (shape) { + is StructureShape -> { + parseStructure(shape, ctx) + } + is UnionShape -> parseUnion(shape, ctx) + } + rust("?") + } + } + } + } + + /** Generate a parser for operation input + * Because only a subset of fields of the operation may be impacted by the document, a builder is passed + * through: + * + * ```rust + * fn parse_some_operation(inp: &[u8], builder: my_operation::Builder) -> Result { + * ... + * } + * ``` + */ + fun operationParser(operationShape: OperationShape): RuntimeType { + val outputShape = operationShape.outputShape(model) + val fnName = outputShape.id.name.toString().toSnakeCase() + return RuntimeType.forInlineFun(fnName, "xml_deser") { + it.rustBlock( + "pub fn $fnName(inp: &[u8], mut builder: #1T) -> Result<#1T, #2T>", + outputShape.builderSymbol(symbolProvider), + xmlError + ) { + val shapeName = XmlName( + local = outputShape.expectTrait(SyntheticOutputTrait::class.java).originalId!!.name, + prefix = null + ) + rustTemplate( + """ + use std::convert::TryFrom; + let mut doc = #{Document}::try_from(inp)?; + let mut decoder = doc.root_element()?; + let start_el = decoder.start_el(); + if !(${shapeName.compareTo("start_el")}) { + return Err(#{XmlError}::custom(format!("invalid root, expected $shapeName got {:?}", start_el))) + } + """, + *codegenScope + ) + val members = operationShape.operationXmlMembers() + parseStructureInner(members, builder = "builder", Ctx(tag = "decoder", accum = null)) + rust("Ok(builder)") + } + } + } + + fun errorParser(errorShape: StructureShape, xmlErrors: RuntimeType): RuntimeType { + val fnName = errorShape.id.name.toString().toSnakeCase() + return RuntimeType.forInlineFun(fnName, "xml_deser") { + it.rustBlock( + "pub fn $fnName(inp: &[u8], mut builder: #1T) -> Result<#1T, #2T>", + errorShape.builderSymbol(symbolProvider), + xmlError + ) { + rustTemplate( + """ + use std::convert::TryFrom; + let mut document = #{Document}::try_from(inp)?; + let mut error_decoder = #{xml_errors}::error_scope(&mut document)?; + """, + *codegenScope, + "xml_errors" to xmlErrors + ) + val members = errorShape.errorXmlMembers() + parseStructureInner(members, builder = "builder", Ctx(tag = "error_decoder", accum = null)) + rust("Ok(builder)") + } + } + } + + /** + * Update a structure builder based on the [members], specifying where to find each member (document vs. attributes) + */ + private fun RustWriter.parseStructureInner(members: XmlMemberIndex, builder: String, outerCtx: Ctx) { + members.attributeMembers.forEach { member -> + val temp = safeName("attrib") + withBlock("let $temp = ", ";") { + parseAttributeMember(member, outerCtx) + } + rust("$builder.${symbolProvider.toMemberName(member)} = $temp;") + } + parseLoop(outerCtx) { ctx -> + members.dataMembers.forEach { member -> + case(member) { + val temp = safeName() + withBlock("let $temp = ", ";") { + parseMember( + member, + ctx.copy(accum = "$builder.${symbolProvider.toMemberName(member)}.take()") + ) + } + rust("$builder = $builder.${member.setterName()}($temp);") + } + } + } + } + + /** + * The core XML parsing abstraction: A loop that reads through the top level tags at the current scope & + * generates a match expression + * When [ignoreUnexpected] is true, unexpected tags are ignored + */ + private fun RustWriter.parseLoop(ctx: Ctx, ignoreUnexpected: Boolean = true, inner: RustWriter.(Ctx) -> Unit) { + rustBlock("while let Some(mut tag) = ${ctx.tag}.next_tag()") { + rustBlock("match tag.start_el()") { + inner(ctx.copy(tag = "tag")) + if (ignoreUnexpected) { + rust("_ => {}") + } + } + } + } + + /** + * Generate an XML parser for a given member + */ + private fun RustWriter.parseMember(memberShape: MemberShape, ctx: Ctx) { + val target = model.expectShape(memberShape.target) + val symbol = symbolProvider.toSymbol(memberShape) + conditionalBlock("Some(", ")", symbol.isOptional()) { + conditionalBlock("Box::new(", ")", symbol.isBoxed()) { + when (target) { + is StringShape, is BooleanShape, is NumberShape, is TimestampShape, is BlobShape -> parsePrimitiveInner( + memberShape + ) { + rustTemplate("#{try_data}(&mut ${ctx.tag})?.as_ref()", *codegenScope) + } + is MapShape -> if (memberShape.isFlattened()) { + parseFlatMap(target, ctx) + } else { + parseMap(target, ctx) + } + is CollectionShape -> if (memberShape.isFlattened()) { + parseFlatList(target, ctx) + } else { + parseList(target, ctx) + } + is StructureShape -> { + parseStructure(target, ctx) + } + is UnionShape -> parseUnion(target, ctx) + else -> TODO("Unhandled: $target") + } + // each internal `parseT` function writes an `Result` expression, unwrap those: + rust("?") + } + } + } + + private fun RustWriter.parseAttributeMember(memberShape: MemberShape, ctx: Ctx) { + rustBlock("") { + rustTemplate( + """let s = ${ctx.tag} + .start_el() + .attr(${memberShape.xmlName().toString().dq()});""", + *codegenScope + ) + rustBlock("match s") { + rust("None => None,") + withBlock("Some(s) => Some(", ")") { + parsePrimitiveInner(memberShape) { + rust("s") + } + rust("?") + } + } + } + } + + private fun RustWriter.parseUnion(shape: UnionShape, ctx: Ctx) { + val fnName = shape.id.name.toString().toSnakeCase() + "_inner" + val symbol = symbolProvider.toSymbol(shape) + val nestedParser = RuntimeType.forInlineFun(fnName, "xml_deser") { + it.rustBlockTemplate( + "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{Shape}, #{XmlError}>", + *codegenScope, "Shape" to symbol + ) { + val members = shape.members() + rustTemplate("let mut base: Option<#{Shape}> = None;", *codegenScope) + parseLoop(Ctx(tag = "decoder", accum = null)) { ctx -> + members.forEach { member -> + val variantName = member.memberName.toPascalCase() + case(member) { + val current = + """ + (match base.take() { + None => None, + Some(${format(symbol)}::$variantName(inner)) => Some(inner), + Some(_) => return Err(#{XmlError}::custom("mixed variants")) + }) + """ + withBlock("let tmp = ", ";") { + parseMember(member, ctx.copy(accum = current)) + } + rust("base = Some(#T::$variantName(tmp));", symbol) + } + } + } + rustTemplate("""base.ok_or_else(||#{XmlError}::custom("expected union, got nothing"))""", *codegenScope) + } + } + rust("#T(&mut ${ctx.tag})", nestedParser) + } + + /** + * The match clause to check if the tag matches a given member + */ + private fun RustWriter.case(member: MemberShape, inner: RustWriter.() -> Unit) { + rustBlock("s if ${member.xmlName().compareTo("s")} => ") { + inner() + } + rust(",") + } + + private fun RustWriter.parseStructure(shape: StructureShape, ctx: Ctx) { + val fnName = shape.id.name.toString().toSnakeCase() + "_inner" + val symbol = symbolProvider.toSymbol(shape) + val nestedParser = RuntimeType.forInlineFun(fnName, "xml_deser") { + it.rustBlockTemplate( + "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{Shape}, #{XmlError}>", + *codegenScope, "Shape" to symbol + ) { + rustTemplate( + """ + let mut builder = #{Shape}::builder(); + """, + *codegenScope, "Shape" to symbol + ) + val members = shape.xmlMembers() + parseStructureInner(members, "builder", Ctx(tag = "decoder", accum = null)) + withBlock("Ok(builder.build()", ")") { + if (StructureGenerator.fallibleBuilder(shape, symbolProvider)) { + rust(""".map_err(|_|{XmlError}::custom("missing field"))?""") + } + } + } + } + rust("#T(&mut ${ctx.tag})", nestedParser) + } + + private fun RustWriter.parseList(target: CollectionShape, ctx: Ctx) { + val fnName = "deserialize_${target.member.id.name.toSnakeCase()}" + val member = target.member + val listParser = RuntimeType.forInlineFun(fnName, "xml_deser") { + it.rustBlockTemplate( + "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{List}, #{XmlError}>", + *codegenScope, + "List" to symbolProvider.toSymbol(target) + ) { + rust("let mut out = std::vec::Vec::new();") + parseLoop(Ctx(tag = "decoder", accum = null)) { ctx -> + case(member) { + withBlock("out.push(", ");") { + parseMember(member, ctx) + } + } + } + rust("Ok(out)") + } + } + rust("#T(&mut ${ctx.tag})", listParser) + } + + private fun RustWriter.parseFlatList(target: CollectionShape, ctx: Ctx) { + val list = safeName("list") + withBlock("Result::<#T, #T>::Ok({", "})", symbolProvider.toSymbol(target), xmlError) { + val accum = ctx.accum ?: throw CodegenException("Need accum to parse flat list") + rustTemplate("""let mut $list = $accum.unwrap_or_default();""", *codegenScope) + withBlock("$list.push(", ");") { + parseMember(target.member, ctx) + } + rust(list) + } + } + + private fun RustWriter.parseMap(target: MapShape, ctx: Ctx) { + val fnName = "deserialize_${target.value.id.name.toSnakeCase()}" + val mapParser = RuntimeType.forInlineFun(fnName, "xml_deser") { + it.rustBlockTemplate( + "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{Map}, #{XmlError}>", + *codegenScope, + "Map" to symbolProvider.toSymbol(target) + ) { + rust("let mut out = #T::new();", RustType.HashMap.RuntimeType) + parseLoop(Ctx(tag = "decoder", accum = null)) { ctx -> + rustBlock("s if ${XmlName(local = "entry").compareTo("s")} => ") { + rust("#T(&mut ${ctx.tag}, &mut out)?;", mapEntryParser(target, ctx)) + } + } + rust("Ok(out)") + } + } + rust("#T(&mut ${ctx.tag})", mapParser) + } + + private fun RustWriter.parseFlatMap(target: MapShape, ctx: Ctx) { + val map = safeName("map") + val entryDecoder = mapEntryParser(target, ctx) + withBlock("Result::<#T, #T>::Ok({", "})", symbolProvider.toSymbol(target), xmlError) { + val accum = ctx.accum ?: throw CodegenException("need accum to parse flat map") + rustTemplate( + """ + let mut $map = $accum.unwrap_or_default(); + #{decoder}(&mut tag, &mut $map)?; + $map + """, + *codegenScope, + "decoder" to entryDecoder + ) + } + } + + private fun mapEntryParser( + target: MapShape, + ctx: Ctx + ): RuntimeType { + + val fnName = target.value.id.name.toSnakeCase() + "_entry" + return RuntimeType.forInlineFun(fnName, "xml_deser") { + it.rustBlockTemplate( + "pub fn $fnName(decoder: &mut #{ScopedDecoder}, out: &mut #{Map}) -> Result<(), #{XmlError}>", + *codegenScope, + "Map" to symbolProvider.toSymbol(target) + ) { + rust("let mut k: Option = None;") + rust( + "let mut v: Option<#T> = None;", + symbolProvider.toSymbol(model.expectShape(target.value.target)) + ) + parseLoop(Ctx("decoder", accum = null)) { + case(target.key) { + withBlock("k = Some(", ")") { + parseMember(target.key, ctx = ctx.copy(accum = null)) + } + } + case(target.value) { + withBlock("v = Some(", ")") { + parseMember(target.value, ctx = ctx.copy(accum = "v")) + } + } + } + + rustTemplate( + """ + let k = k.ok_or_else(||#{XmlError}::custom("missing key map entry"))?; + let v = v.ok_or_else(||#{XmlError}::custom("missing value map entry"))?; + out.insert(k, v); + Ok(()) + """, + *codegenScope + ) + } + } + } + + /** + * Parse a simple member from a data field + * [provider] generates code for the inner data field + */ + private fun RustWriter.parsePrimitiveInner(member: MemberShape, provider: RustWriter.() -> Unit) { + when (val shape = model.expectShape(member.target)) { + is StringShape -> parseStringInner(shape, provider) + is NumberShape, is BooleanShape -> { + rustBlock("") { + rust("use std::str::FromStr;") + withBlock("#T::from_str(", ")", symbolProvider.toSymbol(shape)) { + provider() + } + rustTemplate( + """.map_err(|_|#{XmlError}::custom("expected ${escape(shape.toString())}"))""", + *codegenScope + ) + } + } + is TimestampShape -> { + val timestampFormat = + index.determineTimestampFormat( + member, + HttpBinding.Location.DOCUMENT, + TimestampFormatTrait.Format.DATE_TIME + ) + val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) + withBlock("#T::from_str(", ")", RuntimeType.Instant(runtimeConfig)) { + provider() + rust(", #T", timestampFormatType) + } + rustTemplate( + """.map_err(|_|#{XmlError}::custom("expected ${escape(shape.toString())}"))""", + *codegenScope + ) + } + is BlobShape -> { + withBlock("#T(", ")", RuntimeType.Base64Decode(runtimeConfig)) { + provider() + } + rustTemplate( + """.map_err(|err|#{XmlError}::custom(format!("invalid base64: {:?}", err))).map(#{Blob}::new)""", + *codegenScope + ) + } + else -> TODO(shape.toString()) + } + } + + private fun RustWriter.parseStringInner(shape: StringShape, provider: RustWriter.() -> Unit) { + withBlock("Result::<#T, #T>::Ok(", ")", symbolProvider.toSymbol(shape), xmlError) { + val enumTrait = shape.getTrait(EnumTrait::class.java).orElse(null) + if (enumTrait == null) { + provider() + // if it's already `Cow::Owned` then `.into()` is free (vs. to_string()) + rust(".into()") + } else { + val enumSymbol = symbolProvider.toSymbol(shape) + withBlock("#T::from(", ")", enumSymbol) { + provider() + } + } + } + } + + private fun MemberShape.xmlName(): XmlName { + val override = this.getMemberTrait(model, XmlNameTrait::class.java).orNull() + return override?.let { XmlName.parse(it.value) } ?: XmlName(local = this.memberName) + } + + private fun MemberShape.isFlattened(): Boolean { + return getMemberTrait(model, XmlFlattenedTrait::class.java).isPresent + } + + fun XmlName.compareTo(start_el: String) = + "$start_el.matches(${this.toString().dq()})" + + data class XmlMemberIndex(val dataMembers: List, val attributeMembers: List) { + companion object { + fun fromMembers(members: List): XmlMemberIndex { + val (attribute, data) = members.partition { it.hasTrait(XmlAttributeTrait::class.java) } + return XmlMemberIndex(data, attribute) + } + } + } + + private fun OperationShape.operationXmlMembers(): XmlMemberIndex { + val outputShape = this.outputShape(model) + val documentMembers = + index.getResponseBindings(this).filter { it.value.location == HttpBinding.Location.DOCUMENT } + .keys.map { outputShape.expectMember(it) } + return XmlMemberIndex.fromMembers(documentMembers) + } + + private fun StructureShape.errorXmlMembers(): XmlMemberIndex { + val documentMembers = + index.getResponseBindings(this).filter { it.value.location == HttpBinding.Location.DOCUMENT } + .keys.map { this.expectMember(it) } + return XmlMemberIndex.fromMembers(documentMembers) + } + + private fun StructureShape.xmlMembers(): XmlMemberIndex { + return XmlMemberIndex.fromMembers(this.members().toList()) + } +} diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/XmlBindingTraitParserGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/XmlBindingTraitParserGeneratorTest.kt new file mode 100644 index 0000000000000000000000000000000000000000..599c25a18baa155ab20e0a27e658f346b3d435b9 --- /dev/null +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/XmlBindingTraitParserGeneratorTest.kt @@ -0,0 +1,183 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.smithy.protocols + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.rustlang.RustModule +import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.smithy.transformers.RecursiveShapeBoxer +import software.amazon.smithy.rust.codegen.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.testutil.renderWithModelBuilder +import software.amazon.smithy.rust.codegen.testutil.testProtocolConfig +import software.amazon.smithy.rust.codegen.testutil.testSymbolProvider +import software.amazon.smithy.rust.codegen.testutil.unitTest +import software.amazon.smithy.rust.codegen.util.lookup +import software.amazon.smithy.rust.codegen.util.outputShape + +internal class XmlBindingTraitParserGeneratorTest { + private val baseModel = """ + namespace test + use aws.protocols#restXml + union Choice { + @xmlFlattened + @xmlName("Hi") + flatMap: MyMap, + + deepMap: MyMap, + + @xmlFlattened + flatList: SomeList, + + deepList: SomeList, + + s: String, + + enum: FooEnum, + + date: Timestamp, + + number: Double, + + top: Top, + + blob: Blob + } + + @enum([{name: "FOO", value: "FOO"}]) + string FooEnum + + map MyMap { + @xmlName("Name") + key: String, + + @xmlName("Setting") + value: Choice, + } + + list SomeList { + member: Choice + } + + structure Top { + choice: Choice, + + @xmlAttribute + extra: Long, + + @xmlName("prefix:local") + renamedWithPrefix: String + } + + @http(uri: "/top", method: "POST") + operation Op { + input: Top, + output: Top + } + """.asSmithyModel() + + @Test + fun `generates valid parsers`() { + val model = RecursiveShapeBoxer.transform(OperationNormalizer(baseModel).transformModel(OperationNormalizer.NoBody, OperationNormalizer.NoBody)) + val symbolProvider = testSymbolProvider(model) + val parserGenerator = XmlBindingTraitParserGenerator(testProtocolConfig(model)) + val operationParser = parserGenerator.operationParser(model.lookup("test#Op")) + val project = TestWorkspace.testProject(testSymbolProvider(model)) + project.lib { + it.unitTest( + name = "valid_input", + test = """ + let xml = br#" + + + some key + + hello + + + + hey + + "#; + let output = ${it.format(operationParser)}(xml, output::op_output::Builder::default()).unwrap().build(); + let mut map = std::collections::HashMap::new(); + map.insert("some key".to_string(), model::Choice::S("hello".to_string())); + assert_eq!(output.choice, Some(model::Choice::FlatMap(map))); + assert_eq!(output.renamed_with_prefix.as_deref(), Some("hey")); + """ + ) + + it.unitTest( + name = "ignore_extras", + test = """ + let xml = br#" + + + + + 5 + + + + some key + + hello + + + + + "#; + let output = ${it.format(operationParser)}(xml, output::op_output::Builder::default()).unwrap().build(); + let mut map = std::collections::HashMap::new(); + map.insert("some key".to_string(), model::Choice::S("hello".to_string())); + assert_eq!(output.choice, Some(model::Choice::FlatMap(map))); + """ + ) + + it.unitTest( + name = "nopanics_on_invalid", + test = """ + let xml = br#" + + + + + 5 + + + + some key + + hello + + + + + "#; + ${it.format(operationParser)}(xml, output::op_output::Builder::default()).expect_err("invalid input"); + """ + ) + } + project.withModule(RustModule.default("model", public = true)) { + model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, it) + UnionGenerator(model, symbolProvider, it, model.lookup("test#Choice")).render() + val enum = model.lookup("test#FooEnum") + EnumGenerator(symbolProvider, it, enum, enum.expectTrait(EnumTrait::class.java)).render() + } + + project.withModule(RustModule.default("output", public = true)) { + model.lookup("test#Op").outputShape(model).renderWithModelBuilder(model, symbolProvider, it) + } + project.compileAndTest() + } +} diff --git a/rust-runtime/inlineable/Cargo.toml b/rust-runtime/inlineable/Cargo.toml index a3b03707d0caeaedc0dd51cdbf875a7794a9f4da..98db25d0242c5d8f3fcd3cbeaadac792070d2e7b 100644 --- a/rust-runtime/inlineable/Cargo.toml +++ b/rust-runtime/inlineable/Cargo.toml @@ -9,11 +9,12 @@ are to allow this crate to be compilable and testable in isolation, no client co """ [dependencies] -"serde" = { version ="1", features = ["derive"] } +"serde" = { version = "1", features = ["derive"] } "serde_json" = "1" "http" = "0.2.1" "smithy-types" = { version = "0.0.1", path = "../smithy-types" } -"smithy-http" = { version = "0.0.1", path = "../smithy-http"} +"smithy-http" = { version = "0.0.1", path = "../smithy-http" } +"smithy-xml" = { path = "../smithy-xml" } "fastrand" = "1" [dev-dependencies] diff --git a/rust-runtime/inlineable/src/lib.rs b/rust-runtime/inlineable/src/lib.rs index c51ab23a3a24b1021fb3f51efa6335848aa206c0..484d2c837ed662cceab7314b0eb164b622de7b57 100644 --- a/rust-runtime/inlineable/src/lib.rs +++ b/rust-runtime/inlineable/src/lib.rs @@ -13,6 +13,10 @@ mod idempotency_token; mod instant_epoch; mod instant_httpdate; mod instant_iso8601; +#[allow(unused)] +mod rest_xml_unwrapped_errors; +#[allow(unused)] +mod rest_xml_wrapped_errors; // This test is outside of uuid.rs to enable copying the entirety of uuid.rs into the SDK without // requiring a proptest dependency diff --git a/rust-runtime/inlineable/src/rest_xml_unwrapped_errors.rs b/rust-runtime/inlineable/src/rest_xml_unwrapped_errors.rs new file mode 100644 index 0000000000000000000000000000000000000000..9d9f0ac12eb258bd3804de0a23ebd485baa346dc --- /dev/null +++ b/rust-runtime/inlineable/src/rest_xml_unwrapped_errors.rs @@ -0,0 +1,64 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +//! Error abstractions for `noErrorWrapping`. Code generators should either inline this file +//! or its companion `rest_xml_wrapped_errors.rs` for code generation + +use smithy_xml::decode::{try_data, Document, ScopedDecoder, XmlError}; +use std::convert::TryFrom; + +pub fn is_error(response: &http::Response) -> bool { + !response.status().is_success() +} + +pub fn body_is_error(body: &[u8]) -> Result { + let mut doc = Document::try_from(body)?; + let scoped = doc.root_element()?; + Ok(scoped.start_el().matches("Error")) +} + +pub fn error_scope<'a, 'b>(doc: &'a mut Document<'b>) -> Result, XmlError> { + let scoped = doc.root_element()?; + if !scoped.start_el().matches("Error") { + return Err(XmlError::custom("expected error as root")); + } + Ok(scoped) +} + +pub fn parse_generic_error(body: &[u8]) -> Result { + let mut doc = Document::try_from(body)?; + let mut root = doc.root_element()?; + let mut err = smithy_types::Error::default(); + while let Some(mut tag) = root.next_tag() { + match tag.start_el().local() { + "Code" => err.code = Some(String::from(try_data(&mut tag)?)), + "Message" => err.message = Some(String::from(try_data(&mut tag)?)), + "RequestId" => err.request_id = Some(String::from(try_data(&mut tag)?)), + _ => {} + } + } + Ok(err) +} + +#[cfg(test)] +mod test { + use super::{body_is_error, parse_generic_error}; + + #[test] + fn parse_unwrapped_error() { + let xml = br#" + Sender + InvalidGreeting + Hi + setting + foo-id +"#; + assert!(body_is_error(xml).unwrap()); + let parsed = parse_generic_error(xml).expect("valid xml"); + assert_eq!(parsed.request_id(), Some("foo-id")); + assert_eq!(parsed.message(), Some("Hi")); + assert_eq!(parsed.code(), Some("InvalidGreeting")); + } +} diff --git a/rust-runtime/inlineable/src/rest_xml_wrapped_errors.rs b/rust-runtime/inlineable/src/rest_xml_wrapped_errors.rs new file mode 100644 index 0000000000000000000000000000000000000000..c6bc6a0f8e729e42ecf15cae30a8ab8f7e4f1cc0 --- /dev/null +++ b/rust-runtime/inlineable/src/rest_xml_wrapped_errors.rs @@ -0,0 +1,112 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +use smithy_xml::decode::{try_data, Document, ScopedDecoder, XmlError}; +use std::convert::TryFrom; + +pub fn is_error(response: &http::Response) -> bool { + !response.status().is_success() +} + +pub fn body_is_error(body: &[u8]) -> Result { + let mut doc = Document::try_from(body)?; + let scoped = doc.root_element()?; + Ok(scoped.start_el().matches("ErrorResponse")) +} + +pub fn parse_generic_error(body: &[u8]) -> Result { + let mut doc = Document::try_from(body)?; + let mut root = doc.root_element()?; + let mut err = smithy_types::Error::default(); + while let Some(mut tag) = root.next_tag() { + match tag.start_el().local() { + "Error" => { + while let Some(mut error_field) = tag.next_tag() { + match error_field.start_el().local() { + "Code" => err.code = Some(String::from(try_data(&mut error_field)?)), + "Message" => err.message = Some(String::from(try_data(&mut error_field)?)), + _ => {} + } + } + } + "RequestId" => err.request_id = Some(String::from(try_data(&mut tag)?)), + _ => {} + } + } + Ok(err) +} + +#[allow(unused)] +pub fn error_scope<'a, 'b>(doc: &'a mut Document<'b>) -> Result, XmlError> { + let root = doc + .next_start_element() + .ok_or_else(|| XmlError::custom("no root found searching for an Error"))?; + if !root.matches("ErrorResponse") { + return Err(XmlError::custom("expected ErrorResponse as root")); + } + + while let Some(el) = doc.next_start_element() { + if el.matches("Error") && el.depth() == 1 { + return Ok(doc.scoped_to(el)); + } + // otherwise, ignore it + } + Err(XmlError::custom("No Error found inside of ErrorResponse")) +} + +#[cfg(test)] +mod test { + use super::{body_is_error, parse_generic_error}; + use crate::rest_xml_wrapped_errors::error_scope; + use smithy_types::Document; + use std::convert::TryFrom; + + #[test] + fn parse_wrapped_error() { + let xml = br#" + + Sender + InvalidGreeting + Hi + setting + + + foo-id +"#; + assert!(body_is_error(xml).unwrap()); + let parsed = parse_generic_error(xml).expect("valid xml"); + assert_eq!(parsed.request_id(), Some("foo-id")); + assert_eq!(parsed.message(), Some("Hi")); + assert_eq!(parsed.code(), Some("InvalidGreeting")); + } + + #[test] + fn test_error_scope() { + let xml: &[u8] = br#" + foo-id + foo-id + These are not the errors you are looking for + + Sender + InvalidGreeting + Hi + setting + + + foo-id +"#; + let mut doc = smithy_xml::decode::Document::try_from(xml).expect("valid"); + let mut error = error_scope(&mut doc).expect("contains error"); + let mut keys = vec![]; + while let Some(tag) = error.next_tag() { + keys.push(tag.start_el().local().to_owned()); + // read this the full contents of this element + } + assert_eq!( + keys, + vec!["Type", "Code", "Message", "AnotherSetting", "Ignore",] + ) + } +} diff --git a/rust-runtime/smithy-xml/Cargo.toml b/rust-runtime/smithy-xml/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..efae22a6c9479dcd0cb1509548e6b6aaca0ae46f --- /dev/null +++ b/rust-runtime/smithy-xml/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "smithy-xml" +version = "0.1.0" +authors = ["AWS Rust SDK Team ", "Russell Cohen "] +edition = "2018" + +[dependencies] +xmlparser = "0.13.3" +thiserror = "1" + +[dev-dependencies] +proptest = "1" diff --git a/rust-runtime/smithy-xml/src/decode.rs b/rust-runtime/smithy-xml/src/decode.rs new file mode 100644 index 0000000000000000000000000000000000000000..b7f20da87d5621141e2c2760066ebc39c7dacc72 --- /dev/null +++ b/rust-runtime/smithy-xml/src/decode.rs @@ -0,0 +1,558 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +use crate::unescape::unescape; +use std::borrow::Cow; +use std::convert::TryFrom; +use thiserror::Error; +use xmlparser::{ElementEnd, Token, Tokenizer}; + +pub type Depth = usize; + +// in general, these errors are just for reporting what happened, there isn't +// much value in lots of different match variants + +#[derive(Debug, Error)] +pub enum XmlError { + #[error("XML Parse Error")] + InvalidXml(#[from] xmlparser::Error), + + #[error("Invalid XML Escape: {esc}")] + InvalidEscape { esc: String }, + + #[error("Error parsing XML: {0}")] + Custom(Cow<'static, str>), + #[error("Encountered another error parsing XML: {0}")] + Unhandled(#[from] Box), +} + +impl XmlError { + pub fn custom(msg: impl Into>) -> Self { + XmlError::Custom(msg.into()) + } +} + +#[derive(PartialEq, Debug)] +pub struct Name<'a> { + pub prefix: &'a str, + pub local: &'a str, +} + +impl Name<'_> { + /// Check if a given name matches a tag name composed of `prefix:local` or just `local` + pub fn matches(&self, tag_name: &str) -> bool { + let split = tag_name.find(':'); + match split { + None => tag_name == self.local, + Some(idx) => { + let (prefix, local) = tag_name.split_at(idx); + let local = &local[1..]; + self.local == local && self.prefix == prefix + } + } + } +} + +#[derive(Debug, PartialEq)] +pub struct Attr<'a> { + name: Name<'a>, + // attribute values can be escaped (eg. with double quotes, so we need a Cow) + value: Cow<'a, str>, +} + +#[derive(Debug, PartialEq)] +pub struct StartEl<'a> { + name: Name<'a>, + attributes: Vec>, + closed: bool, + depth: Depth, +} + +/// Xml Start Element +/// +/// ```xml +/// +/// ^^^ ^^^^^ +/// name attributes +/// ``` +impl<'a> StartEl<'a> { + pub fn depth(&self) -> Depth { + self.depth + } + + fn new(local: &'a str, prefix: &'a str, depth: Depth) -> Self { + Self { + name: Name { prefix, local }, + attributes: vec![], + closed: false, + depth, + } + } + + /// Retrieve an attribute with a given key + /// + /// key `prefix:local` combined as a str, joined by a `:` + pub fn attr<'b>(&'b self, key: &'b str) -> Option<&'b str> { + self.attributes + .iter() + .find(|attr| attr.name.matches(key)) + .map(|attr| attr.value.as_ref()) + } + + /// Returns whether this `StartEl` matches a given name + /// in `prefix:local` form. + pub fn matches(&self, pat: &str) -> bool { + self.name.matches(pat) + } + + /// Local component of this element's name + /// + /// ```xml + /// + /// ^^^ + /// ``` + pub fn local(&self) -> &str { + self.name.local + } + + /// Prefix component of this elements name (or empty string) + /// ```xml + /// + /// ^^^ + /// ``` + pub fn prefix(&self) -> &str { + self.name.prefix + } + + /// Returns true of `el` at `depth` is a match for this `start_el` + fn end_el(&self, el: ElementEnd, depth: Depth) -> bool { + if depth != self.depth { + return false; + } + match el { + ElementEnd::Open => false, + ElementEnd::Close(prefix, local) => { + prefix.as_str() == self.name.prefix && local.as_str() == self.name.local + } + ElementEnd::Empty => false, + } + } +} + +/// Xml Document abstraction +/// +/// This document wraps a lazy tokenizer with depth tracking. +/// Constructing a document is essentially free. +pub struct Document<'a> { + tokenizer: Tokenizer<'a>, + depth: Depth, +} + +impl<'a> TryFrom<&'a [u8]> for Document<'a> { + type Error = XmlError; + + fn try_from(value: &'a [u8]) -> Result { + Ok(Document::new( + std::str::from_utf8(value).map_err(|err| XmlError::Unhandled(Box::new(err)))?, + )) + } +} + +impl<'inp> Document<'inp> { + pub fn new(doc: &'inp str) -> Self { + Document { + tokenizer: Tokenizer::from(doc), + depth: 0, + } + } + + /// "Depth first" iterator + /// + /// Unlike [`next_tag()`](ScopedDecoder::next_tag), this method returns the next + /// start element regardless of depth. This is useful to give a pointer into the middle + /// of a document to start reading. + /// + /// ```xml + /// <-- first call returns this: + /// <-- next call + /// <-- next call returns this + /// hello <-- then this: + /// + /// <-- second call to next_tag returns this + /// + /// ``` + pub fn next_start_element<'a>(&'a mut self) -> Option> { + next_start_element(self) + } + + /// A scoped reader for the entire document + pub fn root_element<'a>(&'a mut self) -> Result, XmlError> { + let start_el = self + .next_start_element() + .ok_or_else(|| XmlError::custom("no root element"))?; + Ok(ScopedDecoder { + doc: self, + start_el, + terminated: false, + }) + } + + /// A scoped reader for a specific tag + /// + /// This method is necessary for when you need to return a ScopedDecoder from a function + /// since normally the stacked-ownership that `next_tag()` uses would prevent returning a reference + /// to a field owned by the current function + pub fn scoped_to<'a>(&'a mut self, start_el: StartEl<'inp>) -> ScopedDecoder<'inp, 'a> { + ScopedDecoder { + doc: self, + start_el, + terminated: false, + } + } +} + +/// Depth tracking iterator +/// +/// ```xml +/// <- startel depth 0 +/// <- startel depth 1 +/// <- startel depth 2 +/// <- endel depth 2 +/// <- endel depth 1 +/// <- endel depth 0 +/// ``` +impl<'inp> Iterator for Document<'inp> { + type Item = Result<(Token<'inp>, Depth), XmlError>; + fn next<'a>(&'a mut self) -> Option, Depth), XmlError>> { + let tok = self.tokenizer.next()?; + let tok = match tok { + Err(e) => return Some(Err(e.into())), + Ok(tok) => tok, + }; + // depth bookkeeping + match tok { + Token::ElementEnd { + end: ElementEnd::Close(_, _), + .. + } => { + self.depth -= 1; + } + Token::ElementEnd { + end: ElementEnd::Empty, + .. + } => self.depth -= 1, + t @ Token::ElementStart { .. } => { + self.depth += 1; + // We want the startel and endel to have the same depth, but after the opener, + // the parser will be at depth 1. Return the previous depth: + return Some(Ok((t, self.depth - 1))); + } + _ => {} + } + Some(Ok((tok, self.depth))) + } +} + +/// XmlTag Abstraction +/// +/// ScopedDecoder represents a tag-scoped view into an XML document. Methods +/// on `ScopedDecoder` return `None` when the current tag has been exhausted. +pub struct ScopedDecoder<'inp, 'a> { + doc: &'a mut Document<'inp>, + start_el: StartEl<'inp>, + terminated: bool, +} + +/// When a scoped decoder is dropped, its entire scope is consumed so that the +/// next read begins at the next tag at the same depth. +impl Drop for ScopedDecoder<'_, '_> { + fn drop(&mut self) { + for _ in self {} + } +} + +impl<'inp> ScopedDecoder<'inp, '_> { + /// The start element for this scope + pub fn start_el<'a>(&'a self) -> &'a StartEl<'inp> { + &self.start_el + } + + /// Returns the next top-level tag in this scope + /// The returned reader will fully read the tag during its lifetime. If it is dropped without + /// the data being read, the reader will be advanced until the matching close tag. If you read + /// an element with `next_tag()` and you want to ignore it, simply drop the resulting `ScopeDecoder`. + /// + /// ```xml + /// <-- scoped reader on this tag + /// <-- first call to next_tag returns this + /// <-- to get inner data, call `next_tag` on the returned decoder for `A` + /// hello + /// + /// <-- second call to next_tag returns this + /// + /// ``` + pub fn next_tag<'a>(&'a mut self) -> Option> { + let next_tag = next_start_element(self)?; + Some(self.nested_decoder(next_tag)) + } + + fn nested_decoder<'a>(&'a mut self, start_el: StartEl<'inp>) -> ScopedDecoder<'inp, 'a> { + ScopedDecoder { + doc: &mut self.doc, + start_el, + terminated: false, + } + } +} + +impl<'inp, 'a> Iterator for ScopedDecoder<'inp, 'a> { + type Item = Result<(Token<'inp>, Depth), XmlError>; + + fn next(&mut self) -> Option { + if self.start_el.closed { + self.terminated = true; + } + if self.terminated { + return None; + } + let (tok, depth) = match self.doc.next() { + Some(Ok((tok, depth))) => (tok, depth), + other => return other, + }; + + match tok { + Token::ElementEnd { end, .. } if self.start_el.end_el(end, depth) => { + self.terminated = true; + return None; + } + _ => {} + } + Some(Ok((tok, depth))) + } +} + +/// Load the next start element out of a depth-tagged token iterator +fn next_start_element<'a, 'inp>( + tokens: &'a mut impl Iterator, Depth), XmlError>>, +) -> Option> { + let mut out = StartEl::new("", "", 0); + loop { + match tokens.next()? { + Ok((Token::ElementStart { local, prefix, .. }, depth)) => { + out.name.local = local.as_str(); + out.name.prefix = prefix.as_str(); + out.depth = depth; + } + Ok(( + Token::Attribute { + prefix, + local, + value, + .. + }, + _, + )) => out.attributes.push(Attr { + name: Name { + local: local.as_str(), + prefix: prefix.as_str(), + }, + value: unescape(value.as_str()).ok()?, + }), + Ok(( + Token::ElementEnd { + end: ElementEnd::Open, + .. + }, + _, + )) => break, + Ok(( + Token::ElementEnd { + end: ElementEnd::Empty, + .. + }, + _, + )) => { + out.closed = true; + break; + } + _ => {} + } + } + Some(out) +} + +/// Returns the data element at the current position +/// +/// If the current position is not a data element (and is instead a ) an error +/// will be returned +pub fn try_data<'a, 'inp>( + tokens: &'a mut impl Iterator, Depth), XmlError>>, +) -> Result, XmlError> { + loop { + match tokens.next().map(|opt| opt.map(|opt| opt.0)) { + None => return Ok(Cow::Borrowed("")), + Some(Ok(Token::Text { text })) if !text.as_str().trim().is_empty() => { + return unescape(text.as_str().trim()) + } + Some(Ok(e @ Token::ElementStart { .. })) => { + return Err(XmlError::custom(format!( + "Looking for a data element, found: {:?}", + e + ))) + } + Some(Err(e)) => return Err(e), + _ => {} + } + } +} + +#[cfg(test)] +mod test { + use crate::decode::{try_data, Attr, Depth, Document, Name, StartEl}; + + // test helper to create a closed startel + fn closed<'a>(local: &'a str, prefix: &'a str, depth: Depth) -> StartEl<'a> { + let mut s = StartEl::new(local, prefix, depth); + s.closed = true; + s + } + + #[test] + fn scoped_tokens() { + let xml = r#""#; + let mut doc = Document::new(xml); + let mut root = doc.root_element().expect("valid document"); + assert_eq!(root.start_el().local(), "Response"); + assert_eq!(root.next_tag().expect("tag exists").start_el().local(), "A"); + assert!(root.next_tag().is_none()); + } + + #[test] + fn handle_depth_properly() { + let xml = r#""#; + let mut doc = Document::new(xml); + let mut scoped = doc.root_element().expect("valid document"); + assert_eq!( + scoped.next_tag().unwrap().start_el(), + &StartEl::new("Response", "", 1) + ); + let closed_a = closed("A", "", 1); + assert_eq!(scoped.next_tag().unwrap().start_el(), &closed_a); + assert!(scoped.next_tag().is_none()) + } + + #[test] + fn self_closing() { + let xml = r#""#; + let mut doc = Document::new(xml); + let mut scoped = doc.root_element().expect("valid doc"); + assert_eq!(scoped.start_el.closed, true); + assert!(scoped.next_tag().is_none()) + } + + #[test] + fn terminate_scope() { + let xml = r#""#; + let mut doc = Document::new(xml); + let mut response_iter = doc.root_element().expect("valid doc"); + let mut struct_iter = response_iter.next_tag().unwrap(); + assert_eq!( + struct_iter.next_tag().as_ref().map(|t| t.start_el()), + Some(&StartEl::new("A", "", 2)) + ); + // When the inner iter is dropped, it will read to the end of its scope + // prevent accidental behavior where we didn't read a full node + drop(struct_iter); + assert_eq!( + response_iter.next_tag().unwrap().start_el(), + &closed("More", "", 1) + ); + } + + #[test] + fn read_data_invalid() { + let xml = r#""#; + let mut doc = Document::new(xml); + let mut resp = doc.root_element().unwrap(); + try_data(&mut resp).expect_err("no data"); + } + + #[test] + fn read_data() { + let xml = r#"hello"#; + let mut doc = Document::new(xml); + let mut scoped = doc.root_element().unwrap(); + assert_eq!(try_data(&mut scoped).unwrap(), "hello"); + } + + #[test] + fn read_attributes() { + let xml = r#"hello"#; + let mut tokenizer = Document::new(xml); + let root = tokenizer.root_element().unwrap(); + + assert_eq!( + root.start_el().attributes, + vec![Attr { + name: Name { + prefix: "xsi".into(), + local: "type".into() + }, + value: "CanonicalUser".into() + }] + ) + } + + #[test] + fn escape_data() { + let xml = r#">"#; + let mut doc = Document::new(xml); + let mut root = doc.root_element().unwrap(); + assert_eq!(try_data(&mut root).unwrap(), ">"); + assert_eq!(root.start_el().attr("key"), Some("\"hey\">")); + } + + #[test] + fn nested_self_closer() { + let xml = r#" + + + "#; + let mut doc = Document::new(xml); + let mut root = doc.root_element().unwrap(); + let mut string_list = root.next_tag().unwrap(); + assert_eq!(string_list.start_el(), &closed("stringList", "", 1)); + assert!(string_list.next_tag().is_none()); + drop(string_list); + assert_eq!( + root.next_tag().unwrap().start_el(), + &StartEl::new("stringSet", "", 1) + ); + } + + #[test] + fn confusing_nested_same_name_tag() { + // an inner b which could be confused as closing the outer b if depth + // is not properly tracked: + let root_tags = &["a", "b", "c", "d"]; + let xml = r#" + + + + + + + + more + "#; + let mut doc = Document::new(xml); + let mut root = doc.root_element().unwrap(); + let mut cmp = vec![]; + while let Some(tag) = root.next_tag() { + cmp.push(tag.start_el().local().to_owned()); + } + assert_eq!(root_tags, cmp.as_slice()); + } +} diff --git a/rust-runtime/smithy-xml/src/lib.rs b/rust-runtime/smithy-xml/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..a921c80d3f94eedbedc91c4151c33b086d5f5a18 --- /dev/null +++ b/rust-runtime/smithy-xml/src/lib.rs @@ -0,0 +1,4 @@ +//! Abstractions for Smithy +//! [XML Binding Traits](https://awslabs.github.io/smithy/1.0/spec/core/xml-traits.html) +pub mod decode; +mod unescape; diff --git a/rust-runtime/smithy-xml/src/unescape.rs b/rust-runtime/smithy-xml/src/unescape.rs new file mode 100644 index 0000000000000000000000000000000000000000..25e1a7f0ee03f15f2d3bb4fc991109654f1fdcba --- /dev/null +++ b/rust-runtime/smithy-xml/src/unescape.rs @@ -0,0 +1,160 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +use crate::decode::XmlError; +use std::borrow::Cow; + +/// Unescape XML encoded characters +/// +/// This function will unescape the 4 literal escapes: +/// - `<`, `>`, `&`, `"`, and `'` +/// - Decimal escapes: `{` +/// - Hex escapes: ` ` +/// +/// If no escape sequences are present, Cow<&'str> will be returned, avoiding the need +/// to copy the String. +pub fn unescape(s: &str) -> Result, XmlError> { + // no &, no need to escape anything + if !s.contains('&') { + return Ok(Cow::Borrowed(s)); + } + // this will be strictly larger than required avoiding the need for another allocation + let mut res = String::with_capacity(s.len()); + // could consider memchr as performance optimization + let mut sections = s.split('&'); + // push content before the first & + if let Some(prefix) = sections.next() { + res.push_str(prefix); + } + for section in sections { + // entites look like &; + match section.find(';') { + Some(idx) => { + let entity = §ion[..idx]; + match entity { + "lt" => res.push('<'), + "gt" => res.push('>'), + "amp" => res.push('&'), + "quot" => res.push('"'), + "apos" => res.push('\''), + entity => { + // eg. + let (entity, radix) = if let Some(entity) = entity.strip_prefix("#x") { + (entity, 16) + } else if let Some(entity) = entity.strip_prefix("#") { + // eg. { + (entity, 10) + } else { + return Err(XmlError::InvalidEscape { + esc: entity.to_string(), + }); + }; + let char_code = u32::from_str_radix(&entity, radix).map_err(|_| { + XmlError::InvalidEscape { + esc: format!( + "Expected numeric escape in base {}; got: {}", + radix, &entity + ), + } + })?; + let chr = + std::char::from_u32(char_code).ok_or(XmlError::InvalidEscape { + esc: format!("invalid char code: {}", char_code), + })?; + res.push(chr); + } + } + // push everything from the `;` to the next `&` + res.push_str(§ion[idx + 1..]) + } + None => { + return Err(XmlError::InvalidEscape { + esc: "Unterminated pattern".to_string(), + }) + } + } + } + Ok(Cow::Owned(res)) +} + +#[cfg(test)] +mod test { + use crate::unescape::unescape; + use std::borrow::Cow; + + #[test] + fn basic_unescape() { + assert_eq!( + unescape("< > ' " &").unwrap(), + "< > ' \" &" + ); + assert_eq!( + unescape("Since a > b, b is less than a").unwrap(), + "Since a > b, b is less than a" + ); + } + + #[test] + fn no_need_to_escape() { + assert_eq!(unescape("hello 🍕!").unwrap(), Cow::Borrowed("hello 🍕!")); + } + + #[test] + fn complex_unescape() { + // Test cases adapted from Apache Commons StringEscapeUtilsTest.java + assert_eq!( + unescape("a<b>c"d'e&f;;").unwrap(), + "ac\"d'e&f;;" + ); + assert_eq!(unescape("&lt;").unwrap(), "<") + } + + #[test] + fn newline_encoding() { + assert_eq!(unescape(" ").unwrap(), "\n"); + assert_eq!(unescape(" ").unwrap(), "\r"); + } + + #[test] + fn xml_eol_encoding() { + assert_eq!(unescape(" ").unwrap(), "\n \n"); + assert_eq!( + unescape("a b c ").unwrap(), + "a\r\n b\n c\r" + ); + assert_eq!( + unescape("a … b…").unwrap(), + "a\r\u{0085} b\u{0085}" + ); + assert_eq!( + unescape("a 
 b… c
").unwrap(), + "a\r\u{2028} b\u{0085} c\u{2028}" + ); + } + + #[test] + fn invalid_escapes() { + unescape("<e;").expect_err("lte does not make a ≤"); + unescape("<").expect_err("unterminated escape sequence"); + unescape("&#Q1234;").expect_err("Q does not began a numeric sequence"); + unescape(".14;").expect_err("decimal escape"); + unescape("&#xZZ").expect_err("Z is not hex"); + unescape("here is a & but without an escape sequence...").expect_err("naked &"); + } + + use proptest::prelude::*; + proptest! { + #[test] + fn no_panics(s: String) { + let unescaped = unescape(&s); + // if the string needed to be escaped, we + if s.contains("&") { + assert!( + matches!(unescaped, Ok(Cow::Owned(_)) | Err(_)) + ); + } + } + } +} diff --git a/rust-runtime/smithy-xml/tests/handwritten_parsers.rs b/rust-runtime/smithy-xml/tests/handwritten_parsers.rs new file mode 100644 index 0000000000000000000000000000000000000000..9f1a2d60a7684af1dea9a07bfa2433f5d16a13a9 --- /dev/null +++ b/rust-runtime/smithy-xml/tests/handwritten_parsers.rs @@ -0,0 +1,234 @@ +//! A collection of handwritten parsers similar to the +//! parsers that are code generated by XmlParserGenerator.kt +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +use smithy_xml::decode::{try_data, Document, ScopedDecoder, XmlError}; +use std::collections::HashMap; + +#[derive(Eq, PartialEq, Debug)] +enum FooEnum { + Unknown(String), +} + +impl<'a> From<&'a str> for FooEnum { + fn from(s: &'a str) -> Self { + Self::Unknown(s.to_string()) + } +} + +#[derive(Eq, PartialEq, Debug)] +struct FlatXmlMap { + my_map: HashMap, +} + +#[derive(Eq, PartialEq, Debug)] +struct XmlMap { + values: HashMap, +} + +#[derive(Eq, PartialEq, Debug)] +struct XmlAttribute { + foo: String, + bar: String, +} + +fn deserialize_xml_attribute(inp: &str) -> Result { + let mut doc = Document::new(inp); + let mut root = doc.root_element()?; + #[allow(unused_assignments)] + let mut foo: Option = None; + let mut bar: Option = None; + foo = root.start_el().attr("foo").map(|attr| attr.to_string()); + while let Some(mut tag) = root.next_tag() { + if tag.start_el().matches("bar") { + bar = Some(try_data(&mut tag)?.to_string()); + } + } + Ok(XmlAttribute { + foo: foo.ok_or_else(|| XmlError::custom("missing foo"))?, + bar: bar.ok_or_else(|| XmlError::custom("missing bar"))?, + }) +} + +fn deserialize_flat_xml_map(inp: &str) -> Result { + let mut doc = Document::new(inp); + let mut root = doc.root_element()?; + let mut my_map: Option> = None; + while let Some(mut tag) = root.next_tag() { + if tag.start_el().matches("myMap") { + let mut _my_map = my_map.unwrap_or_default(); + deserialize_foo_enum_map_entry(&mut tag, &mut _my_map)?; + my_map = Some(_my_map); + } + } + Ok(FlatXmlMap { + my_map: my_map.unwrap(), + }) +} + +fn deserialize_xml_map(inp: &str) -> Result { + let mut doc = Document::new(inp); + let mut root = doc.root_element()?; + let mut my_map: Option> = None; + while let Some(mut tag) = root.next_tag() { + if tag.start_el().matches("values") { + my_map = Some(deserialize_foo_enum_map(&mut tag)?); + } + } + Ok(XmlMap { + values: my_map.ok_or_else(|| XmlError::custom("missing map"))?, + }) +} + +fn deserialize_foo_enum_map( + decoder: &mut ScopedDecoder, +) -> Result, XmlError> { + let mut out: HashMap = HashMap::new(); + while let Some(mut tag) = decoder.next_tag() { + if tag.start_el().matches("entry") { + deserialize_foo_enum_map_entry(&mut tag, &mut out)?; + } + } + Ok(out) +} + +fn deserialize_foo_enum_map_entry( + decoder: &mut ScopedDecoder, + out: &mut HashMap, +) -> Result<(), XmlError> { + let mut k: Option = None; + let mut v: Option = None; + while let Some(mut tag) = decoder.next_tag() { + match tag.start_el() { + s if s.matches("key") => k = Some(try_data(&mut tag)?.to_string()), + s if s.matches("value") => v = Some(FooEnum::from(try_data(&mut tag)?.as_ref())), + _ => {} + } + } + match (k, v) { + (Some(k), Some(v)) => { + out.insert(k, v); + } + _ => return Err(XmlError::custom("missing key value in map")), + } + Ok(()) +} + +#[test] +fn deserialize_map_test() { + let xml = r#" + + + example-key1 + hello + example1 + + + example-key2 + example2 + + +"#; + + let mut out = HashMap::new(); + out.insert("example-key1".to_string(), FooEnum::from("example1")); + out.insert("example-key2".to_string(), FooEnum::from("example2")); + assert_eq!( + deserialize_xml_map(xml).expect("valid"), + XmlMap { values: out } + ) +} + +pub fn deserialize_nested_string_list( + decoder: &mut ScopedDecoder, +) -> Result>, XmlError> { + let mut out = std::vec::Vec::new(); + while let Some(mut tag) = decoder.next_tag() { + match tag.start_el() { + s if s.matches("member") => { + out.push(deserialize_string_list(&mut tag)?); + } + _ => {} + } + } + Ok(out) +} + +pub fn deserialize_string_list( + decoder: &mut ScopedDecoder, +) -> Result, XmlError> { + let mut out = std::vec::Vec::new(); + while let Some(mut tag) = decoder.next_tag() { + match dbg!(tag.start_el()) { + s if s.matches("member") => { + out.push(dbg!({ + smithy_xml::decode::try_data(&mut tag)?.to_string() + })); + } + _ => {} + }; + } + println!("done"); + Ok(out) +} + +#[test] +fn test_nested_string_list() { + let xml = r#" + + + foo + bar + + + baz + qux + + + "#; + let mut doc = Document::new(xml); + let mut root = doc.root_element().unwrap(); + assert_eq!( + deserialize_nested_string_list(&mut root).unwrap(), + vec![vec!["foo", "bar"], vec!["baz", "qux"]] + ); +} + +#[test] +fn deserialize_flat_map_test() { + let xml = r#" + + foo + Foo + + + baz + Baz + + "#; + + let mut out = HashMap::new(); + out.insert("foo".to_string(), FooEnum::from("Foo")); + out.insert("baz".to_string(), FooEnum::from("Baz")); + assert_eq!( + deserialize_flat_xml_map(xml).unwrap(), + FlatXmlMap { my_map: out } + ) +} + +#[test] +fn test_deserialize_xml_attribute() { + let xml = r#" + examplebar +"#; + assert_eq!( + deserialize_xml_attribute(xml).expect("valid"), + XmlAttribute { + foo: "example".to_string(), + bar: "examplebar".to_string() + } + ); +}