Skip to content

Commit 574f4b2

Browse files
committed
fix: more specific warning
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
1 parent 230f581 commit 574f4b2

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

tuning/sft_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,17 +171,19 @@ def train(
171171
fast_moe_config = None
172172
if fast_moe_config is not None:
173173
# If LoRA with ScatterMoE detected, raise warning
174+
accepted_layers = ["all-linear"]
174175
if (
175176
peft_config is not None
176177
and hasattr(peft_config, "target_modules")
177178
and fast_moe_config.fast_moe is not None
179+
and peft_config.target_modules != accepted_layers
178180
):
179181
logger.warning(
180182
"You are running lora with the ScatterMoE plugin, please note that "
181183
"passing target modules that are part of the moe module can cause unexpected "
182184
"behaviors and unsuccessful tuning while LoRA tuning with ScatterMoE. "
183185
"For safe tuning, only pass linear modules such as those in the attn layer "
184-
"(i.e. ['q_proj', 'v_proj', 'o_proj', 'k_proj']) or pass 'all-linear'"
186+
"(i.e. ['q_proj', 'v_proj', 'o_proj', 'k_proj'])"
185187
)
186188

187189
task_type = "CAUSAL_LM"

0 commit comments

Comments
 (0)