Skip to content

Commit baa4675

Browse files
authored
[Enhance] Support internal metrics for gdn A_log and norm & vl model in general (#1615)
* [Enhance] Support internal metrics for gdn A_log and norm & vl model in general * [Enhance][internal metrics] address claude code review suggestions
1 parent 090dde0 commit baa4675

1 file changed

Lines changed: 59 additions & 1 deletion

File tree

xtuner/v1/utils/internal_metrics.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
from mmengine.dist import get_world_size
77
from pydantic import BaseModel, ConfigDict, TypeAdapter, model_validator
88
from torch import nn
9+
from torch.distributed.tensor import DTensor
910
from torch.utils.hooks import RemovableHandle
1011
from typing_extensions import TypedDict
1112

1213
from xtuner.v1.model import MoE
1314
from xtuner.v1.model.base import BaseModel as XTunerBaseModel
1415
from xtuner.v1.model.base import ModelItem
1516
from xtuner.v1.module import LMHead, MHAConfig, MLAConfig, MultiHeadAttention, MultiLatentAttention
17+
from xtuner.v1.module.attention.gated_deltanet import FusedRMSNormGated
1618
from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer
1719
from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEDecoderLayer
1820
from xtuner.v1.utils.device import get_device
@@ -37,6 +39,8 @@
3739

3840
class InternalMetrics(TypedDict, total=False):
3941
weight_rms: dict[str, float]
42+
weight_min: dict[str, float]
43+
weight_max: dict[str, float]
4044
maxvio: dict[str, float]
4145
drop_ratio: dict[str, float]
4246
router_logits_max: dict[str, float]
@@ -53,13 +57,15 @@ class InternalMetricsConfig(BaseModel):
5357
internal_metrics_interval: int | None = None
5458
monitor_weights_rms_norm: bool = True
5559
monitor_attn_logits_stats: bool = True
60+
monitor_gdn_stats: bool | None = None # only applies to models w/ GDN
5661
monitor_moe_router_logits_stats: bool | None = None # only applies to MoE models
5762
monitor_moe_load_balance_stats: bool | None = None
5863

5964
@model_validator(mode="after")
6065
def post_init(self):
6166
monitoring_fields = [
6267
self.monitor_weights_rms_norm,
68+
self.monitor_gdn_stats,
6369
self.monitor_attn_logits_stats,
6470
self.monitor_moe_router_logits_stats,
6571
self.monitor_moe_load_balance_stats,
@@ -74,7 +80,7 @@ def post_init(self):
7480
class InternalMetricsRecorder:
7581
def __init__(self, internal_metrics_cfg: InternalMetricsConfig, model: XTunerBaseModel):
7682
self.internal_metrics_cfg = internal_metrics_cfg
77-
self.model = model
83+
self.model: XTunerBaseModel = model.language_model if hasattr(model, "language_model") else model # type: ignore[assignment]
7884

7985
self.hooks: list[RemovableHandle] = []
8086

@@ -91,6 +97,10 @@ def _init_metrics_dict(self) -> InternalMetrics:
9197
if self.internal_metrics_cfg.monitor_weights_rms_norm:
9298
metrics["weight_rms"] = {}
9399

100+
if self.internal_metrics_cfg.monitor_gdn_stats:
101+
metrics["weight_min"] = {}
102+
metrics["weight_max"] = {}
103+
94104
if self.internal_metrics_cfg.monitor_attn_logits_stats:
95105
attn_cfg: MHAConfig | MLAConfig = self.model.config.attention # type: ignore[attr-defined]
96106

@@ -153,6 +163,44 @@ def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype:
153163

154164
self.metrics["weight_rms"][layer_name] = param_rms.item()
155165

166+
@torch.no_grad()
167+
def calculate_module_weight_min_max(self, module_or_param: nn.Module | torch.Tensor, layer_name: str):
168+
"""Calculate the min and max of the module's parameters."""
169+
self._check_closed()
170+
171+
if "weight_min" not in self.metrics or "weight_max" not in self.metrics:
172+
return
173+
174+
if isinstance(module_or_param, nn.Module):
175+
all_params = [param.data for param in module_or_param.parameters() if param.requires_grad]
176+
elif isinstance(module_or_param, torch.Tensor):
177+
if module_or_param.requires_grad:
178+
all_params = [module_or_param.data]
179+
else:
180+
all_params = []
181+
else:
182+
raise TypeError(f"module_or_param must be nn.Module or torch.Tensor, got {type(module_or_param)}")
183+
184+
if not all_params:
185+
return
186+
187+
local_params = []
188+
for param in all_params:
189+
if isinstance(param, DTensor):
190+
local_params.append(param.to_local())
191+
else:
192+
local_params.append(param)
193+
194+
local_min = torch.min(torch.stack([p.min() for p in local_params]))
195+
local_max = torch.max(torch.stack([p.max() for p in local_params]))
196+
197+
if dist.is_initialized() and dist.get_world_size() > 1:
198+
dist.all_reduce(local_min, op=dist.ReduceOp.MIN)
199+
dist.all_reduce(local_max, op=dist.ReduceOp.MAX)
200+
201+
self.metrics["weight_min"][layer_name] = local_min.item()
202+
self.metrics["weight_max"][layer_name] = local_max.item()
203+
156204
def register_attn_output_hook(self, module: nn.Module):
157205
"""Register attention output hook as a forward hook."""
158206
self._check_closed()
@@ -179,6 +227,16 @@ def pop_metrics(self, data_batches: list[ModelItem]):
179227
if self.internal_metrics_cfg.monitor_weights_rms_norm and isinstance(module, RMS_NORM_MONITOR_MODULES):
180228
self.calculate_module_weight_rms(module, self._clean_module_name(name), dtype=torch.float32)
181229

230+
if (
231+
self.internal_metrics_cfg.monitor_gdn_stats
232+
and FusedRMSNormGated is not None
233+
and isinstance(module, FusedRMSNormGated)
234+
):
235+
self.calculate_module_weight_min_max(module, self._clean_module_name(name))
236+
237+
if self.internal_metrics_cfg.monitor_gdn_stats and hasattr(module, "A_log"):
238+
self.calculate_module_weight_min_max(module.A_log, f"{self._clean_module_name(name)}.A_log")
239+
182240
additional_kwargs = {}
183241
if self.internal_metrics_cfg.monitor_moe_router_logits_stats and isinstance(self.model, MoE):
184242
# for MoE model, add additional kwargs to return necessary stats

0 commit comments

Comments
 (0)