File tree Expand file tree Collapse file tree
modelopt/torch/speculative/plugins Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -618,13 +618,32 @@ def modify(
618618 self .is_quantized = False
619619
620620 if self .eagle_use_torch_compile :
621+ self ._activate_torch_compile ()
622+
623+ self ._cached_attn_blk_masks = {}
624+
625+ def _activate_torch_compile (self ):
626+ import torch ._dynamo
627+
628+ torch ._dynamo .config .suppress_errors = True # Allow fallback to eager mode
629+
630+ # Individual try-catch for each function to maximize torch.compile usage
631+ try :
621632 self ._prepare_eagle_inputs = torch .compile (self ._prepare_eagle_inputs , dynamic = False )
633+ except Exception :
634+ print ("Disabling torch.compile for _prepare_eagle_inputs due to compilation error." )
635+
636+ try :
622637 self ._eagle_forward = torch .compile (
623638 self ._eagle_forward , dynamic = False , mode = "max-autotune"
624639 )
625- self ._eagle_loss = torch .compile (self ._eagle_loss , dynamic = False , fullgraph = True )
640+ except Exception :
641+ print ("Disabling torch.compile for _eagle_forward due to compilation error." )
626642
627- self ._cached_attn_blk_masks = {}
643+ try :
644+ self ._eagle_loss = torch .compile (self ._eagle_loss , dynamic = False , fullgraph = True )
645+ except Exception :
646+ print ("Disabling torch.compile for _eagle_loss due to compilation error." )
628647
629648 def _get_ttt_attention_mask (self , batch_size , seq_length , ttt_step ):
630649 # compile and cached flex attention masks in first call
You can’t perform that action at this time.
0 commit comments