Skip to content

Commit 331d4b7

Browse files
committed
Add librt functionality for lazy deserialization
1 parent d7e3268 commit 331d4b7

File tree

6 files changed

+299
-7
lines changed

6 files changed

+299
-7
lines changed

mypy/cache.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def read(cls, data: ReadBuffer, data_file: str) -> CacheMeta | None:
239239
# Always use this type alias to refer to type tags.
240240
Tag = u8
241241

242+
# Note: all tags should be kept in sync with mypyc/lib-rt/internal/librt_internal.c.
242243
# Primitives.
243244
LITERAL_FALSE: Final[Tag] = 0
244245
LITERAL_TRUE: Final[Tag] = 1
@@ -264,6 +265,7 @@ def read(cls, data: ReadBuffer, data_file: str) -> CacheMeta | None:
264265
# Four integers representing source file (line, column) range.
265266
LOCATION: Final[Tag] = 152
266267

268+
RESERVED: Final[Tag] = 254
267269
END_TAG: Final[Tag] = 255
268270

269271

mypy/nodes.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4930,7 +4930,20 @@ def read(cls, data: ReadBuffer) -> SymbolTableNode:
49304930
sym.plugin_generated = read_bool(data)
49314931
cross_ref = read_str_opt(data)
49324932
if cross_ref is None:
4933-
sym.node = read_symbol(data)
4933+
tag = read_tag(data)
4934+
if tag == TYPE_INFO:
4935+
sym.node = TypeInfo.read(data)
4936+
else:
4937+
# This logic is temporary, to make sure we don't introduce
4938+
# regressions until we have proper lazy deserialization.
4939+
# It has negligible performance impact.
4940+
try:
4941+
from librt.internal import extract_symbol
4942+
except ImportError:
4943+
sym.node = read_symbol(data, tag)
4944+
else:
4945+
node_bytes = extract_symbol(data)
4946+
sym.node = read_symbol(ReadBuffer(node_bytes), tag)
49344947
else:
49354948
sym.cross_ref = cross_ref
49364949
assert read_tag(data) == END_TAG
@@ -5333,17 +5346,14 @@ def local_definitions(
53335346
TSTRING_EXPR: Final[Tag] = 229
53345347

53355348

5336-
def read_symbol(data: ReadBuffer) -> SymbolNode:
5337-
tag = read_tag(data)
5349+
def read_symbol(data: ReadBuffer, tag: Tag) -> SymbolNode:
53385350
# The branches here are ordered manually by type "popularity".
53395351
if tag == VAR:
53405352
return Var.read(data)
53415353
if tag == FUNC_DEF:
53425354
return FuncDef.read(data)
53435355
if tag == DECORATOR:
53445356
return Decorator.read(data)
5345-
if tag == TYPE_INFO:
5346-
return TypeInfo.read(data)
53475357
if tag == OVERLOADED_FUNC_DEF:
53485358
return OverloadedFuncDef.read(data)
53495359
if tag == TYPE_VAR_EXPR:

mypy/typeshed/stubs/librt/librt/internal.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ def read_int(data: ReadBuffer, /) -> int: ...
1919
def write_tag(data: WriteBuffer, value: u8, /) -> None: ...
2020
def read_tag(data: ReadBuffer, /) -> u8: ...
2121
def cache_version() -> u8: ...
22+
def extract_symbol(data: ReadBuffer, /) -> bytes: ...

mypyc/lib-rt/internal/librt_internal.c

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,273 @@ write_tag(PyObject *self, PyObject *const *args, size_t nargs) {
920920
return Py_None;
921921
}
922922

923+
// Serialization tags understood by the skipping code in this file.
// Every value here must be kept in sync with cache.py, nodes.py, and types.py.

// Primitive types.
#define LITERAL_FALSE       0
#define LITERAL_TRUE        1
#define LITERAL_NONE        2
#define LITERAL_INT         3
#define LITERAL_STR         4
#define LITERAL_BYTES       5
#define LITERAL_FLOAT       6
#define LITERAL_COMPLEX     7

// Supported builtin collections.
#define LIST_GEN            20
#define LIST_INT            21
#define LIST_STR            22
#define LIST_BYTES          23
#define TUPLE_GEN           24
#define DICT_STR_GEN        30

// This is the smallest custom class tag.
#define MYPY_FILE           50

// Instance class has special formats: INSTANCE is the outer tag and one of
// the INSTANCE_xxx tags below follows it to select the layout.
#define INSTANCE            80
#define INSTANCE_SIMPLE     81
#define INSTANCE_GENERIC    82
#define INSTANCE_STR        83
#define INSTANCE_FUNCTION   84
#define INSTANCE_INT        85
#define INSTANCE_BOOL       86
#define INSTANCE_OBJECT     87

// Upper end of the tag space.
#define RESERVED            254
#define END_TAG             255
957+
958+
// Forward declaration.
959+
static char _skip_object(PyObject *data, uint8_t tag);
960+
961+
static inline char
962+
_skip(PyObject *data, Py_ssize_t size) {
963+
// We are careful about error conditions, so all
964+
// _skip_xxx() functions can return an error value.
965+
_CHECK_READ(data, size, CPY_NONE_ERROR)
966+
((ReadBufferObject *)data)->ptr += size;
967+
return CPY_NONE;
968+
}
969+
970+
static inline char
971+
_skip_short_int(PyObject *data, uint8_t first) {
972+
if ((first & TWO_BYTES_INT_BIT) == 0)
973+
return CPY_NONE;
974+
if ((first & FOUR_BYTES_INT_BIT) == 0)
975+
return _skip(data, 1);
976+
return _skip(data, 3);
977+
}
978+
979+
static inline char
980+
_skip_int(PyObject *data) {
981+
_CHECK_READ(data, 1, CPY_NONE_ERROR)
982+
983+
uint8_t first;
984+
_READ(&first, data, uint8_t);
985+
if (likely(first != LONG_INT_TRAILER)) {
986+
return _skip_short_int(data, first);
987+
}
988+
989+
_CHECK_READ(data, 1, CPY_NONE_ERROR)
990+
_READ(&first, data, uint8_t);
991+
Py_ssize_t size_and_sign = _read_short_int(data, first);
992+
if (size_and_sign == CPY_INT_TAG)
993+
return CPY_NONE_ERROR;
994+
if ((Py_ssize_t)size_and_sign < 0) {
995+
PyErr_SetString(PyExc_ValueError, "invalid int data");
996+
return CPY_NONE_ERROR;
997+
}
998+
Py_ssize_t size = size_and_sign >> 2;
999+
return _skip(data, size);
1000+
}
1001+
1002+
// This is essentially a wrapper around _read_short_int() that makes
1003+
// sure the result is valid.
1004+
static inline Py_ssize_t
1005+
_read_size(PyObject *data) {
1006+
_CHECK_READ(data, 1, -1)
1007+
uint8_t first;
1008+
_READ(&first, data, uint8_t);
1009+
// We actually allow serializing lists/dicts with over 4 billion items,
1010+
// but we don't really need to, fail with ValueError just in case.
1011+
if (unlikely(first == LONG_INT_TRAILER)) {
1012+
PyErr_SetString(PyExc_ValueError, "unsupported size");
1013+
return -1;
1014+
}
1015+
CPyTagged tagged_size = _read_short_int(data, first);
1016+
if (tagged_size == CPY_INT_TAG)
1017+
return -1;
1018+
if ((Py_ssize_t)tagged_size < 0) {
1019+
PyErr_SetString(PyExc_ValueError, "invalid size");
1020+
return -1;
1021+
}
1022+
Py_ssize_t size = tagged_size >> 1;
1023+
return size;
1024+
}
1025+
1026+
static inline char
1027+
_skip_str_bytes(PyObject *data) {
1028+
Py_ssize_t size = _read_size(data);
1029+
if (size < 0)
1030+
return CPY_NONE_ERROR;
1031+
return _skip(data, size);
1032+
}
1033+
1034+
// List/dict logic should be kept in sync with mypy/cache.py
1035+
static inline char
1036+
_skip_list_gen(PyObject *data) {
1037+
Py_ssize_t size = _read_size(data);
1038+
if (size < 0)
1039+
return CPY_NONE_ERROR;
1040+
int i;
1041+
for (i = 0; i < size; i++) {
1042+
uint8_t tag = read_tag_internal(data);
1043+
if (unlikely(tag == CPY_LL_UINT_ERROR && PyErr_Occurred())) {
1044+
return CPY_NONE_ERROR;
1045+
}
1046+
if (unlikely(_skip_object(data, tag) == CPY_NONE_ERROR))
1047+
return CPY_NONE_ERROR;
1048+
}
1049+
return CPY_NONE;
1050+
}
1051+
1052+
static inline char
1053+
_skip_list_int(PyObject *data) {
1054+
Py_ssize_t size = _read_size(data);
1055+
if (size < 0)
1056+
return CPY_NONE_ERROR;
1057+
int i;
1058+
for (i = 0; i < size; i++) {
1059+
if (unlikely(_skip_int(data) == CPY_NONE_ERROR))
1060+
return CPY_NONE_ERROR;
1061+
}
1062+
return CPY_NONE;
1063+
}
1064+
1065+
static inline char
1066+
_skip_list_str_bytes(PyObject *data) {
1067+
Py_ssize_t size = _read_size(data);
1068+
if (size < 0)
1069+
return CPY_NONE_ERROR;
1070+
int i;
1071+
for (i = 0; i < size; i++) {
1072+
if (unlikely(_skip_str_bytes(data) == CPY_NONE_ERROR))
1073+
return CPY_NONE_ERROR;
1074+
}
1075+
return CPY_NONE;
1076+
}
1077+
1078+
static inline char
1079+
_skip_dict_str_gen(PyObject *data) {
1080+
Py_ssize_t size = _read_size(data);
1081+
if (size < 0)
1082+
return CPY_NONE_ERROR;
1083+
int i;
1084+
for (i = 0; i < size; i++) {
1085+
// Bare key followed by tagged value.
1086+
if (unlikely(_skip_str_bytes(data) == CPY_NONE_ERROR))
1087+
return CPY_NONE_ERROR;
1088+
uint8_t tag = read_tag_internal(data);
1089+
if (unlikely(tag == CPY_LL_UINT_ERROR && PyErr_Occurred())) {
1090+
return CPY_NONE_ERROR;
1091+
}
1092+
if (unlikely(_skip_object(data, tag) == CPY_NONE_ERROR))
1093+
return CPY_NONE_ERROR;
1094+
}
1095+
return CPY_NONE;
1096+
}
1097+
1098+
// Similar to mypy/cache.py, the convention is that the caller reads
1099+
// the opening tag for custom classes.
1100+
static inline char
1101+
_skip_class(PyObject *data) {
1102+
while (1) {
1103+
uint8_t tag = read_tag_internal(data);
1104+
if (unlikely(tag == CPY_LL_UINT_ERROR && PyErr_Occurred())) {
1105+
return CPY_NONE_ERROR;
1106+
}
1107+
if (tag == END_TAG) {
1108+
return CPY_NONE;
1109+
}
1110+
if (unlikely(_skip_object(data, tag) == CPY_NONE_ERROR)) {
1111+
return CPY_NONE_ERROR;
1112+
}
1113+
}
1114+
}
1115+
1116+
// Instance has special compact layout (as an important optimization).
1117+
static inline char
1118+
_skip_instance(PyObject *data) {
1119+
uint8_t second_tag = read_tag_internal(data);
1120+
if (unlikely(second_tag == CPY_LL_UINT_ERROR && PyErr_Occurred())) {
1121+
return CPY_NONE_ERROR;
1122+
}
1123+
if (second_tag >= INSTANCE_STR && second_tag <= INSTANCE_OBJECT) {
1124+
return CPY_NONE;
1125+
}
1126+
if (second_tag == INSTANCE_SIMPLE) {
1127+
return _skip_str_bytes(data);
1128+
}
1129+
if (second_tag == INSTANCE_GENERIC) {
1130+
return _skip_class(data);
1131+
}
1132+
PyErr_Format(PyExc_ValueError, "Unexpected instance tag: %d", second_tag);
1133+
return CPY_NONE_ERROR;
1134+
}
1135+
1136+
// This is the main dispatch point. Branches are ordered manually
1137+
// based roughly on frequency in self-check.
1138+
static char
1139+
_skip_object(PyObject *data, uint8_t tag) {
1140+
if (tag == LITERAL_STR || tag == LITERAL_BYTES)
1141+
return _skip_str_bytes(data);
1142+
if (tag == LITERAL_NONE || tag == LITERAL_FALSE || tag == LITERAL_TRUE)
1143+
return CPY_NONE;
1144+
if (tag == LIST_GEN || tag == TUPLE_GEN)
1145+
return _skip_list_gen(data);
1146+
if (tag == LITERAL_INT)
1147+
return _skip_int(data);
1148+
if (tag == INSTANCE)
1149+
return _skip_instance(data);
1150+
if (tag > MYPY_FILE && tag < RESERVED)
1151+
return _skip_class(data);
1152+
if (tag == LIST_INT)
1153+
return _skip_list_int(data);
1154+
if (tag == LIST_STR || tag == LIST_BYTES)
1155+
return _skip_list_str_bytes(data);
1156+
if (tag == DICT_STR_GEN)
1157+
return _skip_dict_str_gen(data);
1158+
if (tag == LITERAL_FLOAT)
1159+
return _skip(data, 8);
1160+
if (tag == LITERAL_COMPLEX)
1161+
return _skip(data, 16);
1162+
PyErr_Format(PyExc_ValueError, "Unsupported tag: %d", tag);
1163+
return CPY_NONE_ERROR;
1164+
}
1165+
1166+
// Return a new bytes object holding the raw serialized data of one symbol,
// and advance the read position of `data` past it. The span runs from the
// current position through the END_TAG that _skip_class() consumes; the
// symbol's opening tag is expected to have been read by the caller.
// Returns NULL if skipping fails or the bytes object cannot be created.
static PyObject*
extract_symbol_internal(PyObject *data) {
    ReadBufferObject *buf = (ReadBufferObject *)data;
    char *start = buf->ptr;
    if (unlikely(_skip_class(data) == CPY_NONE_ERROR))
        return NULL;
    // PyBytes_FromStringAndSize() already yields NULL on failure, so its
    // result can be returned as is.
    return PyBytes_FromStringAndSize(start, buf->ptr - start);
}
1177+
1178+
// Python-level entry point for librt.internal.extract_symbol(data).
// Validates the argument count and the buffer argument, then delegates to
// extract_symbol_internal().
static PyObject*
extract_symbol(PyObject *self, PyObject *const *args, size_t nargs) {
    PyObject *data;

    if (likely(nargs == 1)) {
        data = args[0];
    } else {
        PyErr_Format(PyExc_TypeError,
                     "extract_symbol() takes exactly 1 argument (%zu given)", nargs);
        return NULL;
    }
    _CHECK_READ_BUFFER(data, NULL)
    return extract_symbol_internal(data);
}
1189+
9231190
static uint8_t
9241191
cache_version_internal(void) {
9251192
return 0;
@@ -954,6 +1221,7 @@ static PyMethodDef librt_internal_module_methods[] = {
9541221
{"write_tag", (PyCFunction)write_tag, METH_FASTCALL, PyDoc_STR("write a short int")},
9551222
{"read_tag", (PyCFunction)read_tag, METH_FASTCALL, PyDoc_STR("read a short int")},
9561223
{"cache_version", (PyCFunction)cache_version, METH_NOARGS, PyDoc_STR("cache format version")},
1224+
{"extract_symbol", (PyCFunction)extract_symbol, METH_FASTCALL, PyDoc_STR("extract bytes for a mypy symbol")},
9571225
{NULL, NULL, 0, NULL}
9581226
};
9591227

@@ -1005,6 +1273,7 @@ librt_internal_module_exec(PyObject *m)
10051273
(void *)ReadBuffer_type_internal,
10061274
(void *)WriteBuffer_type_internal,
10071275
(void *)NativeInternal_API_Version,
1276+
(void *)extract_symbol_internal
10081277
};
10091278
PyObject *c_api_object = PyCapsule_New((void *)NativeInternal_API, "librt.internal._C_API", NULL);
10101279
if (PyModule_Add(m, "_C_API", c_api_object) < 0) {

mypyc/lib-rt/internal/librt_internal.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
// API version -- more recent versions must maintain backward compatibility, i.e.
1212
// we can add new features but not remove or change existing features (unless
1313
// ABI version is changed, but see the comment above).
14-
#define LIBRT_INTERNAL_API_VERSION 0
14+
#define LIBRT_INTERNAL_API_VERSION 1
1515

1616
// Number of functions in the capsule API. If you add a new function, also increase
1717
// LIBRT_INTERNAL_API_VERSION.
18-
#define LIBRT_INTERNAL_API_LEN 20
18+
#define LIBRT_INTERNAL_API_LEN 21
1919

2020
#ifdef LIBRT_INTERNAL_MODULE
2121

@@ -41,6 +41,7 @@ static uint8_t cache_version_internal(void);
4141
static PyTypeObject *ReadBuffer_type_internal(void);
4242
static PyTypeObject *WriteBuffer_type_internal(void);
4343
static int NativeInternal_API_Version(void);
44+
static PyObject *extract_symbol_internal(PyObject *data);
4445

4546
#else
4647

@@ -66,6 +67,7 @@ static void *NativeInternal_API[LIBRT_INTERNAL_API_LEN];
6667
#define ReadBuffer_type_internal (*(PyTypeObject* (*)(void)) NativeInternal_API[17])
6768
#define WriteBuffer_type_internal (*(PyTypeObject* (*)(void)) NativeInternal_API[18])
6869
#define NativeInternal_API_Version (*(int (*)(void)) NativeInternal_API[19])
70+
#define extract_symbol_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[20])
6971

7072
static int
7173
import_librt_internal(void)

mypyc/primitives/misc_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,14 @@
503503
error_kind=ERR_NEVER,
504504
)
505505

506+
# Map calls to librt.internal.extract_symbol() onto the C function
# extract_symbol_internal().
function_op(
    name="librt.internal.extract_symbol",
    c_function_name="extract_symbol_internal",
    arg_types=[object_rprimitive],
    return_type=bytes_rprimitive,
    error_kind=ERR_MAGIC,
)
513+
506514
function_op(
507515
name="librt.base64.b64encode",
508516
arg_types=[bytes_rprimitive],

0 commit comments

Comments
 (0)