Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/example_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ jobs:
pip_install_extras: "[hf,dev-test]"
runner: linux-amd64-gpu-h100-latest-2

##### Speculative Decoding Example Tests (requires 25.08 image) #####
##### Speculative Decoding Example Tests (requires 26.01 image) #####
speculative-decoding-pr:
needs: [check-file-changes, wait-checks]
if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true'
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
docker_image: "nvcr.io/nvidia/pytorch:26.01-py3"
example: speculative_decoding
pip_install_extras: "[hf,dev-test]"
runner: linux-amd64-gpu-l4-latest-1
Expand All @@ -103,7 +103,7 @@ jobs:
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
docker_image: "nvcr.io/nvidia/pytorch:26.01-py3"
example: speculative_decoding
pip_install_extras: "[hf,dev-test]"
runner: linux-amd64-gpu-h100-latest-2
Expand Down
38 changes: 16 additions & 22 deletions examples/speculative_decoding/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from packaging.version import Version
from PIL import Image
from scripts.ar_validate import validate_ar
from torch.distributed.tensor.experimental._attention import _SDPAMerger
from torch.utils.data import Dataset
from transformers import AutoProcessor, Trainer, TrainerCallback
from transformers.trainer_pt_utils import LabelSmoother
Expand Down Expand Up @@ -581,7 +580,7 @@ def on_step_end(self, args, state, control, **kwargs):
def get_patched_templated_ring_attn(orig_templated_attn: Callable):
"""
Return patched version of
torch.distributed.tensor.experimental._attention._templated_ring_attention
torch.distributed.tensor.experimental._context_parallel._attention._templated_ring_attention
to support TTT.
"""

Expand Down Expand Up @@ -630,7 +629,7 @@ def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype):
return attn_bias

def patched_templated_attn(*args, **kwargs):
"""Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
"""Patched version of _templated_ring_attention."""
# Get original attention op
# Sensitive to impl of _templated_ring_attention
original_op = args[2]
Expand Down Expand Up @@ -678,40 +677,35 @@ def patch_ring_attention_for_ttt():
"""Patch torch ring attention to support context parallelism for TTT."""
# Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask.

if not (
Version(torch.__version__) > Version("2.7.1")
and Version(torch.__version__) < Version("2.9.0")
):
if Version(torch.__version__) < Version("2.10.0"):
raise RuntimeError(
f"Context parallel TTT only supported for PyTorch 2.8.0 now. "
f"Context parallel TTT only supported for PyTorch >= 2.10.0. "
f"Got {torch.__version__}. "
f"Please use nvcr.io/nvidia/pytorch:25.08-py3 or torch 2.8.0 or cp_size=1."
f"Please use torch 2.10.0 or cp_size=1."
)

from torch.distributed.tensor.experimental._context_parallel import _attention

# 1. Disable load balance, which is designed for causal mask.
# This affects how buffers are sharded, so it needs to be done permanently before accelerate/hf trainer init.
torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False
_attention._cp_options.enable_load_balance = False

# 2. Patch templated ring attention for TTT mask.
original_templated_ring_attention = (
torch.distributed.tensor.experimental._attention._templated_ring_attention
)
original_templated_ring_attention_backward = (
torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
)
torch.distributed.tensor.experimental._attention._templated_ring_attention = (
get_patched_templated_ring_attn(original_templated_ring_attention)
original_templated_ring_attention = _attention._templated_ring_attention
original_templated_ring_attention_backward = _attention._templated_ring_attention_backward
_attention._templated_ring_attention = get_patched_templated_ring_attn(
original_templated_ring_attention
)
torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
get_patched_templated_ring_attn(original_templated_ring_attention_backward)
_attention._templated_ring_attention_backward = get_patched_templated_ring_attn(
original_templated_ring_attention_backward
)

# 3. Patch merger to skip the blank shard to avoid difference in output.
original_sdpa_merger_step = _SDPAMerger.step
original_sdpa_merger_step = _attention._SDPAMerger.step

def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
    """Forward a shard to the original merger step unless the shard is blank.

    A shard is treated as blank when ``lse.sum() <= 0``; in that case the
    step is skipped and nothing is merged, so the blank shard cannot alter
    the merged output. Otherwise the call is delegated, with the same
    arguments, to the original ``_SDPAMerger.step`` captured before patching.
    """
    shard_is_blank = lse.sum() <= 0
    if shard_is_blank:
        # Nothing to merge for a blank shard.
        return None
    return original_sdpa_merger_step(self, out, lse, partial)

_SDPAMerger.step = patched_sdpa_merger_step
_attention._SDPAMerger.step = patched_sdpa_merger_step
7 changes: 2 additions & 5 deletions tests/examples/speculative_decoding/test_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,8 @@ def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagl
available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
if cp_size == 2 and available_gpus < 2:
pytest.skip("cp_size=2 requires at least 2 GPUs, but only {} found.".format(available_gpus))
if cp_size == 2 and not (
Version(torch.__version__) > Version("2.7.0")
and Version(torch.__version__) < Version("2.9.0")
):
pytest.skip("cp_size=2 requires torch 2.8.0")
if cp_size == 2 and not Version(torch.__version__) >= Version("2.10.0"):
pytest.skip("cp_size=2 requires torch 2.10.0")
# Create an ultra-tiny EAGLE config for testing to reduce memory usage
tiny_eagle_config = {
"max_position_embeddings": 128,
Expand Down
Loading