Skip to content

Commit 0f56a5f

Browse files
committed
fix: revert covariance
1 parent b29fe5f commit 0f56a5f

1 file changed

Lines changed: 22 additions & 59 deletions

File tree

deepmd/pt/entrypoints/main.py

Lines changed: 22 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -656,102 +656,65 @@ def grad_probe(FLAGS) -> None:
656656
float(norm_std),
657657
)
658658

659-
# Compute pairwise dot product and Pearson correlation statistics
659+
# Compute pairwise dot product and cosine similarity statistics
660660
for (k1, k2), dots in pairwise_dots.items():
661661
if dots:
662-
dot_mean = np.mean(dots)
663-
dot_std = np.std(dots)
662+
dot_mean = float(np.mean(dots))
663+
dot_std = float(np.std(dots))
664664

665665
g1_avg = grads_per_task[k1]
666666
g2_avg = grads_per_task[k2]
667-
dot_accum = float(np.dot(g1_avg, g2_avg))
667+
dot_global = float(np.dot(g1_avg, g2_avg))
668668

669-
# Pearson correlation across groups:
670-
# cov_sum = Σ_j Cov(g1_j, g2_j) = E[dot(g1,g2)] - dot(ḡ1, ḡ2)
671-
# var_sum = Σ_j Var(g_j) = E[||g||²] - ||ḡ||²
672-
# pearson = cov_sum / sqrt(var1_sum * var2_sum)
673-
cov_sum = float(dot_mean - dot_accum)
669+
# Group-level cosine: per-group cos_i = dot_i / (||g1_i|| * ||g2_i||)
674670
norms_1 = np.array(task_batch_norms[k1])
675671
norms_2 = np.array(task_batch_norms[k2])
676-
var1_sum = float(np.mean(norms_1**2) - np.dot(g1_avg, g1_avg))
677-
var2_sum = float(np.mean(norms_2**2) - np.dot(g2_avg, g2_avg))
678-
denom = np.sqrt(max(var1_sum, 0.0) * max(var2_sum, 0.0))
679-
pearson = float(cov_sum / denom) if denom > 1e-12 else float("nan")
680-
681-
# Group-level weighted cosine similarity:
682-
# per-group cos_i = dot_i / (||g1_i|| * ||g2_i||)
683-
# weight w_i = ||g1_i|| * ||g2_i|| so high-norm groups dominate
684-
norm_products = norms_1 * norms_2 # (n_groups,) weight per group
672+
norm_products = norms_1 * norms_2
685673
valid = norm_products > 1e-12
686674
cos_per_group = np.where(
687675
valid,
688676
np.array(dots) / np.where(valid, norm_products, 1.0),
689-
0.0,
677+
float("nan"),
690678
)
691-
norm_product_sum = float(np.sum(norm_products))
692-
if norm_product_sum > 1e-12:
693-
cos_group_mean = float(
694-
np.sum(norm_products * cos_per_group) / norm_product_sum
695-
)
696-
cos_group_std = float(
697-
np.sqrt(
698-
np.sum(norm_products * (cos_per_group - cos_group_mean) ** 2)
699-
/ norm_product_sum
700-
)
701-
)
679+
valid_cos = cos_per_group[valid]
680+
if len(valid_cos) > 0:
681+
cos_group_mean = float(np.mean(valid_cos))
682+
cos_group_std = float(np.std(valid_cos))
702683
else:
703684
cos_group_mean = float("nan")
704685
cos_group_std = float("nan")
705686

706-
# Global-level cosine similarity:
707-
# cosine of accumulated average gradient vectors;
708-
# large-norm batches naturally dominate the direction of the average
687+
# Global-level cosine: cosine of accumulated average gradient vectors
709688
norm_g1 = float(np.linalg.norm(g1_avg))
710689
norm_g2 = float(np.linalg.norm(g2_avg))
711690
denom_global = norm_g1 * norm_g2
712691
cos_global = (
713-
float(dot_accum / denom_global)
692+
float(dot_global / denom_global)
714693
if denom_global > 1e-12
715694
else float("nan")
716695
)
717696

718697
grads_per_task[f"dot_{k1}_{k2}_mean"] = dot_mean
719698
grads_per_task[f"dot_{k1}_{k2}_std"] = dot_std
720-
grads_per_task[f"dot_{k1}_{k2}_accum"] = dot_accum
721-
grads_per_task[f"pearson_{k1}_{k2}"] = pearson
699+
grads_per_task[f"dot_{k1}_{k2}_global"] = dot_global
722700
grads_per_task[f"cos_group_mean_{k1}_{k2}"] = cos_group_mean
723701
grads_per_task[f"cos_group_std_{k1}_{k2}"] = cos_group_std
724702
grads_per_task[f"cos_global_{k1}_{k2}"] = cos_global
725703

726704
log.info(
727-
"Grad dot product '%s' vs '%s': Mean=%.4e, Std=%.4e, Accum=%.4e "
728-
"(SNR=%.2f, n_groups=%d)",
729-
k1,
730-
k2,
731-
float(dot_mean),
732-
float(dot_std),
733-
float(dot_accum),
734-
abs(dot_mean) / dot_std if dot_std > 0 else float("inf"),
735-
len(dots),
736-
)
737-
log.info(
738-
"Pearson correlation '%s' vs '%s': %.4f "
739-
"(cov_sum=%.4e, var1=%.4e, var2=%.4e)",
740-
k1,
741-
k2,
742-
pearson,
743-
cov_sum,
744-
var1_sum,
745-
var2_sum,
746-
)
747-
log.info(
748-
"Cosine similarity '%s' vs '%s': "
749-
"group_mean=%.4f, group_std=%.4f, global=%.4f",
705+
"Grad similarity '%s' vs '%s': "
706+
"dot_mean=%.4e, dot_std=%.4e, dot_global=%.4e, "
707+
"cos_group_mean=%.4f, cos_group_std=%.4f, cos_global=%.4f "
708+
"(n_groups=%d)",
750709
k1,
751710
k2,
711+
dot_mean,
712+
dot_std,
713+
dot_global,
752714
cos_group_mean,
753715
cos_group_std,
754716
cos_global,
717+
len(dots),
755718
)
756719

757720
save_dict = {f"grads_{k}": v for k, v in grads_per_task.items()}

0 commit comments

Comments
 (0)