Commit b30bfcb

wanghan-iapcm and Han Wang authored
feat(skill): add skill for debug gradient flow in the pt expt backend (#5280)

Summary by CodeRabbit (Release Notes):

* Documentation
  * Added a comprehensive debugging guide for diagnosing gradient flow issues during model training, including a step-by-step diagnostic methodology and common root causes.
  * Added a diagnostic script with practical code examples and walkthroughs for isolating gradient behavior across different training modes.

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 41757f2 commit b30bfcb

2 files changed: 485 additions & 0 deletions
---
name: debug-gradient-flow
description: Diagnose gradient flow issues in training, especially for compiled models (torch.compile/make_fx). Systematically isolates which loss components (energy, force, virial) contribute gradients to which parameters, and identifies where the gradient chain breaks.
license: LGPL-3.0-or-later
metadata:
  author: deepmd-kit
  version: '1.0'
---

# Debugging Gradient Flow in Training

Use this method when a loss component (force, virial, energy) does not decrease during training, or when compiled model training diverges from uncompiled training.

## When to use

- A loss term (e.g. `rmse_f`, `rmse_v`) stays flat or goes NaN during training
- Compiled training (`enable_compile=True`) behaves differently from uncompiled training
- After adding a new loss component or model output
- After changes to `make_fx` tracing, `torch.compile`, or `autograd.grad` code paths
## Method: Per-component gradient isolation

The core technique: **zero out all loss terms except one**, run `loss.backward()`, and count which model parameters receive non-zero gradients. Compare across uncompiled and compiled paths to pinpoint where gradients are lost.
### Step 1: Write a gradient probe script

Create a script that constructs a trainer, injects labels if needed, and reports per-parameter gradient status:

```python
def check_grad(trainer, label_overrides=None):
    trainer.wrapper.train()
    trainer.optimizer.zero_grad(set_to_none=True)
    inp, lab = trainer.get_data(is_train=True)
    lr = trainer.scheduler.get_last_lr()[0]

    # Override labels to isolate a single loss component
    if label_overrides:
        lab.update(label_overrides)

    _, loss, more_loss = trainer.wrapper(**inp, cur_lr=lr, label=lab)
    loss.backward()

    # Map each trainable parameter name to whether it received a non-zero gradient
    status = {}
    for name, p in trainer.wrapper.named_parameters():
        if p.requires_grad:
            has_grad = p.grad is not None and p.grad.abs().sum() > 0
            status[name] = has_grad
    return status
```
### Step 2: Run for each loss component in isolation

Test each loss component separately by zeroing out the others:

```python
scenarios = {
    "energy only": {"find_force": 0.0, "find_virial": 0.0},
    "force only": {"find_energy": 0.0, "find_virial": 0.0},
    "virial only": {
        "find_energy": 0.0,
        "find_force": 0.0,
        "virial": torch.randn(nframes, 9, ...),  # inject if data lacks virial
        "find_virial": 1.0,
    },
    "all losses": {
        "virial": torch.randn(nframes, 9, ...),
        "find_virial": 1.0,
    },
}
```

If training data lacks virial labels, inject synthetic ones; the numerical values don't matter, only gradient flow does.
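A minimal driver sketch, assuming the `trainer` object from Step 1 and the `scenarios` dict and `check_grad` probe defined above; the per-scenario counts are what Step 3 compares across trainers:

```python
# Run each scenario through the probe and report how many parameters
# received a non-zero gradient for that isolated loss component.
counts = {}
for name, overrides in scenarios.items():
    status = check_grad(trainer, label_overrides=overrides)
    counts[name] = sum(status.values())
    print(f"{name:<12}: {counts[name]}/{len(status)} params with grad")
```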
### Step 3: Compare compiled vs uncompiled

Run each scenario for both the compiled and the uncompiled trainer. Present the results as a table:

```
             Uncompiled   Compiled
energy only:      22/22      22/22
force only:       20/22      16/22   <-- problem
virial only:      20/22      16/22   <-- problem
all losses:       22/22      22/22   <-- OK in practice
```

Key interpretations:

- **Same count on both paths**: gradient flow is correct
- **Compiled < Uncompiled**: `make_fx` or `torch.compile` breaks some gradient paths
- **0 grads in compiled**: catastrophic failure (e.g. wrong `create_graph`, wrong backend)
- **"all losses" is OK but an isolated component isn't**: the missing grads are covered by other loss terms; this may be acceptable
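One way to produce the table above: collect the Step 2 counts once per trainer and print them side by side. A sketch, assuming `counts_uncompiled` and `counts_compiled` are the per-scenario count dicts and `n_params` is the total number of trainable parameters (names are illustrative):

```python
# Side-by-side summary of per-scenario gradient counts.
print(f"{'':<14}{'Uncompiled':>12}{'Compiled':>12}")
for name, uc in counts_uncompiled.items():
    cc = counts_compiled[name]
    marker = " <-- problem" if cc < uc else ""
    row_uc = f"{uc}/{n_params}"
    row_cc = f"{cc}/{n_params}"
    print(f"{name + ':':<14}{row_uc:>12}{row_cc:>12}{marker}")
```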
### Step 4: Identify affected parameters

When the compiled path has fewer grads, print the per-parameter diff:

```python
print(f"{'Parameter':<60} {'Uncompiled':>10} {'Compiled':>10}")
for name in sorted(status_uncompiled):
    uc = "GRAD" if status_uncompiled[name] else "-"
    cc = "GRAD" if status_compiled[name] else "-"
    marker = " <-- DIFF" if uc != cc else ""
    print(f"{name:<60} {uc:>10} {cc:>10}{marker}")
```

This tells you exactly which layers lose gradients and helps locate the broken link in the computation graph.
### Step 5: Bisect the cause

If the compiled path has fewer grads, test these layers in order:

| Layer | What to try | What it tests |
| --- | --- | --- |
| `make_fx` only (no `torch.compile`) | Replace `torch.compile(traced, ...)` with just `traced` | Is `make_fx` the problem or `torch.compile`? |
| Different `torch.compile` backends | Try `eager`, `aot_eager`, `inductor` | Which backend breaks gradients? |
| `model.train()` vs `model.eval()` during tracing | Toggle training mode before `make_fx` | Does `create_graph=self.training` get the wrong value? |
| `coord.requires_grad_(True)` placement | Check whether coord has grad before entering the compiled graph | Is the autograd entry point correct? |

```python
# Test make_fx only (no torch.compile)
traced = make_fx(fn)(ext_coord, ext_atype, nlist, mapping, fparam, aparam)
# Use traced directly instead of torch.compile(traced)

# Test different backends
for backend in ["eager", "aot_eager", "inductor"]:
    compiled = torch.compile(traced, backend=backend, dynamic=False)
    # ... run gradient check
```
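For the last two rows of the table, a minimal sketch (assuming the same `fn` and input tensors as above; `model` stands for the module wrapped by `fn` and is illustrative):

```python
# Row 3: create_graph=self.training is captured at trace time, so the module
# must be in training mode before make_fx tracing.
model.train()  # not model.eval()
traced = make_fx(fn)(ext_coord, ext_atype, nlist, mapping, fparam, aparam)

# Row 4: the coordinates must require grad before they enter the traced or
# compiled graph, otherwise the force (the coordinate gradient of the energy)
# cannot be formed inside it.
assert ext_coord.requires_grad, "ext_coord lost requires_grad before tracing"
```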
## Common root causes

### 1. `create_graph=False` during tracing

**Symptom**: force/virial loss doesn't decrease; 0 params get grad from the force/virial loss.

**Cause**: calling `model.eval()` before `make_fx` tracing makes `create_graph=self.training` evaluate to `False`. The `autograd.grad` call that computes the force is then traced without graph creation, so the force tensor is detached from the model parameters.

**Fix**: call `model.train()` before `make_fx` tracing.

**Location**: `_trace_and_compile` in `deepmd/pt_expt/train/training.py`
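A self-contained toy illustration of the effect (not deepmd code): when the force comes from `autograd.grad` with `create_graph=False`, backpropagating a force loss leaves the parameters without gradients.

```python
import torch

# Toy "model": energy depends on coordinates through a learnable weight.
w = torch.nn.Parameter(torch.tensor(2.0))
coord = torch.randn(5, requires_grad=True)

def force_from_energy(create_graph: bool) -> torch.Tensor:
    energy = (w * coord**2).sum()
    # create_graph controls whether this gradient stays connected to w
    # in the autograd graph.
    (force,) = torch.autograd.grad(energy, coord, create_graph=create_graph)
    return force

for create_graph in (True, False):
    w.grad = None
    loss = force_from_energy(create_graph).pow(2).sum()
    if loss.requires_grad:
        loss.backward()
    print(f"create_graph={create_graph}: w.grad = {w.grad}")
# create_graph=True  -> w.grad is a non-zero tensor
# create_graph=False -> w.grad stays None (force is detached from w)
```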
### 2. `torch.compile` inductor backend kills second-order gradients

**Symptom**: force/virial loss doesn't decrease; 0 params get grad with `inductor`, while `eager`/`aot_eager` work fine.

**Cause**: the inductor backend's graph lowering doesn't support backward through the `make_fx`-decomposed `autograd.grad` ops.

**Fix**: default to the `aot_eager` backend.
### 3. Ghost force contributions discarded

**Symptom**: force values differ between the compiled and uncompiled models.

**Cause**: using `extended_force[:, :nloc, :]` (a plain slice) instead of scatter-summing the ghost-atom contributions back onto local atoms via `mapping`.

**Fix**: `torch.zeros(...).scatter_add_(1, mapping_idx, extended_force[:, :actual_nall, :])`
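A minimal sketch of the reduction with assumed shapes (`extended_force`: `(nframes, nall, 3)`, `mapping`: `(nframes, nall)` giving the local index each extended atom maps to); the sizes and variable names other than `mapping` and `extended_force` are illustrative:

```python
import torch

nframes, nloc, nall = 2, 4, 10  # illustrative sizes: nall = local + ghost atoms
extended_force = torch.randn(nframes, nall, 3)
mapping = torch.randint(0, nloc, (nframes, nall))  # extended atom -> local atom index

# Wrong: drops the force contributions carried by ghost atoms.
force_sliced = extended_force[:, :nloc, :]

# Right: scatter-sum every extended atom's contribution onto its local atom.
mapping_idx = mapping.unsqueeze(-1).expand(-1, -1, 3)
force = torch.zeros(nframes, nloc, 3).scatter_add_(1, mapping_idx, extended_force)
```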
### 4. Virial RMSE normalization mismatch

**Symptom**: `rmse_v` values differ between backends by a factor of `natoms`.

**Cause**: the dpmodel backend computes `rmse_v = sqrt(l2_virial_loss)`, missing the `* atom_norm` normalization that the other backends apply.

**Fix**: `rmse_v = sqrt(l2_virial_loss) * atom_norm`
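For reference, a minimal numeric sketch of the normalization, assuming `atom_norm = 1 / natoms` as the factor-of-`natoms` symptom implies (values are illustrative):

```python
import math

natoms = 192          # illustrative system size
l2_virial_loss = 3.7  # illustrative squared virial error

atom_norm = 1.0 / natoms
rmse_v_unnormalized = math.sqrt(l2_virial_loss)      # dpmodel before the fix
rmse_v = math.sqrt(l2_virial_loss) * atom_norm       # matches the other backends
assert math.isclose(rmse_v_unnormalized / rmse_v, natoms)  # differ by exactly natoms
```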
## Verification

After fixing, always verify:

1. **Gradient count matches**: uncompiled and compiled should have the same number of params with grad for each isolated loss component
2. **Numerical consistency**: compiled model energy/force/virial should match uncompiled to float precision (`atol=1e-10, rtol=1e-10`)
3. **Loss decreases**: run a few training steps and verify that `rmse_f` / `rmse_v` actually decrease
4. **Regression test**: add a test that catches the bug by reverting the fix and confirming the test fails

```bash
# Run compiled consistency test
python -m pytest source/tests/pt_expt/test_training.py::TestCompiledConsistency -v
# Run loss consistency test
python -m pytest source/tests/consistent/loss/test_ener.py -v
# Run full training smoke test
python -m pytest source/tests/pt_expt/test_training.py -v
```
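For item 2, a minimal sketch of the numerical-consistency check (assuming `model_uncompiled` and `model_compiled` are forward callables returning dicts with `energy`, `force`, and `virial` for the same inputs; the names are illustrative, not the test suite's API):

```python
import torch

out_ref = model_uncompiled(coord, atype, box)
out_cmp = model_compiled(coord, atype, box)
for key in ("energy", "force", "virial"):
    # Compiled and uncompiled outputs should agree to float precision.
    torch.testing.assert_close(out_cmp[key], out_ref[key], atol=1e-10, rtol=1e-10)
```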
