Unverified Commit 25c50a5f authored by Matteo Bigoi's avatar Matteo Bigoi Committed by GitHub
Browse files

[Server] Respect @required trait in model generation (#1148)



Since adding @required to a shape member is not a backwards-compatible change, the server implementation can generate non-Optional structure members for @required shape members.

On top of this generated server code shouldn't use non-exhaustive tag on structures for the same non-backwards compatibility and because it forces to constructs instances through builders.

This change introduces the new parameter handleRequired when constructing the SymbolVisitor. This parameter is set to false for client codegen and to true for server codegen.

With this new parameter we control whether to always treat symbols as Option or to follow the model directives more strictly remove optionality for @required members.

Co-authored-by: default avatardavid-perez <d@vidp.dev>
Co-authored-by: default avatarRussell Cohen <rcoh@amazon.com>
parent a1f06527
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -66,7 +66,7 @@ class RustCodegenServerPlugin : SmithyBuildPlugin {
                // Generate [ByteStream] instead of `Blob` for streaming binary shapes (e.g. S3 GetObject)
                .let { StreamingShapeSymbolProvider(it, model) }
                // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes
                .let { BaseSymbolMetadataProvider(it) }
                .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf()) }
                // Streaming shapes need different derives (e.g. they cannot derive Eq)
                .let { StreamingShapeMetadataProvider(it, model) }
                // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot
+2 −1
Original line number Diff line number Diff line
@@ -67,7 +67,8 @@ class ServerCodegenVisitor(context: PluginContext, private val codegenDecorator:
        val symbolVisitorConfig =
            SymbolVisitorConfig(
                runtimeConfig = settings.runtimeConfig,
                codegenConfig = settings.codegenConfig
                codegenConfig = settings.codegenConfig,
                handleRequired = true
            )
        val baseModel = baselineTransform(context.model)
        val service = settings.getService(baseModel)
+9 −11
Original line number Diff line number Diff line
@@ -45,12 +45,13 @@ import software.amazon.smithy.rust.codegen.smithy.generators.protocol.MakeOperat
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolTraitImplGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.smithy.makeOptional
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolBodyGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.toOptional
import software.amazon.smithy.rust.codegen.smithy.wrapOptional
import software.amazon.smithy.rust.codegen.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
@@ -450,7 +451,7 @@ private class ServerHttpProtocolImplGenerator(
        serverRenderResponseHeaders(operationShape)

        for (binding in bindings) {
            val serializedValue = serverRenderBindingSerializer(binding, operationShape)
            val serializedValue = serverRenderBindingSerializer(binding)
            if (serializedValue != null) {
                serializedValue(this)
            }
@@ -521,9 +522,7 @@ private class ServerHttpProtocolImplGenerator(

    private fun serverRenderBindingSerializer(
        binding: HttpBindingDescriptor,
        operationShape: OperationShape,
    ): Writable? {
        val operationName = symbolProvider.toSymbol(operationShape).name
        val member = binding.member
        return when (binding.location) {
            HttpLocation.HEADER,
@@ -728,7 +727,7 @@ private class ServerHttpProtocolImplGenerator(
                        rustTemplate(
                            """
                            input = input.${binding.member.setterName()}(
                                #{deserializer}(m$index)?
                                ${symbolProvider.toOptional(binding.member, "#{deserializer}(m$index)?")}
                            );
                            """.trimIndent(),
                            *codegenScope,
@@ -825,7 +824,7 @@ private class ServerHttpProtocolImplGenerator(
                        """
                        if !seen_$memberName && k == "${it.locationName}" {
                            input = input.${it.member.setterName()}(
                                #{deserializer}(&v)?
                                ${symbolProvider.toOptional(it.member, "#{deserializer}(&v)?")}
                            );
                            seen_$memberName = true;
                        }
@@ -979,7 +978,7 @@ private class ServerHttpProtocolImplGenerator(
                rustTemplate(
                    """
                    let value = <_>::from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref());
                    Ok(Some(value))
                    Ok(${symbolProvider.wrapOptional(binding.member, "value")})
                    """.trimIndent(),
                    *codegenScope,
                )
@@ -1008,7 +1007,7 @@ private class ServerHttpProtocolImplGenerator(
                    """
                    let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
                    let value = #{DateTime}::from_str(&value, #{format})?;
                    Ok(Some(value))
                    Ok(${symbolProvider.wrapOptional(binding.member, "value")})
                    """.trimIndent(),
                    *codegenScope,
                    "format" to timestampFormatType,
@@ -1019,7 +1018,7 @@ private class ServerHttpProtocolImplGenerator(

    // TODO These functions can be replaced with the ones in https://docs.rs/aws-smithy-types/latest/aws_smithy_types/primitive/trait.Parse.html
    private fun generateParseStrAsPrimitiveFn(binding: HttpBindingDescriptor): RuntimeType {
        val output = symbolProvider.toSymbol(binding.member).makeOptional()
        val output = symbolProvider.toSymbol(binding.member)
        val fnName = generateParseStrFnName(binding)
        return RuntimeType.forInlineFun(fnName, operationDeserModule) { writer ->
            writer.rustBlockTemplate(
@@ -1030,7 +1029,7 @@ private class ServerHttpProtocolImplGenerator(
                rustTemplate(
                    """
                    let value = std::str::FromStr::from_str(value)?;
                    Ok(Some(value))
                    Ok(${symbolProvider.wrapOptional(binding.member, "value")})
                    """.trimIndent(),
                    *codegenScope,
                )
@@ -1074,7 +1073,6 @@ private class ServerHttpProtocolImplGenerator(
            }
        }
    }

    private fun getStreamingBodyTraitBounds(operationShape: OperationShape): String {
        if (operationShape.inputShape(model).hasStreamingMember(model)) {
            return "\n B: Into<#{SmithyHttp}::byte_stream::ByteStream>,"
+2 −1
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ import software.amazon.smithy.build.SmithyBuildPlugin
import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.rust.codegen.rustlang.Attribute.Companion.NonExhaustive
import software.amazon.smithy.rust.codegen.rustlang.RustReservedWordSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.customizations.ClientCustomizations
import software.amazon.smithy.rust.codegen.smithy.customize.CombinedCodegenDecorator
@@ -52,7 +53,7 @@ class RustCodegenPlugin : SmithyBuildPlugin {
                // Generate `ByteStream` instead of `Blob` for streaming binary shapes (e.g. S3 GetObject)
                .let { StreamingShapeSymbolProvider(it, model) }
                // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes
                .let { BaseSymbolMetadataProvider(it) }
                .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf(NonExhaustive)) }
                // Streaming shapes need different derives (e.g. they cannot derive Eq)
                .let { StreamingShapeMetadataProvider(it, model) }
                // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot
+10 −3
Original line number Diff line number Diff line
@@ -15,7 +15,6 @@ import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.Attribute.Companion.NonExhaustive
import software.amazon.smithy.rust.codegen.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.smithy.RuntimeType.Companion.PartialEq
import software.amazon.smithy.rust.codegen.util.hasTrait
@@ -67,10 +66,18 @@ abstract class SymbolMetadataProvider(private val base: RustSymbolProvider) : Wr
    abstract fun enumMeta(stringShape: StringShape): RustMetadata
}

class BaseSymbolMetadataProvider(base: RustSymbolProvider) : SymbolMetadataProvider(base) {
/**
 * The base metadata supports a list of attributes that are used by generators to decorate code.
 * By default we apply #[non_exhaustive] only to client structures since model changes should
 * be considered as breaking only when generating server code.
 */
class BaseSymbolMetadataProvider(
    base: RustSymbolProvider,
    additionalAttributes: List<Attribute>,
) : SymbolMetadataProvider(base) {
    private val containerDefault = RustMetadata(
        Attribute.Derives(defaultDerives.toSet()),
        additionalAttributes = listOf(NonExhaustive),
        additionalAttributes = additionalAttributes,
        public = true
    )

Loading