Skip to content

Commit b3d4bb7

Browse files
soceanainnNusnus
authored andcommitted
Give user registered types priority when encoding / decoding JSON
1 parent 1a5eb30 commit b3d4bb7

2 files changed

Lines changed: 50 additions & 17 deletions

File tree

kombu/utils/json.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,20 @@ class JSONEncoder(json.JSONEncoder):
2323
"""Kombu custom json encoder."""
2424

2525
def default(self, o):
26+
for t, (marker, encoder) in _encoders.items():
27+
if isinstance(o, t):
28+
return (
29+
encoder(o) if marker is None else _as(marker, encoder(o))
30+
)
31+
2632
reducer = getattr(o, "__json__", None)
2733
if reducer is not None:
2834
return reducer()
2935

3036
if isinstance(o, textual_types):
3137
return str(o)
3238

33-
for t, (marker, encoder) in _encoders.items():
39+
for t, (marker, encoder) in _default_encoders.items():
3440
if isinstance(o, t):
3541
return (
3642
encoder(o) if marker is None else _as(marker, encoder(o))
@@ -66,7 +72,7 @@ def dumps(
6672
def object_hook(o: dict):
6773
"""Hook function to perform custom deserialization."""
6874
if o.keys() == {"__type__", "__value__"}:
69-
decoder = _decoders.get(o["__type__"])
75+
decoder = _decoders.get(o["__type__"]) or _default_decoders.get(o["__type__"])
7076
if decoder:
7177
return decoder(o["__value__"])
7278
else:
@@ -97,6 +103,16 @@ def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook):
97103
T = TypeVar("T")
98104
EncodedT = TypeVar("EncodedT")
99105

106+
# Separate user registered types from Kombu registered types to allow us to give preference to user types
107+
_encoders: dict[type, tuple[str | None, EncoderT]] = {}
108+
_decoders: dict[str, DecoderT] = {}
109+
110+
_default_encoders: dict[type, tuple[str | None, EncoderT]] = {}
111+
_default_decoders: dict[str, DecoderT] = {
112+
"bytes": lambda o: o.encode("utf-8"),
113+
"base64": lambda o: base64.b64decode(o.encode("utf-8")),
114+
}
115+
100116

101117
def register_type(
102118
t: type[T],
@@ -110,32 +126,40 @@ def register_type(
110126
is not placed in an envelope, so `decoder` is unnecessary. Decoding must
111127
instead be handled outside this library.
112128
"""
113-
_encoders[t] = (marker, encoder)
114-
if marker is not None:
115-
_decoders[marker] = decoder
129+
_register_type(t, marker, encoder, decoder, is_default_encoder=False)
116130

117131

118-
_encoders: dict[type, tuple[str | None, EncoderT]] = {}
119-
_decoders: dict[str, DecoderT] = {
120-
"bytes": lambda o: o.encode("utf-8"),
121-
"base64": lambda o: base64.b64decode(o.encode("utf-8")),
122-
}
132+
def _register_type(
133+
t: type[T],
134+
marker: str | None,
135+
encoder: Callable[[T], EncodedT],
136+
decoder: Callable[[EncodedT], T] = lambda d: d,
137+
is_default_encoder: bool = True,
138+
):
139+
if is_default_encoder:
140+
_default_encoders[t] = (marker, encoder)
141+
if marker is not None:
142+
_default_decoders[marker] = decoder
143+
else:
144+
_encoders[t] = (marker, encoder)
145+
if marker is not None:
146+
_decoders[marker] = decoder
123147

124148

125149
def _register_default_types():
126150
# NOTE: datetime should be registered before date,
127151
# because datetime is also instance of date.
128-
register_type(datetime, "datetime", datetime.isoformat,
129-
datetime.fromisoformat)
130-
register_type(
152+
_register_type(datetime, "datetime", datetime.isoformat,
153+
datetime.fromisoformat)
154+
_register_type(
131155
date,
132156
"date",
133157
lambda o: o.isoformat(),
134-
lambda o: datetime.fromisoformat(o).date(),
158+
lambda o: datetime.fromisoformat(o).date()
135159
)
136-
register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
137-
register_type(Decimal, "decimal", str, Decimal)
138-
register_type(
160+
_register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
161+
_register_type(Decimal, "decimal", str, Decimal)
162+
_register_type(
139163
uuid.UUID,
140164
"uuid",
141165
lambda o: {"hex": o.hex},

t/unit/utils/test_json.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,15 @@ def test_register_type_overrides_defaults(self):
9595
loaded_value = loads(dumps({'u': value}))
9696
assert loaded_value == {'u': "custom"}
9797

98+
def test_register_type_takes_priority(self):
99+
class MyDecimal(Decimal):
100+
pass
101+
102+
register_type(MyDecimal, "mydecimal", str, MyDecimal)
103+
original = {'md': MyDecimal('3314132.13363235235324234123213213214134')}
104+
loaded_value = loads(dumps(original))
105+
assert original == loaded_value
106+
98107
def test_register_type_with_new_type(self):
99108
# Guaranteed never before seen type
100109
@dataclass()

0 commit comments

Comments
 (0)