Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ private void generateDeserializer() {
var symbol = symbolProvider.toSymbol(shape);
var deserializerSymbol = symbol.expectProperty(SymbolProperties.DESERIALIZER);
var schemaSymbol = symbol.expectProperty(SymbolProperties.SCHEMA);
var unknownSymbol = symbol.expectProperty(SymbolProperties.UNION_UNKNOWN);
writer.putContext("schema", schemaSymbol);
writer.write("""
class $1L:
Expand All @@ -168,6 +169,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None:
${4C|}
case _:
logger.debug("Unexpected member schema: %s", schema)
self._set_result($5L(tag=schema.member_name or ""))

def _set_result(self, value: $2T) -> None:
if self._result is not None:
Expand All @@ -177,7 +179,8 @@ raise SerializationError("Unions must have exactly one value, but found more tha
deserializerSymbol.getName(),
symbol,
schemaSymbol,
writer.consumer(w -> deserializeMembers()));
writer.consumer(w -> deserializeMembers()),
unknownSymbol.getName());
}

private void deserializeMembers() {
Expand Down
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add unit tests in test_deserializers.py for the unknown event handling path.

Also, the test fixture EventStreamDeserializer._consumer still raises SmithyError in the default case. Should that be updated to match the new behavior?

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
SpecificShapeDeserializer,
)
from smithy_core.schemas import Schema
from smithy_core.shapes import ShapeType
from smithy_core.shapes import ShapeID, ShapeType
from smithy_core.traits import EventHeaderTrait
from smithy_core.utils import expect_type

Expand Down Expand Up @@ -50,11 +50,34 @@ def read_struct(
message_deserializer = self._create_deserializer(schema, headers)
message_deserializer.read_struct(schema, consumer)
else:
member_schema = schema.members[member_name]
message_deserializer = self._create_deserializer(
member_schema, headers
)
consumer(member_schema, message_deserializer)
member_schema = schema.members.get(member_name)
if member_schema is None:
# Unknown event type. Call the consumer with a
# schema that carries the event type name as
# member_name and a member_index of -1 so the
# generated default branch constructs the unknown
# variant with the correct tag.
logger.debug("Unknown event type: %s", member_name)

_UNKNOWN_TARGET = Schema(
id=ShapeID("smithy.unknown#Unknown"),
shape_type=ShapeType.STRUCTURE,
)
unknown_schema = Schema(
id=ShapeID(f"smithy.unknown#Unknown${member_name}"),
shape_type=ShapeType.STRUCTURE,
member_target=_UNKNOWN_TARGET,
member_index=-1,
)
consumer(
unknown_schema,
self._payload_codec.create_deserializer(b"{}"),
)
else:
message_deserializer = self._create_deserializer(
member_schema, headers
)
consumer(member_schema, message_deserializer)
case "exception":
member_name = expect_type(str, headers[":exception-type"])
member_schema = schema.members[member_name]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,9 @@ async def receive(self) -> E | None:
)
result = self._deserializer(deserializer)
logger.debug("Successfully deserialized event: %s", result)
if isinstance(getattr(result, "value"), Exception):
raise result.value # type: ignore
value = getattr(result, "value", None)
if isinstance(value, Exception):
raise value
return result

async def close(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def serialize_members(self, serializer: ShapeSerializer):


@dataclass
class EventStreamUnknownEvent:
class EventStreamUnknown:
tag: str

def serialize(self, serializer: ShapeSerializer):
Expand All @@ -396,7 +396,7 @@ def serialize_members(self, serializer: ShapeSerializer):
| EventStreamPayloadEvent
| EventStreamBlobPayloadEvent
| EventStreamErrorEvent
| EventStreamUnknownEvent
| EventStreamUnknown
)


Expand Down Expand Up @@ -429,7 +429,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None:
self._set_result(EventStreamErrorEvent(ErrorEvent.deserialize(de)))

case _:
raise SmithyError(f"Unexpected member schema: {schema}")
self._set_result(EventStreamUnknown(tag=schema.member_name or ""))

def _set_result(self, value: EventStream) -> None:
if self._result is not None:
Expand Down Expand Up @@ -635,6 +635,19 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
]


UNKNOWN_EVENT_CASE = (
EventStreamUnknown(tag="intermediateGroupEvent"),
EventMessage(
headers={
":message-type": "event",
":event-type": "intermediateGroupEvent",
":content-type": "application/json",
},
payload=b"{}",
),
)


INITIAL_REQUEST_CASE = (
EventStreamOperationInputOutput(message="The initial request!"),
EventMessage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
EventStreamDeserializer,
EventStreamErrorEvent,
EventStreamOperationInputOutput,
EventStreamUnknown,
)


Expand Down Expand Up @@ -126,3 +127,20 @@ async def test_read_closed_receiver_source() -> None:
with pytest.raises(IOError):
await receiver.receive()
assert receiver.closed


def test_deserialize_unknown_event_type():
message = EventMessage(
headers={
":message-type": "event",
":event-type": "intermediateGroupEvent",
":content-type": "application/json",
},
payload=b"{}",
)
source = Event.decode(BytesIO(message.encode()))
assert source is not None
deserializer = EventDeserializer(event=source, payload_codec=JSONCodec())
result = EventStreamDeserializer().deserialize(deserializer)
assert isinstance(result, EventStreamUnknown)
assert result.tag == "intermediateGroupEvent"
Loading