Commit d1c1840

feat: log pre-clip gradient total_norm and per-param norm histogram to tensorboard

1 parent 8ed49be

1 file changed: deepmd/pt/train/training.py (26 additions, 0 deletions)
@@ -991,6 +991,13 @@ def step(_step_id: int, task_key: str = "Default") -> None:
             )
             loss.backward()
             if self.gradient_max_norm > 0.0:
+                # Collect per-parameter gradient norms before clipping.
+                if self.enable_tensorboard:
+                    pre_clip_named_norms = [
+                        (name, p.grad.detach().norm().item())
+                        for name, p in self.wrapper.named_parameters()
+                        if p.grad is not None
+                    ]
                 # FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.
                 total_norm = torch.nn.utils.clip_grad_norm_(
                     self.wrapper.parameters(),
@@ -1342,6 +1349,25 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                     writer.add_scalar(
                         f"{task_key}/{item}", more_loss[item], display_step_id
                     )
+                # === Gradient diagnostics (pre-clip) ===
+                if self.gradient_max_norm > 0.0:
+                    writer.add_scalar(
+                        "grad/total_norm", total_norm.item(), display_step_id
+                    )
+                    norms = torch.tensor(
+                        [gn for _, gn in pre_clip_named_norms],
+                        dtype=torch.float32,
+                        device="cpu",
+                    )
+                    writer.add_histogram("grad/param_norms", norms, display_step_id)
+                    # Log top-10 largest per-parameter gradient norms.
+                    # Shorten name: keep everything after "atomic_model.".
+                    pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
+                    for name, gn in pre_clip_named_norms[:10]:
+                        idx = name.find("atomic_model.")
+                        if idx >= 0:
+                            name = name[idx + len("atomic_model.") :]
+                        writer.add_scalar(f"grad_top10/{name}", gn, display_step_id)

         self.wrapper.train()
         self.t0 = time.time()
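
For context, a minimal standalone sketch of the same pattern outside the deepmd trainer: collect per-parameter gradient norms before clipping, then log the pre-clip total norm and a norm histogram via torch.utils.tensorboard.SummaryWriter. Names such as Net, model, writer, and max_norm are illustrative, not part of this commit. Note that clip_grad_norm_ itself returns the total norm computed before clipping, which is why the commit can log it after the clip call.

import torch
from torch.utils.tensorboard import SummaryWriter


class Net(torch.nn.Module):
    # Toy two-layer model standing in for the deepmd wrapper (assumption).
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 16)
        self.fc2 = torch.nn.Linear(16, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))


model = Net()
writer = SummaryWriter("runs/grad_diag_demo")
max_norm = 1.0

for step in range(10):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()

    # Per-parameter gradient norms, taken before clipping mutates .grad.
    named_norms = [
        (name, p.grad.detach().norm().item())
        for name, p in model.named_parameters()
        if p.grad is not None
    ]
    # clip_grad_norm_ returns the total norm computed *before* clipping.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    writer.add_scalar("grad/total_norm", total_norm.item(), step)
    writer.add_histogram(
        "grad/param_norms",
        torch.tensor([gn for _, gn in named_norms], dtype=torch.float32),
        step,
    )
    model.zero_grad()

writer.close()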
