Commit cf84059

fix sampling method in grad-probe

1 parent 77e4419

2 files changed: 48 additions & 19 deletions


deepmd/main.py

Lines changed: 10 additions & 0 deletions
@@ -878,6 +878,16 @@ def main_parser() -> argparse.ArgumentParser:
         default=1,
         help="Number of batches per task to average gradient over.",
     )
+    parser_grad_probe.add_argument(
+        "-k",
+        "--accumulate-k",
+        type=int,
+        default=1,
+        dest="accumulate_k",
+        help="Accumulate K batches before computing each dot product. "
+        "Total batches collected = nbatches * accumulate_k. "
+        "Larger K reduces norm variance and improves SNR of mean/std.",
+    )
     return parser

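The new flag composes with the existing batch-count option: the probe now draws nbatches * accumulate_k batches in total, and every recorded statistic comes from one group of accumulate_k summed batch gradients rather than from a single batch. A minimal standalone sketch of that arithmetic follows; only the "-k"/"--accumulate-k" spelling and the dest come from the diff, while the "--nbatches" spelling, the parser object, and the example values are assumptions (the diff shows only the existing option's help text and the FLAGS.nbatches attribute):

    import argparse

    # stand-in parser; in deepmd this would be the grad-probe subparser
    parser = argparse.ArgumentParser()
    parser.add_argument("--nbatches", type=int, default=1)
    parser.add_argument(
        "-k", "--accumulate-k", type=int, default=1, dest="accumulate_k"
    )

    # e.g. "--nbatches 32 -k 4" draws 32 * 4 = 128 batches in total,
    # recorded as 32 groups of 4 summed per-batch gradients
    args = parser.parse_args(["--nbatches", "32", "-k", "4"])
    total_batches = args.nbatches * args.accumulate_k
    assert total_batches == 128
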
deepmd/pt/entrypoints/main.py

Lines changed: 38 additions & 19 deletions
@@ -547,20 +547,25 @@ def grad_probe(FLAGS) -> None:
     param_names: list | None = None
     param_shapes: list | None = None

+    accumulate_k = getattr(FLAGS, "accumulate_k", 1)
+    total_batches = FLAGS.nbatches * accumulate_k
+
     # Initialize trackers
     task_batch_norms = {k: [] for k in trainer.model_keys}
     task_accum_grads = {k: None for k in trainer.model_keys}
-    pairwise_sims = {
+    pairwise_dots = {
         (k1, k2): []
         for i, k1 in enumerate(trainer.model_keys)
         for j, k2 in enumerate(trainer.model_keys)
         if i < j
     }

+    # Accumulators within each group of K batches
+    group_grads = {k: None for k in trainer.model_keys}
+
     cur_lr = config["learning_rate"]["start_lr"]

-    for b in range(FLAGS.nbatches):
-        batch_grads = {}
+    for b in range(total_batches):
         for task_key in trainer.model_keys:
             trainer.optimizer.zero_grad(set_to_none=True)
             input_dict, label_dict, _ = trainer.get_data(
@@ -574,7 +579,6 @@ def grad_probe(FLAGS) -> None:
             model = module.model[task_key]
             descriptor = model.get_descriptor()

-            # Extract current batch gradient
             current_grads = []
             for name, param in descriptor.named_parameters():
                 if param.grad is not None:
@@ -584,15 +588,17 @@

             if current_grads:
                 grad_vec = np.concatenate(current_grads)
-                task_batch_norms[task_key].append(np.linalg.norm(grad_vec))
                 if task_accum_grads[task_key] is None:
-                    task_accum_grads[task_key] = grad_vec
+                    task_accum_grads[task_key] = grad_vec.copy()
                 else:
                     task_accum_grads[task_key] += grad_vec
-                batch_grads[task_key] = grad_vec
+
+                if group_grads[task_key] is None:
+                    group_grads[task_key] = grad_vec.copy()
+                else:
+                    group_grads[task_key] += grad_vec

             if param_names is None:
-                # get names and shapes from descriptor parameters
                 param_names = [
                     n for n, p in descriptor.named_parameters() if p.requires_grad
                 ]
@@ -602,19 +608,26 @@
                     if p.requires_grad
                 ]

-        # Compute pairwise dot products for this batch
-        for k1, k2 in pairwise_sims.keys():
-            if k1 in batch_grads and k2 in batch_grads:
-                g1 = batch_grads[k1]
-                g2 = batch_grads[k2]
-                pairwise_sims[(k1, k2)].append(float(np.dot(g1, g2)))
+        # At the end of each group of K batches, record norms and dot products then reset
+        if (b + 1) % accumulate_k == 0:
+            for task_key in trainer.model_keys:
+                if group_grads[task_key] is not None:
+                    task_batch_norms[task_key].append(
+                        np.linalg.norm(group_grads[task_key])
+                    )
+            for k1, k2 in pairwise_dots.keys():
+                g1 = group_grads.get(k1)
+                g2 = group_grads.get(k2)
+                if g1 is not None and g2 is not None:
+                    pairwise_dots[(k1, k2)].append(float(np.dot(g1, g2)))
+            group_grads = {k: None for k in trainer.model_keys}

     # Compute final statistics and log for each task
     for task_key in trainer.model_keys:
         accumulated_grads = task_accum_grads[task_key]
         norms = task_batch_norms[task_key]
         if accumulated_grads is not None:
-            avg_grad_vec = accumulated_grads / FLAGS.nbatches
+            avg_grad_vec = accumulated_grads / total_batches
             norm_mean = np.mean(norms)
             norm_std = np.std(norms)
         else:
@@ -627,16 +640,19 @@
         grads_per_task[f"{task_key}_norm_std"] = norm_std

         log.info(
-            "Task '%s': collected %d batches. Avg Grad Norm=%.4e, Mean(Batch Norms)=%.4e, Std(Batch Norms)=%.4e",
+            "Task '%s': collected %d batches (groups=%d, k=%d). "
+            "Avg Grad Norm=%.4e, Mean(Batch Norms)=%.4e, Std(Batch Norms)=%.4e",
             task_key,
-            len(norms),
+            total_batches,
+            FLAGS.nbatches,
+            accumulate_k,
             float(np.linalg.norm(avg_grad_vec)),
             float(norm_mean),
             float(norm_std),
         )

     # Compute pairwise dot product statistics
-    for (k1, k2), dots in pairwise_sims.items():
+    for (k1, k2), dots in pairwise_dots.items():
         if dots:
             dot_mean = np.mean(dots)
             dot_std = np.std(dots)
@@ -651,12 +667,15 @@
             grads_per_task[f"dot_{k1}_{k2}_accum"] = dot_accum

             log.info(
-                "Grad dot product '%s' vs '%s': Mean=%.4e, Std=%.4e, Accum=%.4e",
+                "Grad dot product '%s' vs '%s': Mean=%.4e, Std=%.4e, Accum=%.4e "
+                "(SNR=%.2f, n_groups=%d)",
                 k1,
                 k2,
                 float(dot_mean),
                 float(dot_std),
                 float(dot_accum),
+                abs(dot_mean) / dot_std if dot_std > 0 else float("inf"),
+                len(dots),
             )

     save_dict = {f"grads_{k}": v for k, v in grads_per_task.items()}
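
Why grouping helps: if each per-batch gradient is a shared mean direction plus zero-mean noise, summing accumulate_k batches scales the signal linearly in K while the noise grows only like sqrt(K), so the dot product of two K-sums concentrates around K^2 times the true inner product and its SNR improves with K. Below is a self-contained numpy sketch of this sampling scheme and of the SNR statistic logged above; the task keys, gradient dimension, and noise model are made up for illustration and do not come from the codebase:

    import numpy as np

    rng = np.random.default_rng(0)
    dim, nbatches, accumulate_k = 1000, 32, 4

    # two hypothetical tasks whose mean gradients are partially aligned
    base = rng.normal(size=dim)
    mu = {
        "task_a": base + 0.3 * rng.normal(size=dim),
        "task_b": base + 0.3 * rng.normal(size=dim),
    }

    def fake_batch_grad(key):
        # stand-in for one batch's descriptor gradient:
        # task mean plus heavy per-batch sampling noise
        return mu[key] + 5.0 * rng.normal(size=dim)

    dots = []
    for _ in range(nbatches):
        # sum K per-batch gradients before the pairwise dot product,
        # mirroring the group_grads accumulation in the diff
        g = {
            k: sum(fake_batch_grad(k) for _ in range(accumulate_k))
            for k in mu
        }
        dots.append(float(np.dot(g["task_a"], g["task_b"])))

    dot_mean, dot_std = np.mean(dots), np.std(dots)
    snr = abs(dot_mean) / dot_std if dot_std > 0 else float("inf")
    print(f"Mean={dot_mean:.4e}, Std={dot_std:.4e}, SNR={snr:.2f}")

Rerunning this sketch with accumulate_k = 1 (and nbatches = 128 to keep the batch budget fixed) typically yields a markedly lower SNR, which is what the new flag's help text means by "Larger K reduces norm variance and improves SNR of mean/std."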
