
Commit 4c921e8

add flight recorder tutorial
Summary: Add a tutorial for debugging single-rank hangs in distributed PyTorch jobs using the TorchComms Flight Recorder, covering both aggregated text dump analysis and per-rank pickle-based CLI detection workflows.
1 parent 05c08f2 commit 4c921e8

1 file changed

Lines changed: 320 additions & 0 deletions

# Debugging a Rank Hang with Flight Recorder

This tutorial walks through a complete workflow for diagnosing a
**single-rank hang** in a distributed PyTorch job using the
**TorchComms Flight Recorder**. Two complementary approaches are shown:

| | Approach A | Approach B |
|---|---|---|
| **Source** | Debug server aggregated text dumps | Per-rank pickle files + FR CLI |
| **Files** | `torchcomms_fr_trace_<ts>.txt` | `per_rank/rank_0`, `per_rank/rank_1` |
| **Reader** | Human (open in editor / `cat`) | `python -m torch.distributed.flight_recorder.fr_trace` |
| **Strength** | Single file, all ranks at a glance | Automated cross-rank mismatch detection |

---

## Table of Contents

1. [Background](#background)
2. [The Scenario](#the-scenario)
3. [Running the Demo](#running-the-demo)
4. [Approach A — Aggregated Text Dumps](#approach-a--aggregated-text-dumps)
5. [Approach B — Per-Rank Pickle Dumps + FR CLI](#approach-b--per-rank-pickle-dumps--fr-cli)
6. [What to Look For](#what-to-look-for)
7. [Reference](#reference)

---

## Background

The **Flight Recorder** is a ring buffer that records every collective
operation issued through a TorchComms communicator. Each entry captures:

| Field | Description |
|---|---|
| `collective_seq_id` | Monotonically increasing sequence number (same across all ranks for a given collective) |
| `profiling_name` | e.g. `nccl:all_reduce`, `nccl:broadcast` |
| `state` | `scheduled` → `started` → `completed` |
| `retired` | `true` once the post-hook fires |
| `input_dims` / `output_dims` | Tensor shapes |
| `traceback` | Python stack trace at the call site |

The debug server provides **two dump mechanisms**:

* **Aggregated text dumps** — the periodic dumper on rank 0 fetches FR
  data from all ranks and writes a single human-readable text file
  every *N* seconds (`torchcomms_fr_trace_<ts>.txt`).
* **Per-rank pickle dumps** — the periodic dumper also triggers each
  rank's worker server to write a pickle-format trace file (`per_rank/rank_<N>`).
  These are consumed by the **FR CLI** for automated cross-rank
  mismatch analysis, and can also be inspected directly, as sketched below.
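
For a quick look at a per-rank file without the CLI, you can unpickle it
directly. The sketch below is illustrative rather than part of the demo: it
assumes the file deserializes to a mapping with an `entries` list whose items
carry the fields from the table above; the exact layout is an implementation
detail of the recorder, so verify it against your build.

```python
# inspect_rank_dump.py -- hypothetical helper, not shipped with the demo.
# ASSUMPTION: the per-rank file unpickles to a dict with an "entries"
# list of per-collective dicts keyed by the field names tabled above.
import pickle
import sys

with open(sys.argv[1], "rb") as f:      # e.g. per_rank/rank_1
    dump = pickle.load(f)

for entry in dump["entries"][-5:]:      # last few recorded collectives
    print(
        entry.get("collective_seq_id"),
        entry.get("profiling_name"),
        entry.get("state"),
        entry.get("input_dims"),
        entry.get("retired"),
    )
```
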
---

## The Scenario

The demo script (`verify_flight_recorder.py`) creates a two-phase workload:

```
Phase 1 (all ranks): 3× all_reduce + 1× broadcast   ← completes normally
Phase 2:
  • Hanging rank → enters time.sleep(∞)
  • Other ranks  → issue another all_reduce → timeout after COMM_TIMEOUT seconds
```
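
The demo itself drives TorchComms, whose API is not reproduced in this
tutorial. As a rough, self-contained stand-in for the same failure shape, the
sketch below uses plain `torch.distributed` with the gloo backend; the file
name and env handling here are illustrative, not the actual
`verify_flight_recorder.py`.

```python
# hang_demo.py -- condensed stand-in for the demo's two-phase workload,
# using plain torch.distributed (gloo) instead of TorchComms.
import os
import time

import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group("gloo")  # torchrun supplies rank/world env vars
    rank = dist.get_rank()
    world = dist.get_world_size()
    hanging_rank = int(os.environ.get("HANGING_RANK", world - 1))

    # Phase 1: every rank participates -- these collectives complete.
    t = torch.ones(1024)
    for _ in range(3):
        dist.all_reduce(t)
    dist.broadcast(t, src=0)

    # Phase 2: the hanging rank goes silent before the next collective.
    if rank == hanging_rank:
        time.sleep(10**9)  # simulate a stuck rank (deadlock, slow I/O, ...)
    else:
        dist.all_reduce(t)  # blocks until the process-group timeout fires


if __name__ == "__main__":
    main()
```

Launched with `torchrun --nproc_per_node=2 hang_demo.py`, this reproduces the
same signature: one rank never issues the final collective.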

When the timeout fires, the dump directory contains:

```
FR_DUMP_DIR/
├── torchcomms_fr_trace_<ts>.txt   ← Approach A: aggregated text
├── stacks_<ts>.txt                ← Python stack dumps
└── per_rank/                      ← Approach B: per-rank pickles
    ├── rank_0
    └── rank_1
```

---

## Running the Demo

### Prerequisites

* `torchcomms` and `torch.distributed.debug` installed (with deps:
  `tabulate`, `jinja2`, `aiohttp`).
* Use `TEST_BACKEND=gloo TEST_DEVICE=cpu` for CPU-only testing, or run on
  a CUDA host with ≥ 2 GPUs.

### Launch

```bash
FR_DUMP_DIR=/tmp/fr_hang_debug \
FR_DUMP_INTERVAL=3 \
COMM_TIMEOUT=15 \
TEST_BACKEND=gloo \
TEST_DEVICE=cpu \
torchrun --nproc_per_node=2 verify_flight_recorder.py
```

| Variable | Default | Description |
|---|---|---|
| `FR_DUMP_DIR` | `/tmp/fr_hang_debug` | Root dump directory |
| `FR_DUMP_INTERVAL` | `5` | Seconds between periodic dumps |
| `COMM_TIMEOUT` | `30` | Communicator timeout (seconds) |
| `HANGING_RANK` | `-1` (last rank) | Which rank to hang |
| `TEST_BACKEND` | `gloo` | Communication backend |
| `TEST_DEVICE` | `cuda` | Tensor device |
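
For context on where these variables end up: the demo presumably forwards
them to the debug server at startup. A sketch of that wiring, using the
`start_debug_server()` entry point and the parameter names listed in the
Reference section (the full signature lives in the linked docs):

```python
# debug_server_setup.py -- illustrative wiring only; the parameter names
# follow the Reference section below, but check the linked docs for the
# authoritative signature and defaults.
import os

from torch.distributed.debug import start_debug_server

start_debug_server(
    dump_dir=os.environ.get("FR_DUMP_DIR", "/tmp/fr_hang_debug"),
    dump_interval=float(os.environ.get("FR_DUMP_INTERVAL", "5")),
)
```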

### Expected Output

```
[Rank 0/2] device=0, hanging_rank=1, timeout=15s
[Rank 1/2] device=1, hanging_rank=1, timeout=15s
[Rank 0] Debug server: http://localhost:25999
[Rank 0] Periodic dumps every 3.0s → /tmp/fr_hang_debug
[Rank 0] Per-rank pickles → /tmp/fr_hang_debug/per_rank
[Rank 0] Phase 1: Running 3 all_reduce + 1 broadcast
[Rank 0] Phase 1 complete
[Rank 0] Phase 2: all_reduce (rank 1 will NOT participate)
[Rank 0] Expecting timeout in ~15s ...
[Rank 1] Phase 1 complete
[Rank 1] >>> HANGING – entering infinite sleep <<<

... periodic mismatch warnings every 3 seconds ...

Not all ranks joining collective, sequence number: 4
collective: nccl:all_reduce
missing ranks: {1}
collective state: scheduled

... ~15 seconds pass ...

[Rank 0] Caught timeout: RuntimeError: Timed out waiting 15000ms for recv operation
[Rank 0] Pickle trace written to /tmp/fr_hang_debug/per_rank/rank_0
```

---

## Approach A — Aggregated Text Dumps

The debug server writes periodic text snapshots aggregating data from
all ranks:

```bash
$ ls /tmp/fr_hang_debug/torchcomms_fr_trace_*.txt
torchcomms_fr_trace_20260401_192058.txt
torchcomms_fr_trace_20260401_192101.txt
torchcomms_fr_trace_20260401_192104.txt
...
```

### Reading the dump

```bash
cat /tmp/fr_hang_debug/torchcomms_fr_trace_20260401_192104.txt
```

The **Collectives** table shows every recorded operation:

```
--- Collectives ---
id  group_id   pass_check  collective_seq_id  collective_name  collective_state  missing_ranks
0   main_comm  True        0                  nccl:all_reduce  scheduled
1   main_comm  True        1                  nccl:all_reduce  scheduled
2   main_comm  True        2                  nccl:all_reduce  scheduled
3   main_comm  True        3                  nccl:broadcast   scheduled
4   main_comm  True        4                  nccl:all_reduce  scheduled         {1}   ← MISMATCH
```
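
Over a long run you will accumulate many snapshots. Since the
`missing_ranks` column is the key signal, a few lines of Python can flag the
first dump in which it appears; this rough scan keys off the trailing `{...}`
set visible in the excerpt above:

```python
# scan_dumps.py -- flag Collectives rows that list missing ranks across
# all snapshots; relies only on the text layout shown above.
import glob
import re

for path in sorted(glob.glob("/tmp/fr_hang_debug/torchcomms_fr_trace_*.txt")):
    with open(path) as f:
        for line in f:
            if re.search(r"\{\d+(?:,\s*\d+)*\}\s*$", line):
                print(f"{path}: {line.strip()}")
```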

The **NCCL Calls** table confirms which ranks participated:

```
--- NCCL Calls ---
id  collective_id  group_id   global_rank  collective_type
0   0              main_comm  0            nccl:all_reduce
1   0              main_comm  1            nccl:all_reduce
...
6   3              main_comm  0            nccl:broadcast
7   3              main_comm  1            nccl:broadcast
8   4              main_comm  0            nccl:all_reduce   ← Only rank 0!
```
180+
181+
The **Dump File** section confirms per-rank pickle files were written:
182+
183+
```
184+
=== TorchComms FR Dump File ===
185+
Rank 0: OK - Flight Recorder debug info written to /tmp/fr_hang_debug/per_rank/rank_0
186+
Rank 1: OK - Flight Recorder debug info written to /tmp/fr_hang_debug/per_rank/rank_1
187+
```

### Stacks dump

The `stacks_*.txt` files show Python tracebacks, pinpointing the
exact line each rank is stuck at:

```bash
$ cat /tmp/fr_hang_debug/stacks_20260401_192104.txt

=== Rank 0 ===
File "verify_flight_recorder.py", line 148 in main   ← all_reduce (waiting)

=== Rank 1 ===
File "verify_flight_recorder.py", line 140 in main   ← time.sleep (the hang!)
```

**Interpretation:** Rank 1 never issued `collective_seq_id=4`.
The stacks dump confirms it is stuck in `time.sleep`, not in a collective.

---

## Approach B — Per-Rank Pickle Dumps + FR CLI

The debug server's periodic dump also triggers each rank's worker
server to write a pickle trace file. These land in the `per_rank/`
subdirectory:

```bash
$ ls /tmp/fr_hang_debug/per_rank/
rank_0  rank_1
```

Both files exist because each rank's worker server writes its own
trace independently — even the hanging rank has a dump.

### Cross-rank mismatch analysis

```bash
python -m torch.distributed.flight_recorder.fr_trace \
    /tmp/fr_hang_debug/per_rank -p rank_
```

**Actual output:**

```
Not all ranks joining collective, sequence number: 4
internal record id: 4
group info: main_comm:gloo
collective: nccl:all_reduce
missing ranks: {1}
input sizes: [[1024]]
output sizes: [[1024]]
world size: 2
expected ranks: {0, 1}
collective state: scheduled
```

The CLI automatically detected that **rank 1 never issued
`collective_seq_id=4`** and flagged the mismatch.
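
The same core check is easy to re-derive by hand from the pickles, which
helps if you want to script alerting around it. This sketch reuses the layout
assumption from the Background section (an `entries` list keyed by the tabled
field names):

```python
# diff_ranks.py -- re-derives the CLI's core check by hand. Same layout
# ASSUMPTION as the earlier inspection sketch: dict with an "entries" list.
import pickle


def seq_ids(path: str) -> set[int]:
    with open(path, "rb") as f:
        dump = pickle.load(f)
    return {e["collective_seq_id"] for e in dump["entries"]}


r0 = seq_ids("/tmp/fr_hang_debug/per_rank/rank_0")
r1 = seq_ids("/tmp/fr_hang_debug/per_rank/rank_1")
print("only on rank 0:", r0 - r1)  # expected here: {4}
print("only on rank 1:", r1 - r0)  # expected here: set()
```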

### Side-by-side raw entry view

```bash
python -m torch.distributed.flight_recorder.fr_trace \
    /tmp/fr_hang_debug/per_rank -p rank_ -j
```

**Actual output:**

```
Rank 0                                             Rank 1
-------------------------------------------------  -------------------------------------------------
all_reduce(input_sizes=[[1024]], state=scheduled)  all_reduce(input_sizes=[[1024]], state=scheduled)
all_reduce(input_sizes=[[1024]], state=scheduled)  all_reduce(input_sizes=[[1024]], state=scheduled)
all_reduce(input_sizes=[[1024]], state=scheduled)  all_reduce(input_sizes=[[1024]], state=scheduled)
broadcast(input_sizes=[[1024]], state=scheduled)   broadcast(input_sizes=[[1024]], state=scheduled)
all_reduce(input_sizes=[[1024]], state=scheduled)
```

Rank 0 has **5 entries** (3 `all_reduce` + 1 `broadcast` + the stuck
`all_reduce`). Rank 1 has only **4 entries** — the 5th `all_reduce` is
missing because rank 1 hung before issuing it.

### With stack traces

```bash
python -m torch.distributed.flight_recorder.fr_trace \
    /tmp/fr_hang_debug/per_rank -p rank_ -j --print_stack_trace
```

This adds Python stack traces to each entry, showing exactly where in
user code each collective was called.

---

## What to Look For

| Symptom | Likely cause |
|---|---|
| `missing_ranks: {N}` in the Collectives table | Rank N hung or crashed before issuing the next collective |
| Rank X's last entry is `state=started`, others are `completed` | Rank X issued the collective but is waiting for a peer that never joined |
| Mismatched `collective_name` at the same `collective_seq_id` | Code-path divergence — ranks are calling different collectives |
| Mismatched `input_sizes` / `output_sizes` | Tensor shape inconsistency across ranks |
| Stacks dump shows `time.sleep` or user code (not a collective) | The rank is stuck in compute, not in a collective |
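
The third and fourth symptoms are mechanical enough to script. Under the same
pickle-layout assumption as the earlier sketches, a toy checker for name and
shape divergence at a shared sequence id might look like this:

```python
# symptom_checks.py -- toy checks for the code-path-divergence and
# shape-mismatch rows above, under the same pickle-layout ASSUMPTION
# as the earlier sketches.
import pickle


def by_seq(path: str) -> dict:
    with open(path, "rb") as f:
        return {e["collective_seq_id"]: e for e in pickle.load(f)["entries"]}


a = by_seq("/tmp/fr_hang_debug/per_rank/rank_0")
b = by_seq("/tmp/fr_hang_debug/per_rank/rank_1")

for seq in sorted(a.keys() & b.keys()):
    if a[seq]["profiling_name"] != b[seq]["profiling_name"]:
        print(f"seq {seq}: ranks called different collectives:",
              a[seq]["profiling_name"], "vs", b[seq]["profiling_name"])
    if a[seq].get("input_dims") != b[seq].get("input_dims"):
        print(f"seq {seq}: tensor shape mismatch:",
              a[seq].get("input_dims"), "vs", b[seq].get("input_dims"))
```
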
---

## Reference

For full API documentation on the tools used in this tutorial, see
[`torch.distributed` — Debug HTTP Server](https://pytorch.org/docs/main/distributed.html#torch-distributed-debug-http-server),
which covers:

* `start_debug_server()` parameters (`dump_dir`, `dump_interval`, `enabled_dumps`, etc.)
* Architecture (frontend server on rank 0, per-rank worker servers)
* All frontend and worker-level HTTP endpoints
* Periodic dumping configuration and supported handlers
* Registering custom handlers (Python and C++)
* Troubleshooting common issues

### FR CLI Quick Reference

```bash
# Cross-rank mismatch analysis:
python -m torch.distributed.flight_recorder.fr_trace <dir> -p <prefix>

# Side-by-side raw entries per rank:
python -m torch.distributed.flight_recorder.fr_trace <dir> -p <prefix> -j

# With stack traces:
python -m torch.distributed.flight_recorder.fr_trace <dir> -p <prefix> -j --print_stack_trace

# Best-effort when some rank dumps are missing:
python -m torch.distributed.flight_recorder.fr_trace <dir> -p <prefix> --allow-incomplete-ranks
```
