Skip to content

Commit d5ce31d

Browse files
committed
torch.compile safely
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
1 parent 58dfce0 commit d5ce31d

1 file changed

Lines changed: 21 additions & 2 deletions

File tree

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -618,13 +618,32 @@ def modify(
618618
self.is_quantized = False
619619

620620
if self.eagle_use_torch_compile:
621+
self._activate_torch_compile()
622+
623+
self._cached_attn_blk_masks = {}
624+
625+
def _activate_torch_compile(self):
626+
import torch._dynamo
627+
628+
torch._dynamo.config.suppress_errors = True # Allow fallback to eager mode
629+
630+
# Individual try-catch for each function to maximize torch.compile usage
631+
try:
621632
self._prepare_eagle_inputs = torch.compile(self._prepare_eagle_inputs, dynamic=False)
633+
except Exception:
634+
print("Disabling torch.compile for _prepare_eagle_inputs due to compilation error.")
635+
636+
try:
622637
self._eagle_forward = torch.compile(
623638
self._eagle_forward, dynamic=False, mode="max-autotune"
624639
)
625-
self._eagle_loss = torch.compile(self._eagle_loss, dynamic=False, fullgraph=True)
640+
except Exception:
641+
print("Disabling torch.compile for _eagle_forward due to compilation error.")
626642

627-
self._cached_attn_blk_masks = {}
643+
try:
644+
self._eagle_loss = torch.compile(self._eagle_loss, dynamic=False, fullgraph=True)
645+
except Exception:
646+
print("Disabling torch.compile for _eagle_loss due to compilation error.")
628647

629648
def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step):
630649
# compile and cached flex attention masks in first call

0 commit comments

Comments
 (0)