diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index a94fe4f2e2..cc44684d80 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -1722,12 +1722,20 @@ def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None: # TODO: It's still unclear how empty layers_pattern (None, [], or "") should behave # For now, empty layers_pattern means any layer pattern is ok if layers_pattern is None or len(layers_pattern) == 0: - layer_index = re.match(r".*\.[^.]*\.(\d+)\.", key) + layer_index = re.search(r"(?:^|\.)[^.]*\.(\d+)\.", key) + # Avoid treating expert index as "layer index" in MoE module paths like "...experts....." + if layer_index is not None and key[layer_index.start() :].startswith(".experts."): + layer_index = None else: layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern + layer_index = None for pattern in layers_pattern: - layer_index = re.match(rf".*\.{pattern}\.(\d+)\.", key) - if layer_index is not None: + m = re.search(rf"(?:^|\.){pattern}\.(\d+)\.", key) + # Avoid treating expert index as "layer index" in MoE module paths like "...experts....." + if m is not None and key[m.start() :].startswith(".experts."): + continue + if m is not None: + layer_index = m break if layer_index is None: diff --git a/tests/test_layers_to_transform_moe.py b/tests/test_layers_to_transform_moe.py new file mode 100644 index 0000000000..2423ee0587 --- /dev/null +++ b/tests/test_layers_to_transform_moe.py @@ -0,0 +1,45 @@ +from torch import nn + +from peft import LoraConfig, get_peft_model + + +def test_layers_to_transform_filters_by_layer_not_expert_index(): + class ToyMoEBlock(nn.Module): + def __init__(self): + super().__init__() + self.self_attn = nn.Module() + self.self_attn.q_proj = nn.Linear(4, 4, bias=False) + + self.mlp = nn.Module() + self.mlp.experts = nn.ModuleList([nn.Module() for _ in range(2)]) + for e in range(2): + self.mlp.experts[e].up_proj = nn.Linear(4, 4, bias=False) + + def forward(self, x): + return x + + class ToyMoEModel(nn.Module): + def __init__(self): + super().__init__() + self.model = nn.Module() + self.model.layers = nn.ModuleList([ToyMoEBlock() for _ in range(4)]) + + def forward(self, x): + return x + + model = ToyMoEModel() + + config = LoraConfig( + target_modules=["q_proj", "up_proj"], + # layers_pattern="layers", + layers_to_transform=[1], + r=2, + lora_alpha=4, + ) + model = get_peft_model(model, config) + targeted = set(model.targeted_module_names) + + assert "model.layers.1.self_attn.q_proj" in targeted + assert "model.layers.1.mlp.experts.0.up_proj" in targeted + assert "model.layers.1.mlp.experts.1.up_proj" in targeted + assert "model.layers.2.mlp.experts.1.up_proj" not in targeted # must not match by expert index