Skip to content

Commit 4285b96

Browse files
committed
ensure info is cached
1 parent 57bf7e6 commit 4285b96

2 files changed

Lines changed: 119 additions & 4 deletions

File tree

src/msgspec/_core.c

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4958,11 +4958,36 @@ convert_types_generic_alias(TypeNodeCollectState *state, PyObject *obj, PyObject
49584958
// built-in container generic, such as 'collections.abc.Mapping'
49594959

49604960
if (MS_UNLIKELY(Py_TYPE(obj) == (PyTypeObject *)state->mod->types_generic_alias)) {
4961-
PyObject *genericAliasArgsList = Py_BuildValue("OO", origin, args);
4961+
// subscribed typing._GenericAlias instances are cached within the typing module
4962+
// we make use of this fact, by storing a __msgspec_cache__ attribute on the
4963+
// subscribed instance. only subscribed types are cache, so
4964+
// 'typing._GenericAlias(list, int) is typing._GenericAlias(list, int)' would be
4965+
// false.
4966+
// to achieve the same behaviour when re-creating a typing._GenericAlias from a
4967+
// types.GenericAlias, we first construct a temporary *unbound*
4968+
// typing._GenericAlias, on which we then call __getattr__. effectively doing
4969+
// typing._GenericAlias(list, T)[int], for which
4970+
// 'typing._GenericAlias(list, T)[int] is typing._GenericAlias(list, T)[int]'
4971+
// holds true
4972+
PyObject *params = PyObject_GetAttrString(origin, "__parameters__");
4973+
if (params == NULL) {
4974+
Py_DECREF(origin);
4975+
return NULL;
4976+
}
4977+
4978+
// create a new typing._GenericAlias with the unbound type params of the
4979+
// original types.GenericAlias.
4980+
// given a Mapping[str, int], this would produce a _GenericAlias(Mapping, (~K, ~V))
4981+
PyObject *new_alias = PyObject_CallFunctionObjArgs(state->mod->typing_generic_alias, origin, params, NULL);
4982+
if (new_alias == NULL) {
4983+
return NULL;
4984+
}
49624985

4963-
PyObject *newGenericAlias = PyObject_CallObject(state->mod->typing_generic_alias, genericAliasArgsList);
4964-
Py_DECREF(genericAliasArgsList);
4965-
return newGenericAlias;
4986+
// bind it to the concrete types.
4987+
// given a _GenericAlias(Mapping, (~K, ~V)), produce a Mapping[str, int] again
4988+
PyObject *result = PyObject_CallMethod(new_alias, "__getitem__", "O", args);
4989+
Py_DECREF(new_alias);
4990+
return result;
49664991
}
49674992
return obj;
49684993
}

tests/unit/test_common.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@
5656
T = TypeVar("T")
5757

5858

59+
def make_new_alias(alias):
60+
"""Replicate the logic to produce a types._GenericAlias from a typing.GenericAlias"""
61+
return typing._GenericAlias(
62+
alias.__origin__, alias.__origin__.__parameters__
63+
).__getitem__(*alias.__args__)
64+
65+
5966
def assert_eq(x, y):
6067
assert x == y
6168
assert type(x) is type(y)
@@ -1634,6 +1641,49 @@ def __iter__(self):
16341641
msgspec.msgpack.encode(mod.Foo({"x": 1})), type=mod.Foo[int]
16351642
)
16361643

1644+
@pytest.mark.parametrize(
1645+
"mapping_type", ["collections.abc.Mapping", "typing.Mapping"]
1646+
)
1647+
@py312_plus
1648+
def test_inherited_builtin_generic_typevar_syntax_info_cached(
1649+
self, mapping_type: str
1650+
):
1651+
source = f"""
1652+
from msgspec import Struct, StructMeta
1653+
import collections
1654+
import abc
1655+
import typing
1656+
1657+
class CombinedMeta(StructMeta, abc.ABCMeta):
1658+
pass
1659+
1660+
class Foo[T]({mapping_type}[str, T], Struct, metaclass=CombinedMeta):
1661+
data: dict[str, T]
1662+
1663+
def __getitem__(self, x):
1664+
return self.data[x]
1665+
1666+
def __len__(self):
1667+
return len(self.data)
1668+
1669+
def __iter__(self):
1670+
return iter(self.data)
1671+
"""
1672+
1673+
with temp_module(source) as mod:
1674+
typ = mod.Foo[int]
1675+
dec = msgspec.json.Decoder(typ)
1676+
info = make_new_alias(typ).__msgspec_cache__
1677+
assert info is not None
1678+
assert sys.getrefcount(info) <= 4 # info + attr + decoder + func call
1679+
dec2 = msgspec.json.Decoder(typ)
1680+
assert make_new_alias(typ).__msgspec_cache__ is info
1681+
assert sys.getrefcount(info) <= 5
1682+
1683+
del dec
1684+
del dec2
1685+
assert sys.getrefcount(info) <= 3
1686+
16371687

16381688
class TestStructPostInit:
16391689
@pytest.mark.parametrize("array_like", [False, True])
@@ -1953,6 +2003,46 @@ def __iter__(self):
19532003
msgspec.msgpack.encode(mod.Foo({"x": 1})), type=mod.Foo[int]
19542004
)
19552005

2006+
@pytest.mark.parametrize(
2007+
"mapping_type", ["collections.abc.Mapping", "typing.Mapping"]
2008+
)
2009+
@py312_plus
2010+
def test_inherited_builtin_generic_typevar_syntax_info_cached(
2011+
self, mapping_type: str
2012+
):
2013+
source = f"""
2014+
import dataclasses
2015+
import collections
2016+
import typing
2017+
2018+
@dataclasses.dataclass
2019+
class Foo[T]({mapping_type}[str, T]):
2020+
data: dict[str, T]
2021+
2022+
def __getitem__(self, x):
2023+
return self.data[x]
2024+
2025+
def __len__(self):
2026+
return len(self.data)
2027+
2028+
def __iter__(self):
2029+
return iter(self.data)
2030+
"""
2031+
2032+
with temp_module(source) as mod:
2033+
typ = mod.Foo[int]
2034+
dec = msgspec.json.Decoder(typ)
2035+
info = make_new_alias(typ).__msgspec_cache__
2036+
assert info is not None
2037+
assert sys.getrefcount(info) <= 4 # info + attr + decoder + func call
2038+
dec2 = msgspec.json.Decoder(typ)
2039+
assert make_new_alias(typ).__msgspec_cache__ is info
2040+
assert sys.getrefcount(info) <= 5
2041+
2042+
del dec
2043+
del dec2
2044+
assert sys.getrefcount(info) <= 3
2045+
19562046

19572047
class TestStructOmitDefaults:
19582048
def test_omit_defaults(self, proto):

0 commit comments

Comments
 (0)