@@ -656,102 +656,65 @@ def grad_probe(FLAGS) -> None:
656656 float (norm_std ),
657657 )
658658
659- # Compute pairwise dot product and Pearson correlation statistics
659+ # Compute pairwise dot product and cosine similarity statistics
660660 for (k1 , k2 ), dots in pairwise_dots .items ():
661661 if dots :
662- dot_mean = np .mean (dots )
663- dot_std = np .std (dots )
662+ dot_mean = float ( np .mean (dots ) )
663+ dot_std = float ( np .std (dots ) )
664664
665665 g1_avg = grads_per_task [k1 ]
666666 g2_avg = grads_per_task [k2 ]
667- dot_accum = float (np .dot (g1_avg , g2_avg ))
667+ dot_global = float (np .dot (g1_avg , g2_avg ))
668668
669- # Pearson correlation across groups:
670- # cov_sum = Σ_j Cov(g1_j, g2_j) = E[dot(g1,g2)] - dot(ḡ1, ḡ2)
671- # var_sum = Σ_j Var(g_j) = E[||g||²] - ||ḡ||²
672- # pearson = cov_sum / sqrt(var1_sum * var2_sum)
673- cov_sum = float (dot_mean - dot_accum )
669+ # Group-level cosine: per-group cos_i = dot_i / (||g1_i|| * ||g2_i||)
674670 norms_1 = np .array (task_batch_norms [k1 ])
675671 norms_2 = np .array (task_batch_norms [k2 ])
676- var1_sum = float (np .mean (norms_1 ** 2 ) - np .dot (g1_avg , g1_avg ))
677- var2_sum = float (np .mean (norms_2 ** 2 ) - np .dot (g2_avg , g2_avg ))
678- denom = np .sqrt (max (var1_sum , 0.0 ) * max (var2_sum , 0.0 ))
679- pearson = float (cov_sum / denom ) if denom > 1e-12 else float ("nan" )
680-
681- # Group-level weighted cosine similarity:
682- # per-group cos_i = dot_i / (||g1_i|| * ||g2_i||)
683- # weight w_i = ||g1_i|| * ||g2_i|| so high-norm groups dominate
684- norm_products = norms_1 * norms_2 # (n_groups,) weight per group
672+ norm_products = norms_1 * norms_2
685673 valid = norm_products > 1e-12
686674 cos_per_group = np .where (
687675 valid ,
688676 np .array (dots ) / np .where (valid , norm_products , 1.0 ),
689- 0.0 ,
677+ float ( "nan" ) ,
690678 )
691- norm_product_sum = float (np .sum (norm_products ))
692- if norm_product_sum > 1e-12 :
693- cos_group_mean = float (
694- np .sum (norm_products * cos_per_group ) / norm_product_sum
695- )
696- cos_group_std = float (
697- np .sqrt (
698- np .sum (norm_products * (cos_per_group - cos_group_mean ) ** 2 )
699- / norm_product_sum
700- )
701- )
679+ valid_cos = cos_per_group [valid ]
680+ if len (valid_cos ) > 0 :
681+ cos_group_mean = float (np .mean (valid_cos ))
682+ cos_group_std = float (np .std (valid_cos ))
702683 else :
703684 cos_group_mean = float ("nan" )
704685 cos_group_std = float ("nan" )
705686
706- # Global-level cosine similarity:
707- # cosine of accumulated average gradient vectors;
708- # large-norm batches naturally dominate the direction of the average
687+ # Global-level cosine: cosine of accumulated average gradient vectors
709688 norm_g1 = float (np .linalg .norm (g1_avg ))
710689 norm_g2 = float (np .linalg .norm (g2_avg ))
711690 denom_global = norm_g1 * norm_g2
712691 cos_global = (
713- float (dot_accum / denom_global )
692+ float (dot_global / denom_global )
714693 if denom_global > 1e-12
715694 else float ("nan" )
716695 )
717696
718697 grads_per_task [f"dot_{ k1 } _{ k2 } _mean" ] = dot_mean
719698 grads_per_task [f"dot_{ k1 } _{ k2 } _std" ] = dot_std
720- grads_per_task [f"dot_{ k1 } _{ k2 } _accum" ] = dot_accum
721- grads_per_task [f"pearson_{ k1 } _{ k2 } " ] = pearson
699+ grads_per_task [f"dot_{ k1 } _{ k2 } _global" ] = dot_global
722700 grads_per_task [f"cos_group_mean_{ k1 } _{ k2 } " ] = cos_group_mean
723701 grads_per_task [f"cos_group_std_{ k1 } _{ k2 } " ] = cos_group_std
724702 grads_per_task [f"cos_global_{ k1 } _{ k2 } " ] = cos_global
725703
726704 log .info (
727- "Grad dot product '%s' vs '%s': Mean=%.4e, Std=%.4e, Accum=%.4e "
728- "(SNR=%.2f, n_groups=%d)" ,
729- k1 ,
730- k2 ,
731- float (dot_mean ),
732- float (dot_std ),
733- float (dot_accum ),
734- abs (dot_mean ) / dot_std if dot_std > 0 else float ("inf" ),
735- len (dots ),
736- )
737- log .info (
738- "Pearson correlation '%s' vs '%s': %.4f "
739- "(cov_sum=%.4e, var1=%.4e, var2=%.4e)" ,
740- k1 ,
741- k2 ,
742- pearson ,
743- cov_sum ,
744- var1_sum ,
745- var2_sum ,
746- )
747- log .info (
748- "Cosine similarity '%s' vs '%s': "
749- "group_mean=%.4f, group_std=%.4f, global=%.4f" ,
705+ "Grad similarity '%s' vs '%s': "
706+ "dot_mean=%.4e, dot_std=%.4e, dot_global=%.4e, "
707+ "cos_group_mean=%.4f, cos_group_std=%.4f, cos_global=%.4f "
708+ "(n_groups=%d)" ,
750709 k1 ,
751710 k2 ,
711+ dot_mean ,
712+ dot_std ,
713+ dot_global ,
752714 cos_group_mean ,
753715 cos_group_std ,
754716 cos_global ,
717+ len (dots ),
755718 )
756719
757720 save_dict = {f"grads_{ k } " : v for k , v in grads_per_task .items ()}
0 commit comments