Commit 2307425 (parent: 807930e)

fix(attention): Another fix attempt for dynamo issues with torch compile and fa4.

2 files changed: 26 additions & 11 deletions

src/modalities/models/model_factory.py
Lines changed: 21 additions & 7 deletions
@@ -63,13 +63,18 @@ class ModelFactory:
     """Model factory class to create models."""

     @staticmethod
-    def _requires_graph_break_friendly_compile(module: nn.Module) -> bool:
+    def _requires_eager_execution(module: nn.Module) -> bool:
         if isinstance(module, GPT2Block):
             return module.attn.attention_impl == AttentionImplementation.DAO_FLASH_V4

         attention_impl = getattr(module, "attention_impl", None)
         return attention_impl == AttentionImplementation.DAO_FLASH_V4

+    # TODO remove?
+    # @staticmethod
+    # def _requires_graph_break_friendly_compile(module: nn.Module) -> bool:
+    #     return ModelFactory._requires_eager_execution(module)
+
     @staticmethod
     def _is_model_on_meta_device(model: nn.Module) -> bool:
         """
@@ -410,15 +415,24 @@ def get_parent_module_and_child_name(child_module: nn.Module, model: nn.Module)

         for _, module in model.named_modules():
             if isinstance(module, block_types):
-                options = {"trace.enabled": True} if debug else {}
-                compiled_fullgraph = fullgraph
-                if compiled_fullgraph and ModelFactory._requires_graph_break_friendly_compile(module):
-                    compiled_fullgraph = False
+                if ModelFactory._requires_eager_execution(module):
                     logger.warning(
-                        "Disabling `fullgraph=True` for `%s` because FlashAttention-4 currently graph-breaks under "
-                        "torch.compile when tracing into flash_attn.cute internals.",
+                        "Skipping `torch.compile` for `%s` because FlashAttention-4 currently graph-breaks under "
+                        "TorchDynamo when tracing into flash_attn.cute internals.",
                         module.__class__.__name__,
                     )
+                    continue
+
+                options = {"trace.enabled": True} if debug else {}
+                compiled_fullgraph = fullgraph
+                # TODO remove?
+                # if compiled_fullgraph and ModelFactory._requires_graph_break_friendly_compile(module):
+                #     compiled_fullgraph = False
+                #     logger.warning(
+                #         "Disabling `fullgraph=True` for `%s` because FlashAttention-4 currently graph-breaks under "
+                #         "torch.compile when tracing into flash_attn.cute internals.",
+                #         module.__class__.__name__,
+                #     )

                 compiled_module = torch.compile(module, fullgraph=compiled_fullgraph, options=options)
                 parent_module, child_name = get_parent_module_and_child_name(child_module=module, model=model)
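
In effect, the compile loop now leaves FA4 blocks fully eager instead of compiling them with `fullgraph=False` as the previous attempt did. Below is a minimal, self-contained sketch of the resulting behavior; `DummyBlock` and the string markers are illustrative stand-ins, not the repository's actual types:

```python
import torch
import torch.nn as nn

# Illustrative stand-in for AttentionImplementation.DAO_FLASH_V4.
DAO_FLASH_V4 = "dao_flash_v4"


class DummyBlock(nn.Module):
    """Toy transformer block carrying the attention_impl marker."""

    def __init__(self, attention_impl: str) -> None:
        super().__init__()
        self.attention_impl = attention_impl
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def requires_eager_execution(module: nn.Module) -> bool:
    # Mirrors the patched check: FA4 blocks are skipped entirely rather
    # than compiled with fullgraph=False.
    return getattr(module, "attention_impl", None) == DAO_FLASH_V4


blocks = [DummyBlock(DAO_FLASH_V4), DummyBlock("pytorch_flash_attention")]
compiled = [
    block if requires_eager_execution(block) else torch.compile(block, fullgraph=True)
    for block in blocks
]
# compiled[0] is the untouched eager module; compiled[1] is an OptimizedModule.
```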

tests/test_torch_compile.py
Lines changed: 5 additions & 4 deletions
@@ -106,18 +106,19 @@ def test_get_compiled_model_empty_block_names(gpt2_model: GPT2LLM) -> None:


 @pytest.mark.skipif(not is_flash_attn_v4_available(), reason="FA4 not installed")
-def test_get_compiled_model_disables_fullgraph_for_fa4(monkeypatch: MonkeyPatch, gpt2_model: GPT2LLM) -> None:
-    recorded_fullgraph_values: list[bool] = []
+def test_get_compiled_model_skips_compile_for_fa4(monkeypatch: MonkeyPatch, gpt2_model: GPT2LLM) -> None:
+    compile_call_count = 0

     for block in gpt2_model.transformer.h.values():
         block.attn.attention_impl = AttentionImplementation.DAO_FLASH_V4

     def fake_compile(module: nn.Module, fullgraph: bool, options: dict[str, object]) -> nn.Module:
-        recorded_fullgraph_values.append(fullgraph)
+        nonlocal compile_call_count
+        compile_call_count += 1
         return module

     monkeypatch.setattr(torch, "compile", fake_compile)

     ModelFactory.get_compiled_model(gpt2_model, ["GPT2Block"], fullgraph=True)

-    assert recorded_fullgraph_values == [False] * len(gpt2_model.transformer.h)
+    assert compile_call_count == 0
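
A possible refinement, not part of this commit: instead of skipping compilation of the whole block, the FA4 kernel call alone could be excluded from tracing with PyTorch's `torch.compiler.disable`, letting the rest of the block compile with `fullgraph=False`. A hedged sketch; `fa4_attention` is a hypothetical wrapper, not the repository's API:

```python
import torch
import torch.nn.functional as F


@torch.compiler.disable
def fa4_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the flash_attn.cute FA4 kernel call.
    # The decorator keeps TorchDynamo from tracing into this function,
    # producing a clean graph break instead of a Dynamo error, so the
    # enclosing block could still be compiled (with fullgraph=False).
    return F.scaled_dot_product_attention(q, k, v)
```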
