|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +from types import SimpleNamespace |
| 15 | + |
| 16 | +from omegaconf import DictConfig |
| 17 | +from peft import PeftModel |
| 18 | +from transformers import Qwen3Config, Qwen3ForCausalLM |
| 19 | + |
| 20 | +from nemo.collections.speechlm2.parts.lora import maybe_install_lora |
| 21 | + |
| 22 | + |
def _make_qwen3_stub():
    """Build a minimal stand-in for a speechlm2 model wrapper.

    Returns a ``SimpleNamespace`` carrying the three attributes
    ``maybe_install_lora`` reads: a tiny Qwen3 causal LM, a ``DictConfig``
    with an enabled ``lora`` section, and a handle to the LLM's input
    embedding module.
    """
    tiny_cfg = Qwen3Config(
        vocab_size=32,
        hidden_size=16,
        intermediate_size=32,
        num_hidden_layers=1,
        num_attention_heads=2,
        num_key_value_heads=2,
        max_position_embeddings=32,
    )
    tiny_llm = Qwen3ForCausalLM(tiny_cfg)
    # LoRA hyperparameters kept deliberately small so the test stays fast.
    lora_section = {
        "r": 4,
        "lora_alpha": 8,
        "lora_dropout": 0.0,
        "target_modules": ["q_proj", "v_proj"],
        "task_type": "CAUSAL_LM",
    }
    return SimpleNamespace(
        cfg=DictConfig({"prevent_freeze_params": [], "lora": lora_section}),
        llm=tiny_llm,
        embed_tokens=tiny_llm.model.embed_tokens,
    )
| 51 | + |
| 52 | + |
def test_maybe_install_lora_restores_qwen3_input_embeddings_temporarily():
    """``maybe_install_lora`` must handle a Qwen3 LLM whose ``embed_tokens``
    attribute has been detached: it wraps the model with PEFT, installs LoRA
    adapters on the targeted projections, and leaves the embeddings detached
    again afterwards (i.e. any restoration is only temporary)."""
    stub = _make_qwen3_stub()
    # Simulate the state where the input embeddings were moved off the LLM.
    del stub.llm.model.embed_tokens

    maybe_install_lora(stub)

    # The LLM is now a PEFT wrapper with LoRA adapters on q_proj,
    # and LoRA parameters are protected from freezing via the config.
    assert isinstance(stub.llm, PeftModel)
    wrapped_attn = stub.llm.base_model.model.model.layers[0].self_attn
    assert hasattr(wrapped_attn.q_proj, "lora_A")
    assert stub.cfg.prevent_freeze_params == [r"^.+\.lora_.+$"]
    # Embeddings must not remain attached after LoRA installation completes.
    assert not hasattr(stub.llm.base_model.model.model, "embed_tokens")
0 commit comments