Skip to content

Commit 69a0925

Browse files
authored
Send serialized ASTs to parallel workers (#20991)
This way we can properly benefit from the native parser in parallel type checking. Self-check is now ~2.2x faster with 5 workers compared to in-process checking (also with the native parser). It also uses less memory, but with 5 workers, self-check still takes ~twice as much memory compared to in-process. The implementation is mostly straightforward. The GC freeze hack needed some tuning, as there is no single hot-spot in terms of allocations anymore. Note: do _not_ use `maturin develop` for any performance measurements, as this creates a very slow wheel.
1 parent e7d9f2d commit 69a0925

File tree

9 files changed

+302
-93
lines changed

9 files changed

+302
-93
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ jobs:
220220
# To speed-up process until ast_serialize is on PyPI.
221221
- name: Install pinned ast-serialize
222222
if: ${{ matrix.dev_ast_serialize }}
223-
run: pip install ast-serialize@git+https://github.com/mypyc/ast_serialize.git@d277690a078c7784667a640ed1045e725bc42c00
223+
run: pip install ast-serialize@git+https://github.com/mypyc/ast_serialize.git@052c5bfa3b2a5bf07c0b163ccbe2c5ccbfae9ac5
224224

225225
- name: Setup tox environment
226226
run: |

mypy/build.py

Lines changed: 104 additions & 37 deletions
Large diffs are not rendered by default.

mypy/build_worker/worker.py

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,24 @@
2727

2828
from mypy import util
2929
from mypy.build import (
30+
SCC,
3031
AckMessage,
3132
BuildManager,
33+
Graph,
3234
GraphMessage,
3335
SccRequestMessage,
3436
SccResponseMessage,
3537
SccsDataMessage,
3638
SourcesDataMessage,
37-
load_graph,
3839
load_plugins,
3940
process_stale_scc,
4041
)
4142
from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT
42-
from mypy.errors import CompileError, Errors, report_internal_error
43+
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
4344
from mypy.fscache import FileSystemCache
4445
from mypy.ipc import IPCException, IPCServer, receive, send
4546
from mypy.modulefinder import BuildSource, BuildSourceSet, compute_search_paths
47+
from mypy.nodes import FileRawData
4648
from mypy.options import Options
4749
from mypy.util import read_py_file
4850
from mypy.version import __version__
@@ -123,42 +125,24 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:
123125
if manager is None:
124126
return
125127

126-
# Mirror the GC freeze hack in the coordinator.
127-
if platform.python_implementation() == "CPython":
128-
gc.disable()
129-
try:
130-
graph = load_graph(sources, manager)
131-
except CompileError:
132-
# CompileError during loading will be reported by the coordinator.
133-
return
134-
if platform.python_implementation() == "CPython":
135-
gc.freeze()
136-
gc.unfreeze()
137-
gc.enable()
138-
for id in graph:
139-
manager.import_map[id] = graph[id].dependencies_set
140-
# Ignore errors during local graph loading to check that receiving
141-
# early errors from coordinator works correctly.
142-
manager.errors.reset()
143-
144-
# Notify worker we are done loading graph.
128+
# Notify coordinator we are done with setup.
145129
send(server, AckMessage())
146-
147-
# Compare worker graph and coordinator, with parallel parser we will only use the latter.
148130
graph_data = GraphMessage.read(receive(server), manager)
149-
assert set(manager.missing_modules) == graph_data.missing_modules
150-
coordinator_graph = graph_data.graph
151-
assert coordinator_graph.keys() == graph.keys()
131+
# Update some manager data in-place as it has been passed to semantic analyzer.
132+
manager.missing_modules |= graph_data.missing_modules
133+
graph = graph_data.graph
152134
for id in graph:
153-
assert graph[id].dependencies_set == coordinator_graph[id].dependencies_set
154-
assert graph[id].suppressed_set == coordinator_graph[id].suppressed_set
155-
send(server, AckMessage())
135+
manager.import_map[id] = graph[id].dependencies_set
136+
# Link modules dicts, so that plugins will get access to ASTs as we parse them.
137+
manager.plugin.set_modules(manager.modules)
156138

139+
# Notify coordinator we are ready to receive computed graph SCC structure.
140+
send(server, AckMessage())
157141
sccs = SccsDataMessage.read(receive(server)).sccs
158142
manager.scc_by_id = {scc.id: scc for scc in sccs}
159143
manager.top_order = [scc.id for scc in sccs]
160144

161-
# Notify coordinator we are ready to process SCCs.
145+
# Notify coordinator we are ready to start processing SCCs.
162146
send(server, AckMessage())
163147
while True:
164148
scc_message = SccRequestMessage.read(receive(server))
@@ -169,20 +153,17 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:
169153
scc = manager.scc_by_id[scc_id]
170154
t0 = time.time()
171155
try:
172-
for id in scc.mod_ids:
173-
state = graph[id]
174-
# Extra if below is needed only because we are using local graph.
175-
# TODO: clone options when switching to coordinator graph.
176-
if state.tree is None:
177-
# Parse early to get errors related data, such as ignored
178-
# and skipped lines before replaying the errors.
179-
state.parse_file()
180-
else:
181-
state.setup_errors()
182-
if id in scc_message.import_errors:
183-
manager.errors.set_file(state.xpath, id, state.options)
184-
for err_info in scc_message.import_errors[id]:
185-
manager.errors.add_error_info(err_info)
156+
if platform.python_implementation() == "CPython":
157+
# Since we are splitting the GC freeze hack into multiple smaller freezes,
158+
# we should collect young generations to not accumulate accidental garbage.
159+
gc.collect(generation=1)
160+
gc.collect(generation=0)
161+
gc.disable()
162+
load_states(scc, graph, manager, scc_message.import_errors, scc_message.mod_data)
163+
if platform.python_implementation() == "CPython":
164+
gc.freeze()
165+
gc.unfreeze()
166+
gc.enable()
186167
result = process_stale_scc(graph, scc, manager, from_cache=graph_data.from_cache)
187168
# We must commit after each SCC, otherwise we break --sqlite-cache.
188169
manager.metastore.commit()
@@ -193,6 +174,34 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:
193174
manager.add_stats(total_process_stale_time=time.time() - t0, stale_sccs_processed=1)
194175

195176

177+
def load_states(
178+
scc: SCC,
179+
graph: Graph,
180+
manager: BuildManager,
181+
import_errors: dict[str, list[ErrorInfo]],
182+
mod_data: dict[str, tuple[bytes, FileRawData | None]],
183+
) -> None:
184+
"""Re-create full state of an SCC as it would have been in coordinator."""
185+
for id in scc.mod_ids:
186+
state = graph[id]
187+
# Re-clone options since we don't send them, it is usually faster than deserializing.
188+
state.options = state.options.clone_for_module(state.id)
189+
suppressed_deps_opts, raw_data = mod_data[id]
190+
state.parse_file(raw_data=raw_data)
191+
# Set data that is needed to be written to cache meta.
192+
state.known_suppressed_deps_opts = suppressed_deps_opts
193+
assert state.tree is not None
194+
import_lines = {imp.line for imp in state.tree.imports}
195+
state.imports_ignored = {
196+
line: codes for line, codes in state.tree.ignored_lines.items() if line in import_lines
197+
}
198+
# Replay original errors encountered during graph loading in coordinator.
199+
if id in import_errors:
200+
manager.errors.set_file(state.xpath, id, state.options)
201+
for err_info in import_errors[id]:
202+
manager.errors.add_error_info(err_info)
203+
204+
196205
def setup_worker_manager(sources: list[BuildSource], ctx: ServerContext) -> BuildManager | None:
197206
data_dir = os.path.dirname(os.path.dirname(__file__))
198207
# This is used for testing only now.

mypy/cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def read(cls, data: ReadBuffer, data_file: str) -> CacheMeta | None:
262262
LIST_BYTES: Final[Tag] = 23
263263
TUPLE_GEN: Final[Tag] = 24
264264
DICT_STR_GEN: Final[Tag] = 30
265+
DICT_INT_GEN: Final[Tag] = 31
265266

266267
# Misc classes.
267268
EXTRA_ATTRS: Final[Tag] = 150

mypy/main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ def main(
9797
stdout, stderr, options.hide_error_codes, hide_success=bool(options.output)
9898
)
9999

100+
if options.num_workers:
101+
# Supporting both parsers would be really tricky, so just support the new one.
102+
options.native_parser = True
103+
100104
if options.allow_redefinition_new and not options.local_partial_types:
101105
fail(
102106
"error: --local-partial-types must be enabled if using --allow-redefinition-new",

mypy/nativeparse.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
EllipsisExpr,
7777
Expression,
7878
ExpressionStmt,
79+
FileRawData,
7980
FloatExpr,
8081
ForStmt,
8182
FuncDef,
@@ -169,7 +170,6 @@ def __init__(self, options: Options) -> None:
169170
self.options = options
170171
self.errors: list[dict[str, Any]] = []
171172
self.num_funcs = 0
172-
self.uses_template_strings = False
173173

174174
def add_error(
175175
self,
@@ -195,7 +195,7 @@ def add_error(
195195

196196

197197
def native_parse(
198-
filename: str, options: Options, skip_function_bodies: bool = False
198+
filename: str, options: Options, skip_function_bodies: bool = False, imports_only: bool = False
199199
) -> tuple[MypyFile, list[dict[str, Any]], TypeIgnores]:
200200
"""Parse a Python file using the native Rust-based parser.
201201
@@ -208,6 +208,8 @@ def native_parse(
208208
skip_function_bodies: If True, many function and method bodies are omitted from
209209
the AST, useful for parsing stubs or extracting signatures without full
210210
implementation details
211+
imports_only: If True, create an empty MypyFile with the actual serialized defs
212+
stored in binary_data.
211213
212214
Returns:
213215
A tuple containing:
@@ -222,20 +224,27 @@ def native_parse(
222224
node.path = filename
223225
return node, [], []
224226

225-
b, errors, ignores, import_bytes, is_partial_package = parse_to_binary_ast(
226-
filename, options, skip_function_bodies
227+
b, errors, ignores, import_bytes, is_partial_package, uses_template_strings = (
228+
parse_to_binary_ast(filename, options, skip_function_bodies)
227229
)
228230
data = ReadBuffer(b)
229231
n = read_int(data)
230232
state = State(options)
231-
defs = read_statements(state, data, n)
233+
if imports_only:
234+
defs = []
235+
else:
236+
defs = read_statements(state, data, n)
232237

233238
imports = deserialize_imports(import_bytes)
234239

235240
node = MypyFile(defs, imports)
236241
node.path = filename
237242
node.is_partial_stub_package = is_partial_package
238-
node.uses_template_strings = state.uses_template_strings
243+
if imports_only:
244+
node.raw_data = FileRawData(
245+
b, import_bytes, errors, dict(ignores), is_partial_package, uses_template_strings
246+
)
247+
node.uses_template_strings = uses_template_strings
239248
# Merge deserialization errors with parsing errors
240249
all_errors = errors + state.errors
241250
return node, all_errors, ignores
@@ -263,7 +272,7 @@ def read_statements(state: State, data: ReadBuffer, n: int) -> list[Statement]:
263272

264273
def parse_to_binary_ast(
265274
filename: str, options: Options, skip_function_bodies: bool = False
266-
) -> tuple[bytes, list[dict[str, Any]], TypeIgnores, bytes, bool]:
275+
) -> tuple[bytes, list[dict[str, Any]], TypeIgnores, bytes, bool, bool]:
267276
ast_bytes, errors, ignores, import_bytes, ast_data = ast_serialize.parse(
268277
filename,
269278
skip_function_bodies=skip_function_bodies,
@@ -278,6 +287,7 @@ def parse_to_binary_ast(
278287
ignores,
279288
import_bytes,
280289
ast_data["is_partial_package"],
290+
ast_data["uses_template_strings"],
281291
)
282292

283293

@@ -1528,7 +1538,6 @@ def read_expression(state: State, data: ReadBuffer) -> Expression:
15281538
expect_end_tag(data)
15291539
return expr
15301540
elif tag == nodes.TSTRING_EXPR:
1531-
state.uses_template_strings = True
15321541
nparts = read_int(data)
15331542
titems: list[Expression | tuple[Expression, str, str | None, Expression | None]] = []
15341543
for _ in range(nparts):

mypy/nodes.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import mypy.strconv
3232
from mypy.cache import (
33+
DICT_INT_GEN,
3334
DICT_STR_GEN,
3435
DT_SPEC,
3536
END_TAG,
@@ -41,6 +42,7 @@
4142
Tag,
4243
WriteBuffer,
4344
read_bool,
45+
read_bytes,
4446
read_int,
4547
read_int_list,
4648
read_int_opt,
@@ -52,6 +54,7 @@
5254
read_str_opt_list,
5355
read_tag,
5456
write_bool,
57+
write_bytes,
5558
write_int,
5659
write_int_list,
5760
write_int_opt,
@@ -307,6 +310,69 @@ def read(cls, data: ReadBuffer) -> SymbolNode:
307310
Definition: _TypeAlias = tuple[str, "SymbolTableNode", Optional["TypeInfo"]]
308311

309312

313+
class FileRawData:
314+
"""Raw (binary) data representing parsed, but not deserialized file."""
315+
316+
__slots__ = (
317+
"defs",
318+
"imports",
319+
"raw_errors",
320+
"ignored_lines",
321+
"is_partial_stub_package",
322+
"uses_template_strings",
323+
)
324+
325+
defs: bytes
326+
imports: bytes
327+
raw_errors: list[dict[str, Any]] # TODO: switch to more precise type here.
328+
ignored_lines: dict[int, list[str]]
329+
is_partial_stub_package: bool
330+
uses_template_strings: bool
331+
332+
def __init__(
333+
self,
334+
defs: bytes,
335+
imports: bytes,
336+
raw_errors: list[dict[str, Any]],
337+
ignored_lines: dict[int, list[str]],
338+
is_partial_stub_package: bool,
339+
uses_template_strings: bool,
340+
) -> None:
341+
self.defs = defs
342+
self.imports = imports
343+
self.raw_errors = raw_errors
344+
self.ignored_lines = ignored_lines
345+
self.is_partial_stub_package = is_partial_stub_package
346+
self.uses_template_strings = uses_template_strings
347+
348+
def write(self, data: WriteBuffer) -> None:
349+
write_bytes(data, self.defs)
350+
write_bytes(data, self.imports)
351+
write_tag(data, LIST_GEN)
352+
write_int_bare(data, len(self.raw_errors))
353+
for err in self.raw_errors:
354+
write_json(data, err)
355+
write_tag(data, DICT_INT_GEN)
356+
write_int_bare(data, len(self.ignored_lines))
357+
for line, codes in self.ignored_lines.items():
358+
write_int(data, line)
359+
write_str_list(data, codes)
360+
write_bool(data, self.is_partial_stub_package)
361+
write_bool(data, self.uses_template_strings)
362+
363+
@classmethod
364+
def read(cls, data: ReadBuffer) -> FileRawData:
365+
defs = read_bytes(data)
366+
imports = read_bytes(data)
367+
assert read_tag(data) == LIST_GEN
368+
raw_errors = [read_json(data) for _ in range(read_int_bare(data))]
369+
assert read_tag(data) == DICT_INT_GEN
370+
ignored_lines = {read_int(data): read_str_list(data) for _ in range(read_int_bare(data))}
371+
return FileRawData(
372+
defs, imports, raw_errors, ignored_lines, read_bool(data), read_bool(data)
373+
)
374+
375+
310376
class MypyFile(SymbolNode):
311377
"""The abstract syntax tree of a single source file."""
312378

@@ -328,6 +394,7 @@ class MypyFile(SymbolNode):
328394
"plugin_deps",
329395
"future_import_flags",
330396
"_is_typeshed_file",
397+
"raw_data",
331398
)
332399

333400
__match_args__ = ("name", "path", "defs")
@@ -370,6 +437,8 @@ class MypyFile(SymbolNode):
370437
# Future imports defined in this file. Populated during semantic analysis.
371438
future_import_flags: set[str]
372439
_is_typeshed_file: bool | None
440+
# For native parser store actual serialized data here.
441+
raw_data: FileRawData | None
373442

374443
def __init__(
375444
self,
@@ -400,6 +469,7 @@ def __init__(
400469
self.uses_template_strings = False
401470
self.future_import_flags = set()
402471
self._is_typeshed_file = None
472+
self.raw_data = None
403473

404474
def local_definitions(self) -> Iterator[Definition]:
405475
"""Return all definitions within the module (including nested).

0 commit comments

Comments
 (0)