 
 try:
     from megatron.core.extensions.transformer_engine import (
+        TEColumnParallelGroupedLinear,
         TEColumnParallelLinear,
+        TERowParallelGroupedLinear,
         TERowParallelLinear,
     )
 
+    from modelopt.torch.quantization.plugins.megatron import (
+        _QuantMegatronTEGroupedLinear as QuantTEGroupedLinear,
+    )
     from modelopt.torch.quantization.plugins.megatron import (
         _QuantTEMCoreColumnParallelLinear as QuantTEColumnParallelLinear,
     )
     from modelopt.torch.quantization.plugins.megatron import (
         _QuantTEMCoreRowParallelLinear as QuantTERowParallelLinear,
     )
 
     HAS_TE = True
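+    # Grouped GEMM classes are exported as None by Megatron-Core when the
+    # installed Transformer Engine build does not provide them, hence the
+    # explicit None checks below.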
+    HAS_TE_GROUPED = (
+        TEColumnParallelGroupedLinear is not None and TERowParallelGroupedLinear is not None
+    )
 except ImportError:
     HAS_TE = False
+    HAS_TE_GROUPED = False
 
 
 def megatron_replace_lora_module_hook(model: torch.nn.Module):
@@ -126,6 +135,44 @@ def _register_adapter_with_device(
 
         super()._register_adapter(adapter_name, lora_a, lora_b, rank, scale, enable)
 
+    def _call_lora_module(
+        self, module: nn.Module, x: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
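+        # TE grouped GEMM layers expect the per-group split sizes (e.g. the
+        # m_splits argument) forwarded via *args/**kwargs in addition to x;
+        # plain linear layers accept only the input tensor.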
+        if HAS_TE_GROUPED and isinstance(
+            module, (TEColumnParallelGroupedLinear, TERowParallelGroupedLinear)
+        ):
+            output = module(x, *args, **kwargs)
+        else:
+            output = module(x)
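+        # Megatron/TE parallel linears may return an (output, bias) tuple;
+        # keep only the output tensor.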
+        if isinstance(output, tuple):
+            return output[0]
+        return output
+
+    def forward(self, x: torch.Tensor, *args, **kwargs):
+        # Bypass LoRAModule.forward to avoid double-applying adapters.
+        output = super(LoRAModule, self).forward(x, *args, **kwargs)
+
+        if isinstance(output, tuple):
+            result = output[0]
+            other_outputs = output[1:]
+        else:
+            result = output
+            other_outputs = ()
+
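+        # Standard LoRA update: result = base(x) + scale * lora_b(lora_a(x)),
+        # accumulated for every enabled adapter.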
+        for adapter_name in self._lora_adapters:
+            adapter = self._lora_adapters[adapter_name]
+            if adapter["enable"]:
+                lora_a = adapter["lora_a"]
+                lora_b = adapter["lora_b"]
+                lora_a_output = self._call_lora_module(lora_a, x, *args, **kwargs)
+                lora_b_output = self._call_lora_module(lora_b, lora_a_output, *args, **kwargs)
+                scale = adapter["scale"]
+                result = result + scale * lora_b_output
+
+        if other_outputs:
+            return (result, *other_outputs)
+        return result
+
 
 @LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"})
 class _LoRAMegatronColumnParallelLinear(_MegatronParallelLoRABase):
@@ -157,14 +204,26 @@ def update_layer_lora(
         with torch.no_grad():
             attr_config.lora_a_init(lora_a.weight)
 
-        if HAS_TE and isinstance(self, TEColumnParallelLinear):
+        if HAS_TE_GROUPED and isinstance(self, TEColumnParallelGroupedLinear):
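+            # Mirror the wrapped grouped layer: one LoRA B GEMM per expert,
+            # projecting from the adapter rank up to the original output size.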
+            lora_b = TEColumnParallelGroupedLinear(
+                self.num_gemms,
+                attr_config.rank,
+                output_size,
+                config=self.config,
+                bias=False,
+                skip_bias_add=False,
+                is_expert=True,  # TODO (Jingyu)
+                pg_collection=getattr(self, "_pg_collection", None),
+                # tp_comm_buffer_name  # TODO (Jingyu)
+            )
+        elif HAS_TE and isinstance(self, TEColumnParallelLinear):
             lora_b = TEColumnParallelLinear(
                 attr_config.rank,
                 output_size,
                 config=self.config,
                 bias=False,
                 gather_output=False,
-                is_expert=False,  # TODO (Jingyu): Hard coded to False
+                is_expert=getattr(self, "is_expert", False),  # TODO (Jingyu)
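+                # Propagate the wrapped layer's expert flag so the adapter is
+                # configured consistently with the layer it augments.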
                 init_method=attr_config.lora_b_init,
             )
         else:
@@ -227,15 +286,28 @@ def update_layer_lora(
         input_size = getattr(self, "input_size", None) or self.in_features
         output_size = getattr(self, "output_size", None) or self.out_features
 
-        if HAS_TE and isinstance(self, TERowParallelLinear):
+        if HAS_TE_GROUPED and isinstance(self, TERowParallelGroupedLinear):
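+            # One LoRA A GEMM per expert, projecting the row-parallel input
+            # down to the adapter rank.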
+            lora_a = TERowParallelGroupedLinear(
+                self.num_gemms,
+                input_size,
+                attr_config.rank,
+                config=self.config,
+                init_method=attr_config.lora_a_init,
+                bias=False,
+                skip_bias_add=True,
+                is_expert=True,  # TODO (Jingyu)
+                pg_collection=getattr(self, "_pg_collection", None),
+                # tp_comm_buffer_name?  # TODO (Jingyu)
+            )
+        elif HAS_TE and isinstance(self, TERowParallelLinear):
             lora_a = TERowParallelLinear(
                 input_size,
                 attr_config.rank,
                 config=self.config,
                 input_is_parallel=True,
                 skip_bias_add=True,
                 bias=False,
-                is_expert=False,  # TODO (Jingyu): Hard coded to False
+                is_expert=getattr(self, "is_expert", False),  # TODO (Jingyu)
                 init_method=attr_config.lora_a_init,
             )
         else:
@@ -286,13 +358,26 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
 
 
 if HAS_TE:
+    # Register TEColumnParallelLinear and TERowParallelLinear with the
+    # LoRAModuleRegistry so that update_model automatically replaces these
+    # layers with their LoRA-enabled counterparts when applying LoRA.
     LoRAModuleRegistry.register({TEColumnParallelLinear: "te_mcore_ColumnParallelLinear"})(
         _LoRAMegatronColumnParallelLinear
     )
     LoRAModuleRegistry.register({TERowParallelLinear: "te_mcore_RowParallelLinear"})(
         _LoRAMegatronRowParallelLinear
     )
 
+if HAS_TE_GROUPED:
+    # Likewise register the grouped GEMM variants so update_model can swap
+    # them for their LoRA-enabled counterparts.
+    LoRAModuleRegistry.register(
+        {TEColumnParallelGroupedLinear: "te_mcore_ColumnParallelGroupedLinear"}
+    )(_LoRAMegatronColumnParallelLinear)
+    LoRAModuleRegistry.register({TERowParallelGroupedLinear: "te_mcore_RowParallelGroupedLinear"})(
+        _LoRAMegatronRowParallelLinear
+    )
 
 # Register quantized versions if available
 LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})(
@@ -309,6 +394,11 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
     _LoRAMegatronRowParallelLinear
 )
 
+if HAS_TE_GROUPED:
+    LoRAModuleRegistry.register({QuantTEGroupedLinear: "quant_te_mcore_GroupedLinear"})(
+        _LoRAMegatronColumnParallelLinear
+    )
+
 
 class _QuantLoRAMegatronColumnParallelLinear(
     _LoRAMegatronColumnParallelLinear, QuantColumnParallelLinear