66from mmengine .dist import get_world_size
77from pydantic import BaseModel , ConfigDict , TypeAdapter , model_validator
88from torch import nn
9+ from torch .distributed .tensor import DTensor
910from torch .utils .hooks import RemovableHandle
1011from typing_extensions import TypedDict
1112
1213from xtuner .v1 .model import MoE
1314from xtuner .v1 .model .base import BaseModel as XTunerBaseModel
1415from xtuner .v1 .model .base import ModelItem
1516from xtuner .v1 .module import LMHead , MHAConfig , MLAConfig , MultiHeadAttention , MultiLatentAttention
17+ from xtuner .v1 .module .attention .gated_deltanet import FusedRMSNormGated
1618from xtuner .v1 .module .decoder_layer .dense_decoder_layer import DenseDecoderLayer
1719from xtuner .v1 .module .decoder_layer .moe_decoder_layer import MoEDecoderLayer
1820from xtuner .v1 .utils .device import get_device
3739
3840class InternalMetrics (TypedDict , total = False ):
3941 weight_rms : dict [str , float ]
42+ weight_min : dict [str , float ]
43+ weight_max : dict [str , float ]
4044 maxvio : dict [str , float ]
4145 drop_ratio : dict [str , float ]
4246 router_logits_max : dict [str , float ]
@@ -53,13 +57,15 @@ class InternalMetricsConfig(BaseModel):
5357 internal_metrics_interval : int | None = None
5458 monitor_weights_rms_norm : bool = True
5559 monitor_attn_logits_stats : bool = True
60+ monitor_gdn_stats : bool | None = None # only applies to models w/ GDN
5661 monitor_moe_router_logits_stats : bool | None = None # only applies to MoE models
5762 monitor_moe_load_balance_stats : bool | None = None
5863
5964 @model_validator (mode = "after" )
6065 def post_init (self ):
6166 monitoring_fields = [
6267 self .monitor_weights_rms_norm ,
68+ self .monitor_gdn_stats ,
6369 self .monitor_attn_logits_stats ,
6470 self .monitor_moe_router_logits_stats ,
6571 self .monitor_moe_load_balance_stats ,
@@ -74,7 +80,7 @@ def post_init(self):
7480class InternalMetricsRecorder :
7581 def __init__ (self , internal_metrics_cfg : InternalMetricsConfig , model : XTunerBaseModel ):
7682 self .internal_metrics_cfg = internal_metrics_cfg
77- self .model = model
83+ self .model : XTunerBaseModel = model . language_model if hasattr ( model , "language_model" ) else model # type: ignore[assignment]
7884
7985 self .hooks : list [RemovableHandle ] = []
8086
@@ -91,6 +97,10 @@ def _init_metrics_dict(self) -> InternalMetrics:
9197 if self .internal_metrics_cfg .monitor_weights_rms_norm :
9298 metrics ["weight_rms" ] = {}
9399
100+ if self .internal_metrics_cfg .monitor_gdn_stats :
101+ metrics ["weight_min" ] = {}
102+ metrics ["weight_max" ] = {}
103+
94104 if self .internal_metrics_cfg .monitor_attn_logits_stats :
95105 attn_cfg : MHAConfig | MLAConfig = self .model .config .attention # type: ignore[attr-defined]
96106
@@ -153,6 +163,44 @@ def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype:
153163
154164 self .metrics ["weight_rms" ][layer_name ] = param_rms .item ()
155165
166+ @torch .no_grad ()
167+ def calculate_module_weight_min_max (self , module_or_param : nn .Module | torch .Tensor , layer_name : str ):
168+ """Calculate the min and max of the module's parameters."""
169+ self ._check_closed ()
170+
171+ if "weight_min" not in self .metrics or "weight_max" not in self .metrics :
172+ return
173+
174+ if isinstance (module_or_param , nn .Module ):
175+ all_params = [param .data for param in module_or_param .parameters () if param .requires_grad ]
176+ elif isinstance (module_or_param , torch .Tensor ):
177+ if module_or_param .requires_grad :
178+ all_params = [module_or_param .data ]
179+ else :
180+ all_params = []
181+ else :
182+ raise TypeError (f"module_or_param must be nn.Module or torch.Tensor, got { type (module_or_param )} " )
183+
184+ if not all_params :
185+ return
186+
187+ local_params = []
188+ for param in all_params :
189+ if isinstance (param , DTensor ):
190+ local_params .append (param .to_local ())
191+ else :
192+ local_params .append (param )
193+
194+ local_min = torch .min (torch .stack ([p .min () for p in local_params ]))
195+ local_max = torch .max (torch .stack ([p .max () for p in local_params ]))
196+
197+ if dist .is_initialized () and dist .get_world_size () > 1 :
198+ dist .all_reduce (local_min , op = dist .ReduceOp .MIN )
199+ dist .all_reduce (local_max , op = dist .ReduceOp .MAX )
200+
201+ self .metrics ["weight_min" ][layer_name ] = local_min .item ()
202+ self .metrics ["weight_max" ][layer_name ] = local_max .item ()
203+
156204 def register_attn_output_hook (self , module : nn .Module ):
157205 """Register attention output hook as a forward hook."""
158206 self ._check_closed ()
@@ -179,6 +227,16 @@ def pop_metrics(self, data_batches: list[ModelItem]):
179227 if self .internal_metrics_cfg .monitor_weights_rms_norm and isinstance (module , RMS_NORM_MONITOR_MODULES ):
180228 self .calculate_module_weight_rms (module , self ._clean_module_name (name ), dtype = torch .float32 )
181229
230+ if (
231+ self .internal_metrics_cfg .monitor_gdn_stats
232+ and FusedRMSNormGated is not None
233+ and isinstance (module , FusedRMSNormGated )
234+ ):
235+ self .calculate_module_weight_min_max (module , self ._clean_module_name (name ))
236+
237+ if self .internal_metrics_cfg .monitor_gdn_stats and hasattr (module , "A_log" ):
238+ self .calculate_module_weight_min_max (module .A_log , f"{ self ._clean_module_name (name )} .A_log" )
239+
182240 additional_kwargs = {}
183241 if self .internal_metrics_cfg .monitor_moe_router_logits_stats and isinstance (self .model , MoE ):
184242 # for MoE model, add additional kwargs to return necessary stats
0 commit comments