Commit 3036a9e
Feat: Context Parallel for Eagle3 Training (#745)
## What does this PR do?

**Type of change:** New feature

**Overview:**

- Supported context parallel (CP) for Eagle3 training by patching torch ring attention.
- Requires the following library versions for stable CP:
  - torch 2.8.0
  - transformers 5.0.0
  - accelerate 1.12.0
- Moved to FSDP2.
- Removed unused arguments from the training script (`--multi_gpu`, `--fsdp_transformer_layer_cls_to_wrap`).
- Bumped the CI container to `nvcr.io/nvidia/pytorch:25.08-py3`.

## Usage

```bash
./launch_train.sh --model $MODEL \
    --output_dir $OUTPUT_DIR \
    --data $DATA \
    --num_epochs 0.1 \
    --train_bs 1 \
    --eagle_config eagle_config.json \
    --training_seq_len 1024 \
    --cp_size 2  # newly added
```

## Testing

- SDPA-level correctness: tested TTT attention with and without CP; diff < 1%.

  ```
  === Compare context-parallel (CP) outputs and grads with non-CP ===
  Forward output comparison (CP vs non-CP):
    Absolute diff (adiff) cp_out vs out: 0.001953125
    Relative diff (rdiff) cp_out vs out: 0.00182342529296875
  WQ (query proj) grad comparison (CP vs non-CP):
    Absolute diff (adiff) cp_wq_grad vs wq_grad: 0.0078125
    Relative diff (rdiff) cp_wq_grad vs wq_grad: 0.00347900390625
  WK (key proj) grad comparison (CP vs non-CP):
    Absolute diff (adiff) cp_wk_grad vs wk_grad: 0.0078125
    Relative diff (rdiff) cp_wk_grad vs wk_grad: 0.002471923828125
  WV (value proj) grad comparison (CP vs non-CP):
    Absolute diff (adiff) cp_wv_grad vs wv_grad: 0.25
    Relative diff (rdiff) cp_wv_grad vs wv_grad: 0.0069580078125
  ==============================================================
  ```

- E2E training accuracy (Llama3.1-8B, unsynthesized Magpie):

  <img width="911" height="630" alt="image" src="https://github.com/user-attachments/assets/1ecacc7f-c720-494c-9c1b-b60e7ced7baa" />

- Peak memory reserved (Llama3.1-8B, 8xH100, train_length=4k):

  | cp_size | max_memory_allocated (MB) | max_memory_reserved (MB) |
  |---------|---------------------------|--------------------------|
  | 1       | 65040.20                  | 79018.00                 |
  | 2       | 50409.17                  | 73098.00                 |
  | 4       | 45120.92                  | 72052.00                 |
  | 8       | 38882.12                  | 66484.00                 |

- Max training length test (Llama3.1-8B, H100):

  | cp_size | 6k  | 12k | 24k | 48k |
  |---------|-----|-----|-----|-----|
  | 1       | ✅  | OOM | OOM | OOM |
  | 2       | ✅  | ✅  | OOM | OOM |
  | 4       | ✅  | ✅  | ✅  | OOM |
  | 8       | ✅  | ✅  | ✅  | ✅  |

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Summary by CodeRabbit

- **New Features**
  - Added context parallelism (CP) and data parallelism shard size configuration parameters to training arguments.
- **Enhancements**
  - Improved TTT attention masking support for speculative decoding workflows.
  - Enhanced training launch script with improved parallelism configuration handling.
- **Chores**
  - Updated core dependencies: torch, transformers, accelerate, and wandb.
  - Added FSDP configuration file for distributed training setup.

---------

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
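The adiff/rdiff numbers in the testing log above come from comparing CP outputs and gradients elementwise against the non-CP reference. The PR does not show the comparison code, so the definitions below are an assumption: a minimal sketch where `adiff` is the max elementwise absolute difference and `rdiff` normalizes it by the reference magnitude.

```python
def adiff(a, b):
    """Maximum elementwise absolute difference between two flat sequences."""
    return max(abs(x - y) for x, y in zip(a, b))

def rdiff(a, b, eps=1e-12):
    """adiff normalized by the largest magnitude in the reference `b`
    (eps guards against an all-zero reference)."""
    return adiff(a, b) / max(max(abs(y) for y in b), eps)

# Example: a CP output that deviates slightly from the non-CP reference.
ref = [1.0, -2.0, 4.0]
cp = [1.0, -2.0, 4.01]
print(round(adiff(cp, ref), 6), round(rdiff(cp, ref), 6))
```

In practice these would be computed on flattened tensors (e.g. `cp_out` vs `out`, `cp_wq_grad` vs `wq_grad`) gathered from all CP ranks.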
1 parent 04165ac commit 3036a9e

12 files changed: +305 −80 lines changed

.github/workflows/example_tests.yml

Lines changed: 26 additions & 3 deletions
```diff
@@ -63,7 +63,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        example: [llm_distill, llm_qat, llm_sparsity, speculative_decoding]
+        example: [llm_distill, llm_qat, llm_sparsity]
     uses: ./.github/workflows/_example_tests_runner.yml
     secrets: inherit
     with:
@@ -77,7 +77,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        example: [llm_distill, llm_qat, llm_sparsity, speculative_decoding]
+        example: [llm_distill, llm_qat, llm_sparsity]
     uses: ./.github/workflows/_example_tests_runner.yml
     secrets: inherit
     with:
@@ -86,6 +86,28 @@ jobs:
       pip_install_extras: "[hf,dev-test]"
       runner: linux-amd64-gpu-h100-latest-2
 
+  ##### Speculative Decoding Example Tests (requires 25.08 image) #####
+  speculative-decoding-pr:
+    needs: [check-file-changes, wait-checks]
+    if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true'
+    uses: ./.github/workflows/_example_tests_runner.yml
+    secrets: inherit
+    with:
+      docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
+      example: speculative_decoding
+      pip_install_extras: "[hf,dev-test]"
+      runner: linux-amd64-gpu-l4-latest-1
+
+  speculative-decoding-non-pr:
+    if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }}
+    uses: ./.github/workflows/_example_tests_runner.yml
+    secrets: inherit
+    with:
+      docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
+      example: speculative_decoding
+      pip_install_extras: "[hf,dev-test]"
+      runner: linux-amd64-gpu-h100-latest-2
+
   ##### TensorRT-LLM Example Tests #####
   trtllm-pr:
     needs: [check-file-changes, wait-checks]
@@ -150,14 +172,15 @@ jobs:
   example-pr-required-check:
     # Run even if example tests are skipped
     if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }}
-    needs: [check-file-changes, torch-pr, trtllm-pr, onnx-pr]
+    needs: [check-file-changes, torch-pr, speculative-decoding-pr, trtllm-pr, onnx-pr]
     runs-on: ubuntu-latest
     steps:
       - name: Required GPU tests did not succeed
        if: |
          needs.check-file-changes.result != 'success' ||
          (needs.check-file-changes.outputs.any_changed == 'true' && (
            needs.torch-pr.result != 'success' ||
+            needs.speculative-decoding-pr.result != 'success' ||
            needs.trtllm-pr.result != 'success' ||
            needs.onnx-pr.result != 'success'
          ))
```

examples/speculative_decoding/README.md

Lines changed: 3 additions & 5 deletions
````diff
@@ -30,7 +30,7 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM,
 
 ### Docker
 
-Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.06-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.
+Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.08-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.
 
 Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install dataset and example-specific dependencies.
 
@@ -56,7 +56,7 @@ See [other-datasets](#other-datasets) section for other dataset options and inst
 ## Getting Started: Simplified Workflow
 
 ```bash
-bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct --num_gpu 4
+bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct
 ```
 
 This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in Modelopt. Specifically, it
@@ -74,12 +74,11 @@ For small base models that fit in GPU memory, we can collocate them with draft m
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
     --data input_conversations/daring-anteater.jsonl \
-    --num_gpu $NUM_GPU \
     --num_epochs $NUM_EPOCH \
     --eagle_config eagle_config.json
 ```
 
-This command will launch `main.py` with `accelerate`. See [section: interact with modelopt.torch.speculative](#interact-with-modelopttorchspeculative) for more details.
+FSDP2 is used by default. To enable context parallelism for long-context training, specify `--cp_size n`.
 The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
 
 ## Training Draft Model with Offline Base Model
@@ -118,7 +117,6 @@ Once we finish dumping hidden states, launch offline training with an extra `--o
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
     --data $DATA \
-    --num_gpu $NUM_GPU \
     --num_epochs $NUM_EPOCH \
     --eagle_config eagle_config.json \
     --offline-data $HIDDEN_STATES_DIR
````

examples/speculative_decoding/eagle_utils.py

Lines changed: 149 additions & 0 deletions
```diff
@@ -13,21 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import json
 import os
+from collections.abc import Callable
 from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from types import FrameType
 from typing import Any
 
 import numpy as np
 import torch
 import transformers
 from datasets import load_dataset
+from packaging.version import Version
 from PIL import Image
 from scripts.ar_validate import validate_ar
+from torch.distributed.tensor.experimental._attention import _SDPAMerger
 from torch.utils.data import Dataset
 from transformers import AutoProcessor, Trainer, TrainerCallback
 from transformers.trainer_pt_utils import LabelSmoother
 
+import modelopt
+from modelopt.torch.speculative.utils import get_ttt_msk_func
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import is_master
 
@@ -566,3 +576,142 @@ def on_step_end(self, args, state, control, **kwargs):
         except Exception:
             print_rank_0("AR validation not available.")
         return control
+
+
+def get_patched_templated_ring_attn(orig_templated_attn: Callable):
+    """
+    Return a patched version of
+    torch.distributed.tensor.experimental._attention._templated_ring_attention
+    to support TTT.
+    """
+
+    def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype):
+        """Get the chunk-interleaved TTT mask for the current rank.
+
+        e.g., 2 ranks, ttt_step=1:
+        full_ttt_mask = [[0, 0, 0, 0, x, 0, 0, 0],
+                         [x, 0, 0, 0, 0, x, 0, 0],
+                         [x, x, 0, 0, 0, 0, x, 0],
+                         [x, x, x, 0, 0, 0, 0, x]]
+
+        rank 0, step0: [[0, 0, x, 0],
+                        [x, 0, 0, x]]
+
+        rank 1, step0: [[0, 0, x, 0],
+                        [x, 0, 0, x]]
+
+        rank 0, step1: [[0, 0, 0, 0],
+                        [0, 0, 0, 0]]
+
+        rank 1, step1: [[x, x, 0, 0],
+                        [x, x, 0, 0]]
+        """
+        device = torch.cuda.current_device()
+        q_indices = torch.arange(q_len * rank, q_len * (rank + 1), device=device)
+        kv_indices = (
+            torch.arange(q_len * size * (ttt_step + 1), device=device)
+            .view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :]
+            .reshape(-1)
+        )
+        msk_func = get_ttt_msk_func(q_len * size, ttt_step)
+        attn_mask = msk_func(
+            None,
+            None,
+            q_indices.view(1, 1, -1, 1),
+            kv_indices.view(1, 1, 1, -1),
+        )
+        attn_bias = torch.where(
+            attn_mask,
+            torch.zeros((), dtype=dtype, device=attn_mask.device),
+            torch.full((), torch.finfo(dtype).min, dtype=dtype, device=attn_mask.device),
+        )
+
+        return attn_bias
+
+    def patched_templated_attn(*args, **kwargs):
+        """Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
+        # Get the original attention op.
+        # Sensitive to the implementation of _templated_ring_attention.
+        original_op = args[2]
+
+        # This patch is only enabled for the eagle model via a context manager, not the base model.
+        patch_enabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH
+
+        if patch_enabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
+            raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")
+
+        # Unset is_causal to use the custom attention mask.
+        if patch_enabled:
+            kwargs["is_causal"] = False
+
+        def patched_op(*args, **kwargs):
+            # Inspect the parent frame to get the current shard info.
+            # This is sensitive to the torch _templated_ring_attention implementation.
+            try:
+                frame: FrameType = inspect.currentframe()
+                f_back: FrameType = frame.f_back
+                rank = f_back.f_locals["rank"]
+                size = f_back.f_locals["size"]
+                query = f_back.f_locals["query"]
+                key = f_back.f_locals["key"]
+                i = f_back.f_locals["i"]
+                ttt_step = (key.shape[2] // query.shape[2]) - 1
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to capture loop variables in patched _templated_ring_attention: {e}"
+                ) from e
+            # Set the attention mask to the permuted TTT mask.
+            if "attn_bias" in kwargs:
+                kwargs["attn_bias"] = _get_sharded_ttt_msk(
+                    i, rank, size, query.shape[2], ttt_step, query.dtype
+                )
+            # Perform shard attention.
+            return original_op(*args, **kwargs)
+
+        return orig_templated_attn(args[0], args[1], patched_op, *args[3:], **kwargs)
+
+    return patched_templated_attn
+
+
+def patch_ring_attention_for_ttt():
+    """Patch torch ring attention to support context parallelism for TTT."""
+    # Torch ring attention only supports no mask or a causal mask. We apply the following patches to enable the TTT mask.
+
+    if not (
+        Version(torch.__version__) > Version("2.7.1")
+        and Version(torch.__version__) < Version("2.9.0")
+    ):
+        raise RuntimeError(
+            f"Context-parallel TTT is only supported on PyTorch 2.8.x for now. "
+            f"Got {torch.__version__}. "
+            f"Please use nvcr.io/nvidia/pytorch:25.08-py3 or torch 2.8.0, or set cp_size=1."
+        )
+
+    # 1. Disable load balancing, which is designed for the causal mask.
+    # This affects how buffers are sharded, so it must be done permanently before accelerate/HF trainer init.
+    torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False
+
+    # 2. Patch templated ring attention for the TTT mask.
+    original_templated_ring_attention = (
+        torch.distributed.tensor.experimental._attention._templated_ring_attention
+    )
+    original_templated_ring_attention_backward = (
+        torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
+    )
+    torch.distributed.tensor.experimental._attention._templated_ring_attention = (
+        get_patched_templated_ring_attn(original_templated_ring_attention)
+    )
+    torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
+        get_patched_templated_ring_attn(original_templated_ring_attention_backward)
+    )
+
+    # 3. Patch the merger to skip the blank shard to avoid a difference in output.
+    original_sdpa_merger_step = _SDPAMerger.step
+
+    def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
+        if lse.sum() <= 0:
+            return
+        return original_sdpa_merger_step(self, out, lse, partial)
+
+    _SDPAMerger.step = patched_sdpa_merger_step
```
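The chunk-interleaved KV selection inside `_get_sharded_ttt_msk` above can be mirrored in plain Python to see which KV positions a rank attends to at each ring step. A CPU-only sketch (the helper name is illustrative; it reproduces the `torch.arange(...).view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :]` indexing with lists):

```python
def sharded_kv_indices(i, rank, size, q_len, ttt_step):
    """KV token indices visible to `rank` at ring step `i`, mirroring the
    chunk-interleaved layout used by the patched ring attention."""
    src = (rank - i) % size  # which shard's KV chunk arrives at this ring step
    total = q_len * size     # tokens per TTT step across all ranks
    # One contiguous chunk of q_len indices per TTT step, all from shard `src`.
    return [s * total + src * q_len + k
            for s in range(ttt_step + 1)
            for k in range(q_len)]

# 2 ranks, q_len=2, ttt_step=1: at step 0 rank 0 sees its own chunks,
# one per TTT step.
print(sharded_kv_indices(0, 0, 2, 2, 1))  # [0, 1, 4, 5]
```

At ring step 1 the ranks swap sources, so rank 0 sees rank 1's chunks (`[2, 3, 6, 7]`), which is why the per-step masks in the docstring differ between ranks.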
examples/speculative_decoding/fsdp_config.json

Lines changed: 1 addition & 0 deletions

```diff
@@ -0,0 +1 @@
+{"fsdp_version":2}
```

examples/speculative_decoding/launch_train.sh

Lines changed: 23 additions & 16 deletions
```diff
@@ -74,14 +74,6 @@ while [ $# -gt 0 ]; do
       if [[ "$1" != *=* ]]; then shift; fi
       EAGLE_CONFIG="${1#*=}"
       ;;
-    --fsdp_transformer_layer_cls_to_wrap*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
-      ;;
-    --num_gpu*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      NUM_GPU="${1#*=}"
-      ;;
     --disable_tqdm*)
       if [[ "$1" != *=* ]]; then shift; fi
       DISABLE_TQDM="${1#*=}"
@@ -102,6 +94,14 @@ while [ $# -gt 0 ]; do
       if [[ "$1" != *=* ]]; then shift; fi
       AR_VALIDATE_STEPS="${1#*=}"
       ;;
+    --cp_size*)
+      if [[ "$1" != *=* ]]; then shift; fi
+      CP_SIZE="${1#*=}"
+      ;;
+    --dp_size*)
+      if [[ "$1" != *=* ]]; then shift; fi
+      DP_SHARD_SIZE="${1#*=}"
+      ;;
     *)
       >&2 printf "Error: Invalid argument ${1#*=}\n"
       exit 1
@@ -129,15 +129,15 @@ LR=${LR:-"1e-4"}
 TRAIN_BS=${TRAIN_BS:-4}
 MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
 MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
-FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
-NUM_GPU=${NUM_GPU:-1}
 TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
 OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""}
 DISABLE_TQDM=${DISABLE_TQDM:-False}
 VLM_PROCESSOR=${VLM_PROCESSOR:-}
 VLM_IMG_DIR=${VLM_IMG_DIR:-}
 AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}
 ESTIMATE_AR=${ESTIMATE_AR:-False}
+CP_SIZE=${CP_SIZE:-1}
+DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}
 
 if [[ "$MODE" == "medusa" ]]; then
   SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -163,21 +163,25 @@ else
   OFFLINE_TRAINING_ARGS=""
 fi
 
-if [[ "$NUM_GPU" == 1 ]]; then
-  MULTI_GPU=""
-else
-  MULTI_GPU="--multi_gpu"
-fi
 
 if [[ "$VLM_PROCESSOR" != "" ]]; then
   VLM_ARGS="--vlm_processor $VLM_PROCESSOR --vlm_img_dir $VLM_IMG_DIR"
 else
   VLM_ARGS=""
 fi
 
+if [[ "$GPU_COUNT" -gt 1 ]]; then
+  # Use FSDP2 when multiple GPUs are available
+  FSDP_ARGS="--fsdp 'full_shard' --fsdp_config fsdp_config.json"
+else
+  # Otherwise, single-GPU training
+  FSDP_ARGS=""
+fi
+
 # Disable tokenizers parallelism to avoid warning
 export TOKENIZERS_PARALLELISM=False
-CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+CMD="accelerate launch --mixed_precision bf16 main.py \
     --mode $MODE \
     --eagle_decoder_type $EAGLE_DECODER_TYPE \
     --model_name_or_path $MODEL \
@@ -206,6 +210,9 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
     $VLM_ARGS \
     $OFFLINE_TRAINING_ARGS \
     $SPECULATIVE_ARGS \
+    $FSDP_ARGS \
+    --cp_size $CP_SIZE \
+    --dp_shard_size $DP_SHARD_SIZE \
 "
 
 start_time=$(date +%s)
```
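The launch script's default `DP_SHARD_SIZE=$((GPU_COUNT/CP_SIZE))` factors the world size into a context-parallel by data-parallel mesh. A small sketch of that sizing rule (the helper is hypothetical, and the divisibility check is an added safeguard not present in the script, where shell arithmetic silently truncates):

```python
def parallel_sizes(gpu_count: int, cp_size: int = 1) -> tuple[int, int]:
    """Derive the FSDP data-parallel shard size from the GPU count and the
    requested context-parallel size, mirroring DP_SHARD_SIZE=$((GPU_COUNT/CP_SIZE))."""
    if gpu_count % cp_size != 0:
        raise ValueError(f"cp_size={cp_size} must divide gpu_count={gpu_count}")
    return cp_size, gpu_count // cp_size

# 8 GPUs with 2-way CP leaves a 4-way FSDP shard group.
print(parallel_sizes(8, 2))  # (2, 4)
```

So `--cp_size 2` on an 8-GPU node yields the `cp_size=2` row of the memory table above: sequences are split across 2 ranks while FSDP shards parameters across the remaining 4-way dimension.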
