diff --git a/README.md b/README.md index 29f346a70..11b0ea5a4 100644 --- a/README.md +++ b/README.md @@ -251,6 +251,17 @@ For compatibility the default is to convert field names to `camelCase`. You can MyMessage().to_dict(casing=betterproto.Casing.SNAKE) ``` +#### Proto3 canonical JSON + +By default, enum values are serialized using their stripped Python names (e.g. `"HEARTS"` for a proto enum value `SUIT_HEARTS`). To use the original `.proto` enum names as required by the [proto3 JSON spec](https://protobuf.dev/programming-guides/json/), pass `proto3_json=True`: + +```python +Card(suit=Suit.HEARTS).to_dict(proto3_json=True) +# {"suit": "SUIT_HEARTS"} instead of {"suit": "HEARTS"} +``` + +When deserializing, `from_dict()` accepts all three forms: stripped names, full proto names, and integer values. + ### Determining if a message was sent Sometimes it is useful to be able to determine whether a message has been sent on the wire. This is how the Google wrapper types work to let you know whether a value is unset, set as the default (zero value), or set as something else, for example. diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index ce8a26a4d..af3ddd57e 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -1405,7 +1405,10 @@ def FromString(cls: Type[T], data: bytes) -> T: return cls().parse(data) def to_dict( - self, casing: Casing = Casing.CAMEL, include_default_values: bool = False + self, + casing: Casing = Casing.CAMEL, + include_default_values: bool = False, + proto3_json: bool = False, ) -> Dict[str, Any]: """ Returns a JSON serializable dict representation of this object. @@ -1419,6 +1422,10 @@ def to_dict( If ``True`` will include the default values of fields. Default is ``False``. E.g. an ``int32`` field will be included with a value of ``0`` if this is set to ``True``, otherwise this would be ignored. 
+ proto3_json: :class:`bool` + If ``True`` will use proto3 canonical JSON format for enum values, + serializing them with their original .proto names (e.g. "MARK_TYPE_BOLD") + instead of the stripped Python names (e.g. "BOLD"). Default is ``False``. Returns -------- @@ -1466,7 +1473,8 @@ def to_dict( value = [_Duration.delta_to_json(i) for i in value] else: value = [ - i.to_dict(casing, include_default_values) for i in value + i.to_dict(casing, include_default_values, proto3_json) + for i in value ] if value or include_default_values: output[cased_name] = value @@ -1480,12 +1488,16 @@ def to_dict( field_name=field_name, meta=meta ) ): - output[cased_name] = value.to_dict(casing, include_default_values) + output[cased_name] = value.to_dict( + casing, include_default_values, proto3_json + ) elif meta.proto_type == TYPE_MAP: output_map = {**value} for k in value: if hasattr(value[k], "to_dict"): - output_map[k] = value[k].to_dict(casing, include_default_values) + output_map[k] = value[k].to_dict( + casing, include_default_values, proto3_json + ) if value or include_default_values: output[cased_name] = output_map @@ -1514,24 +1526,32 @@ def to_dict( else: output[cased_name] = b64encode(value).decode("utf8") elif meta.proto_type == TYPE_ENUM: + + def _enum_name(member): + if proto3_json: + return getattr(member, "proto_name", None) or member.name + return member.name + if field_is_repeated: enum_class = field_types[field_name].__args__[0] if isinstance(value, typing.Iterable) and not isinstance( value, str ): - output[cased_name] = [enum_class(el).name for el in value] + output[cased_name] = [ + _enum_name(enum_class(el)) for el in value + ] else: # transparently upgrade single value to repeated - output[cased_name] = [enum_class(value).name] + output[cased_name] = [_enum_name(enum_class(value))] elif value is None: if include_default_values: output[cased_name] = value elif meta.optional: enum_class = field_types[field_name].__args__[0] - output[cased_name] = 
enum_class(value).name + output[cased_name] = _enum_name(enum_class(value)) else: enum_class = field_types[field_name] # noqa - output[cased_name] = enum_class(value).name + output[cased_name] = _enum_name(enum_class(value)) elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): if field_is_repeated: output[cased_name] = [_dump_float(n) for n in value] @@ -1591,10 +1611,18 @@ def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]: ) elif meta.proto_type == TYPE_ENUM: enum_cls = cls._betterproto.cls_by_field[field_name] + + def _parse_enum(e, ec=enum_cls): + if isinstance(e, int): + return ec.try_value(e) + return ec.from_string(e) + if isinstance(value, list): - value = [enum_cls.from_string(e) for e in value] + value = [_parse_enum(e) for e in value] + elif isinstance(value, int): + value = _parse_enum(value) elif isinstance(value, str): - value = enum_cls.from_string(value) + value = _parse_enum(value) elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): value = ( [_parse_float(n) for n in value] diff --git a/src/betterproto/enum.py b/src/betterproto/enum.py index 6b1b7e0a4..3d8bc45d3 100644 --- a/src/betterproto/enum.py +++ b/src/betterproto/enum.py @@ -35,12 +35,14 @@ def _is_descriptor(obj: object) -> bool: class EnumType(EnumMeta if TYPE_CHECKING else type): _value_map_: Mapping[int, Enum] _member_map_: Mapping[str, Enum] + _proto_names_: Dict[str, str] # Maps Python name -> original proto name def __new__( mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any] ) -> Self: value_map = {} member_map = {} + proto_names = namespace.pop("_proto_names_", {}) new_mcs = type( f"{name}Type", @@ -50,7 +52,11 @@ def __new__( + [EnumType, type] ) ), # reorder the bases so EnumType and type are last to avoid conflicts - {"_value_map_": value_map, "_member_map_": member_map}, + { + "_value_map_": value_map, + "_member_map_": member_map, + "_proto_names_": proto_names, + }, ) members = { @@ -71,7 +77,8 @@ def __new__( for name, value in 
members.items(): member = value_map.get(value) if member is None: - member = cls.__new__(cls, name=name, value=value) # type: ignore + proto_name = proto_names.get(name, name) + member = cls.__new__(cls, name=name, value=value, proto_name=proto_name) # type: ignore value_map[value] = member member_map[name] = member type.__setattr__(new_mcs, name, member) @@ -123,17 +130,27 @@ class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType): name: Optional[str] value: int + proto_name: Optional[str] if not TYPE_CHECKING: - def __new__(cls, *, name: Optional[str], value: int) -> Self: + def __new__( + cls, *, name: Optional[str], value: int, proto_name: Optional[str] = None + ) -> Self: self = super().__new__(cls, value) super().__setattr__(self, "name", name) super().__setattr__(self, "value", value) + # proto_name is the original name from the .proto file (e.g. "MARK_TYPE_BOLD") + # used for proto3 canonical JSON serialization + super().__setattr__(self, "proto_name", proto_name or name) return self def __getnewargs_ex__(self) -> Tuple[Tuple[()], Dict[str, Any]]: - return (), {"name": self.name, "value": self.value} + return (), { + "name": self.name, + "value": self.value, + "proto_name": self.proto_name, + } def __str__(self) -> str: return self.name or "None" @@ -181,6 +198,9 @@ def try_value(cls, value: int = 0) -> Self: def from_string(cls, name: str) -> Self: """Return the value which corresponds to the string name. + Accepts both the Python member name (e.g. "BOLD") and the original + proto name (e.g. "MARK_TYPE_BOLD") per the proto3 JSON spec. + Parameters ----------- name: :class:`str` @@ -191,7 +211,12 @@ def from_string(cls, name: str) -> Self: :exc:`ValueError` The member was not found in the Enum. 
""" - try: - return cls._member_map_[name] - except KeyError as e: - raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e + # Try stripped Python name first + member = cls._member_map_.get(name) + if member is not None: + return member + # Try original proto name + for m in cls._member_map_.values(): + if getattr(m, "proto_name", None) == name: + return m + raise ValueError(f"Unknown value {name} for enum {cls.__name__}") diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index e330e6884..e6ddb2885 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -650,6 +650,7 @@ class EnumEntry: name: str value: int comment: str + proto_name: str # Original name from .proto file def __post_init__(self) -> None: # Get entries/allowed values for this Enum @@ -662,6 +663,7 @@ def __post_init__(self) -> None: comment=get_comment( proto_file=self.source_file, path=self.path + [2, entry_number] ), + proto_name=entry_proto_value.name, ) for entry_number, entry_proto_value in enumerate(self.proto_obj.value) ] diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 4a252aec6..871bbf2bf 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -4,6 +4,9 @@ class {{ enum.py_name }}(betterproto.Enum): {{ enum.comment }} {% endif %} + # Mapping from Python member names to original proto names (for canonical JSON) + _proto_names_ = { {% for entry in enum.entries %}"{{ entry.name }}": "{{ entry.proto_name }}", {% endfor %}} + {% for entry in enum.entries %} {{ entry.name }} = {{ entry.value }} {% if entry.comment %} diff --git a/tests/test_proto3_json.py b/tests/test_proto3_json.py new file mode 100644 index 000000000..2a68337a9 --- /dev/null +++ b/tests/test_proto3_json.py @@ -0,0 +1,101 @@ +"""Test proto3 canonical JSON enum serialization. 
+ +Verifies that to_dict(proto3_json=True) uses the original .proto enum +names per the proto3 JSON spec, while to_dict() (default) preserves +the existing stripped Python names for backwards compatibility. +""" + +from typing import List + +import pytest + +import betterproto + + +class Suit(betterproto.Enum): + _proto_names_ = { + "UNSPECIFIED": "SUIT_UNSPECIFIED", + "HEARTS": "SUIT_HEARTS", + "DIAMONDS": "SUIT_DIAMONDS", + "CLUBS": "SUIT_CLUBS", + "SPADES": "SUIT_SPADES", + } + + UNSPECIFIED = 0 + HEARTS = 1 + DIAMONDS = 2 + CLUBS = 3 + SPADES = 4 + + +from dataclasses import dataclass + + +@dataclass(eq=False, repr=False) +class Card(betterproto.Message): + suit: "Suit" = betterproto.enum_field(1) + value: int = betterproto.int32_field(2) + + +@dataclass(eq=False, repr=False) +class Hand(betterproto.Message): + cards: List["Card"] = betterproto.message_field(1) + + +class TestProto3JsonSerialization: + """to_dict(proto3_json=True) uses full proto names.""" + + def test_single_enum(self): + card = Card(suit=Suit.HEARTS, value=10) + d = card.to_dict(proto3_json=True) + assert d["suit"] == "SUIT_HEARTS" + + def test_default_uses_stripped_name(self): + card = Card(suit=Suit.HEARTS, value=10) + d = card.to_dict() + assert d["suit"] == "HEARTS" + + def test_nested_propagates(self): + hand = Hand(cards=[Card(suit=Suit.SPADES, value=1)]) + d = hand.to_dict(proto3_json=True) + assert d["cards"][0]["suit"] == "SUIT_SPADES" + + +class TestProto3JsonDeserialization: + """from_dict accepts both proto names and stripped names.""" + + def test_accept_proto_name(self): + card = Card().from_dict({"suit": "SUIT_HEARTS", "value": 10}) + assert card.suit == Suit.HEARTS + + def test_accept_stripped_name(self): + card = Card().from_dict({"suit": "HEARTS", "value": 10}) + assert card.suit == Suit.HEARTS + + def test_accept_integer(self): + card = Card().from_dict({"suit": 1, "value": 10}) + assert card.suit == Suit.HEARTS + + def test_round_trip_proto3(self): + original = 
Card(suit=Suit.DIAMONDS, value=7) + d = original.to_dict(proto3_json=True) + restored = Card().from_dict(d) + assert restored.suit == Suit.DIAMONDS + assert restored.value == 7 + + def test_round_trip_default(self): + original = Card(suit=Suit.CLUBS, value=3) + d = original.to_dict() + restored = Card().from_dict(d) + assert restored.suit == Suit.CLUBS + + +class TestEnumWithoutProtoNames: + """Enums without _proto_names_ (backwards compat).""" + + def test_proto3_json_falls_back_to_name(self): + """Without _proto_names_, proto3_json=True uses the Python name.""" + from tests.test_enum import Colour + + # Colour doesn't have _proto_names_, so proto_name == name + assert Colour.RED.proto_name == "RED"