We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d7c2d15 commit 0f7796eCopy full SHA for 0f7796e
1 file changed
tuning/sft_trainer.py
@@ -397,6 +397,7 @@ def train(
397
# For LoRa ScatterMoE, disable grad for ScatterMoE
398
if peft_config is not None:
399
for module in model.modules():
400
+ # Use string comparison to check if ScatterMoE module
401
if module.__class__.__name__ == "ScatterMoE":
402
for param in module.parameters():
403
param.requires_grad = False
0 commit comments