2525from torch .distributed .device_mesh import DeviceMesh
2626from torch .distributed .fsdp import MixedPrecisionPolicy , fully_shard
2727from torch .distributed .tensor import Shard
28- from torch .distributed .tensor .parallel import parallelize_module
28+ from torch .distributed .tensor .parallel import (
29+ ColwiseParallel ,
30+ RowwiseParallel ,
31+ parallelize_module ,
32+ )
2933
3034import lmms_engine .parallel .process_group_manager as pgm
3135from lmms_engine .models .aero_realtime .backbone_registry import family_is_moe
@@ -48,32 +52,97 @@ def _ep_style_cls(family: str):
4852 raise ValueError (f"no EP ParallelStyle for backbone_family={ family } " )
4953
5054
55+ _QWEN3_VL_LIKE_TP_PLAN = {
56+ "self_attn.q_proj" : ColwiseParallel (use_local_output = True ),
57+ "self_attn.k_proj" : ColwiseParallel (use_local_output = True ),
58+ "self_attn.v_proj" : ColwiseParallel (use_local_output = True ),
59+ "self_attn.o_proj" : RowwiseParallel (use_local_output = True ),
60+ "mlp.gate_proj" : ColwiseParallel (use_local_output = True ),
61+ "mlp.up_proj" : ColwiseParallel (use_local_output = True ),
62+ "mlp.down_proj" : RowwiseParallel (use_local_output = True ),
63+ }
64+
65+
def _tp_plan_for_family(family: str):
    """Look up the per-decoder-layer TP plan for a dense backbone family.

    Only ``qwen3_vl`` has a plan registered here. MoE families are not
    covered because they are parallelized with EP rather than TP; further
    dense families should be added to this lookup once their plans exist.

    Args:
        family: The backbone family name.

    Returns:
        The plan mapping for ``family``.

    Raises:
        ValueError: If no TP plan is registered for ``family``.
    """
    if family != "qwen3_vl":
        raise ValueError(f"no TP plan for backbone_family={family}")
    return _QWEN3_VL_LIKE_TP_PLAN
76+
77+
def _check_divisible(name: str, value: int, degree: int) -> None:
    """Raise ``ValueError`` unless ``value`` is an exact multiple of ``degree``.

    Args:
        name: Label for the quantity being checked; used only in the error
            message.
        value: The size that has to split evenly.
        degree: The tensor-parallel degree to split by; must be positive.

    Raises:
        ValueError: If ``degree`` is not positive, or if ``value`` is not
            divisible by ``degree``.
    """
    if degree <= 0:
        # A zero degree would otherwise surface as ZeroDivisionError from the
        # modulo below, and a negative degree could pass the check silently.
        raise ValueError(f"tp_degree ({degree}) must be a positive integer")
    if value % degree != 0:
        raise ValueError(f"{name} ({value}) must be divisible by tp_degree ({degree})")
81+
82+
def _validate_aero_realtime_tp_config(model, tp_degree: int) -> None:
    """Reject model configs that cannot be split across ``tp_degree`` ranks.

    Does nothing when ``tp_degree`` is 1 or less. Otherwise the backbone
    family must not be MoE, and ``hidden_size``, ``intermediate_size``,
    ``num_attention_heads`` and ``num_key_value_heads`` of
    ``model.config.text_config`` must each be divisible by ``tp_degree``.
    When the process group manager reports ``cp_world_size > 1``, the
    per-rank attention head count must also be divisible by that value.

    Args:
        model: Object exposing ``config.backbone_family`` and
            ``config.text_config``.
        tp_degree: Requested tensor-parallel degree.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    if tp_degree <= 1:
        return

    backbone = model.config.backbone_family
    if family_is_moe(backbone):
        raise ValueError(f"tp_degree>1 is not supported for MoE backbone_family={backbone}; use ep_degree instead")

    # Dense families: every sharded dimension of the text config must split evenly.
    text_cfg = model.config.text_config
    sharded_fields = (
        "hidden_size",
        "intermediate_size",
        "num_attention_heads",
        "num_key_value_heads",
    )
    for field_name in sharded_fields:
        _check_divisible(field_name, getattr(text_cfg, field_name), tp_degree)

    # The value reported to the user as sp_ulysses_degree is read from cp_world_size.
    ulysses_degree = pgm.process_group_manager.cp_world_size
    heads_per_tp_rank = text_cfg.num_attention_heads // tp_degree
    if ulysses_degree <= 1:
        return
    if heads_per_tp_rank % ulysses_degree != 0:
        raise ValueError(
            f"num_attention_heads / tp_degree ({heads_per_tp_rank}) must be divisible by "
            f"sp_ulysses_degree ({ulysses_degree})"
        )
105+
106+
def apply_aero_realtime_parallel(
    model,
    ep_mesh: DeviceMesh = None,
    tp_mesh: DeviceMesh = None,
    **kwargs,
):
    """Apply expert / tensor parallelism to the aero language_model.

    - MoE families (``ep_mesh`` required, ``tp_mesh`` must be ``None``):
      wrap each decoder layer's ``mlp.experts`` with the family's
      ParallelStyle.
    - Dense families (``tp_mesh`` required, ``ep_mesh`` must be ``None``):
      apply the family's per-layer TP plan to each decoder layer.

    Args:
        model: Object exposing ``config.backbone_family`` and
            ``language_model.layers``.
        ep_mesh: Device mesh used for expert parallelism.
        tp_mesh: Device mesh used for tensor parallelism.
        **kwargs: Accepted but not used by this function.

    Raises:
        ValueError: If the supplied meshes do not match the backbone family,
            or if the family has no EP style / TP plan registered.
    """
    family = model.config.backbone_family

    if family_is_moe(family):
        # Explicit raises instead of ``assert``: assertions are removed under
        # ``python -O``, which would let a bad mesh reach parallelize_module.
        if tp_mesh is not None:
            raise ValueError(f"tp_mesh not supported for MoE backbone_family={family}")
        if ep_mesh is None:
            raise ValueError("ep_mesh required for MoE backbone family")

        style_cls = _ep_style_cls(family)
        num_moe_layers = 0
        for decoder_layer in model.language_model.layers:
            parallelize_module(
                decoder_layer.mlp.experts,
                device_mesh=ep_mesh,
                # A fresh style instance is created for every layer.
                parallelize_plan=style_cls(),
            )
            num_moe_layers += 1
        logger.info(f"Applied {style_cls.__name__} to {num_moe_layers} aero_realtime MoE layers")
        return

    if ep_mesh is not None:
        raise ValueError(f"ep_mesh not supported for dense backbone_family={family}")
    if tp_mesh is None:
        raise ValueError("tp_mesh required for dense backbone family")

    tp_plan = _tp_plan_for_family(family)
    for decoder_layer in model.language_model.layers:
        parallelize_module(decoder_layer, device_mesh=tp_mesh, parallelize_plan=tp_plan)
    logger.info(f"Applied {family} text TP to {len(model.language_model.layers)} aero_realtime decoder layers")
77146
78147
79148def apply_aero_realtime_fsdp2 (
@@ -163,16 +232,21 @@ def apply_aero_realtime_parallelize_fn(
163232
164233 Mirrors the qwen3_5_moe / qwen3_vl_moe two-stage flow:
165234 1. capture ``full_state_dict`` BEFORE parallelization
166- 2. apply EP (if ep_size>1; requires MoE family )
235+ 2. apply EP (MoE families, ep_size>1) or TP (dense families, tp_size>1 )
167236 3. apply FSDP2
168237 4. reload full state dict into the now-sharded model
169238 """
170239 ep_size = pgm .process_group_manager .ep_size
240+ tp_size = pgm .process_group_manager .tp_world_size
241+ _validate_aero_realtime_tp_config (model , tp_size )
171242 full_state_dict = model .state_dict ()
172243
173244 if ep_size > 1 :
174245 ep_mesh = pgm .process_group_manager .device_mesh ["ep" ]
175246 apply_aero_realtime_parallel (model , ep_mesh = ep_mesh , ** kwargs )
247+ elif tp_size > 1 :
248+ tp_mesh = pgm .process_group_manager .device_mesh ["tp" ]
249+ apply_aero_realtime_parallel (model , tp_mesh = tp_mesh , ** kwargs )
176250
177251 apply_aero_realtime_fsdp2 (model , train_args , ** kwargs )
178252 fsdp2_load_full_state_dict (model , full_state_dict )
0 commit comments