Skip to content

Commit 4b7ae4e

Browse files
committed
feat(aero_realtime): TP for dense backbones via family TP plan
1 parent cbb157c commit 4b7ae4e

1 file changed

Lines changed: 93 additions & 19 deletions

File tree

src/lmms_engine/parallel/aero_realtime/parallelize.py

Lines changed: 93 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from torch.distributed.device_mesh import DeviceMesh
2626
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
2727
from torch.distributed.tensor import Shard
28-
from torch.distributed.tensor.parallel import parallelize_module
28+
from torch.distributed.tensor.parallel import (
29+
ColwiseParallel,
30+
RowwiseParallel,
31+
parallelize_module,
32+
)
2933

3034
import lmms_engine.parallel.process_group_manager as pgm
3135
from lmms_engine.models.aero_realtime.backbone_registry import family_is_moe
@@ -48,32 +52,97 @@ def _ep_style_cls(family: str):
4852
raise ValueError(f"no EP ParallelStyle for backbone_family={family}")
4953

5054

55+
_QWEN3_VL_LIKE_TP_PLAN = {
56+
"self_attn.q_proj": ColwiseParallel(use_local_output=True),
57+
"self_attn.k_proj": ColwiseParallel(use_local_output=True),
58+
"self_attn.v_proj": ColwiseParallel(use_local_output=True),
59+
"self_attn.o_proj": RowwiseParallel(use_local_output=True),
60+
"mlp.gate_proj": ColwiseParallel(use_local_output=True),
61+
"mlp.up_proj": ColwiseParallel(use_local_output=True),
62+
"mlp.down_proj": RowwiseParallel(use_local_output=True),
63+
}
64+
65+
66+
def _tp_plan_for_family(family: str):
67+
"""Return the per-decoder-layer TP plan for dense backbone families.
68+
69+
MoE families are handled via EP, not TP, so this only covers dense
70+
families that ship a TP plan today (``qwen3_vl``). Add more dense
71+
families here as their TP plans land.
72+
"""
73+
if family == "qwen3_vl":
74+
return _QWEN3_VL_LIKE_TP_PLAN
75+
raise ValueError(f"no TP plan for backbone_family={family}")
76+
77+
78+
def _check_divisible(name: str, value: int, degree: int) -> None:
79+
if value % degree != 0:
80+
raise ValueError(f"{name} ({value}) must be divisible by tp_degree ({degree})")
81+
82+
83+
def _validate_aero_realtime_tp_config(model, tp_degree: int) -> None:
84+
if tp_degree <= 1:
85+
return
86+
87+
family = model.config.backbone_family
88+
if family_is_moe(family):
89+
raise ValueError(f"tp_degree>1 is not supported for MoE backbone_family={family}; use ep_degree instead")
90+
91+
# Dense families: validate text_config divisibility.
92+
text_config = model.config.text_config
93+
_check_divisible("hidden_size", text_config.hidden_size, tp_degree)
94+
_check_divisible("intermediate_size", text_config.intermediate_size, tp_degree)
95+
_check_divisible("num_attention_heads", text_config.num_attention_heads, tp_degree)
96+
_check_divisible("num_key_value_heads", text_config.num_key_value_heads, tp_degree)
97+
98+
sp_degree = pgm.process_group_manager.cp_world_size
99+
local_attention_heads = text_config.num_attention_heads // tp_degree
100+
if sp_degree > 1 and local_attention_heads % sp_degree != 0:
101+
raise ValueError(
102+
f"num_attention_heads / tp_degree ({local_attention_heads}) must be divisible by "
103+
f"sp_ulysses_degree ({sp_degree})"
104+
)
105+
106+
51107
def apply_aero_realtime_parallel(
52108
model,
53-
ep_mesh: DeviceMesh,
109+
ep_mesh: DeviceMesh = None,
54110
tp_mesh: DeviceMesh = None,
55111
**kwargs,
56112
):
57-
"""Apply EP ParallelStyle to each language_model decoder layer's
58-
``mlp.experts``. Only meaningful for MoE backbone families."""
59-
assert tp_mesh is None, "Tensor Parallelism is not supported yet for AeroRealtime"
113+
"""Apply expert / tensor parallelism to the aero language_model.
60114
115+
- MoE families (``ep_mesh`` required): wrap each decoder layer's
116+
``mlp.experts`` with the family's ParallelStyle.
117+
- Dense families (``tp_mesh`` required): apply the family's per-layer
118+
TP plan to each decoder layer.
119+
"""
61120
family = model.config.backbone_family
62-
if not family_is_moe(family):
63-
raise ValueError(f"ep_degree>1 requires an MoE backbone_family; got {family}")
121+
is_moe = family_is_moe(family)
64122

65-
style_cls = _ep_style_cls(family)
66-
num_moe_layers = 0
123+
if is_moe:
124+
assert tp_mesh is None, f"tp_mesh not supported for MoE backbone_family={family}"
125+
assert ep_mesh is not None, "ep_mesh required for MoE backbone family"
126+
127+
style_cls = _ep_style_cls(family)
128+
num_moe_layers = 0
129+
for decoder_layer in model.language_model.layers:
130+
parallelize_module(
131+
decoder_layer.mlp.experts,
132+
device_mesh=ep_mesh,
133+
parallelize_plan=style_cls(),
134+
)
135+
num_moe_layers += 1
136+
logger.info(f"Applied {style_cls.__name__} to {num_moe_layers} aero_realtime MoE layers")
137+
return
138+
139+
assert ep_mesh is None, f"ep_mesh not supported for dense backbone_family={family}"
140+
assert tp_mesh is not None, "tp_mesh required for dense backbone family"
141+
142+
tp_plan = _tp_plan_for_family(family)
67143
for decoder_layer in model.language_model.layers:
68-
module = decoder_layer.mlp
69-
parallelize_module(
70-
module.experts,
71-
device_mesh=ep_mesh,
72-
parallelize_plan=style_cls(),
73-
)
74-
num_moe_layers += 1
75-
76-
logger.info(f"Applied {style_cls.__name__} to {num_moe_layers} aero_realtime MoE layers")
144+
parallelize_module(decoder_layer, device_mesh=tp_mesh, parallelize_plan=tp_plan)
145+
logger.info(f"Applied {family} text TP to {len(model.language_model.layers)} aero_realtime decoder layers")
77146

78147

79148
def apply_aero_realtime_fsdp2(
@@ -163,16 +232,21 @@ def apply_aero_realtime_parallelize_fn(
163232
164233
Mirrors the qwen3_5_moe / qwen3_vl_moe two-stage flow:
165234
1. capture ``full_state_dict`` BEFORE parallelization
166-
2. apply EP (if ep_size>1; requires MoE family)
235+
2. apply EP (MoE families, ep_size>1) or TP (dense families, tp_size>1)
167236
3. apply FSDP2
168237
4. reload full state dict into the now-sharded model
169238
"""
170239
ep_size = pgm.process_group_manager.ep_size
240+
tp_size = pgm.process_group_manager.tp_world_size
241+
_validate_aero_realtime_tp_config(model, tp_size)
171242
full_state_dict = model.state_dict()
172243

173244
if ep_size > 1:
174245
ep_mesh = pgm.process_group_manager.device_mesh["ep"]
175246
apply_aero_realtime_parallel(model, ep_mesh=ep_mesh, **kwargs)
247+
elif tp_size > 1:
248+
tp_mesh = pgm.process_group_manager.device_mesh["tp"]
249+
apply_aero_realtime_parallel(model, tp_mesh=tp_mesh, **kwargs)
176250

177251
apply_aero_realtime_fsdp2(model, train_args, **kwargs)
178252
fsdp2_load_full_state_dict(model, full_state_dict)

0 commit comments

Comments
 (0)