@@ -547,20 +547,25 @@ def grad_probe(FLAGS) -> None:
     param_names: list | None = None
     param_shapes: list | None = None
 
+    accumulate_k = getattr(FLAGS, "accumulate_k", 1)
+    total_batches = FLAGS.nbatches * accumulate_k
+
     # Initialize trackers
     task_batch_norms = {k: [] for k in trainer.model_keys}
     task_accum_grads = {k: None for k in trainer.model_keys}
-    pairwise_sims = {
+    pairwise_dots = {
         (k1, k2): []
         for i, k1 in enumerate(trainer.model_keys)
         for j, k2 in enumerate(trainer.model_keys)
         if i < j
     }
 
+    # Accumulators within each group of K batches
+    group_grads = {k: None for k in trainer.model_keys}
+
     cur_lr = config["learning_rate"]["start_lr"]
 
-    for b in range(FLAGS.nbatches):
-        batch_grads = {}
+    for b in range(total_batches):
         for task_key in trainer.model_keys:
             trainer.optimizer.zero_grad(set_to_none=True)
             input_dict, label_dict, _ = trainer.get_data(
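Note on the new accumulation scheme: the probe now draws `FLAGS.nbatches * accumulate_k` batches in total and closes a statistics group every `accumulate_k`-th batch, so `FLAGS.nbatches` becomes the number of groups. A minimal sketch of the boundary arithmetic, with hypothetical values standing in for the real FLAGS:

```python
# Sketch: how batches map onto groups under the accumulate_k scheme.
# Hypothetical values; the real ones come from FLAGS.
nbatches, accumulate_k = 3, 4          # -> 3 groups of 4 batches each
total_batches = nbatches * accumulate_k

for b in range(total_batches):
    group_id = b // accumulate_k
    if (b + 1) % accumulate_k == 0:    # same boundary test the probe uses
        print(f"batch {b}: closes group {group_id}")
# batch 3: closes group 0
# batch 7: closes group 1
# batch 11: closes group 2
```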
@@ -574,7 +579,6 @@ def grad_probe(FLAGS) -> None:
             model = module.model[task_key]
             descriptor = model.get_descriptor()
 
-            # Extract current batch gradient
             current_grads = []
             for name, param in descriptor.named_parameters():
                 if param.grad is not None:
@@ -584,15 +588,17 @@ def grad_probe(FLAGS) -> None:
 
             if current_grads:
                 grad_vec = np.concatenate(current_grads)
-                task_batch_norms[task_key].append(np.linalg.norm(grad_vec))
                 if task_accum_grads[task_key] is None:
-                    task_accum_grads[task_key] = grad_vec
+                    task_accum_grads[task_key] = grad_vec.copy()
                 else:
                     task_accum_grads[task_key] += grad_vec
-                batch_grads[task_key] = grad_vec
+
+                if group_grads[task_key] is None:
+                    group_grads[task_key] = grad_vec.copy()
+                else:
+                    group_grads[task_key] += grad_vec
 
             if param_names is None:
-                # get names and shapes from descriptor parameters
                 param_names = [
                     n for n, p in descriptor.named_parameters() if p.requires_grad
                 ]
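For context, `grad_vec` is the usual flatten-and-concatenate pattern over parameter gradients, and the new `.copy()` calls keep `task_accum_grads` and `group_grads` from aliasing the same array, which the in-place `+=` would otherwise corrupt. A self-contained sketch of the flattening pattern, using a toy `torch.nn.Linear` in place of the descriptor from this diff:

```python
import numpy as np
import torch

# Toy stand-in for the descriptor; any nn.Module works the same way.
net = torch.nn.Linear(4, 2)
loss = net(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Flatten every parameter gradient into one long vector, as the probe does.
current_grads = [
    p.grad.detach().cpu().numpy().flatten()
    for _, p in net.named_parameters()
    if p.grad is not None
]
grad_vec = np.concatenate(current_grads)
print(grad_vec.shape)  # (10,) = 4*2 weights + 2 biases
```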
@@ -602,19 +608,26 @@ def grad_probe(FLAGS) -> None:
                     if p.requires_grad
                 ]
 
-        # Compute pairwise dot products for this batch
-        for k1, k2 in pairwise_sims.keys():
-            if k1 in batch_grads and k2 in batch_grads:
-                g1 = batch_grads[k1]
-                g2 = batch_grads[k2]
-                pairwise_sims[(k1, k2)].append(float(np.dot(g1, g2)))
+        # At the end of each group of K batches, record norms and dot products, then reset
+        if (b + 1) % accumulate_k == 0:
+            for task_key in trainer.model_keys:
+                if group_grads[task_key] is not None:
+                    task_batch_norms[task_key].append(
+                        np.linalg.norm(group_grads[task_key])
+                    )
+            for k1, k2 in pairwise_dots.keys():
+                g1 = group_grads.get(k1)
+                g2 = group_grads.get(k2)
+                if g1 is not None and g2 is not None:
+                    pairwise_dots[(k1, k2)].append(float(np.dot(g1, g2)))
+            group_grads = {k: None for k in trainer.model_keys}
 
     # Compute final statistics and log for each task
     for task_key in trainer.model_keys:
         accumulated_grads = task_accum_grads[task_key]
         norms = task_batch_norms[task_key]
         if accumulated_grads is not None:
-            avg_grad_vec = accumulated_grads / FLAGS.nbatches
+            avg_grad_vec = accumulated_grads / total_batches
             norm_mean = np.mean(norms)
             norm_std = np.std(norms)
         else:
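Worth noting when reading the grouped statistics: `group_grads` holds the *sum* of `accumulate_k` per-batch gradients, not their mean, so each recorded group dot product equals the sum of all K*K cross-batch dot products by bilinearity. A small numeric sketch of that identity with toy vectors and `K = 2`:

```python
import numpy as np

# Two tasks, two batches per group (toy gradients).
g1 = [np.array([1.0, 0.0]), np.array([0.0, 2.0])]   # task A, batches 0 and 1
g2 = [np.array([0.5, 0.5]), np.array([1.0, 1.0])]   # task B, batches 0 and 1

group_dot = np.dot(sum(g1), sum(g2))                 # what the probe records
cross = sum(np.dot(a, b) for a in g1 for b in g2)    # all K*K batch-level dots
assert np.isclose(group_dot, cross)                  # identical by bilinearity
print(group_dot)  # 4.5
```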
@@ -627,16 +640,19 @@ def grad_probe(FLAGS) -> None:
         grads_per_task[f"{task_key}_norm_std"] = norm_std
 
         log.info(
-            "Task '%s': collected %d batches. Avg Grad Norm=%.4e, Mean(Batch Norms)=%.4e, Std(Batch Norms)=%.4e",
+            "Task '%s': collected %d batches (groups=%d, k=%d). "
+            "Avg Grad Norm=%.4e, Mean(Batch Norms)=%.4e, Std(Batch Norms)=%.4e",
             task_key,
-            len(norms),
+            total_batches,
+            FLAGS.nbatches,
+            accumulate_k,
             float(np.linalg.norm(avg_grad_vec)),
             float(norm_mean),
             float(norm_std),
         )
 
     # Compute pairwise dot product statistics
-    for (k1, k2), dots in pairwise_sims.items():
+    for (k1, k2), dots in pairwise_dots.items():
         if dots:
             dot_mean = np.mean(dots)
             dot_std = np.std(dots)
@@ -651,12 +667,15 @@ def grad_probe(FLAGS) -> None:
             grads_per_task[f"dot_{k1}_{k2}_accum"] = dot_accum
 
             log.info(
-                "Grad dot product '%s' vs '%s': Mean=%.4e, Std=%.4e, Accum=%.4e",
+                "Grad dot product '%s' vs '%s': Mean=%.4e, Std=%.4e, Accum=%.4e "
+                "(SNR=%.2f, n_groups=%d)",
                 k1,
                 k2,
                 float(dot_mean),
                 float(dot_std),
                 float(dot_accum),
+                abs(dot_mean) / dot_std if dot_std > 0 else float("inf"),
+                len(dots),
             )
 
     save_dict = {f"grads_{k}": v for k, v in grads_per_task.items()}
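The SNR figure added to the log is the absolute mean of the per-group dot products over their standard deviation: values well above 1 suggest the alignment (or conflict) between the two tasks' gradients is consistent across groups rather than noise. A minimal sketch of the same computation on a toy list of dots:

```python
import numpy as np

# Toy per-group dot products between two tasks' gradients.
dots = [0.8, 1.1, 0.9, 1.2, 1.0]

dot_mean = np.mean(dots)
dot_std = np.std(dots)
snr = abs(dot_mean) / dot_std if dot_std > 0 else float("inf")
print(f"Mean={dot_mean:.4e}, Std={dot_std:.4e}, SNR={snr:.2f}")
# SNR well above 1 here: the positive alignment is consistent across groups.
```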