@@ -991,6 +991,13 @@ def step(_step_id: int, task_key: str = "Default") -> None:
             )
             loss.backward()
             if self.gradient_max_norm > 0.0:
+                # Collect per-parameter gradient norms before clipping.
+                if self.enable_tensorboard:
+                    pre_clip_named_norms = [
+                        (name, p.grad.detach().norm().item())
+                        for name, p in self.wrapper.named_parameters()
+                        if p.grad is not None
+                    ]
                 # FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.
                 total_norm = torch.nn.utils.clip_grad_norm_(
                     self.wrapper.parameters(),
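For context, a minimal, self-contained sketch of the pattern in the hunk above: per-parameter norms are snapshotted *before* `clip_grad_norm_` scales the gradients in place, and the value it returns is the pre-clip global norm. The toy model, loss, and `max_norm` value here are illustrative assumptions, not taken from this change:

```python
import torch

# Hypothetical stand-in for self.wrapper; any nn.Module behaves the same way.
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
loss = model(torch.randn(16, 4)).pow(2).mean()
loss.backward()

# Snapshot per-parameter gradient norms before clipping mutates p.grad.
pre_clip_named_norms = [
    (name, p.grad.detach().norm().item())
    for name, p in model.named_parameters()
    if p.grad is not None
]

# clip_grad_norm_ scales gradients in place and returns the total norm
# computed before clipping.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# With the default 2-norm, the total norm is the 2-norm of the per-parameter norms.
per_param = torch.tensor([gn for _, gn in pre_clip_named_norms])
assert torch.isclose(total_norm, per_param.norm(), rtol=1e-4)
```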
@@ -1342,6 +1349,25 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                     writer.add_scalar(
                         f"{task_key}/{item}", more_loss[item], display_step_id
                     )
+                # === Gradient diagnostics (pre-clip) ===
+                if self.gradient_max_norm > 0.0:
+                    writer.add_scalar(
+                        "grad/total_norm", total_norm.item(), display_step_id
+                    )
+                    norms = torch.tensor(
+                        [gn for _, gn in pre_clip_named_norms],
+                        dtype=torch.float32,
+                        device="cpu",
+                    )
+                    writer.add_histogram("grad/param_norms", norms, display_step_id)
+                    # Log top-10 largest per-parameter gradient norms.
+                    # Shorten name: keep everything after "atomic_model.".
+                    pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
+                    for name, gn in pre_clip_named_norms[:10]:
+                        idx = name.find("atomic_model.")
+                        if idx >= 0:
+                            name = name[idx + len("atomic_model.") :]
+                        writer.add_scalar(f"grad_top10/{name}", gn, display_step_id)

             self.wrapper.train()
             self.t0 = time.time()
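A minimal sketch of how these TensorBoard panels can be reproduced with a bare `SummaryWriter`, decoupled from the trainer; the log directory, step value, and fabricated norm list below are assumptions for illustration only:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="/tmp/grad_diag_demo")  # hypothetical log dir
display_step_id = 0
# Fabricated (name, norm) pairs standing in for pre_clip_named_norms.
pre_clip_named_norms = [
    ("model.atomic_model.descriptor.weight", 0.8),
    ("model.atomic_model.fitting_net.bias", 0.3),
]
total_norm = torch.tensor(0.85)

# Scalar: global gradient norm before clipping.
writer.add_scalar("grad/total_norm", total_norm.item(), display_step_id)

# Histogram: distribution of per-parameter gradient norms.
norms = torch.tensor([gn for _, gn in pre_clip_named_norms], dtype=torch.float32)
writer.add_histogram("grad/param_norms", norms, display_step_id)

# Scalars: largest per-parameter norms, tags shortened after "atomic_model.".
pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
for name, gn in pre_clip_named_norms[:10]:
    idx = name.find("atomic_model.")
    if idx >= 0:
        name = name[idx + len("atomic_model.") :]
    writer.add_scalar(f"grad_top10/{name}", gn, display_step_id)

writer.close()
# View with: tensorboard --logdir /tmp/grad_diag_demo
```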