Commit e35e80c

[Feature] Support share MTP weights. (#1672)
* Refactor MTP configuration to support weight sharing across layers. Updated the MoE and MTPBlock classes to handle shared weights and adjusted layer initialization accordingly. Added a share_weights parameter to MTPConfig for better control over layer behavior.
* Updated the checkpointing mechanism to ensure shared MTP heads are recomputed as necessary.
* Resolve review comments.

1 parent: 714483a

3 files changed

Lines changed: 27 additions & 8 deletions
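
In short, share_weights=True builds a single physical MTP layer and reuses it at every prediction depth, instead of allocating one layer per depth. A minimal sketch of the idea, with nn.Linear standing in for xtuner's much larger MoE-based MTP layers (the stand-in modules and sizes below are illustrative, not the real classes):

    import torch
    from torch import nn

    hidden_size, num_depths = 16, 3

    # share_weights=False: one physical layer per prediction depth.
    independent = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_depths))

    # share_weights=True: a single physical layer reused at every depth.
    shared = nn.ModuleList([nn.Linear(hidden_size, hidden_size)])

    x = torch.randn(4, hidden_size)
    for step in range(num_depths):
        layer = shared[0]  # every depth indexes the same module when sharing
        x = layer(x)

    print(sum(p.numel() for p in independent.parameters()))  # 816
    print(sum(p.numel() for p in shared.parameters()))       # 272: num_depths x fewer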

xtuner/v1/model/moe/moe.py

Lines changed: 10 additions & 4 deletions
@@ -855,7 +855,8 @@ def build_mtp_block(self, config: MoEConfig) -> MTPBlock:
         else:
             raise ValueError(f"Unsupported layer type {layers_type_list[last_layer_idx]}")
 
-        for i in range(mtp_config.num_layers):
+        num_physical_layer = 1 if mtp_config.share_weights else mtp_config.num_layers
+        for i in range(num_physical_layer):
             # Build MoE decoder layer for MTP
             decoder_layer = MoEDecoderLayer(
                 hidden_size=config.hidden_size,
@@ -894,7 +895,7 @@ def build_mtp_block(self, config: MoEConfig) -> MTPBlock:
             )
             mtp_layers.append(mtp_layer)
 
-        return MTPBlock(mtp_layers=mtp_layers)
+        return MTPBlock(mtp_config=mtp_config, mtp_layers=mtp_layers)
 
     @override
     def from_hf(self, hf_path: str | Path, strict: bool = True) -> tuple:
@@ -1015,7 +1016,9 @@ def fully_shard(
         # Shard MTP block if it exists
         if self.mtp_block is not None:
             for mtp_idx, mtp_layer in enumerate(self.mtp_block.layers):
-                if self._should_recompute(None, mtp_idx=mtp_idx):
+                if self._should_recompute(None, mtp_idx=mtp_idx) or (
+                    self.config.mtp_config is not None and self.config.mtp_config.share_weights
+                ):  # share mtp head must recompute
                     mtp_layer = checkpoint_wrapper(mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT)
                     self.mtp_block.layers[mtp_idx] = mtp_layer
 
@@ -1234,7 +1237,10 @@ def _should_recompute(
         * Global 9 (MTP 2, last layer): no recompute (forced)
         """
         num_layers = self.config.num_hidden_layers
-        mtp_layers = self.config.mtp_config.num_layers if self.config.mtp_config is not None else 0
+        if self.config.mtp_config is not None:
+            mtp_layers = 1 if self.config.mtp_config.share_weights else self.config.mtp_config.num_layers
+        else:
+            mtp_layers = 0
         recompute_ratio = self.fsdp_config.recompute_ratio if self.fsdp_config is not None else 0.0
 
         total_layers = num_layers + mtp_layers
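
The fully_shard change above forces activation checkpointing whenever the MTP head is shared: the single physical layer runs once per prediction depth, so it is wrapped regardless of the usual recompute policy, and _should_recompute now counts the shared head as one layer. A hedged sketch of the wrapping rule, where model stands in for the MoE model (field names follow the diff; the helper itself is illustrative):

    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        checkpoint_wrapper,
    )

    def wrap_mtp_layers(model):
        if model.mtp_block is None:
            return
        shared = model.config.mtp_config is not None and model.config.mtp_config.share_weights
        for mtp_idx, mtp_layer in enumerate(model.mtp_block.layers):
            # A shared head runs num_layers times per forward pass, so it is
            # always recomputed; otherwise the ratio-based policy decides.
            if shared or model._should_recompute(None, mtp_idx=mtp_idx):
                model.mtp_block.layers[mtp_idx] = checkpoint_wrapper(
                    mtp_layer, checkpoint_impl=CheckpointImpl.REENTRANT
                )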

xtuner/v1/module/mtp/config.py

Lines changed: 5 additions & 0 deletions
@@ -20,6 +20,9 @@ class MTPConfig(BaseModel):
     Args:
         num_layers (int): Number of MTP layers (prediction depths). Each layer
             predicts tokens at increasing future positions (i+1, i+2, ..., i+D).
+        share_weights (bool): Whether to share the weights of the MTP layers.
+            If True, the weights of the MTP layers are shared across all layers.
+            Default: False.
         loss_scaling_factor (float): Scaling factor for MTP loss. The total MTP loss
             is computed as the average of losses across all depths, multiplied by
             this factor. Default: 0.1.
@@ -30,6 +33,7 @@ class MTPConfig(BaseModel):
     ...     ...,
     ...     mtp_config=MTPConfig(
     ...         num_layers=2,
+    ...         share_weights=True,
     ...         loss_scaling_factor=0.1,
     ...     ),
     ... )
@@ -38,4 +42,5 @@ class MTPConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
 
     num_layers: Annotated[int, Parameter(group="model")]
+    share_weights: Annotated[bool, Parameter(group="model")] = False
     loss_scaling_factor: Annotated[float, Parameter(group="model")] = 0.1
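
Enabling the feature is a one-line config change. An illustrative usage, assuming MTPConfig is importable from the module path shown above:

    from xtuner.v1.module.mtp.config import MTPConfig

    mtp_config = MTPConfig(num_layers=2, share_weights=True, loss_scaling_factor=0.1)
    # The model still unrolls num_layers prediction steps in the forward pass;
    # share_weights only collapses them onto one physical set of weights.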

xtuner/v1/module/mtp/mtp_block.py

Lines changed: 12 additions & 4 deletions
@@ -7,6 +7,7 @@
 
 from xtuner.v1.data_proto import SequenceContext
 
+from .config import MTPConfig
 from .mtp_layer import MTPLayer
 from .utils import roll_sequence_context
 
@@ -25,6 +26,7 @@ class MTPBlock(nn.Module):
     the predictions of shallower layers.
 
     Args:
+        mtp_config (MTPConfig): MTP configuration.
         mtp_layers (list[MTPLayer]): List of MTP layers. Each layer should be a
             fully constructed MTPLayer instance. The number of layers determines
             the prediction depth (D).
@@ -43,7 +45,7 @@ class MTPBlock(nn.Module):
     ...     mtp_layers.append(mtp_layer)
     >>>
     >>> # Create MTP block
-    >>> mtp_block = MTPBlock(mtp_layers=mtp_layers)
+    >>> mtp_block = MTPBlock(mtp_config=config, mtp_layers=mtp_layers)
     >>>
     >>> # Forward pass
     >>> outputs = mtp_block(
@@ -58,13 +60,17 @@ class MTPBlock(nn.Module):
     >>> # outputs[1]: predictions for i+2
     """
 
-    def __init__(self, *, mtp_layers: list[MTPLayer]):
+    def __init__(self, *, mtp_config: MTPConfig, mtp_layers: list[MTPLayer]):
         super().__init__()
         if not mtp_layers:
             raise ValueError("mtp_layers cannot be empty")
 
+        if mtp_config.share_weights and len(mtp_layers) != 1:
+            raise ValueError(f"share_weights mode requires exactly 1 MTP layer, got {len(mtp_layers)}")
+        if not mtp_config.share_weights and len(mtp_layers) != mtp_config.num_layers:
+            raise ValueError(f"Expected {mtp_config.num_layers} MTP layers, but got {len(mtp_layers)}")
+        self.mtp_config = mtp_config
         self.layers = nn.ModuleList(mtp_layers)
-        self.num_layers = len(mtp_layers)
 
     def forward(
         self,
@@ -97,7 +103,9 @@ def forward(
         current_hidden_states = hidden_states
         current_seq_ctx = seq_ctx
 
-        for layer in self.layers:
+        num_steps = self.mtp_config.num_layers
+        for step in range(num_steps):
+            layer = self.layers[0] if self.mtp_config.share_weights else self.layers[step]
             # Roll sequence context to get future tokens
             # This shifts each packed sequence independently, respecting boundaries
             current_seq_ctx = roll_sequence_context(current_seq_ctx, shifts=-1)
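
The new constructor checks pin the physical layer count to the flag: exactly one layer when sharing, exactly num_layers otherwise, while forward still unrolls num_layers steps. A self-contained sketch of that contract, using simplified stand-ins for MTPConfig and MTPLayer (not the real xtuner classes):

    from dataclasses import dataclass
    from torch import nn

    @dataclass
    class Cfg:  # stand-in for MTPConfig
        num_layers: int
        share_weights: bool

    def validate(cfg, mtp_layers):
        # Mirrors the __init__ validation added in this commit.
        if cfg.share_weights and len(mtp_layers) != 1:
            raise ValueError(f"share_weights mode requires exactly 1 MTP layer, got {len(mtp_layers)}")
        if not cfg.share_weights and len(mtp_layers) != cfg.num_layers:
            raise ValueError(f"Expected {cfg.num_layers} MTP layers, but got {len(mtp_layers)}")

    validate(Cfg(3, share_weights=True), [nn.Identity()])       # ok: one shared layer
    validate(Cfg(3, share_weights=False), [nn.Identity()] * 3)  # ok: one layer per depth
    validate(Cfg(3, share_weights=True), [nn.Identity()] * 3)   # raises ValueError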
