Commit b1235c8

Fix Qwen3 SALM LoRA init (#15570)
1 parent 6a0ecb7

2 files changed

Lines changed: 67 additions & 1 deletion

File tree

nemo/collections/speechlm2/parts/lora.py

Lines changed: 5 additions & 1 deletion
@@ -15,6 +15,7 @@
 from peft import LoraConfig, get_peft_model
 from transformers import PreTrainedModel
 
+from nemo.collections.speechlm2.parts.pretrained import move_embedding
 from nemo.utils import logging
 
 
@@ -25,6 +26,9 @@ def maybe_install_lora(model):
     assert hasattr(model, "llm") and isinstance(model.llm, PreTrainedModel)
     assert "prevent_freeze_params" in model.cfg and isinstance(model.cfg.prevent_freeze_params, (list, ListConfig))
     model.lora_config = LoraConfig(**model.cfg.lora)
-    model.llm = get_peft_model(model.llm, model.lora_config)
+    # PEFT inspects get_input_embeddings() while wrapping the model, so temporarily
+    # restore the embedding layer that SALM keeps outside the LLM for FSDP/TP.
+    with move_embedding(model):
+        model.llm = get_peft_model(model.llm, model.lora_config)
     model.cfg.prevent_freeze_params.append(r"^.+\.lora_.+$")
     logging.info(f"LoRA adapter installed: {model.lora_config}")
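
For context, move_embedding is imported from nemo.collections.speechlm2.parts.pretrained and its body is not shown in this diff. A minimal sketch of the pattern it appears to implement, assuming the attribute layout the test below exercises (model.embed_tokens held outside the LLM, llm.model.embed_tokens inside it), might look like this:

from contextlib import contextmanager


@contextmanager
def move_embedding(model):
    # Sketch only; the real helper in nemo.collections.speechlm2.parts.pretrained
    # may differ. SALM keeps the embedding table outside model.llm so FSDP/TP can
    # shard it separately, but PEFT resolves get_input_embeddings() while wrapping,
    # so the table must be reachable inside the LLM for that call to succeed.
    llm = model.llm  # keep a handle: model.llm is re-assigned inside the block
    llm.model.embed_tokens = model.embed_tokens
    try:
        yield
    finally:
        # Detach again so the embedding stays owned outside the LLM afterwards.
        del llm.model.embed_tokens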
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from types import SimpleNamespace
+
+from omegaconf import DictConfig
+from peft import PeftModel
+from transformers import Qwen3Config, Qwen3ForCausalLM
+
+from nemo.collections.speechlm2.parts.lora import maybe_install_lora
+
+
+def _make_qwen3_stub():
+    llm = Qwen3ForCausalLM(
+        Qwen3Config(
+            vocab_size=32,
+            hidden_size=16,
+            intermediate_size=32,
+            num_hidden_layers=1,
+            num_attention_heads=2,
+            num_key_value_heads=2,
+            max_position_embeddings=32,
+        )
+    )
+    return SimpleNamespace(
+        cfg=DictConfig(
+            {
+                "prevent_freeze_params": [],
+                "lora": {
+                    "r": 4,
+                    "lora_alpha": 8,
+                    "lora_dropout": 0.0,
+                    "target_modules": ["q_proj", "v_proj"],
+                    "task_type": "CAUSAL_LM",
+                },
+            }
+        ),
+        llm=llm,
+        embed_tokens=llm.model.embed_tokens,
+    )
+
+
+def test_maybe_install_lora_restores_qwen3_input_embeddings_temporarily():
+    model = _make_qwen3_stub()
+    del model.llm.model.embed_tokens
+
+    maybe_install_lora(model)
+
+    assert isinstance(model.llm, PeftModel)
+    assert hasattr(model.llm.base_model.model.model.layers[0].self_attn.q_proj, "lora_A")
+    assert model.cfg.prevent_freeze_params == [r"^.+\.lora_.+$"]
+    assert not hasattr(model.llm.base_model.model.model, "embed_tokens")
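
The final asserts also pin down the freezing behavior: the regex appended to prevent_freeze_params is what keeps the injected adapter weights trainable once the rest of the model is frozen. A quick illustration of what that pattern matches, using PEFT's usual lora_A/lora_B parameter naming (the exact parameter names are an assumption here, not taken from this diff):

import re

pattern = re.compile(r"^.+\.lora_.+$")

# LoRA adapter parameters injected by PEFT match...
assert pattern.match("base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight")
assert pattern.match("base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight")
# ...while frozen base weights do not.
assert not pattern.match("model.layers.0.self_attn.q_proj.weight")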
