File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -659,20 +659,29 @@ def grad_probe(FLAGS) -> None:
659659 norm_mean = 0.0
660660 norm_std = 0.0
661661
662+ G_sq_norm = float (np .dot (avg_grad_vec , avg_grad_vec ))
663+ norms_arr = np .array (norms )
664+ mean_sq_norm = float (np .mean (norms_arr ** 2 ))
665+ tr_sigma = mean_sq_norm - G_sq_norm
666+ noise_scale = tr_sigma / G_sq_norm if G_sq_norm > 1e-30 else float ("nan" )
667+
662668 grads_per_task [task_key ] = avg_grad_vec
663669 grads_per_task [f"{ task_key } _norm_mean" ] = norm_mean
664670 grads_per_task [f"{ task_key } _norm_std" ] = norm_std
671+ grads_per_task [f"{ task_key } _noise_scale" ] = noise_scale
665672
666673 log .info (
667674 "Task '%s': collected %d batches (groups=%d, k=%d). "
668- "Avg Grad Norm=%.4e, Mean(Group Norms)=%.4e, Std(Group Norms)=%.4e" ,
675+ "Avg Grad Norm=%.4e, Mean(Group Norms)=%.4e, Std(Group Norms)=%.4e, "
676+ "Noise Scale(group)=%.4e" ,
669677 task_key ,
670678 total_batches ,
671679 total_batches // accumulate_k ,
672680 accumulate_k ,
673681 float (np .linalg .norm (avg_grad_vec )),
674682 float (norm_mean ),
675683 float (norm_std ),
684+ noise_scale ,
676685 )
677686
678687 # Compute pairwise dot product and cosine similarity statistics
You can’t perform that action at this time.
0 commit comments