|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
| 16 | +import inspect |
16 | 17 | import json |
17 | 18 | import os |
| 19 | +from collections.abc import Callable |
18 | 20 | from pathlib import Path |
| 21 | +from typing import TYPE_CHECKING |
| 22 | + |
| 23 | +if TYPE_CHECKING: |
| 24 | + from types import FrameType |
19 | 25 | from typing import Any |
20 | 26 |
|
21 | 27 | import numpy as np |
22 | 28 | import torch |
23 | 29 | import transformers |
24 | 30 | from datasets import load_dataset |
| 31 | +from packaging.version import Version |
25 | 32 | from PIL import Image |
26 | 33 | from scripts.ar_validate import validate_ar |
| 34 | +from torch.distributed.tensor.experimental._attention import _SDPAMerger |
27 | 35 | from torch.utils.data import Dataset |
28 | 36 | from transformers import AutoProcessor, Trainer, TrainerCallback |
29 | 37 | from transformers.trainer_pt_utils import LabelSmoother |
30 | 38 |
|
| 39 | +import modelopt |
| 40 | +from modelopt.torch.speculative.utils import get_ttt_msk_func |
31 | 41 | from modelopt.torch.utils import print_rank_0 |
32 | 42 | from modelopt.torch.utils.distributed import is_master |
33 | 43 |
|
@@ -566,3 +576,142 @@ def on_step_end(self, args, state, control, **kwargs): |
566 | 576 | except Exception: |
567 | 577 | print_rank_0("AR validation not available.") |
568 | 578 | return control |
| 579 | + |
| 580 | + |
def get_patched_templated_ring_attn(orig_templated_attn: Callable):
    """Return a patched version of
    torch.distributed.tensor.experimental._attention._templated_ring_attention
    (or its backward counterpart) to support TTT.

    Args:
        orig_templated_attn: The original torch templated ring-attention
            callable to wrap. Its third positional argument is expected to be
            the per-shard attention op (sensitive to the torch implementation).

    Returns:
        A drop-in replacement with the same call signature that, when the
        CP TTT patch is enabled, disables the causal flag and substitutes a
        rank/iteration-specific TTT attention bias for each ring step.
    """

    def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype):
        """Get chunk-interleaved TTT mask for current rank.
        e.g.:
        2 ranks, ttt_step=1;
        full_ttt_mask = [[0, 0, 0, 0, x, 0, 0, 0],
                         [x, 0, 0, 0, 0, x, 0, 0],
                         [x, x, 0, 0, 0, 0, x, 0],
                         [x, x, x, 0, 0, 0, 0, x],

        rank 0, step0: [[0, 0, x, 0],
                        [x, 0, 0, x]]

        rank 1, step0: [[0, 0, x, 0],
                        [x, 0, 0, x]]

        rank 0, step1: [[0, 0, 0, 0],
                        [0, 0, 0, 0]]

        rank 1, step1: [[x, x, 0, 0],
                        [x, x, 0, 0]]

        """
        device = torch.cuda.current_device()
        # Query rows owned by this rank. Load balancing is disabled (see
        # patch_ring_attention_for_ttt), so shards are contiguous chunks.
        q_indices = torch.arange(q_len * rank, q_len * (rank + 1), device=device)
        # KV columns visible at ring iteration i: this rank holds the shard
        # originally owned by rank (rank - i) % size, repeated across the
        # ttt_step + 1 TTT chunks of the full sequence.
        kv_indices = (
            torch.arange(q_len * size * (ttt_step + 1), device=device)
            .view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :]
            .reshape(-1)
        )
        msk_func = get_ttt_msk_func(q_len * size, ttt_step)
        attn_mask = msk_func(
            None,
            None,
            q_indices.view(1, 1, -1, 1),
            kv_indices.view(1, 1, 1, -1),
        )
        # Convert the boolean mask to an additive bias: 0 where attention is
        # allowed, dtype-min where it is masked out.
        attn_bias = torch.where(
            attn_mask,
            torch.zeros((), dtype=dtype, device=attn_mask.device),
            torch.full((), torch.finfo(dtype).min, dtype=dtype, device=attn_mask.device),
        )

        return attn_bias

    def patched_templated_attn(*args, **kwargs):
        """Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
        # Get original attention op
        # Sensitive to impl of _templated_ring_attention
        original_op = args[2]

        # This patch is only enabled for eagle model by context manager, not base model.
        patch_enabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH

        if patch_enabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
            raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")

        # Unset is_causal to use custom attn mask
        if patch_enabled:
            kwargs["is_causal"] = False

        def patched_op(*args, **kwargs):
            # Inspect the parent frame to get current shard info
            # This is sensitive to torch _templated_ring_attention impl
            try:
                # inspect.currentframe() may return None on some interpreters;
                # the AttributeError that follows is caught below.
                frame = inspect.currentframe()
                f_back: FrameType = frame.f_back
                rank = f_back.f_locals["rank"]
                size = f_back.f_locals["size"]
                query = f_back.f_locals["query"]
                key = f_back.f_locals["key"]
                i = f_back.f_locals["i"]
                # key holds (ttt_step + 1) chunks of the query-length shard.
                ttt_step = (key.shape[2] // query.shape[2]) - 1
            except Exception as e:
                raise RuntimeError(
                    f"Failed to capture loop variables in patched _templated_ring_attention: {e}"
                ) from e
            # Set attn mask to permuted TTT mask. Only do so when the patch is
            # enabled: the base-model (causal) path must keep its original bias.
            if patch_enabled and "attn_bias" in kwargs:
                kwargs["attn_bias"] = _get_sharded_ttt_msk(
                    i, rank, size, query.shape[2], ttt_step, query.dtype
                )
            # Perform shard attention
            return original_op(*args, **kwargs)

        return orig_templated_attn(args[0], args[1], patched_op, *args[3:], **kwargs)

    return patched_templated_attn
| 675 | + |
| 676 | + |
def patch_ring_attention_for_ttt():
    """Patch torch ring attention to support context parallelism for TTT.

    Applies three permanent monkeypatches to
    ``torch.distributed.tensor.experimental._attention``; must run before
    accelerate/HF Trainer initialization (see note on load balancing below).

    Raises:
        RuntimeError: If the installed torch version is not in the supported
            (2.7.1, 2.9.0) range the patched internals were written against.
    """
    # Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask.

    # The patched functions reach into private torch internals
    # (_templated_ring_attention frame locals, _SDPAMerger), so gate on the
    # torch versions they were validated against.
    if not (
        Version(torch.__version__) > Version("2.7.1")
        and Version(torch.__version__) < Version("2.9.0")
    ):
        raise RuntimeError(
            f"Context parallel TTT only supported for PyTorch 2.8.0 now. "
            f"Got {torch.__version__}. "
            f"Please use nvcr.io/nvidia/pytorch:25.08-py3 or torch 2.8.0 or cp_size=1."
        )

    # 1. Disable load balance, which is designed for causal mask.
    # This affect how buffers are sharded. So need to be done permanently before accelerate/hf trainer init.
    torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False

    # 2. Patch templated ring attention for TTT mask.
    # Both forward and backward are wrapped with the same factory; the wrapper
    # only assumes the op is positional arg 2 in each.
    original_templated_ring_attention = (
        torch.distributed.tensor.experimental._attention._templated_ring_attention
    )
    original_templated_ring_attention_backward = (
        torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
    )
    torch.distributed.tensor.experimental._attention._templated_ring_attention = (
        get_patched_templated_ring_attn(original_templated_ring_attention)
    )
    torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
        get_patched_templated_ring_attn(original_templated_ring_attention_backward)
    )

    # 3. Patch merger to skip the blank shard to avoid difference in output.
    original_sdpa_merger_step = _SDPAMerger.step

    def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
        # NOTE(review): lse.sum() <= 0 is used here as the "fully masked shard"
        # marker; presumably a blank shard yields a non-positive log-sum-exp —
        # confirm this holds for all valid (partially attended) shards.
        if lse.sum() <= 0:
            return
        return original_sdpa_merger_step(self, out, lse, partial)

    _SDPAMerger.step = patched_sdpa_merger_step
0 commit comments