@@ -170,29 +170,14 @@ def train(
170170 if fast_moe_config is not None and fast_moe_config .fast_moe is None :
171171 fast_moe_config = None
172172 if fast_moe_config is not None :
173- # Checking for unsupported modules with Scatter MoE for LoRA
174- # Only raise an error for `all-linear`
175- restricted_modules = ["all-linear" ]
176- if (
177- peft_config is not None
178- and hasattr (peft_config , "target_modules" )
179- and any (
180- module in (peft_config .target_modules or [])
181- for module in restricted_modules
182- )
183- and fast_moe_config .fast_moe is not None
184- ):
185- raise ValueError (
186- "`--fast_moe` with LoRA does not currently support `all-linear`, as "
187- "target modules at this time. Please explicitly specify target "
188- "modules when using `--fast_moe` with LoRA."
189- )
190- # If other common non-linear modules, raise warning
173+ # If LoRA with ScatterMoE detected, raise warning
174+ # and exclude router module
191175 if (
192176 peft_config is not None
193177 and hasattr (peft_config , "target_modules" )
194178 and fast_moe_config .fast_moe is not None
195179 ):
180+ peft_config .exclude_modules = "router"
196181 logger .warning (
197182 "You are running lora with the ScatterMoE plugin, please note that "
198183 "passing target modules that are part of the moe module can cause unexpected "
0 commit comments