Commit 6ef654c (1 parent: d1c1840)

Commit message: fix

1 file changed: deepmd/pt/train/training.py (32 additions & 17 deletions)
@@ -990,9 +990,15 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                 **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
             )
             loss.backward()
+            # === Initialize gradient diagnostics variables ===
+            total_norm: torch.Tensor | None = None
+            pre_clip_named_norms: list[tuple[str, float]] = []
             if self.gradient_max_norm > 0.0:
                 # Collect per-parameter gradient norms before clipping.
-                if self.enable_tensorboard:
+                # NOTE: Under FSDP2 with ZeRO stage >= 2, p.grad is a sharded DTensor,
+                # so p.grad.norm() computes the shard-local L2 norm, not the full-parameter
+                # norm. Skip per-param collection in this case to avoid misleading values.
+                if self.enable_tensorboard and self.zero_stage < 2:
                     pre_clip_named_norms = [
                         (name, p.grad.detach().norm().item())
                         for name, p in self.wrapper.named_parameters()
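
Aside (not part of the commit): the NOTE above means per-parameter norms are simply skipped under FSDP2 with ZeRO stage >= 2, where p.grad.norm() is only shard-local. If full-parameter norms were wanted there, one possible approach is to sum squared shard-local norms across the sharding group. A minimal, hypothetical sketch, assuming a recent PyTorch with the public torch.distributed.tensor API and a 1-D sharding mesh; full_param_grad_norms and wrapper are illustrative names, not from training.py:

    # Hypothetical helper, not part of this commit.
    import torch
    import torch.distributed as dist
    from torch.distributed.tensor import DTensor  # public path in recent PyTorch


    def full_param_grad_norms(wrapper: torch.nn.Module) -> list[tuple[str, float]]:
        named_norms: list[tuple[str, float]] = []
        for name, p in wrapper.named_parameters():
            if p.grad is None:
                continue
            g = p.grad.detach()
            if isinstance(g, DTensor):
                # Sum of squares over the local shard, then all-reduce across the
                # sharding group so every rank sees the full-parameter norm.
                # Assumes a 1-D shard placement (no replicated mesh dims).
                local_sq = g.to_local().pow(2).sum()
                dist.all_reduce(
                    local_sq, op=dist.ReduceOp.SUM, group=g.device_mesh.get_group()
                )
                named_norms.append((name, local_sq.sqrt().item()))
            else:
                named_norms.append((name, g.norm().item()))
        return named_norms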
@@ -1350,24 +1356,33 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                     f"{task_key}/{item}", more_loss[item], display_step_id
                 )
             # === Gradient diagnostics (pre-clip) ===
-            if self.gradient_max_norm > 0.0:
+            # Only log if total_norm was computed (i.e., not LKF optimizer).
+            if self.gradient_max_norm > 0.0 and total_norm is not None:
                 writer.add_scalar(
-                    "grad/total_norm", total_norm.item(), display_step_id
-                )
-                norms = torch.tensor(
-                    [gn for _, gn in pre_clip_named_norms],
-                    dtype=torch.float32,
-                    device="cpu",
+                    f"{task_key}/grad/total_norm",
+                    total_norm.item(),
+                    display_step_id,
                 )
-                writer.add_histogram("grad/param_norms", norms, display_step_id)
-                # Log top-10 largest per-parameter gradient norms.
-                # Shorten name: keep everything after "atomic_model.".
-                pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
-                for name, gn in pre_clip_named_norms[:10]:
-                    idx = name.find("atomic_model.")
-                    if idx >= 0:
-                        name = name[idx + len("atomic_model.") :]
-                    writer.add_scalar(f"grad_top10/{name}", gn, display_step_id)
+                # Only log per-parameter norms if list is non-empty.
+                if pre_clip_named_norms:
+                    norms = torch.tensor(
+                        [gn for _, gn in pre_clip_named_norms],
+                        dtype=torch.float32,
+                        device="cpu",
+                    )
+                    writer.add_histogram(
+                        f"{task_key}/grad/param_norms", norms, display_step_id
+                    )
+                    # Log top-10 largest per-parameter gradient norms.
+                    # Shorten name: keep everything after "atomic_model.".
+                    pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
+                    for name, gn in pre_clip_named_norms[:10]:
+                        idx = name.find("atomic_model.")
+                        if idx >= 0:
+                            name = name[idx + len("atomic_model.") :]
+                        writer.add_scalar(
+                            f"{task_key}/grad_top10/{name}", gn, display_step_id
+                        )
 
             self.wrapper.train()
             self.t0 = time.time()
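
Aside (not part of the commit): the new guard `total_norm is not None` matters because total_norm is initialized to None and only assigned on the gradient-clipping path; with the LKF optimizer that path is never taken, so the diagnostics are skipped. A minimal, self-contained sketch of how such a pre-clip total norm is typically produced and then logged the way the diff consumes it, using torch.nn.utils.clip_grad_norm_ (which returns the total norm before clipping); names and paths here are illustrative, not taken from training.py:

    # Illustrative example, not from training.py.
    import torch
    from torch.utils.tensorboard import SummaryWriter

    model = torch.nn.Linear(4, 2)
    writer = SummaryWriter(log_dir="/tmp/grad_diag_demo")  # illustrative path

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()

    gradient_max_norm = 10.0
    total_norm = None  # stays None when clipping is disabled (e.g. an LKF-style path)
    if gradient_max_norm > 0.0:
        # clip_grad_norm_ clips in place and returns the total norm *before* clipping.
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_max_norm)

    # Mirror the guard from the diff: only log when total_norm was computed.
    if total_norm is not None:
        writer.add_scalar("Default/grad/total_norm", total_norm.item(), 0)
    writer.close()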
