Commit 7b3c306

feat: log pre-clip gradient total_norm and per-param norm histogram to tensorboard
1 parent: 65eea4b

1 file changed

Lines changed: 43 additions & 0 deletions

File tree

deepmd/pt/train/training.py

@@ -1058,7 +1058,24 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                 **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
             )
             loss.backward()
+            # === Initialize gradient diagnostics variables ===
+            total_norm: torch.Tensor | None = None
+            pre_clip_named_norms: list[tuple[str, float]] = []
             if self.gradient_max_norm > 0.0:
+                # Collect per-parameter gradient norms before clipping.
+                # NOTE: Under FSDP2 with ZeRO stage >= 2, p.grad is a sharded DTensor,
+                # so p.grad.norm() computes the shard-local L2 norm, not the full-parameter
+                # norm. Skip per-param collection in this case to avoid misleading values.
+                if (
+                    self.enable_tensorboard
+                    and self.zero_stage < 2
+                    and (_step_id % self.tensorboard_freq == 0 or _step_id == 1)
+                ):
+                    pre_clip_named_norms = [
+                        (name, p.grad.detach().norm().item())
+                        for name, p in self.wrapper.named_parameters()
+                        if p.grad is not None
+                    ]
                 # FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.
                 total_norm = torch.nn.utils.clip_grad_norm_(
                     self.wrapper.parameters(),
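
For reference, here is a minimal standalone sketch of the collection-and-clipping pattern this hunk adds, pulled out of the trainer. The toy nn.Linear model, loss, and max_norm value are illustrative assumptions, not part of the commit; the point is that clip_grad_norm_ already returns the pre-clip total norm, so the per-parameter norms are the only values that must be gathered before the call.

# Minimal sketch (toy model; not the deepmd trainer code itself).
import torch

model = torch.nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Per-parameter gradient norms, collected before clipping.
pre_clip_named_norms = [
    (name, p.grad.detach().norm().item())
    for name, p in model.named_parameters()
    if p.grad is not None
]

# clip_grad_norm_ returns the total norm computed *before* clipping,
# so it can be logged directly; finiteness is checked manually here,
# mirroring the diff's note that error_if_nonfinite is unavailable for
# sharded gradients.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
assert torch.isfinite(total_norm), "non-finite gradient norm"

print(f"total_norm={total_norm.item():.4f}", pre_clip_named_norms)
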
@@ -1410,6 +1427,32 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                     writer.add_scalar(
                         f"{task_key}/{item}", more_loss[item], display_step_id
                     )
+                # === Gradient diagnostics (pre-clip) ===
+                # Only log if total_norm was computed (i.e., not LKF optimizer).
+                if self.gradient_max_norm > 0.0 and total_norm is not None:
+                    writer.add_scalar(
+                        f"{task_key}/grad/total_norm",
+                        total_norm.item(),
+                        display_step_id,
+                    )
+                    # Only log per-parameter norms if list is non-empty.
+                    if pre_clip_named_norms:
+                        # Use float32 for histogram to ensure numerical stability
+                        # when gradients are in lower precision (FP16/BF16).
+                        norms = torch.tensor(
+                            [gn for _, gn in pre_clip_named_norms],
+                            dtype=torch.float32,
+                            device="cpu",
+                        )
+                        writer.add_histogram(
+                            f"{task_key}/grad/param_norms", norms, display_step_id
+                        )
+                        # Log top-10 largest per-parameter gradient norms.
+                        pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
+                        for name, gn in pre_clip_named_norms[:10]:
+                            writer.add_scalar(
+                                f"{task_key}/grad_top10/{name}", gn, display_step_id
+                            )

         self.wrapper.train()
         self.t0 = time.time()
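
To see what these writer calls produce, the following sketch replays the same three logging patterns outside the trainer: one scalar for the total norm, one histogram over all per-parameter norms, and one scalar per name for the ten largest. The writer, task_key, step value, and dummy norms are illustrative stand-ins; only the tag scheme matches what this commit adds.

# Standalone sketch of the logging calls (dummy values; not trainer code).
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="./tb_grad_demo")  # illustrative log dir
task_key = "Default"
display_step_id = 1

total_norm = torch.tensor(3.7)                            # dummy pre-clip total norm
pre_clip_named_norms = [("bias", 0.2), ("weight", 3.69)]  # dummy per-param norms

# Scalar: pre-clip total gradient norm.
writer.add_scalar(f"{task_key}/grad/total_norm", total_norm.item(), display_step_id)

# Histogram: distribution of per-parameter norms (float32 on CPU, as in the diff).
norms = torch.tensor([gn for _, gn in pre_clip_named_norms], dtype=torch.float32)
writer.add_histogram(f"{task_key}/grad/param_norms", norms, display_step_id)

# Scalars: the ten largest per-parameter norms, one tag per parameter name.
for name, gn in sorted(pre_clip_named_norms, key=lambda x: x[1], reverse=True)[:10]:
    writer.add_scalar(f"{task_key}/grad_top10/{name}", gn, display_step_id)

writer.close()

Running tensorboard --logdir ./tb_grad_demo then shows the curves under <task_key>/grad and <task_key>/grad_top10.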
