@@ -1058,7 +1058,24 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                 **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
             )
             loss.backward()
+            # === Initialize gradient diagnostics variables ===
+            total_norm: torch.Tensor | None = None
+            pre_clip_named_norms: list[tuple[str, float]] = []
             if self.gradient_max_norm > 0.0:
+                # Collect per-parameter gradient norms before clipping.
+                # NOTE: Under FSDP2 with ZeRO stage >= 2, p.grad is a sharded DTensor,
+                # so p.grad.norm() computes the shard-local L2 norm, not the full-parameter
+                # norm. Skip per-param collection in this case to avoid misleading values.
+                if (
+                    self.enable_tensorboard
+                    and self.zero_stage < 2
+                    and (_step_id % self.tensorboard_freq == 0 or _step_id == 1)
+                ):
+                    pre_clip_named_norms = [
+                        (name, p.grad.detach().norm().item())
+                        for name, p in self.wrapper.named_parameters()
+                        if p.grad is not None
+                    ]
                 # FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.
                 total_norm = torch.nn.utils.clip_grad_norm_(
                     self.wrapper.parameters(),
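Review note: the pre-clip collection above is easy to sanity-check in isolation. Below is a minimal, self-contained sketch of the same pattern on a plain (non-sharded) module; `model`, `max_norm`, and the toy loss are illustrative stand-ins, not code from this PR:

```python
import torch

# Toy model and loss, just to produce gradients.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 1),
)
loss = model(torch.randn(4, 8)).square().mean()
loss.backward()

max_norm = 1.0
# Per-parameter L2 norms *before* clipping. On a non-sharded model,
# p.grad.norm() is the true full-parameter norm.
pre_clip_named_norms = [
    (name, p.grad.detach().norm().item())
    for name, p in model.named_parameters()
    if p.grad is not None
]
# clip_grad_norm_ returns the total pre-clip norm across all parameters.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
print(f"total={total_norm.item():.4f}", pre_clip_named_norms[:3])
```

On a sharded setup the same list comprehension would report shard-local norms, which is exactly the pitfall the `zero_stage < 2` guard in the diff avoids.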
@@ -1410,6 +1427,32 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                     writer.add_scalar(
                         f"{task_key}/{item}", more_loss[item], display_step_id
                     )
+                # === Gradient diagnostics (pre-clip) ===
+                # Only log if total_norm was computed (i.e., not LKF optimizer).
+                if self.gradient_max_norm > 0.0 and total_norm is not None:
+                    writer.add_scalar(
+                        f"{task_key}/grad/total_norm",
+                        total_norm.item(),
+                        display_step_id,
+                    )
+                    # Only log per-parameter norms if list is non-empty.
+                    if pre_clip_named_norms:
+                        # Use float32 for histogram to ensure numerical stability
+                        # when gradients are in lower precision (FP16/BF16).
+                        norms = torch.tensor(
+                            [gn for _, gn in pre_clip_named_norms],
+                            dtype=torch.float32,
+                            device="cpu",
+                        )
+                        writer.add_histogram(
+                            f"{task_key}/grad/param_norms", norms, display_step_id
+                        )
+                        # Log top-10 largest per-parameter gradient norms.
+                        pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
+                        for name, gn in pre_clip_named_norms[:10]:
+                            writer.add_scalar(
+                                f"{task_key}/grad_top10/{name}", gn, display_step_id
+                            )
 
             self.wrapper.train()
             self.t0 = time.time()
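The logging half can be exercised the same way, without a training loop. A sketch assuming only `torch.utils.tensorboard`; the log dir and the placeholder `total_norm` / `pre_clip_named_norms` values are made up for illustration:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="/tmp/grad_diag")  # hypothetical log dir
task_key, display_step_id = "Default", 1
total_norm = torch.tensor(3.2)  # placeholder for clip_grad_norm_'s return value
pre_clip_named_norms = [("0.weight", 2.5), ("2.weight", 1.9), ("0.bias", 0.3)]

writer.add_scalar(f"{task_key}/grad/total_norm", total_norm.item(), display_step_id)
if pre_clip_named_norms:
    # Histogram over all per-parameter norms, in float32 for stability.
    norms = torch.tensor([gn for _, gn in pre_clip_named_norms], dtype=torch.float32)
    writer.add_histogram(f"{task_key}/grad/param_norms", norms, display_step_id)
    # Named scalars only for the largest few norms.
    pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
    for name, gn in pre_clip_named_norms[:10]:
        writer.add_scalar(f"{task_key}/grad_top10/{name}", gn, display_step_id)
writer.close()
```

The top-10 cap matters here: `add_scalar` creates one time series per tag, so logging every parameter by name would bloat the event files on large models, while the histogram still captures the full distribution.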