@@ -548,7 +548,7 @@ def grad_probe(FLAGS) -> None:
548548 param_shapes : list | None = None
549549
550550 accumulate_k = getattr (FLAGS , "accumulate_k" , 1 )
551- total_batches = FLAGS .nbatches * accumulate_k
551+ total_batches = FLAGS .nbatches
552552
553553 # Initialize trackers
554554 task_batch_norms = {k : [] for k in trainer .model_keys }
@@ -610,14 +610,19 @@ def grad_probe(FLAGS) -> None:
610610
611611 # At the end of each group of K batches, record norms and dot products then reset
612612 if (b + 1 ) % accumulate_k == 0 :
613+ # Normalize to per-batch mean so scale is independent of k
614+ group_means = {
615+ k : g / accumulate_k if g is not None else None
616+ for k , g in group_grads .items ()
617+ }
613618 for task_key in trainer .model_keys :
614- if group_grads [task_key ] is not None :
619+ if group_means [task_key ] is not None :
615620 task_batch_norms [task_key ].append (
616- np .linalg .norm (group_grads [task_key ])
621+ np .linalg .norm (group_means [task_key ])
617622 )
618623 for k1 , k2 in pairwise_dots .keys ():
619- g1 = group_grads .get (k1 )
620- g2 = group_grads .get (k2 )
624+ g1 = group_means .get (k1 )
625+ g2 = group_means .get (k2 )
621626 if g1 is not None and g2 is not None :
622627 pairwise_dots [(k1 , k2 )].append (float (np .dot (g1 , g2 )))
623628 group_grads = {k : None for k in trainer .model_keys }
@@ -641,10 +646,10 @@ def grad_probe(FLAGS) -> None:
641646
642647 log .info (
643648 "Task '%s': collected %d batches (groups=%d, k=%d). "
644- "Avg Grad Norm=%.4e, Mean(Batch Norms)=%.4e, Std(Batch Norms)=%.4e" ,
649+ "Avg Grad Norm=%.4e, Mean(Group Norms)=%.4e, Std(Group Norms)=%.4e" ,
645650 task_key ,
646651 total_batches ,
647- FLAGS . nbatches ,
652+ total_batches // accumulate_k ,
648653 accumulate_k ,
649654 float (np .linalg .norm (avg_grad_vec )),
650655 float (norm_mean ),
0 commit comments