Commit f0a966b

feat(pt): log pre-clip gradient total_norm and per-param norm to tensorboard (#5252)
## Summary by CodeRabbit

* **New Features**
  * Enhanced gradient diagnostics during training: when TensorBoard logging is enabled, training now records the total gradient norm, per-parameter gradient norms, and a histogram of per-parameter norms. The system also highlights the top-10 largest parameter gradients as individual scalars and surfaces total-norm values in training logs to improve analysis and monitoring.
1 parent 622179e commit f0a966b
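
For orientation, here is a rough, self-contained sketch of the logging pattern described in the summary, written against plain PyTorch rather than DeePMD-kit's trainer; the model, writer path, max-norm value, and tag names are placeholders, not the project's actual configuration:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

model = torch.nn.Linear(8, 1)          # stand-in for the real model wrapper
writer = SummaryWriter("./tb_demo")    # stand-in for the trainer's writer

x, y = torch.randn(32, 8), torch.randn(32, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()

step = 1
# Per-parameter gradient norms, collected before clipping.
named_norms = [
    (name, p.grad.detach().norm().item())
    for name, p in model.named_parameters()
    if p.grad is not None
]
# clip_grad_norm_ returns the total (pre-clip) norm of all gradients.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

writer.add_scalar("grad/total_norm", total_norm.item(), step)
writer.add_histogram(
    "grad/param_norms",
    torch.tensor([n for _, n in named_norms], dtype=torch.float32),
    step,
)
# Largest per-parameter norms as individual scalars (top-10 in the real code).
for name, n in sorted(named_norms, key=lambda t: t[1], reverse=True)[:10]:
    writer.add_scalar(f"grad_top10/{name}", n, step)
writer.close()
```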

1 file changed

Lines changed: 47 additions & 1 deletion

File tree

deepmd/pt/train/training.py

```diff
@@ -1017,6 +1017,7 @@ def run(self) -> None:
             prof.start()
 
         def step(_step_id: int, task_key: str = "Default") -> None:
+            display_step_id = _step_id + 1
             if self.multi_task:
                 model_index = dp_random.choice(
                     np.arange(self.num_model, dtype=np.int_),
```
```diff
@@ -1047,7 +1048,27 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                     **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
                 )
                 loss.backward()
+                # === Initialize gradient diagnostics variables ===
+                total_norm: torch.Tensor | None = None
+                pre_clip_named_norms: list[tuple[str, float]] = []
                 if self.gradient_max_norm > 0.0:
+                    # Collect per-parameter gradient norms before clipping.
+                    # NOTE: Under FSDP2 with ZeRO stage >= 2, p.grad is a sharded DTensor,
+                    # so p.grad.norm() computes the shard-local L2 norm, not the full-parameter
+                    # norm. Skip per-param collection in this case to avoid misleading values.
+                    if (
+                        self.enable_tensorboard
+                        and self.zero_stage < 2
+                        and (
+                            display_step_id % self.tensorboard_freq == 0
+                            or display_step_id == 1
+                        )
+                    ):
+                        pre_clip_named_norms = [
+                            (name, p.grad.detach().norm().item())
+                            for name, p in self.wrapper.named_parameters()
+                            if p.grad is not None
+                        ]
                     # FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.
                     total_norm = torch.nn.utils.clip_grad_norm_(
                         self.wrapper.parameters(),
```
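
The NOTE in the hunk above can be made concrete: under ZeRO-style sharding (stage >= 2) each rank holds only a shard of every gradient, so a true full-parameter norm would need a cross-rank reduction rather than a local `p.grad.norm()`. A minimal sketch of that reduction, assuming `torch.distributed` is initialized and that the shards partition the gradient elements without overlap (the helper name is ours, not part of this commit):

```python
import torch
import torch.distributed as dist


def full_param_grad_norm(local_grad_shard: torch.Tensor) -> torch.Tensor:
    # Sum the squared shard-local norms across ranks, then take the square
    # root, giving the L2 norm of the full (unsharded) gradient.
    sq = local_grad_shard.detach().float().pow(2).sum()
    dist.all_reduce(sq, op=dist.ReduceOp.SUM)
    return sq.sqrt()
```

The commit instead skips per-parameter collection when `self.zero_stage >= 2`, trading per-parameter detail in that configuration for avoiding the extra communication.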
```diff
@@ -1172,7 +1193,6 @@ def fake_model() -> dict:
                         self.train_loss_accu[task_key][item] += more_loss[item]
 
             # Log and persist
-            display_step_id = _step_id + 1
             if self.display_in_training and (
                 display_step_id % self.disp_freq == 0 or display_step_id == 1
             ):
```
```diff
@@ -1399,6 +1419,32 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                         writer.add_scalar(
                             f"{task_key}/{item}", more_loss[item], display_step_id
                         )
+                    # === Gradient diagnostics (pre-clip) ===
+                    # Only log if total_norm was computed (i.e., not LKF optimizer).
+                    if self.gradient_max_norm > 0.0 and total_norm is not None:
+                        writer.add_scalar(
+                            f"{task_key}/grad/total_norm",
+                            total_norm.item(),
+                            display_step_id,
+                        )
+                        # Only log per-parameter norms if list is non-empty.
+                        if pre_clip_named_norms:
+                            # Use float32 for histogram to ensure numerical stability
+                            # when gradients are in lower precision (FP16/BF16).
+                            norms = torch.tensor(
+                                [gn for _, gn in pre_clip_named_norms],
+                                dtype=torch.float32,
+                                device="cpu",
+                            )
+                            writer.add_histogram(
+                                f"{task_key}/grad/param_norms", norms, display_step_id
+                            )
+                            # Log top-10 largest per-parameter gradient norms.
+                            pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
+                            for name, gn in pre_clip_named_norms[:10]:
+                                writer.add_scalar(
+                                    f"{task_key}/grad_top10/{name}", gn, display_step_id
+                                )
 
                 self.wrapper.train()
                 self.t0 = time.time()
```
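
To check what these calls actually wrote, one option (not part of this commit) is to read the event files back with TensorBoard's `EventAccumulator`; the log directory is a placeholder and the tag below assumes the default `task_key` of `"Default"`:

```python
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Point this at the trainer's TensorBoard log directory (placeholder path).
ea = EventAccumulator("./tensorboard_logs")
ea.Reload()

# Each ScalarEvent carries (wall_time, step, value).
for event in ea.Scalars("Default/grad/total_norm"):
    print(event.step, event.value)
```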
