 from packaging.version import Version
 from PIL import Image
 from scripts.ar_validate import validate_ar
-from torch.distributed.tensor.experimental._attention import _SDPAMerger
 from torch.utils.data import Dataset
 from transformers import AutoProcessor, Trainer, TrainerCallback
 from transformers.trainer_pt_utils import LabelSmoother
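
The direct `_SDPAMerger` import is removed here because these private internals moved; the function below resolves them through the relocated module instead. A hedged compatibility sketch, assuming only the two module paths that actually appear in this diff (both are private torch APIs and may move again):

    try:
        # newer layout used by this change
        from torch.distributed.tensor.experimental._context_parallel import _attention
    except ImportError:
        # older layout that the removed import relied on
        from torch.distributed.tensor.experimental import _attention

    _SDPAMerger = _attention._SDPAMerger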
@@ -581,7 +580,7 @@ def on_step_end(self, args, state, control, **kwargs): |
 def get_patched_templated_ring_attn(orig_templated_attn: Callable):
     """
     Return patched version of
-    torch.distributed.tensor.experimental._attention._templated_ring_attention
+    torch.distributed.tensor.experimental._context_parallel._attention._templated_ring_attention
     to support TTT.
     """

@@ -630,7 +629,7 @@ def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype): |
         return attn_bias

     def patched_templated_attn(*args, **kwargs):
-        """Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
+        """Patched version of _templated_ring_attention."""
         # Get original attention op
         # Sensitive to impl of _templated_ring_attention
         original_op = args[2]
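
For readers unfamiliar with this pattern: the wrapper keeps a closure over the original function and swaps out the attention op it receives at a fixed positional index, which is why the comment flags it as sensitive to the callee's implementation. A minimal standalone sketch of the idea, with hypothetical names (`ring_attention`, `op_with_ttt_bias`) that are not taken from this PR:

    from typing import Callable

    def ring_attention(mesh, seq_dim, op: Callable, q, k, v):
        # Stand-in for the real entry point: the attention op is the 3rd positional arg.
        return op(q, k, v)

    def get_patched(orig: Callable) -> Callable:
        def patched(*args, **kwargs):
            original_op = args[2]  # fragile: tied to the callee's positional layout

            def op_with_ttt_bias(q, k, v):
                # Hypothetical: inject the TTT attention bias, then defer to the real op.
                return original_op(q, k, v)

            return orig(*args[:2], op_with_ttt_bias, *args[3:], **kwargs)

        return patched

    ring_attention = get_patched(ring_attention)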
@@ -678,40 +677,35 @@ def patch_ring_attention_for_ttt(): |
     """Patch torch ring attention to support context parallelism for TTT."""
     # Torch Ring Attention only supports no mask or a causal mask. We apply the following patches to enable the TTT mask.

-    if not (
-        Version(torch.__version__) > Version("2.7.1")
-        and Version(torch.__version__) < Version("2.9.0")
-    ):
+    if Version(torch.__version__) < Version("2.10.0"):
         raise RuntimeError(
-            f"Context parallel TTT only supported for PyTorch 2.8.0 now. "
+            f"Context parallel TTT only supported for PyTorch >= 2.10.0. "
             f"Got {torch.__version__}. "
-            f"Please use nvcr.io/nvidia/pytorch:25.08-py3 or torch 2.8.0 or cp_size=1."
+            f"Please use torch 2.10.0 or cp_size=1."
         )

+    from torch.distributed.tensor.experimental._context_parallel import _attention
+
     # 1. Disable load balance, which is designed for the causal mask.
     # This affects how buffers are sharded, so it must be done permanently before accelerate/HF trainer init.
-    torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False
+    _attention._cp_options.enable_load_balance = False

     # 2. Patch templated ring attention for the TTT mask.
-    original_templated_ring_attention = (
-        torch.distributed.tensor.experimental._attention._templated_ring_attention
-    )
-    original_templated_ring_attention_backward = (
-        torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
-    )
-    torch.distributed.tensor.experimental._attention._templated_ring_attention = (
-        get_patched_templated_ring_attn(original_templated_ring_attention)
+    original_templated_ring_attention = _attention._templated_ring_attention
+    original_templated_ring_attention_backward = _attention._templated_ring_attention_backward
+    _attention._templated_ring_attention = get_patched_templated_ring_attn(
+        original_templated_ring_attention
     )
-    torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
-        get_patched_templated_ring_attn(original_templated_ring_attention_backward)
+    _attention._templated_ring_attention_backward = get_patched_templated_ring_attn(
+        original_templated_ring_attention_backward
     )

     # 3. Patch the merger to skip blank shards to avoid a difference in output.
-    original_sdpa_merger_step = _SDPAMerger.step
+    original_sdpa_merger_step = _attention._SDPAMerger.step

     def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
         if lse.sum() <= 0:
             return
         return original_sdpa_merger_step(self, out, lse, partial)

-    _SDPAMerger.step = patched_sdpa_merger_step
+    _attention._SDPAMerger.step = patched_sdpa_merger_step
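
Taken together, the patch has to run before the trainer is constructed, because disabling load balance changes how context-parallel buffers are sharded. A short usage sketch under that assumption (the importing module name `train` is hypothetical, not from this PR):

    import torch
    from packaging.version import Version

    from train import patch_ring_attention_for_ttt  # hypothetical module path

    # Must land before accelerate / HF Trainer initialization: buffer sharding
    # is decided at init time, after which the load-balance setting is baked in.
    if Version(torch.__version__) >= Version("2.10.0"):
        patch_ring_attention_for_ttt()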