 
 import os
 import socket
+import traceback
 
-import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -65,12 +65,136 @@ def spawn_multiprocess_job(size, job, backend="gloo"):
         assert not p.exitcode
 
 
-def get_device_counts():
-    num_gpus = torch.cuda.device_count()
-    return [
-        1,
-        pytest.param(2, marks=pytest.mark.skipif(num_gpus < 2, reason="need 2 GPUs!")),
-    ]
+def default_worker_teardown(rank, world_size):
+    """Minimal cleanup between tests in persistent workers."""
+    try:
+        from accelerate.state import AcceleratorState
+
+        AcceleratorState._reset_state()
+    except ImportError:
+        pass
+    except Exception as e:
+        print(f"Error resetting AcceleratorState: {e}")
+    torch.cuda.empty_cache()
+
+
+class DistributedWorkerPool:
+    """Persistent worker pool that keeps distributed processes alive across multiple test dispatches.
+
+    Instead of spawning/destroying processes per test (which adds ~10s overhead each time),
+    workers are spawned once and reuse the same ``torch.distributed`` process group.
+    Use with a module-scoped pytest fixture to share workers across all tests in a file.
+
+    Usage::
+
+        pool = DistributedWorkerPool(
+            world_size=2, backend="nccl", teardown_fn=default_worker_teardown
+        )
+
+
+        def _test_fn(rank, size): ...
+
+
+        pool.run(_test_fn)
+        pool.run(partial(other_fn, arg1))
+        pool.shutdown()
+    """
+
+    def __init__(self, world_size, backend="nccl", teardown_fn=default_worker_teardown):
+        assert world_size > 0, "World size must be greater than 0"
+        self.world_size = world_size
+        ctx = mp.get_context("spawn")
+        self._cmd_queues = [ctx.Queue() for _ in range(world_size)]
+        self._result_queue = ctx.Queue()
+        self._processes = []
+
+        port = get_free_port()
+        for rank in range(world_size):
+            p = ctx.Process(
+                target=self._worker_loop,
+                args=(
+                    rank,
+                    world_size,
+                    backend,
+                    port,
+                    self._cmd_queues[rank],
+                    self._result_queue,
+                    teardown_fn,
+                ),
+            )
+            p.start()
+            self._processes.append(p)
+
+        for _ in range(world_size):
+            msg = self._result_queue.get(timeout=120)
+            assert msg == "ready", f"Worker failed to initialize: {msg}"
+
+    @staticmethod
+    def _worker_loop(rank, world_size, backend, port, cmd_queue, result_queue, teardown_fn):
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = str(port)
+        os.environ["LOCAL_RANK"] = str(rank)
+        os.environ["RANK"] = str(rank)
+        os.environ["WORLD_SIZE"] = str(world_size)
+        dist.init_process_group(backend, rank=rank, world_size=world_size)
+        if backend == "nccl" and torch.cuda.is_available():
+            torch.cuda.set_device(rank)
+        torch.manual_seed(1234)
+        result_queue.put("ready")
+
+        while True:
+            cmd = cmd_queue.get()
+            if cmd is None:
+                break
+            fn, args, kwargs = cmd
+            status = "ok"
+            tb = None
+            try:
+                fn(rank, world_size, *args, **kwargs)
+            except Exception:
+                status = "error"
+                tb = traceback.format_exc()
+            finally:
+                if teardown_fn is not None:
+                    try:
+                        teardown_fn(rank, world_size)
+                    except Exception as e:
+                        print(f"Error tearing down worker: {e}")
+                        status = "error"
+                        teardown_tb = traceback.format_exc()
+                        tb = (tb + "\n" if tb else "") + f"[teardown] {teardown_tb}"
+                result_queue.put((status, rank, tb))
+
+        dist.destroy_process_group()
+
+    def run(self, fn, *args, **kwargs):
+        """Dispatch ``fn`` to all workers and block until completion.
+
+        ``fn`` is called as ``fn(rank, world_size, *args, **kwargs)`` and must be picklable
+        (top-level function or ``functools.partial`` of one).
+        """
+        for q in self._cmd_queues:
+            q.put((fn, args, kwargs))
+
+        errors = []
+        for _ in range(self.world_size):
+            status, rank, tb = self._result_queue.get(timeout=600)
+            if status == "error":
+                errors.append(f"--- Rank {rank} ---\n{tb}")
+
+        if errors:
+            raise RuntimeError("Worker(s) failed:\n" + "\n".join(errors))
+
+    def shutdown(self):
+        """Signal all workers to exit and wait for them to finish."""
+        for q in self._cmd_queues:
+            q.put(None)
+        for p in self._processes:
+            p.join(timeout=60)
+            if p.is_alive():
+                p.terminate()
+                # Ensure the terminated process is fully reaped to avoid zombies.
+                p.join(timeout=10)
 
 
 def synchronize_state_dict(model: nn.Module):
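
The class docstring above refers to a module-scoped pytest fixture but the diff itself does not show one, so here is a minimal sketch of that usage. The fixture name `worker_pool`, the `gloo` backend, the two-process world size, and the `_check_allreduce` helper are illustrative assumptions rather than part of this change, and the import of `DistributedWorkerPool` from the utilities module above is left as a comment because its path is not shown here.

    import pytest
    import torch
    import torch.distributed as dist

    # DistributedWorkerPool comes from the test utilities module in this diff
    # (import path omitted here).


    @pytest.fixture(scope="module")
    def worker_pool():
        # Spawn the workers once per test module; gloo keeps this sketch CPU-only.
        pool = DistributedWorkerPool(world_size=2, backend="gloo")
        yield pool
        # Tear the persistent workers down after the last test in the module.
        pool.shutdown()


    def _check_allreduce(rank, world_size):
        # Runs inside each persistent worker; the process group is already initialized.
        t = torch.ones(1)
        dist.all_reduce(t)
        assert t.item() == world_size


    def test_allreduce(worker_pool):
        # Each dispatch reuses the live workers instead of re-spawning processes.
        worker_pool.run(_check_allreduce)

Dispatching through `run` this way is what avoids the per-test spawn overhead the docstring mentions; extra arguments can be bound with `functools.partial`, as in the docstring's `pool.run(partial(other_fn, arg1))` example.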