Skip to content

Commit 43d8d35

Browse files
committed
ensure info is cached
1 parent 57bf7e6 commit 43d8d35

2 files changed

Lines changed: 117 additions & 4 deletions

File tree

src/msgspec/_core.c

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4958,11 +4958,36 @@ convert_types_generic_alias(TypeNodeCollectState *state, PyObject *obj, PyObject
49584958
// built-in container generic, such as 'collections.abc.Mapping'
49594959

49604960
if (MS_UNLIKELY(Py_TYPE(obj) == (PyTypeObject *)state->mod->types_generic_alias)) {
4961-
PyObject *genericAliasArgsList = Py_BuildValue("OO", origin, args);
4961+
// subscribed typing._GenericAlias instances are cached within the typing module
4962+
// we make use of this fact, by storing a __msgspec_cache__ attribute on the
4963+
// subscribed instance. only subscribed types are cache, so
4964+
// 'typing._GenericAlias(list, int) is typing._GenericAlias(list, int)' would be
4965+
// false.
4966+
// to achieve the same behaviour when re-creating a typing._GenericAlias from a
4967+
// types.GenericAlias, we first construct a temporary *unbound*
4968+
// typing._GenericAlias, on which we then call __getattr__. effectively doing
4969+
// typing._GenericAlias(list, T)[int], for which
4970+
// 'typing._GenericAlias(list, T)[int] is typing._GenericAlias(list, T)[int]'
4971+
// holds true
4972+
PyObject *params = PyObject_GetAttrString(origin, "__parameters__");
4973+
if (params == NULL) {
4974+
Py_DECREF(origin);
4975+
return NULL;
4976+
}
4977+
4978+
// create a new typing._GenericAlias with the unbound type params of the
4979+
// original types.GenericAlias.
4980+
// given a Mapping[str, int], this would produce a _GenericAlias(Mapping, (~K, ~V))
4981+
PyObject *new_alias = PyObject_CallFunctionObjArgs(state->mod->typing_generic_alias, origin, params, NULL);
4982+
if (new_alias == NULL) {
4983+
return NULL;
4984+
}
49624985

4963-
PyObject *newGenericAlias = PyObject_CallObject(state->mod->typing_generic_alias, genericAliasArgsList);
4964-
Py_DECREF(genericAliasArgsList);
4965-
return newGenericAlias;
4986+
// bind it to the concrete types.
4987+
// given a _GenericAlias(Mapping, (~K, ~V)), produce a Mapping[str, int] again
4988+
PyObject *result = PyObject_CallMethod(new_alias, "__getitem__", "O", args);
4989+
Py_DECREF(new_alias);
4990+
return result;
49664991
}
49674992
return obj;
49684993
}

tests/unit/test_common.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@
5656
T = TypeVar("T")
5757

5858

59+
def make_new_alias(alias):
60+
"""Replicate the logic to produce a types._GenericAlias from a typing.GenericAlias"""
61+
return typing._GenericAlias(
62+
alias.__origin__, alias.__origin__.__parameters__
63+
).__getitem__(*alias.__args__)
64+
65+
5966
def assert_eq(x, y):
6067
assert x == y
6168
assert type(x) is type(y)
@@ -1634,6 +1641,47 @@ def __iter__(self):
16341641
msgspec.msgpack.encode(mod.Foo({"x": 1})), type=mod.Foo[int]
16351642
)
16361643

1644+
@pytest.mark.parametrize(
1645+
"mapping_type", ["collections.abc.Mapping", "typing.Mapping"]
1646+
)
1647+
@py312_plus
1648+
def test_inherited_builtin_generic_typevar_syntax_info_cached(self, mapping_type: str):
1649+
source = f"""
1650+
from msgspec import Struct, StructMeta
1651+
import collections
1652+
import abc
1653+
import typing
1654+
1655+
class CombinedMeta(StructMeta, abc.ABCMeta):
1656+
pass
1657+
1658+
class Foo[T]({mapping_type}[str, T], Struct, metaclass=CombinedMeta):
1659+
data: dict[str, T]
1660+
1661+
def __getitem__(self, x):
1662+
return self.data[x]
1663+
1664+
def __len__(self):
1665+
return len(self.data)
1666+
1667+
def __iter__(self):
1668+
return iter(self.data)
1669+
"""
1670+
1671+
with temp_module(source) as mod:
1672+
typ = mod.Foo[int]
1673+
dec = msgspec.json.Decoder(typ)
1674+
info = make_new_alias(typ).__msgspec_cache__
1675+
assert info is not None
1676+
assert sys.getrefcount(info) <= 4 # info + attr + decoder + func call
1677+
dec2 = msgspec.json.Decoder(typ)
1678+
assert make_new_alias(typ).__msgspec_cache__ is info
1679+
assert sys.getrefcount(info) <= 5
1680+
1681+
del dec
1682+
del dec2
1683+
assert sys.getrefcount(info) <= 3
1684+
16371685

16381686
class TestStructPostInit:
16391687
@pytest.mark.parametrize("array_like", [False, True])
@@ -1953,6 +2001,46 @@ def __iter__(self):
19532001
msgspec.msgpack.encode(mod.Foo({"x": 1})), type=mod.Foo[int]
19542002
)
19552003

2004+
@pytest.mark.parametrize(
2005+
"mapping_type", ["collections.abc.Mapping", "typing.Mapping"]
2006+
)
2007+
@py312_plus
2008+
def test_inherited_builtin_generic_typevar_syntax_info_cached(
2009+
self, mapping_type: str
2010+
):
2011+
source = f"""
2012+
import dataclasses
2013+
import collections
2014+
import typing
2015+
2016+
@dataclasses.dataclass
2017+
class Foo[T]({mapping_type}[str, T]):
2018+
data: dict[str, T]
2019+
2020+
def __getitem__(self, x):
2021+
return self.data[x]
2022+
2023+
def __len__(self):
2024+
return len(self.data)
2025+
2026+
def __iter__(self):
2027+
return iter(self.data)
2028+
"""
2029+
2030+
with temp_module(source) as mod:
2031+
typ = mod.Foo[int]
2032+
dec = msgspec.json.Decoder(typ)
2033+
info = make_new_alias(typ).__msgspec_cache__
2034+
assert info is not None
2035+
assert sys.getrefcount(info) <= 4 # info + attr + decoder + func call
2036+
dec2 = msgspec.json.Decoder(typ)
2037+
assert make_new_alias(typ).__msgspec_cache__ is info
2038+
assert sys.getrefcount(info) <= 5
2039+
2040+
del dec
2041+
del dec2
2042+
assert sys.getrefcount(info) <= 3
2043+
19562044

19572045
class TestStructOmitDefaults:
19582046
def test_omit_defaults(self, proto):

0 commit comments

Comments
 (0)