Skip to content

Commit c9d6b27

Browse files
committed
fix dot_mean normalization
1 parent cf84059 commit c9d6b27

2 files changed

Lines changed: 14 additions & 9 deletions

File tree

deepmd/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -884,8 +884,8 @@ def main_parser() -> argparse.ArgumentParser:
884884
type=int,
885885
default=1,
886886
dest="accumulate_k",
887-
help="Accumulate K batches before computing each dot product. "
888-
"Total batches collected = nbatches * accumulate_k. "
887+
help="Group size for gradient accumulation before computing each dot product. "
888+
"Total batches collected = nbatches; number of dot product samples = nbatches // k. "
889889
"Larger K reduces norm variance and improves SNR of mean/std.",
890890
)
891891
return parser

deepmd/pt/entrypoints/main.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ def grad_probe(FLAGS) -> None:
548548
param_shapes: list | None = None
549549

550550
accumulate_k = getattr(FLAGS, "accumulate_k", 1)
551-
total_batches = FLAGS.nbatches * accumulate_k
551+
total_batches = FLAGS.nbatches
552552

553553
# Initialize trackers
554554
task_batch_norms = {k: [] for k in trainer.model_keys}
@@ -610,14 +610,19 @@ def grad_probe(FLAGS) -> None:
610610

611611
# At the end of each group of K batches, record norms and dot products then reset
612612
if (b + 1) % accumulate_k == 0:
613+
# Normalize to per-batch mean so scale is independent of k
614+
group_means = {
615+
k: g / accumulate_k if g is not None else None
616+
for k, g in group_grads.items()
617+
}
613618
for task_key in trainer.model_keys:
614-
if group_grads[task_key] is not None:
619+
if group_means[task_key] is not None:
615620
task_batch_norms[task_key].append(
616-
np.linalg.norm(group_grads[task_key])
621+
np.linalg.norm(group_means[task_key])
617622
)
618623
for k1, k2 in pairwise_dots.keys():
619-
g1 = group_grads.get(k1)
620-
g2 = group_grads.get(k2)
624+
g1 = group_means.get(k1)
625+
g2 = group_means.get(k2)
621626
if g1 is not None and g2 is not None:
622627
pairwise_dots[(k1, k2)].append(float(np.dot(g1, g2)))
623628
group_grads = {k: None for k in trainer.model_keys}
@@ -641,10 +646,10 @@ def grad_probe(FLAGS) -> None:
641646

642647
log.info(
643648
"Task '%s': collected %d batches (groups=%d, k=%d). "
644-
"Avg Grad Norm=%.4e, Mean(Batch Norms)=%.4e, Std(Batch Norms)=%.4e",
649+
"Avg Grad Norm=%.4e, Mean(Group Norms)=%.4e, Std(Group Norms)=%.4e",
645650
task_key,
646651
total_batches,
647-
FLAGS.nbatches,
652+
total_batches // accumulate_k,
648653
accumulate_k,
649654
float(np.linalg.norm(avg_grad_vec)),
650655
float(norm_mean),

0 commit comments

Comments (0)