
Commit 756c310

Optimize parallel workers start-up

1 parent fed3593 commit 756c310

1 file changed: mypy/build.py
Lines changed: 54 additions & 21 deletions
@@ -29,6 +29,7 @@
 from collections.abc import Callable, Iterator, Mapping, Sequence, Set as AbstractSet
 from heapq import heappop, heappush
 from textwrap import dedent
+from threading import Thread
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -371,6 +372,7 @@ def default_flush_errors(
     extra_plugins = extra_plugins or []
 
     workers = []
+    connect_threads = []
     if options.num_workers > 0:
         # TODO: switch to something more efficient than pickle (also in the daemon).
         pickled_options = pickle.dumps(options.snapshot())
@@ -383,10 +385,17 @@ def default_flush_errors(
         buf = WriteBuffer()
         sources_message.write(buf)
         sources_data = buf.getvalue()
+
+        def connect(wc: WorkerClient, data: bytes) -> None:
+            # Start loading sources in each worker as soon as it is up.
+            wc.connect()
+            wc.conn.write_bytes(data)
+
+        # We don't wait for workers to be ready until they are actually needed.
         for worker in workers:
-            # Start loading graph in each worker as soon as it is up.
-            worker.connect()
-            worker.conn.write_bytes(sources_data)
+            thread = Thread(target=connect, args=(worker, sources_data))
+            thread.start()
+            connect_threads.append(thread)
 
     try:
         result = build_inner(
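
The hunk above is the heart of the change: instead of connecting to each worker synchronously, the connect-and-send-sources step moves onto one background thread per worker, so worker start-up overlaps with the rest of the build. A minimal, self-contained sketch of the pattern follows; FakeWorker is a hypothetical stand-in for WorkerClient, not mypy's actual class.

import threading
import time


class FakeWorker:
    """Hypothetical stand-in for WorkerClient, whose connect() is slow."""

    def __init__(self, idx: int) -> None:
        self.idx = idx
        self.ready = False

    def connect(self) -> None:
        time.sleep(0.1)  # simulate process spawn + connection handshake
        self.ready = True


workers = [FakeWorker(i) for i in range(4)]
connect_threads = []
for worker in workers:
    thread = threading.Thread(target=worker.connect)
    thread.start()
    connect_threads.append(thread)

# ... the coordinator keeps doing useful work here (parsing, cache loads) ...

# Join only at the point where the workers are actually needed.
for thread in connect_threads:
    thread.join()
assert all(worker.ready for worker in workers)
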
@@ -399,6 +408,7 @@ def default_flush_errors(
             stderr,
             extra_plugins,
             workers,
+            connect_threads,
         )
         result.errors = messages
         return result
@@ -412,6 +422,10 @@ def default_flush_errors(
         e.messages = messages
         raise
     finally:
+        # In case of an early crash it is better to wait for workers to become ready, and
+        # shut them down cleanly. Otherwise, they will linger until connection timeout.
+        for thread in connect_threads:
+            thread.join()
         for worker in workers:
             try:
                 send(worker.conn, SccRequestMessage(scc_id=None, import_errors={}, mod_data={}))
@@ -431,6 +445,7 @@ def build_inner(
     stderr: TextIO,
     extra_plugins: Sequence[Plugin],
     workers: list[WorkerClient],
+    connect_threads: list[Thread],
 ) -> BuildResult:
     if platform.python_implementation() == "CPython":
         # Run gc less frequently, as otherwise we can spend a large fraction of
@@ -486,7 +501,7 @@ def build_inner(
 
     reset_global_state()
     try:
-        graph = dispatch(sources, manager, stdout)
+        graph = dispatch(sources, manager, stdout, connect_threads)
         if not options.fine_grained_incremental:
             type_state.reset_all_subtype_caches()
         if options.timing_stats is not None:
@@ -1156,6 +1171,22 @@ def add_stats(self, **kwds: Any) -> None:
     def stats_summary(self) -> Mapping[str, object]:
         return self.stats
 
+    def broadcast(self, message: bytes) -> None:
+        """Broadcast same message to all workers in parallel."""
+        threads = []
+        for worker in self.workers:
+            thread = Thread(target=worker.conn.write_bytes, args=(message,))
+            thread.start()
+            threads.append(thread)
+        for thread in threads:
+            thread.join()
+
+    def wait_ack(self) -> None:
+        """Wait for an ack from all workers."""
+        for worker in self.workers:
+            buf = receive(worker.conn)
+            assert read_tag(buf) == ACK_MESSAGE
+
     def submit(self, graph: Graph, sccs: list[SCC]) -> None:
         """Submit a stale SCC for processing in current process or parallel workers."""
         if self.workers:
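
Two small helpers are added above. broadcast uses a thread per connection because write_bytes can block once the OS pipe or socket buffer fills, so writing a large payload to the workers one by one would serialize the transfers. The toy below illustrates the effect with socketpair standing in for the real worker connections; it is a sketch of the idea, not mypy's wire format.

import socket
import threading

pairs = [socket.socketpair() for _ in range(3)]
payload = b"x" * (1 << 20)  # 1 MiB, enough to fill default OS buffers


def broadcast(conns: list[socket.socket], message: bytes) -> None:
    # One blocking sendall per thread; the writes overlap instead of queuing.
    threads = [threading.Thread(target=c.sendall, args=(message,)) for c in conns]
    for t in threads:
        t.start()
    for t in threads:
        t.join()


def drain(conn: socket.socket, n: int) -> None:
    # Stand-in for a worker reading its copy of the payload.
    got = 0
    while got < n:
        got += len(conn.recv(65536))


readers = [threading.Thread(target=drain, args=(b, len(payload))) for _a, b in pairs]
for r in readers:
    r.start()
broadcast([a for a, _b in pairs], payload)
for r in readers:
    r.join()
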
@@ -3685,7 +3716,12 @@ def log_configuration(manager: BuildManager, sources: list[BuildSource]) -> None
 # The driver
 
 
-def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO) -> Graph:
+def dispatch(
+    sources: list[BuildSource],
+    manager: BuildManager,
+    stdout: TextIO,
+    connect_threads: list[Thread],
+) -> Graph:
     log_configuration(manager, sources)
 
     t0 = time.time()
@@ -3742,7 +3778,7 @@ def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO)
         dump_graph(graph, stdout)
         return graph
 
-    # Fine grained dependencies that didn't have an associated module in the build
+    # Fine-grained dependencies that didn't have an associated module in the build
     # are serialized separately, so we read them after we load the graph.
     # We need to read them both for running in daemon mode and if we are generating
    # a fine-grained cache (so that we can properly update them incrementally).
@@ -3755,25 +3791,28 @@ def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO)
         if fg_deps_meta is not None:
             manager.fg_deps_meta = fg_deps_meta
         elif manager.stats.get("fresh_metas", 0) > 0:
-            # Clear the stats so we don't infinite loop because of positive fresh_metas
+            # Clear the stats, so we don't infinite loop because of positive fresh_metas
             manager.stats.clear()
             # There were some cache files read, but no fine-grained dependencies loaded.
             manager.log("Error reading fine-grained dependencies cache -- aborting cache load")
             manager.cache_enabled = False
             manager.log("Falling back to full run -- reloading graph...")
-            return dispatch(sources, manager, stdout)
+            return dispatch(sources, manager, stdout, connect_threads)
 
     # If we are loading a fine-grained incremental mode cache, we
     # don't want to do a real incremental reprocess of the
     # graph---we'll handle it all later.
     if not manager.use_fine_grained_cache():
+        # Wait for workers since they may be needed at this point.
+        for thread in connect_threads:
+            thread.join()
         process_graph(graph, manager)
         # Update plugins snapshot.
         write_plugins_snapshot(manager)
         manager.old_plugins_snapshot = manager.plugins_snapshot
     if manager.options.cache_fine_grained or manager.options.fine_grained_incremental:
-        # If we are running a daemon or are going to write cache for further fine grained use,
-        # then we need to collect fine grained protocol dependencies.
+        # If we are running a daemon or are going to write cache for further fine-grained use,
+        # then we need to collect fine-grained protocol dependencies.
         # Since these are a global property of the program, they are calculated after we
         # processed the whole graph.
         type_state.add_all_protocol_deps(manager.fg_deps)
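
Note that on this path connect_threads ends up joined twice: once here in dispatch, just before process_graph first needs the workers, and once more in the finally block added to build above. That is safe because joining an already-finished thread returns immediately:

import threading

t = threading.Thread(target=lambda: None)
t.start()
t.join()  # waits for the thread to finish
t.join()  # joining again is an immediate no-op
assert not t.is_alive()
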
@@ -4166,10 +4205,8 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
     buf = WriteBuffer()
     graph_message.write(buf)
     graph_data = buf.getvalue()
-    for worker in manager.workers:
-        buf = receive(worker.conn)
-        assert read_tag(buf) == ACK_MESSAGE
-        worker.conn.write_bytes(graph_data)
+    manager.wait_ack()
+    manager.broadcast(graph_data)
 
     sccs = sorted_components(graph)
     manager.log(
@@ -4187,13 +4224,9 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
     buf = WriteBuffer()
     sccs_message.write(buf)
     sccs_data = buf.getvalue()
-    for worker in manager.workers:
-        buf = receive(worker.conn)
-        assert read_tag(buf) == ACK_MESSAGE
-        worker.conn.write_bytes(sccs_data)
-    for worker in manager.workers:
-        buf = receive(worker.conn)
-        assert read_tag(buf) == ACK_MESSAGE
+    manager.wait_ack()
+    manager.broadcast(sccs_data)
+    manager.wait_ack()
 
     manager.free_workers = set(range(manager.options.num_workers))
 
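
With the per-worker loops replaced by wait_ack and broadcast, the start-up exchange in process_graph presumably reads: workers ack once their sources are loaded, the coordinator broadcasts the graph, workers ack, the coordinator broadcasts the SCC list, workers ack, and only then do SCC requests flow. Below is a queue-based toy of that handshake; the worker loop and message bodies are placeholders, not mypy's actual protocol.

import queue
import threading

ACK = b"ack"
NUM_WORKERS = 2


def toy_worker(inbox: queue.Queue, outbox: queue.Queue) -> None:
    # Placeholder worker loop mirroring the handshake implied by the diff.
    outbox.put(ACK)  # sources loaded -> ack
    inbox.get()      # receive broadcast graph data
    outbox.put(ACK)  # graph loaded -> ack
    inbox.get()      # receive broadcast SCC list
    outbox.put(ACK)  # ready for SCC requests -> ack


inboxes = [queue.Queue() for _ in range(NUM_WORKERS)]
outboxes = [queue.Queue() for _ in range(NUM_WORKERS)]
threads = [
    threading.Thread(target=toy_worker, args=(i, o))
    for i, o in zip(inboxes, outboxes)
]
for t in threads:
    t.start()

# Coordinator side, in the order process_graph now uses:
for o in outboxes:  # wait_ack (sources)
    assert o.get() == ACK
for i in inboxes:   # broadcast(graph_data)
    i.put(b"graph")
for o in outboxes:  # wait_ack (graph)
    assert o.get() == ACK
for i in inboxes:   # broadcast(sccs_data)
    i.put(b"sccs")
for o in outboxes:  # wait_ack (SCCs)
    assert o.get() == ACK
for t in threads:
    t.join()
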
