@@ -990,9 +990,15 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                     **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
                 )
                 loss.backward()
+                # === Initialize gradient diagnostics variables ===
+                total_norm: torch.Tensor | None = None
+                pre_clip_named_norms: list[tuple[str, float]] = []
                 if self.gradient_max_norm > 0.0:
                     # Collect per-parameter gradient norms before clipping.
-                    if self.enable_tensorboard:
+                    # NOTE: Under FSDP2 with ZeRO stage >= 2, p.grad is a sharded DTensor,
+                    # so p.grad.norm() computes the shard-local L2 norm, not the full-parameter
+                    # norm. Skip per-param collection in this case to avoid misleading values.
+                    if self.enable_tensorboard and self.zero_stage < 2:
                         pre_clip_named_norms = [
                             (name, p.grad.detach().norm().item())
                             for name, p in self.wrapper.named_parameters()
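
For reference, a minimal sketch (not part of this patch) of how a full-parameter gradient norm could still be recovered when gradients are sharded across ranks, e.g. under FSDP2 / ZeRO stage >= 2: all-reduce the squared shard-local norms and take the square root. The helper name, the assumption that each rank holds its shard as a plain tensor (for instance via DTensor.to_local()), and the process-group argument are all illustrative.

# Illustrative sketch only -- not part of this change.
import torch
import torch.distributed as dist

def full_param_grad_norm(local_grad_shard: torch.Tensor, group=None) -> torch.Tensor:
    """Recover the full-parameter L2 norm from a shard-local gradient.

    The global L2 norm equals the square root of the sum of squared
    shard-local norms, so one all-reduce of a single scalar suffices.
    """
    sq_sum = local_grad_shard.detach().float().pow(2).sum()
    dist.all_reduce(sq_sum, op=dist.ReduceOp.SUM, group=group)
    return sq_sum.sqrt()
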
@@ -1350,24 +1356,33 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
13501356 f"{ task_key } /{ item } " , more_loss [item ], display_step_id
13511357 )
13521358 # === Gradient diagnostics (pre-clip) ===
1353- if self .gradient_max_norm > 0.0 :
1359+ # Only log if total_norm was computed (i.e., not LKF optimizer).
1360+ if self .gradient_max_norm > 0.0 and total_norm is not None :
13541361 writer .add_scalar (
1355- "grad/total_norm" , total_norm .item (), display_step_id
1356- )
1357- norms = torch .tensor (
1358- [gn for _ , gn in pre_clip_named_norms ],
1359- dtype = torch .float32 ,
1360- device = "cpu" ,
1362+ f"{ task_key } /grad/total_norm" ,
1363+ total_norm .item (),
1364+ display_step_id ,
13611365 )
1362- writer .add_histogram ("grad/param_norms" , norms , display_step_id )
1363- # Log top-10 largest per-parameter gradient norms.
1364- # Shorten name: keep everything after "atomic_model.".
1365- pre_clip_named_norms .sort (key = lambda x : x [1 ], reverse = True )
1366- for name , gn in pre_clip_named_norms [:10 ]:
1367- idx = name .find ("atomic_model." )
1368- if idx >= 0 :
1369- name = name [idx + len ("atomic_model." ) :]
1370- writer .add_scalar (f"grad_top10/{ name } " , gn , display_step_id )
1366+ # Only log per-parameter norms if list is non-empty.
1367+ if pre_clip_named_norms :
1368+ norms = torch .tensor (
1369+ [gn for _ , gn in pre_clip_named_norms ],
1370+ dtype = torch .float32 ,
1371+ device = "cpu" ,
1372+ )
1373+ writer .add_histogram (
1374+ f"{ task_key } /grad/param_norms" , norms , display_step_id
1375+ )
1376+ # Log top-10 largest per-parameter gradient norms.
1377+ # Shorten name: keep everything after "atomic_model.".
1378+ pre_clip_named_norms .sort (key = lambda x : x [1 ], reverse = True )
1379+ for name , gn in pre_clip_named_norms [:10 ]:
1380+ idx = name .find ("atomic_model." )
1381+ if idx >= 0 :
1382+ name = name [idx + len ("atomic_model." ) :]
1383+ writer .add_scalar (
1384+ f"{ task_key } /grad_top10/{ name } " , gn , display_step_id
1385+ )
13711386
13721387 self .wrapper .train ()
13731388 self .t0 = time .time ()
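
To show how the new task-key-prefixed tags group in TensorBoard, here is a minimal standalone sketch mirroring the logging above. It assumes only torch with the tensorboard extra installed; the log directory, parameter names, and norm values are made-up placeholders.

# Illustrative sketch only -- placeholder names and values throughout.
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="/tmp/grad_diag_demo")  # placeholder path
task_key = "Default"
display_step_id = 100
total_norm = torch.tensor(3.2)  # pretend pre-clip total gradient norm
pre_clip_named_norms = [
    ("atomic_model.descriptor.layers.0.weight", 1.7),  # placeholder entries
    ("atomic_model.fitting_net.layers.1.bias", 0.4),
]

writer.add_scalar(f"{task_key}/grad/total_norm", total_norm.item(), display_step_id)
if pre_clip_named_norms:
    norms = torch.tensor([gn for _, gn in pre_clip_named_norms], dtype=torch.float32)
    writer.add_histogram(f"{task_key}/grad/param_norms", norms, display_step_id)
    pre_clip_named_norms.sort(key=lambda x: x[1], reverse=True)
    for name, gn in pre_clip_named_norms[:10]:
        idx = name.find("atomic_model.")
        if idx >= 0:
            name = name[idx + len("atomic_model.") :]
        writer.add_scalar(f"{task_key}/grad_top10/{name}", gn, display_step_id)
writer.close()

Prefixing every gradient tag with the task key keeps the diagnostics separated per task in multi-task runs, consistent with the existing f"{task_key}/{item}" loss tags.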