---
name: debug-gradient-flow
description: Diagnose gradient flow issues in training, especially for compiled models (torch.compile/make_fx). Systematically isolates which loss components (energy, force, virial) contribute gradients to which parameters, and identifies where the gradient chain breaks.
license: LGPL-3.0-or-later
metadata:
  author: deepmd-kit
  version: '1.0'
---

# Debugging Gradient Flow in Training

Use this method when a loss component (force, virial, energy) does not decrease during training, or when compiled model training diverges from uncompiled training.

## When to use

- A loss term (e.g. `rmse_f`, `rmse_v`) stays flat or NaN during training
- Compiled training (`enable_compile=True`) behaves differently from uncompiled
- After adding a new loss component or model output
- After changes to `make_fx` tracing, `torch.compile`, or `autograd.grad` code paths

## Method: Per-component gradient isolation

The core technique: **zero out all loss terms except one**, run `loss.backward()`, and count which model parameters receive non-zero gradients. Compare across uncompiled and compiled paths to pinpoint where gradients are lost.

### Step 1: Write a gradient probe script

Create a script that constructs a trainer, injects labels if needed, and reports per-parameter gradient status:

```python
def check_grad(trainer, label_overrides=None):
    trainer.wrapper.train()
    trainer.optimizer.zero_grad(set_to_none=True)
    inp, lab = trainer.get_data(is_train=True)
    lr = trainer.scheduler.get_last_lr()[0]

    # Override labels to isolate a single loss component
    if label_overrides:
        lab.update(label_overrides)

    _, loss, more_loss = trainer.wrapper(**inp, cur_lr=lr, label=lab)
    loss.backward()

    # Map each trainable parameter name to whether it received a non-zero gradient
    status = {}
    for name, p in trainer.wrapper.named_parameters():
        if p.requires_grad:
            # .item() converts the 0-dim tensor so status holds plain bools
            has_grad = p.grad is not None and p.grad.abs().sum().item() > 0
            status[name] = has_grad
    return status
```

### Step 2: Run for each loss component in isolation

Test each loss component separately by zeroing out the others:

```python
scenarios = {
    "energy only": {"find_force": 0.0, "find_virial": 0.0},
    "force only": {"find_energy": 0.0, "find_virial": 0.0},
    "virial only": {
        "find_energy": 0.0,
        "find_force": 0.0,
        "virial": torch.randn(nframes, 9, ...),  # inject if data lacks virial
        "find_virial": 1.0,
    },
    "all losses": {
        "virial": torch.randn(nframes, 9, ...),
        "find_virial": 1.0,
    },
}
```

If training data lacks virial labels, inject synthetic ones; the numerical values are irrelevant, only gradient flow matters. A minimal driver over these scenarios is sketched below.
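
A minimal driver, assuming the `check_grad` helper from Step 1 and the `scenarios` dict above are in scope (with the `...` placeholders filled in for your data):

```python
# Run each scenario and count how many parameters received a non-zero gradient.
for name, overrides in scenarios.items():
    status = check_grad(trainer, label_overrides=overrides)
    print(f"{name:<12}: {sum(status.values())}/{len(status)} params with grad")
```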

### Step 3: Compare compiled vs uncompiled

Run each scenario for both compiled and uncompiled trainers. Present results as a table:

```
             Uncompiled   Compiled
energy only:      22/22      22/22
force only:       20/22      16/22   <-- problem
virial only:      20/22      16/22   <-- problem
all losses:       22/22      22/22   <-- OK in practice
```
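
A sketch of the comparison loop, assuming one trainer per path (`trainer_uncompiled` and `trainer_compiled` are hypothetical names):

```python
# Build the comparison table: gradient counts per scenario for both paths.
for name, overrides in scenarios.items():
    uc = check_grad(trainer_uncompiled, label_overrides=overrides)
    cc = check_grad(trainer_compiled, label_overrides=overrides)
    print(f"{name:<14} {sum(uc.values())}/{len(uc):<8} {sum(cc.values())}/{len(cc)}")
```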

Key interpretations:

- **Same count, both paths**: gradient flow is correct
- **Compiled < Uncompiled**: `make_fx` or `torch.compile` breaks some gradient paths
- **0 grads in compiled**: catastrophic failure (e.g. wrong `create_graph`, wrong backend)
- **"all losses" is OK but isolated isn't**: the missing grads are covered by other loss terms; may be acceptable

### Step 4: Identify affected parameters

When compiled has fewer grads, print the per-parameter diff:

```python
print(f"{'Parameter':<60} {'Uncompiled':>10} {'Compiled':>10}")
for name in sorted(status_uncompiled):
    uc = "GRAD" if status_uncompiled[name] else "-"
    cc = "GRAD" if status_compiled[name] else "-"
    marker = " <-- DIFF" if uc != cc else ""
    print(f"{name:<60} {uc:>10} {cc:>10}{marker}")
```

This tells you exactly which layers lose gradients and helps locate the broken link in the computation graph.

### Step 5: Bisect the cause

If compiled has fewer grads, test these layers in order (code sketches for the rows follow the table):

| Layer | What to try | What it tests |
| --- | --- | --- |
| `make_fx` only (no `torch.compile`) | Replace `torch.compile(traced, ...)` with just `traced` | Is `make_fx` the problem or `torch.compile`? |
| Different `torch.compile` backends | Try `eager`, `aot_eager`, `inductor` | Which backend breaks gradients? |
| `model.train()` vs `model.eval()` during tracing | Toggle training mode before `make_fx` | Does `create_graph=self.training` get the wrong value? |
| `coord.requires_grad_(True)` placement | Check if coord has grad before entering compiled graph | Is the autograd entry point correct? |

```python
# Test make_fx only (no torch.compile)
traced = make_fx(fn)(ext_coord, ext_atype, nlist, mapping, fparam, aparam)
# Use traced directly instead of torch.compile(traced)

# Test different backends
for backend in ["eager", "aot_eager", "inductor"]:
    compiled = torch.compile(traced, backend=backend, dynamic=False)
    # ... run gradient check
```
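
Sketches for the last two rows, assuming the same `fn` and input names as above (`model` here stands for whatever module owns the `create_graph=self.training` call):

```python
# Row 3: training mode during tracing. create_graph=self.training inside the
# model must evaluate to True at trace time, or the traced autograd.grad
# detaches force from the parameters.
model.train()
assert model.training
traced = make_fx(fn)(ext_coord, ext_atype, nlist, mapping, fparam, aparam)

# Row 4: autograd entry point. The coordinate tensor must require grad before
# it enters the compiled graph, or no force can flow back through it.
assert ext_coord.requires_grad
```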

## Common root causes

### 1. `create_graph=False` during tracing

**Symptom**: force/virial loss doesn't decrease; 0 params get grad from force/virial loss.

**Cause**: `model.eval()` before `make_fx` tracing makes `create_graph=self.training` evaluate to `False`. The `autograd.grad` that computes force is traced without graph creation, so the force tensor is detached from model parameters.

**Fix**: `model.train()` before `make_fx` tracing.

**Location**: `_trace_and_compile` in `deepmd/pt_expt/train/training.py`
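
A sketch of the pattern at fault (variable names assumed, not the exact deepmd code):

```python
# Inside the model's forward: force is the negative gradient of energy w.r.t.
# extended coordinates. If the model is in eval() mode when make_fx traces
# this call, create_graph is False and the traced force tensor is detached
# from the parameters, so the force loss produces no parameter gradients.
force = -torch.autograd.grad(
    [energy.sum()],
    [extended_coord],
    create_graph=self.training,  # must be True at trace time for training
)[0]
```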

### 2. `torch.compile` inductor backend kills second-order gradients

**Symptom**: force/virial loss doesn't decrease; 0 params get grad with inductor, but `eager`/`aot_eager` work fine.

**Cause**: The inductor backend's graph lowering doesn't support backward through `make_fx`-decomposed `autograd.grad` ops.

**Fix**: Default to `aot_eager` backend.
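
A minimal sketch of the fix, assuming `traced` is the `make_fx` output from the tracing step:

```python
# aot_eager preserves the double-backward structure that inductor's lowering
# loses, so use it as the default backend for training.
compiled = torch.compile(traced, backend="aot_eager", dynamic=False)
```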

### 3. Ghost force contributions discarded

**Symptom**: force values differ between compiled and uncompiled models.

**Cause**: Using `extended_force[:, :nloc, :]` (slice) instead of scatter-summing ghost atom contributions back to local atoms via `mapping`.

**Fix**: `torch.zeros(...).scatter_add_(1, mapping_idx, extended_force[:, :actual_nall, :])`
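
A fuller sketch of the scatter-sum, with assumed shapes (`mapping`: `[nframes, nall]`, each entry the owning local atom index; `extended_force`: `[nframes, nall, 3]`):

```python
# Ghost atoms carry force contributions that a plain [:, :nloc, :] slice would
# silently discard; scatter-add them back onto their owning local atoms.
mapping_idx = mapping.unsqueeze(-1).expand(-1, -1, 3)  # [nframes, nall, 3]
force = torch.zeros(
    (nframes, nloc, 3), dtype=extended_force.dtype, device=extended_force.device
).scatter_add_(1, mapping_idx, extended_force)
```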

### 4. Virial RMSE normalization mismatch

**Symptom**: `rmse_v` values differ between backends by a factor of `natoms`.

**Cause**: the dpmodel `rmse_v = sqrt(l2_virial_loss)` is missing the `* atom_norm` normalization that other backends apply.

**Fix**: `rmse_v = sqrt(l2_virial_loss) * atom_norm`

## Verification

After fixing, always verify:

1. **Gradient count matches**: uncompiled and compiled should have the same number of params with grad for each isolated loss component
1. **Numerical consistency**: compiled model energy/force/virial should match uncompiled to tight tolerance (`atol=1e-10, rtol=1e-10`); see the sketch below
1. **Loss decreases**: run a few training steps and verify `rmse_f` / `rmse_v` actually decrease
1. **Regression test**: add a test that catches the bug by reverting the fix and confirming the test fails
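
A minimal consistency check, assuming `out_compiled` and `out_uncompiled` are output dicts from the two paths on identical inputs:

```python
# Compare each output tensor between the compiled and uncompiled paths.
for key in ("energy", "force", "virial"):
    torch.testing.assert_close(
        out_compiled[key], out_uncompiled[key], atol=1e-10, rtol=1e-10
    )
```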

```bash
# Run compiled consistency test
python -m pytest source/tests/pt_expt/test_training.py::TestCompiledConsistency -v
# Run loss consistency test
python -m pytest source/tests/consistent/loss/test_ener.py -v
# Run full training smoke test
python -m pytest source/tests/pt_expt/test_training.py -v
```