@@ -63,13 +63,18 @@ class ModelFactory:
6363 """Model factory class to create models."""
6464
6565 @staticmethod
66- def _requires_graph_break_friendly_compile (module : nn .Module ) -> bool :
66+ def _requires_eager_execution (module : nn .Module ) -> bool :
6767 if isinstance (module , GPT2Block ):
6868 return module .attn .attention_impl == AttentionImplementation .DAO_FLASH_V4
6969
7070 attention_impl = getattr (module , "attention_impl" , None )
7171 return attention_impl == AttentionImplementation .DAO_FLASH_V4
7272
73+ # TODO remove?
74+ # @staticmethod
75+ # def _requires_graph_break_friendly_compile(module: nn.Module) -> bool:
76+ # return ModelFactory._requires_eager_execution(module)
77+
7378 @staticmethod
7479 def _is_model_on_meta_device (model : nn .Module ) -> bool :
7580 """
@@ -410,15 +415,24 @@ def get_parent_module_and_child_name(child_module: nn.Module, model: nn.Module)
410415
411416 for _ , module in model .named_modules ():
412417 if isinstance (module , block_types ):
413- options = {"trace.enabled" : True } if debug else {}
414- compiled_fullgraph = fullgraph
415- if compiled_fullgraph and ModelFactory ._requires_graph_break_friendly_compile (module ):
416- compiled_fullgraph = False
418+ if ModelFactory ._requires_eager_execution (module ):
417419 logger .warning (
418- "Disabling `fullgraph=True ` for `%s` because FlashAttention-4 currently graph-breaks under "
419- "torch.compile when tracing into flash_attn.cute internals." ,
420+ "Skipping `torch.compile ` for `%s` because FlashAttention-4 currently graph-breaks under "
421+ "TorchDynamo when tracing into flash_attn.cute internals." ,
420422 module .__class__ .__name__ ,
421423 )
424+ continue
425+
426+ options = {"trace.enabled" : True } if debug else {}
427+ compiled_fullgraph = fullgraph
428+ # TODO remove?
429+ # if compiled_fullgraph and ModelFactory._requires_graph_break_friendly_compile(module):
430+ # compiled_fullgraph = False
431+ # logger.warning(
432+ # "Disabling `fullgraph=True` for `%s` because FlashAttention-4 currently graph-breaks under "
433+ # "torch.compile when tracing into flash_attn.cute internals.",
434+ # module.__class__.__name__,
435+ # )
422436
423437 compiled_module = torch .compile (module , fullgraph = compiled_fullgraph , options = options )
424438 parent_module , child_name = get_parent_module_and_child_name (child_module = module , model = model )
0 commit comments