Skip to content

Commit dec7161

Browse files
yeyu-nvidia authored and sugunav14 committed
fix the path change in torch v2.10 for spec dec (#863)
## What does this PR do? **Type of change:** bug fix **Overview:** torch v2.10 changes the path for _SDPAMerger. will need to use the new path for import ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Chores** * Updated internal import references to reflect organizational changes in dependencies. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent b1fefc3 commit dec7161

File tree

3 files changed

+21
-30
lines changed

3 files changed

+21
-30
lines changed

.github/workflows/example_tests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,14 @@ jobs:
8686
pip_install_extras: "[hf,dev-test]"
8787
runner: linux-amd64-gpu-h100-latest-2
8888

89-
##### Speculative Decoding Example Tests (requires 25.08 image) #####
89+
##### Speculative Decoding Example Tests (requires 26.01 image) #####
9090
speculative-decoding-pr:
9191
needs: [check-file-changes, wait-checks]
9292
if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true'
9393
uses: ./.github/workflows/_example_tests_runner.yml
9494
secrets: inherit
9595
with:
96-
docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
96+
docker_image: "nvcr.io/nvidia/pytorch:26.01-py3"
9797
example: speculative_decoding
9898
pip_install_extras: "[hf,dev-test]"
9999
runner: linux-amd64-gpu-l4-latest-1
@@ -103,7 +103,7 @@ jobs:
103103
uses: ./.github/workflows/_example_tests_runner.yml
104104
secrets: inherit
105105
with:
106-
docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
106+
docker_image: "nvcr.io/nvidia/pytorch:26.01-py3"
107107
example: speculative_decoding
108108
pip_install_extras: "[hf,dev-test]"
109109
runner: linux-amd64-gpu-h100-latest-2

examples/speculative_decoding/eagle_utils.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from packaging.version import Version
3232
from PIL import Image
3333
from scripts.ar_validate import validate_ar
34-
from torch.distributed.tensor.experimental._attention import _SDPAMerger
3534
from torch.utils.data import Dataset
3635
from transformers import AutoProcessor, Trainer, TrainerCallback
3736
from transformers.trainer_pt_utils import LabelSmoother
@@ -581,7 +580,7 @@ def on_step_end(self, args, state, control, **kwargs):
581580
def get_patched_templated_ring_attn(orig_templated_attn: Callable):
582581
"""
583582
Return patched version of
584-
torch.distributed.tensor.experimental._attention._templated_ring_attention
583+
torch.distributed.tensor.experimental._context_parallel._attention._templated_ring_attention
585584
to support TTT.
586585
"""
587586

@@ -630,7 +629,7 @@ def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype):
630629
return attn_bias
631630

632631
def patched_templated_attn(*args, **kwargs):
633-
"""Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
632+
"""Patched version of _templated_ring_attention."""
634633
# Get original attention op
635634
# Sensitive to impl of _templated_ring_attention
636635
original_op = args[2]
@@ -678,40 +677,35 @@ def patch_ring_attention_for_ttt():
678677
"""Patch torch ring attention to support context parallelism for TTT."""
679678
# Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask.
680679

681-
if not (
682-
Version(torch.__version__) > Version("2.7.1")
683-
and Version(torch.__version__) < Version("2.9.0")
684-
):
680+
if Version(torch.__version__) < Version("2.10.0"):
685681
raise RuntimeError(
686-
f"Context parallel TTT only supported for PyTorch 2.8.0 now. "
682+
f"Context parallel TTT only supported for PyTorch >= 2.10.0. "
687683
f"Got {torch.__version__}. "
688-
f"Please use nvcr.io/nvidia/pytorch:25.08-py3 or torch 2.8.0 or cp_size=1."
684+
f"Please use torch 2.10.0 or cp_size=1."
689685
)
690686

687+
from torch.distributed.tensor.experimental._context_parallel import _attention
688+
691689
# 1. Disable load balance, which is designed for causal mask.
692690
# This affect how buffers are sharded. So need to be done permanently before accelerate/hf trainer init.
693-
torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False
691+
_attention._cp_options.enable_load_balance = False
694692

695693
# 2. Patch templated ring attention for TTT mask.
696-
original_templated_ring_attention = (
697-
torch.distributed.tensor.experimental._attention._templated_ring_attention
698-
)
699-
original_templated_ring_attention_backward = (
700-
torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
701-
)
702-
torch.distributed.tensor.experimental._attention._templated_ring_attention = (
703-
get_patched_templated_ring_attn(original_templated_ring_attention)
694+
original_templated_ring_attention = _attention._templated_ring_attention
695+
original_templated_ring_attention_backward = _attention._templated_ring_attention_backward
696+
_attention._templated_ring_attention = get_patched_templated_ring_attn(
697+
original_templated_ring_attention
704698
)
705-
torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
706-
get_patched_templated_ring_attn(original_templated_ring_attention_backward)
699+
_attention._templated_ring_attention_backward = get_patched_templated_ring_attn(
700+
original_templated_ring_attention_backward
707701
)
708702

709703
# 3. Patch merger to skip the blank shard to avoid difference in output.
710-
original_sdpa_merger_step = _SDPAMerger.step
704+
original_sdpa_merger_step = _attention._SDPAMerger.step
711705

712706
def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
713707
if lse.sum() <= 0:
714708
return
715709
return original_sdpa_merger_step(self, out, lse, partial)
716710

717-
_SDPAMerger.step = patched_sdpa_merger_step
711+
_attention._SDPAMerger.step = patched_sdpa_merger_step

tests/examples/speculative_decoding/test_eagle.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,8 @@ def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagl
3737
available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
3838
if cp_size == 2 and available_gpus < 2:
3939
pytest.skip("cp_size=2 requires at least 2 GPUs, but only {} found.".format(available_gpus))
40-
if cp_size == 2 and not (
41-
Version(torch.__version__) > Version("2.7.0")
42-
and Version(torch.__version__) < Version("2.9.0")
43-
):
44-
pytest.skip("cp_size=2 requires torch 2.8.0")
40+
if cp_size == 2 and not Version(torch.__version__) >= Version("2.10.0"):
41+
pytest.skip("cp_size=2 requires torch 2.10.0")
4542
# Create an ultra-tiny EAGLE config for testing to reduce memory usage
4643
tiny_eagle_config = {
4744
"max_position_embeddings": 128,

0 commit comments

Comments (0)