# Debugging a Rank Hang with Flight Recorder

This tutorial walks through diagnosing a **single-rank hang** in a
distributed PyTorch job using the **TorchComms Flight Recorder** and the
**Debug Server**'s periodic dump.

For full API documentation on the debug server, its endpoints, and
periodic dumping, see the
[torch.distributed debug HTTP server docs](https://pytorch.org/docs/main/distributed.html#torch-distributed-debug-http-server).

For a reference on the Flight Recorder, see [Flight Recorder Hook](https://meta-pytorch.org/torchcomms/main/hooks.html#flightrecorderhook) in torchcomms.

---

## Table of Contents

1. [Background](#background)
2. [The Scenario](#the-scenario)
3. [Running the Demo](#running-the-demo)
4. [Reading the Aggregated Text Dumps](#reading-the-aggregated-text-dumps)
5. [Running the FR CLI on Per-Rank Pickle Dumps](#running-the-fr-cli-on-per-rank-pickle-dumps)
6. [What to Look For](#what-to-look-for)
7. [FR CLI Quick Reference](#fr-cli-quick-reference)

---

## Background

The **Flight Recorder** is a ring buffer that records every collective
operation issued through a TorchComms communicator. Each entry captures:

| Field | Description |
|---|---|
| `collective_seq_id` | Monotonically increasing sequence number (same across all ranks for a given collective) |
| `profiling_name` | e.g. `nccl:all_reduce`, `nccl:broadcast` |
| `state` | `scheduled` → `started` → `completed` |
| `input_dims` / `output_dims` | Tensor shapes |
| `traceback` | Python stack trace at the call site |

When periodic dumping is enabled on the debug server, each dump cycle
produces two kinds of output:

* **Aggregated text files** (`torchcomms_fr_trace_<ts>.txt`) — the
  frontend on rank 0 fetches FR data from all ranks and writes a
  human-readable table.
* **Per-rank pickle files** (`per_rank/rank_<N>`) — each rank's worker
  server writes its own pickle trace. These can be fed to the
  **FR CLI** (`python -m torch.distributed.flight_recorder.fr_trace`)
  for automated cross-rank mismatch detection.

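The entry fields above can be modeled in plain Python. The dataclass below is purely illustrative (it mirrors the documented fields, not the recorder's actual internal schema), and the set difference at the end shows, in miniature, the kind of cross-rank comparison the FR CLI automates:

```python
# Illustrative model of one Flight Recorder entry (hypothetical names;
# the real recorder's internal schema may differ).
from dataclasses import dataclass, field

@dataclass
class FREntry:
    collective_seq_id: int      # same value on every rank for one collective
    profiling_name: str         # e.g. "nccl:all_reduce"
    state: str                  # "scheduled" -> "started" -> "completed"
    input_dims: list
    output_dims: list
    traceback: list = field(default_factory=list)

# Rank 0 recorded five collectives; rank 1 hung before issuing the fifth.
rank0 = [FREntry(i, "nccl:all_reduce", "scheduled", [[1024]], [[1024]]) for i in range(5)]
rank1 = rank0[:4]

# Cross-rank check: which sequence ids does rank 1 lack?
missing = {e.collective_seq_id for e in rank0} - {e.collective_seq_id for e in rank1}
print(missing)  # → {4}
```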
---

## The Scenario

The demo script (`verify_flight_recorder.py`) creates a two-phase workload:

* Phase 1 (all ranks): 3 `all_reduce` and 1 `broadcast` operations complete normally
* Phase 2:
  * The hanging rank enters `time.sleep`
  * The other ranks issue another `all_reduce` that times out waiting for the hanging rank

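The demo script itself is not reproduced here, but the Phase 2 failure mode can be sketched with nothing more than the standard library: a `threading.Barrier` stands in for the collective, its timeout for `COMM_TIMEOUT`, and a sleeping thread for the hanging rank. All names below are illustrative, not part of the demo:

```python
# Standard-library analogy for the Phase 2 hang (no torchcomms required).
import threading
import time

barrier = threading.Barrier(parties=2)   # plays the role of the collective
result = {}

def healthy_rank():
    # Analogous to the rank that issues the Phase 2 all_reduce and waits.
    try:
        barrier.wait(timeout=2)          # plays the role of COMM_TIMEOUT
        result["status"] = "completed"
    except threading.BrokenBarrierError:
        result["status"] = "timed out waiting for peer"

def hanging_rank():
    # Analogous to the rank that enters time.sleep and never joins.
    time.sleep(60)

t0 = threading.Thread(target=healthy_rank)
t1 = threading.Thread(target=hanging_rank, daemon=True)
t0.start(); t1.start()
t0.join()
print(result["status"])  # → timed out waiting for peer
```

The healthy side cannot distinguish "peer is slow" from "peer will never arrive"; only the timeout surfaces the problem, which is exactly why the Flight Recorder's per-rank history is needed to identify *who* failed to join.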
When the timeout fires, the dump directory contains:

```
FR_DUMP_DIR/
├── torchcomms_fr_trace_<ts>.txt   ← aggregated text
└── per_rank/                      ← per-rank pickle files
    ├── rank_0
    └── rank_1
```

---

## Running the Demo

### Prerequisites

* `torchcomms` and `torch.distributed.debug` installed
* Use `TEST_BACKEND=gloo TEST_DEVICE=cpu` for CPU-only testing, or
  a CUDA host with at least 2 GPUs.

### Launch

```bash
FR_DUMP_DIR=/tmp/fr_hang_debug \
FR_DUMP_INTERVAL=3 \
COMM_TIMEOUT=15 \
TEST_BACKEND=gloo \
TEST_DEVICE=cpu \
torchrun --nproc_per_node=2 verify_flight_recorder.py
```

| Variable | Default | Description |
|---|---|---|
| `FR_DUMP_DIR` | `/tmp/fr_hang_debug` | Root dump directory |
| `FR_DUMP_INTERVAL` | `5` | Seconds between periodic dumps |
| `COMM_TIMEOUT` | `30` | Communicator timeout (seconds) |
| `HANGING_RANK` | `-1` (last rank) | Which rank to hang |
| `TEST_BACKEND` | `gloo` | Communication backend |
| `TEST_DEVICE` | `cuda` | Tensor device |

### Expected Output

```
[Rank 0/2] device=0, hanging_rank=1, timeout=15s
[Rank 1/2] device=1, hanging_rank=1, timeout=15s
[Rank 0] Debug server: http://localhost:25999
[Rank 0] Periodic dumps every 3.0s → /tmp/fr_hang_debug
[Rank 0] Per-rank pickles → /tmp/fr_hang_debug/per_rank
[Rank 0] Phase 1: Running 3 all_reduce + 1 broadcast
[Rank 0] Phase 1 complete
[Rank 0] Phase 2: all_reduce (rank 1 will NOT participate)
[Rank 0] Expecting timeout in ~15s ...
[Rank 1] Phase 1 complete
[Rank 1] >>> HANGING – entering infinite sleep <<<

... periodic mismatch warnings every 3 seconds ...

Not all ranks joining collective, sequence number: 4
collective: nccl:all_reduce
missing ranks: {1}
collective state: scheduled

... ~15 seconds pass ...

[Rank 0] Caught timeout: RuntimeError: Timed out waiting 15000ms for recv operation
[Rank 0] Pickle trace written to /tmp/fr_hang_debug/per_rank/rank_0
```

---

## Reading the Aggregated Text Dumps

The debug server writes periodic text snapshots aggregating data from
all ranks:

```bash
$ ls /tmp/fr_hang_debug/torchcomms_fr_trace_*.txt
torchcomms_fr_trace_20260401_192058.txt
torchcomms_fr_trace_20260401_192101.txt
torchcomms_fr_trace_20260401_192104.txt
...
```

Open one of the snapshots written during the hang:

```bash
cat /tmp/fr_hang_debug/torchcomms_fr_trace_20260401_192104.txt
```

The **Collectives** table shows every recorded operation:

```
--- Collectives ---
 id  group_id   pass_check  collective_seq_id  collective_name  collective_state  missing_ranks
  0  main_comm  True                        0  nccl:all_reduce  scheduled
  1  main_comm  True                        1  nccl:all_reduce  scheduled
  2  main_comm  True                        2  nccl:all_reduce  scheduled
  3  main_comm  True                        3  nccl:broadcast   scheduled
  4  main_comm  True                        4  nccl:all_reduce  scheduled         {1}  ← MISMATCH
```

The **NCCL Calls** table shows which ranks participated:

```
--- NCCL Calls ---
 id  collective_id  group_id   global_rank  collective_type
  0              0  main_comm            0  nccl:all_reduce
  1              0  main_comm            1  nccl:all_reduce
 ...
  6              3  main_comm            0  nccl:broadcast
  7              3  main_comm            1  nccl:broadcast
  8                 main_comm            0  nccl:all_reduce  ← Only rank 0!
```

The **Dump File** section confirms per-rank pickle files were written:

```
=== TorchComms FR Dump File ===
Rank 0: OK - Flight Recorder debug info written to /tmp/fr_hang_debug/per_rank/rank_0
Rank 1: OK - Flight Recorder debug info written to /tmp/fr_hang_debug/per_rank/rank_1
```

The `stacks_*.txt` files show Python tracebacks, pinpointing the
exact line each rank is stuck at:

```bash
$ cat /tmp/fr_hang_debug/stacks_20260401_192104.txt

=== Rank 0 ===
  File "verify_flight_recorder.py", line 148 in main   ← all_reduce (waiting)

=== Rank 1 ===
  File "verify_flight_recorder.py", line 140 in main   ← time.sleep (the hang!)
```

Rank 1 never issued `collective_seq_id=4`. The stacks dump confirms
it is stuck in `time.sleep`, not in a collective.

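With snapshots arriving every few seconds, it helps to jump straight to the first one that records a mismatch. The one-liner below is a hedged convenience, not part of the tooling: it assumes the dump layout above and relies on the observation that a non-empty missing-ranks set renders as `{N}` in the table:

```shell
# List snapshots whose tables contain a non-empty missing-ranks set like {1}.
# FR_DUMP_DIR is assumed to match the value used at launch.
FR_DUMP_DIR=${FR_DUMP_DIR:-/tmp/fr_hang_debug}
grep -l -E '\{[0-9][0-9, ]*\}' "$FR_DUMP_DIR"/torchcomms_fr_trace_*.txt 2>/dev/null || true
```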
---

## Running the FR CLI on Per-Rank Pickle Dumps

The periodic dump also triggers each rank's worker server to write a
pickle trace file into the `per_rank/` subdirectory:

```bash
$ ls /tmp/fr_hang_debug/per_rank/
rank_0  rank_1
```

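Before reaching for the CLI, the per-rank files can be inspected by hand: they are ordinary Python pickles, but their exact layout is an implementation detail of the Flight Recorder, so the sketch below loads one defensively rather than assuming any schema (the path is the demo's default):

```python
# Defensive peek at a per-rank pickle; the payload's structure is an
# implementation detail, so only its top-level shape is reported.
import pickle
import pprint
from pathlib import Path

path = Path("/tmp/fr_hang_debug/per_rank/rank_0")  # demo's default location
if path.exists():
    with path.open("rb") as f:
        data = pickle.load(f)
    print(type(data).__name__)
    if isinstance(data, dict):
        pprint.pprint(sorted(data))          # top-level keys only
    elif isinstance(data, (list, tuple)):
        print(f"{len(data)} entries")
else:
    print(f"{path} not found; run the demo first")
```

For anything beyond a quick look, prefer the CLI below, which understands the trace format and compares ranks for you.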
### Cross-rank mismatch analysis

```bash
python -m torch.distributed.flight_recorder.fr_trace \
    /tmp/fr_hang_debug/per_rank -p rank_
```

Output:

```
Not all ranks joining collective, sequence number: 4
internal record id: 4
group info: main_comm:gloo
collective: nccl:all_reduce
missing ranks: {1}
input sizes: [[1024]]
output sizes: [[1024]]
world size: 2
expected ranks: {0, 1}
collective state: scheduled
```

The CLI detected that rank 1 never issued `collective_seq_id=4`.

### Side-by-side raw entry view

```bash
python -m torch.distributed.flight_recorder.fr_trace \
    /tmp/fr_hang_debug/per_rank -p rank_ -j
```

Output:

```
Rank 0                                             Rank 1
-------------------------------------------------  -------------------------------------------------
all_reduce(input_sizes=[[1024]], state=scheduled)  all_reduce(input_sizes=[[1024]], state=scheduled)
broadcast(input_sizes=[[1024]], state=scheduled)   broadcast(input_sizes=[[1024]], state=scheduled)
all_reduce(input_sizes=[[1024]], state=scheduled)
```

Rank 0 has 5 entries (3 `all_reduce` + 1 `broadcast` + the stuck
`all_reduce`). Rank 1 has only 4 — the 5th `all_reduce` is missing
because rank 1 hung before issuing it.

### With stack traces

```bash
python -m torch.distributed.flight_recorder.fr_trace \
    /tmp/fr_hang_debug/per_rank -p rank_ -j --print_stack_trace
```

This adds Python stack traces to each entry, showing exactly where in
user code each collective was called.

---

## What to Look For

| Symptom | Likely cause |
|---|---|
| `missing_ranks: {N}` in the Collectives table | Rank N hung or crashed before issuing the next collective |
| Rank X's last entry is `state=started`, others are `completed` | Rank X issued the collective but is waiting for a peer that never joined |
| Mismatched `collective_name` at the same `collective_seq_id` | Code-path divergence — ranks are calling different collectives |
| Mismatched `input_sizes` / `output_sizes` | Tensor shape inconsistency across ranks |
| Stacks dump shows `time.sleep` or user code (not a collective) | The rank is stuck in compute, not in a collective |

---

## FR CLI Quick Reference

```bash
# Cross-rank mismatch analysis:
python -m torch.distributed.flight_recorder.fr_trace <dir> -p <prefix>

# Side-by-side raw entries per rank:
python -m torch.distributed.flight_recorder.fr_trace <dir> -p <prefix> -j

# With stack traces:
python -m torch.distributed.flight_recorder.fr_trace <dir> -p <prefix> -j --print_stack_trace

# Best-effort when some rank dumps are missing:
python -m torch.distributed.flight_recorder.fr_trace <dir> -p <prefix> --allow-incomplete-ranks
```