Skip to content

Commit 699444f

Browse files
committed
fix: exclude router module
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
1 parent ec1211e commit 699444f

1 file changed

Lines changed: 3 additions & 18 deletions

File tree

tuning/sft_trainer.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -170,29 +170,14 @@ def train(
170170
if fast_moe_config is not None and fast_moe_config.fast_moe is None:
171171
fast_moe_config = None
172172
if fast_moe_config is not None:
173-
# Checking for unsupported modules with Scatter MoE for LoRA
174-
# Only raise an error for `all-linear`
175-
restricted_modules = ["all-linear"]
176-
if (
177-
peft_config is not None
178-
and hasattr(peft_config, "target_modules")
179-
and any(
180-
module in (peft_config.target_modules or [])
181-
for module in restricted_modules
182-
)
183-
and fast_moe_config.fast_moe is not None
184-
):
185-
raise ValueError(
186-
"`--fast_moe` with LoRA does not currently support `all-linear`, as "
187-
"target modules at this time. Please explicitly specify target "
188-
"modules when using `--fast_moe` with LoRA."
189-
)
190-
# If other common non-linear modules, raise warning
173+
# If LoRA with ScatterMoE detected, raise warning
174+
# and exclude router module
191175
if (
192176
peft_config is not None
193177
and hasattr(peft_config, "target_modules")
194178
and fast_moe_config.fast_moe is not None
195179
):
180+
peft_config.exclude_modules = "router"
196181
logger.warning(
197182
"You are running lora with the ScatterMoE plugin, please note that "
198183
"passing target modules that are part of the moe module can cause unexpected "

0 commit comments

Comments
 (0)