
Commit 046285f

Update support for TE Grouped Linear

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent 4d67c1f · commit 046285f

1 file changed: 94 additions & 4 deletions

File tree

modelopt/torch/peft/lora/plugins/megatron.py

@@ -39,10 +39,15 @@

 try:
     from megatron.core.extensions.transformer_engine import (
+        TEColumnParallelGroupedLinear,
         TEColumnParallelLinear,
+        TERowParallelGroupedLinear,
         TERowParallelLinear,
     )

+    from modelopt.torch.quantization.plugins.megatron import (
+        _QuantMegatronTEGroupedLinear as QuantTEGroupedLinear,
+    )
     from modelopt.torch.quantization.plugins.megatron import (
         _QuantTEMCoreColumnParallelLinear as QuantTEColumnParallelLinear,
     )
@@ -51,8 +56,12 @@
     )

     HAS_TE = True
+    HAS_TE_GROUPED = (
+        TEColumnParallelGroupedLinear is not None and TERowParallelGroupedLinear is not None
+    )
 except ImportError:
     HAS_TE = False
+    HAS_TE_GROUPED = False


 def megatron_replace_lora_module_hook(model: torch.nn.Module):
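
Note on the flag: HAS_TE_GROUPED checks `is not None` on top of the surrounding try/except ImportError because (this is an assumption about megatron.core's behavior, not shown in the diff) the grouped-linear names may be bound to None on older Transformer Engine versions instead of failing the import. A standalone sketch of that situation, using a hypothetical stub module rather than megatron.core itself:

import types

# Hypothetical stand-in for megatron.core.extensions.transformer_engine:
# the import succeeds, but the optional symbols are bound to None.
te_ext = types.ModuleType("te_ext_stub")
te_ext.TEColumnParallelGroupedLinear = None
te_ext.TERowParallelGroupedLinear = None

HAS_TE_GROUPED = (
    te_ext.TEColumnParallelGroupedLinear is not None
    and te_ext.TERowParallelGroupedLinear is not None
)
assert not HAS_TE_GROUPED  # feature detected as unavailable despite a clean import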
@@ -126,6 +135,44 @@ def _register_adapter_with_device(

         super()._register_adapter(adapter_name, lora_a, lora_b, rank, scale, enable)

+    def _call_lora_module(
+        self, module: nn.Module, x: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+        if HAS_TE_GROUPED and isinstance(
+            module, (TEColumnParallelGroupedLinear, TERowParallelGroupedLinear)
+        ):
+            output = module(x, *args, **kwargs)
+        else:
+            output = module(x)
+        if isinstance(output, tuple):
+            return output[0]
+        return output
+
+    def forward(self, x: torch.Tensor, *args, **kwargs):
+        # Bypass LoRAModule.forward to avoid double-applying adapters.
+        output = super(LoRAModule, self).forward(x, *args, **kwargs)
+
+        if isinstance(output, tuple):
+            result = output[0]
+            other_outputs = output[1:]
+        else:
+            result = output
+            other_outputs = ()
+
+        for adapter_name in self._lora_adapters:
+            adapter = self._lora_adapters[adapter_name]
+            if adapter["enable"]:
+                lora_a = adapter["lora_a"]
+                lora_b = adapter["lora_b"]
+                lora_a_output = self._call_lora_module(lora_a, x, *args, **kwargs)
+                lora_b_output = self._call_lora_module(lora_b, lora_a_output, *args, **kwargs)
+                scale = adapter["scale"]
+                result = result + scale * lora_b_output
+
+        if other_outputs:
+            return (result, *other_outputs)
+        return result
+

 @LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"})
 class _LoRAMegatronColumnParallelLinear(_MegatronParallelLoRABase):
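
The new forward applies the standard LoRA update, y = base(x) + scale * B(A(x)); _call_lora_module forwards the extra positional/keyword arguments only to grouped modules, since (as assumed here) TE grouped GEMMs take additional forward arguments such as per-group token splits. A minimal plain-PyTorch sketch of the same computation, not the ModelOpt classes:

import torch
import torch.nn as nn

class TinyLoRALinear(nn.Linear):
    """Illustrates y = base(x) + scale * B(A(x)) with a zero-initialized B."""

    def __init__(self, d_in: int, d_out: int, rank: int = 8, scale: float = 1.0):
        super().__init__(d_in, d_out, bias=False)
        self.lora_a = nn.Linear(d_in, rank, bias=False)   # A: d_in -> rank
        self.lora_b = nn.Linear(rank, d_out, bias=False)  # B: rank -> d_out
        nn.init.zeros_(self.lora_b.weight)                # adapter starts as a no-op
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = super().forward(x)  # frozen base projection
        return result + self.scale * self.lora_b(self.lora_a(x))

x = torch.randn(4, 16)
layer = TinyLoRALinear(16, 32)
# With B == 0 the adapter contributes nothing, so output equals the base GEMM:
assert torch.allclose(layer(x), nn.functional.linear(x, layer.weight))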
@@ -157,14 +204,26 @@ def update_layer_lora(
         with torch.no_grad():
             attr_config.lora_a_init(lora_a.weight)

-        if HAS_TE and isinstance(self, TEColumnParallelLinear):
+        if HAS_TE_GROUPED and isinstance(self, TEColumnParallelGroupedLinear):
+            lora_b = TEColumnParallelGroupedLinear(
+                self.num_gemms,
+                attr_config.rank,
+                output_size,
+                config=self.config,
+                bias=False,
+                skip_bias_add=False,
+                is_expert=True,  # TODO (Jingyu)
+                pg_collection=getattr(self, "_pg_collection", None),
+                # tp_comm_buffer_name  # TODO (Jingyu)
+            )
+        elif HAS_TE and isinstance(self, TEColumnParallelLinear):
             lora_b = TEColumnParallelLinear(
                 attr_config.rank,
                 output_size,
                 config=self.config,
                 bias=False,
                 gather_output=False,
-                is_expert=False,  # TODO (Jingyu): Hard coded to False
+                is_expert=getattr(self, "is_expert", False),  # TODO (Jingyu)
                 init_method=attr_config.lora_b_init,
             )
         else:
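
The grouped constructor's positional order is (num_gemms, in_features, out_features), so this lora_b owns one rank -> output_size GEMM per expert. A hedged sketch of the grouped shape semantics in plain PyTorch; the m_splits token-routing argument is an assumption about the grouped forward, not something this diff shows:

import torch

num_gemms, rank, d_out = 4, 8, 32
# One B_i per expert, each mapping rank -> d_out (mirrors lora_b above).
lora_b_weights = [torch.zeros(d_out, rank) for _ in range(num_gemms)]

def grouped_linear(x: torch.Tensor, m_splits: list[int]) -> torch.Tensor:
    """Apply expert i's GEMM to its contiguous slice of the packed tokens."""
    assert sum(m_splits) == x.shape[0]
    chunks = torch.split(x, m_splits)
    return torch.cat([c @ w.t() for c, w in zip(chunks, lora_b_weights)])

x = torch.randn(10, rank)            # packed LoRA-A outputs for all experts
y = grouped_linear(x, [3, 2, 4, 1])  # shape (10, 32)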
@@ -227,15 +286,28 @@ def update_layer_lora(
         input_size = getattr(self, "input_size", None) or self.in_features
         output_size = getattr(self, "output_size", None) or self.out_features

-        if HAS_TE and isinstance(self, TERowParallelLinear):
+        if HAS_TE_GROUPED and isinstance(self, TERowParallelGroupedLinear):
+            lora_a = TERowParallelGroupedLinear(
+                self.num_gemms,
+                input_size,
+                attr_config.rank,
+                config=self.config,
+                init_method=attr_config.lora_a_init,
+                bias=False,
+                skip_bias_add=True,
+                is_expert=True,  # TODO (Jingyu)
+                pg_collection=getattr(self, "_pg_collection", None),
+                # tp_comm_buffer_name?  # TODO (Jingyu)
+            )
+        elif HAS_TE and isinstance(self, TERowParallelLinear):
             lora_a = TERowParallelLinear(
                 input_size,
                 attr_config.rank,
                 config=self.config,
                 input_is_parallel=True,
                 skip_bias_add=True,
                 bias=False,
-                is_expert=False,  # TODO (Jingyu): Hard coded to False
+                is_expert=getattr(self, "is_expert", False),  # TODO (Jingyu)
                 init_method=attr_config.lora_a_init,
             )
         else:
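
Here lora_a mirrors the base layer's row-parallel layout (input_is_parallel=True in the non-grouped branch), because the incoming activations are already sharded along the input dimension. A single-process sketch of the row-parallel arithmetic this relies on, with two simulated tensor-parallel ranks (not TE code):

import torch

d_in, rank = 16, 8
w = torch.randn(rank, d_in)          # full LoRA-A weight: d_in -> rank
w0, w1 = w[:, :8], w[:, 8:]          # each rank holds half the input dim
x = torch.randn(4, d_in)
x0, x1 = x[:, :8], x[:, 8:]          # input arrives already split

# Each rank computes a partial product; summing them (the all-reduce)
# recovers the full GEMM.
partial_sum = x0 @ w0.t() + x1 @ w1.t()
assert torch.allclose(partial_sum, x @ w.t(), atol=1e-5)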
@@ -286,13 +358,26 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):


 if HAS_TE:
+    # Register TEColumnParallelLinear and TERowParallelLinear with the LoRAModuleRegistry.
+    # This allows update_model to automatically replace these layers with their
+    # corresponding LoRA-enabled modules when applying LoRA.
     LoRAModuleRegistry.register({TEColumnParallelLinear: "te_mcore_ColumnParallelLinear"})(
         _LoRAMegatronColumnParallelLinear
     )
     LoRAModuleRegistry.register({TERowParallelLinear: "te_mcore_RowParallelLinear"})(
         _LoRAMegatronRowParallelLinear
     )

+if HAS_TE_GROUPED:
+    # Register TEColumnParallelGroupedLinear and TERowParallelGroupedLinear with the LoRAModuleRegistry.
+    # This allows update_model to automatically replace these layers with their
+    # corresponding LoRA-enabled modules when applying LoRA.
+    LoRAModuleRegistry.register(
+        {TEColumnParallelGroupedLinear: "te_mcore_ColumnParallelGroupedLinear"}
+    )(_LoRAMegatronColumnParallelLinear)
+    LoRAModuleRegistry.register({TERowParallelGroupedLinear: "te_mcore_RowParallelGroupedLinear"})(
+        _LoRAMegatronRowParallelLinear
+    )

 # Register quantized versions if available
 LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})(
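
Registration is what lets the conversion pass swap layers by type. A hedged sketch of a registry-driven replacement pass; the names and the convert hook are illustrative, not the actual ModelOpt internals:

import torch.nn as nn

REGISTRY: dict[type, type] = {}

def register_lora(base_cls: type, lora_cls: type) -> None:
    REGISTRY[base_cls] = lora_cls

def replace_modules(model: nn.Module) -> None:
    """Recursively swap registered layer types for their LoRA counterparts."""
    for name, child in model.named_children():
        lora_cls = REGISTRY.get(type(child))
        if lora_cls is not None:
            setattr(model, name, lora_cls.convert(child))  # hypothetical hook
        else:
            replace_modules(child)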
@@ -309,6 +394,11 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
     _LoRAMegatronRowParallelLinear
 )

+if HAS_TE_GROUPED:
+    LoRAModuleRegistry.register({QuantTEGroupedLinear: "quant_te_mcore_GroupedLinear"})(
+        _LoRAMegatronColumnParallelLinear
+    )
+

 class _QuantLoRAMegatronColumnParallelLinear(
     _LoRAMegatronColumnParallelLinear, QuantColumnParallelLinear
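
The trailing context shows how quantized variants are built: multiple inheritance combines the LoRA class with the quantized base so both behaviors stack along the method resolution order. A toy illustration of that mixin pattern (stand-in classes, not the real ones):

class QuantLinear:  # stand-in for the quantized base class
    def forward(self, x):
        return f"quant({x})"

class LoRAMixin:  # stand-in for the LoRA wrapper
    def forward(self, x):
        return f"lora({super().forward(x)})"

class QuantLoRALinear(LoRAMixin, QuantLinear):
    pass

assert QuantLoRALinear().forward("x") == "lora(quant(x))"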
