Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2101,36 +2101,37 @@ def get_correct_experts_implementation(self, requested_experts: str | None) -> s

@classmethod
def _can_set_attn_implementation(cls) -> bool:
"""Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
opening the file, but avoids maintaining yet another property flag.
"""Detect whether the class supports setting its attention implementation dynamically. Inspects the module source as a
heuristic, which avoids maintaining yet another property flag.
"""
class_module = sys.modules.get(cls.__module__)
# Missing module entry (e.g. cleared by a test) or custom model in a jupyter notebook / repl -> do not allow to set it
if class_module is None or not hasattr(class_module, "__file__"):
if class_module is None:
return False
class_file = class_module.__file__
with open(class_file, "r", encoding="utf-8") as f:
code = f.read()
# heuristic -> if we find those patterns, the model uses the correct interface
if re.search(r"class \w+Attention\(nn.Module\)", code):
return "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS.get_interface(" in code
else:
# If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
return True
try:
code = inspect.getsource(class_module)
except (OSError, TypeError):
return False
# Heuristic: if we find an `*Attention*(nn.Module)` class, check whether the interface is used
if re.search(r"^class \w*Attention\w*\(nn\.Module\):", code, re.MULTILINE):
return "ALL_ATTENTION_FUNCTIONS.get_interface(" in code
# If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
return True

@classmethod
def _can_set_experts_implementation(cls) -> bool:
"""Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
opening the file, but avoids maintaining yet another property flag.
"""Detect whether the class supports setting its experts implementation dynamically. Inspects the module source as a
heuristic, which avoids maintaining yet another property flag.
"""
class_module = sys.modules.get(cls.__module__)
# Missing module entry (e.g. cleared by a test) or custom model in a jupyter notebook / repl -> do not allow to set it
if class_module is None or not hasattr(class_module, "__file__"):
if class_module is None:
return False
try:
code = inspect.getsource(class_module)
except (OSError, TypeError):
return False
class_file = class_module.__file__
with open(class_file, "r", encoding="utf-8") as f:
code = f.read()
# heuristic -> if we the use_experts_implementation decorator is used, then we can set it
# Heuristic: if the `@use_experts_implementation` decorator is used, then we can set it
return "@use_experts_implementation" in code

def set_attn_implementation(self, attn_implementation: str | dict, allow_all_kernels: bool = False):
Expand Down
54 changes: 54 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3035,6 +3035,60 @@ def test_attention_and_experts_modules_can_be_used_standalone(self):
with self.assertRaisesRegex(KeyError, "`foobar` is not a valid experts implementation registered"):
_ = experts_module(hidden_states, dummy_indices, dummy_scores)

def test_can_set_attn_returns_false_when_module_missing(self):
# Simulate the "module cleared from sys.modules" case (test cleanup, REPL).
from transformers.models.llama.modeling_llama import LlamaModel

original = sys.modules.pop(LlamaModel.__module__)
try:
self.assertFalse(LlamaModel._can_set_attn_implementation())
self.assertFalse(LlamaModel._can_set_experts_implementation())
finally:
sys.modules[LlamaModel.__module__] = original

def test_can_set_attn_skips_torch_dynamic_wrappers(self):
# Simulate FSDP2's `FSDP<ModelName>`: a dynamic subclass whose __module__ lives under torch.*
from transformers.models.llama.modeling_llama import LlamaModel

FSDPLlamaModel = type("FSDPLlamaModel", (LlamaModel,), {})
FSDPLlamaModel.__module__ = "torch.distributed.fsdp._fully_shard._fsdp_state"

# The MRO walk should skip past FSDPLlamaModel to LlamaModel and find the underlying answer.
self.assertTrue(FSDPLlamaModel._can_set_attn_implementation())

def test_can_set_attn_modern_vs_legacy(self):
# Modern interface model: True. Legacy model (T5 doesn't use ALL_ATTENTION_FUNCTIONS): False.
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.models.t5.modeling_t5 import T5Model

self.assertTrue(LlamaModel._can_set_attn_implementation())
self.assertFalse(T5Model._can_set_attn_implementation())

def test_can_set_attn_legacy_edge_cases(self):
# FSMT: bare `class Attention(nn.Module):` -- tightened regex catches this case.
from transformers.models.fsmt.modeling_fsmt import FSMTModel

self.assertFalse(FSMTModel._can_set_attn_implementation())

# SLANet: `class SLANetAttentionGRUCell(nn.Module):` -- "Attention" not at end of class name.
from transformers.models.slanet.modeling_slanet import SLANetBackbone

self.assertFalse(SLANetBackbone._can_set_attn_implementation())

# ShieldGemma2: output dataclass with "Attention" in the name but no `nn.Module` parent.
# No actual Attention class -> assume True (multimodal model or inherits from elsewhere).
from transformers.models.shieldgemma2.modeling_shieldgemma2 import ShieldGemma2ForImageClassification

self.assertTrue(ShieldGemma2ForImageClassification._can_set_attn_implementation())

def test_can_set_experts_moe_vs_dense(self):
# MoE model with @use_experts_implementation: True. Non-MoE model: False.
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.models.mixtral.modeling_mixtral import MixtralModel

self.assertTrue(MixtralModel._can_set_experts_implementation())
self.assertFalse(LlamaModel._can_set_experts_implementation())


@require_torch
class TestTensorSharing(TestCasePlus):
Expand Down
Loading