Skip to content

Commit 543e6dd

Browse files
committed
feat: add noise scale
1 parent 412d4b7 commit 543e6dd

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

deepmd/pt/entrypoints/main.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -659,20 +659,29 @@ def grad_probe(FLAGS) -> None:
659659
norm_mean = 0.0
660660
norm_std = 0.0
661661

662+
G_sq_norm = float(np.dot(avg_grad_vec, avg_grad_vec))
663+
norms_arr = np.array(norms)
664+
mean_sq_norm = float(np.mean(norms_arr ** 2))
665+
tr_sigma = mean_sq_norm - G_sq_norm
666+
noise_scale = tr_sigma / G_sq_norm if G_sq_norm > 1e-30 else float("nan")
667+
662668
grads_per_task[task_key] = avg_grad_vec
663669
grads_per_task[f"{task_key}_norm_mean"] = norm_mean
664670
grads_per_task[f"{task_key}_norm_std"] = norm_std
671+
grads_per_task[f"{task_key}_noise_scale"] = noise_scale
665672

666673
log.info(
667674
"Task '%s': collected %d batches (groups=%d, k=%d). "
668-
"Avg Grad Norm=%.4e, Mean(Group Norms)=%.4e, Std(Group Norms)=%.4e",
675+
"Avg Grad Norm=%.4e, Mean(Group Norms)=%.4e, Std(Group Norms)=%.4e, "
676+
"Noise Scale(group)=%.4e",
669677
task_key,
670678
total_batches,
671679
total_batches // accumulate_k,
672680
accumulate_k,
673681
float(np.linalg.norm(avg_grad_vec)),
674682
float(norm_mean),
675683
float(norm_std),
684+
noise_scale,
676685
)
677686

678687
# Compute pairwise dot product and cosine similarity statistics

0 commit comments

Comments
 (0)