diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml index a64528dd57..8442125f38 100644 --- a/.github/workflows/example_tests.yml +++ b/.github/workflows/example_tests.yml @@ -86,14 +86,14 @@ jobs: pip_install_extras: "[hf,dev-test]" runner: linux-amd64-gpu-h100-latest-2 - ##### Speculative Decoding Example Tests (requires 25.08 image) ##### + ##### Speculative Decoding Example Tests (requires 26.01 image) ##### speculative-decoding-pr: needs: [check-file-changes, wait-checks] if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: - docker_image: "nvcr.io/nvidia/pytorch:25.08-py3" + docker_image: "nvcr.io/nvidia/pytorch:26.01-py3" example: speculative_decoding pip_install_extras: "[hf,dev-test]" runner: linux-amd64-gpu-l4-latest-1 @@ -103,7 +103,7 @@ jobs: uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: - docker_image: "nvcr.io/nvidia/pytorch:25.08-py3" + docker_image: "nvcr.io/nvidia/pytorch:26.01-py3" example: speculative_decoding pip_install_extras: "[hf,dev-test]" runner: linux-amd64-gpu-h100-latest-2 diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 05ec3f864c..3625072b1a 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -31,7 +31,6 @@ from packaging.version import Version from PIL import Image from scripts.ar_validate import validate_ar -from torch.distributed.tensor.experimental._attention import _SDPAMerger from torch.utils.data import Dataset from transformers import AutoProcessor, Trainer, TrainerCallback from transformers.trainer_pt_utils import LabelSmoother @@ -581,7 +580,7 @@ def on_step_end(self, args, state, control, **kwargs): def get_patched_templated_ring_attn(orig_templated_attn: Callable): """ Return patched version of - 
torch.distributed.tensor.experimental._attention._templated_ring_attention + torch.distributed.tensor.experimental._context_parallel._attention._templated_ring_attention to support TTT. """ @@ -630,7 +629,7 @@ def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype): return attn_bias def patched_templated_attn(*args, **kwargs): - """Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention.""" + """Patched version of _templated_ring_attention.""" # Get original attention op # Sensitive to impl of _templated_ring_attention original_op = args[2] @@ -678,40 +677,35 @@ def patch_ring_attention_for_ttt(): """Patch torch ring attention to support context parallelism for TTT.""" # Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask. - if not ( - Version(torch.__version__) > Version("2.7.1") - and Version(torch.__version__) < Version("2.9.0") - ): + if Version(torch.__version__) < Version("2.10.0"): raise RuntimeError( - f"Context parallel TTT only supported for PyTorch 2.8.0 now. " + f"Context parallel TTT only supported for PyTorch >= 2.10.0. " f"Got {torch.__version__}. " - f"Please use nvcr.io/nvidia/pytorch:25.08-py3 or torch 2.8.0 or cp_size=1." + f"Please use torch 2.10.0 or cp_size=1." ) + from torch.distributed.tensor.experimental._context_parallel import _attention + # 1. Disable load balance, which is designed for causal mask. # This affect how buffers are sharded. So need to be done permanently before accelerate/hf trainer init. - torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False + _attention._cp_options.enable_load_balance = False # 2. Patch templated ring attention for TTT mask. 
- original_templated_ring_attention = ( - torch.distributed.tensor.experimental._attention._templated_ring_attention - ) - original_templated_ring_attention_backward = ( - torch.distributed.tensor.experimental._attention._templated_ring_attention_backward - ) - torch.distributed.tensor.experimental._attention._templated_ring_attention = ( - get_patched_templated_ring_attn(original_templated_ring_attention) + original_templated_ring_attention = _attention._templated_ring_attention + original_templated_ring_attention_backward = _attention._templated_ring_attention_backward + _attention._templated_ring_attention = get_patched_templated_ring_attn( + original_templated_ring_attention ) - torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = ( - get_patched_templated_ring_attn(original_templated_ring_attention_backward) + _attention._templated_ring_attention_backward = get_patched_templated_ring_attn( + original_templated_ring_attention_backward ) # 3. Patch merger to skip the blank shard to avoid difference in output. 
- original_sdpa_merger_step = _SDPAMerger.step + original_sdpa_merger_step = _attention._SDPAMerger.step def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool): if lse.sum() <= 0: return return original_sdpa_merger_step(self, out, lse, partial) - _SDPAMerger.step = patched_sdpa_merger_step + _attention._SDPAMerger.step = patched_sdpa_merger_step diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index deabba5702..3775b8a4c2 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -37,11 +37,8 @@ def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagl available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 if cp_size == 2 and available_gpus < 2: pytest.skip("cp_size=2 requires at least 2 GPUs, but only {} found.".format(available_gpus)) - if cp_size == 2 and not ( - Version(torch.__version__) > Version("2.7.0") - and Version(torch.__version__) < Version("2.9.0") - ): - pytest.skip("cp_size=2 requires torch 2.8.0") + if cp_size == 2 and Version(torch.__version__) < Version("2.10.0"): + pytest.skip("cp_size=2 requires torch >= 2.10.0") # Create an ultra-tiny EAGLE config for testing to reduce memory usage tiny_eagle_config = { "max_position_embeddings": 128,