Commit 7739972

h-guo18 authored and danielkorzekwa committed
Feat: Context Parallel for Eagle3 Training (#745)
**Type of change:** New feature

**Overview:**

- Supported context parallel (CP) by patching torch ring attention.
- Requires the following library versions for stable CP:
  - torch 2.8.0
  - transformers 5.0.0
  - accelerate 1.12.0
- Moved to FSDP2.
- Removed unused arguments in the training script (`--multi_gpu`, `--fsdp_transformer_layer_cls_to_wrap`).
- Bumped the CI container to `nvcr.io/nvidia/pytorch:25.08-py3`.

**Usage:**

```bash
./launch_train.sh --model $MODEL \
    --output_dir $OUTPUT_DIR \
    --data $DATA \
    --num_epochs 0.1 \
    --train_bs 1 \
    --eagle_config eagle_config.json \
    --training_seq_len 1024 \
    --cp_size 2  # newly added
```

**Testing:**

- SDPA-level correctness: tested TTT attention with/without CP; diff < 1%.

  ```
  === Compare context-parallel (CP) outputs and grads with non-CP ===
  Forward output comparison (CP vs Non-CP):
    Absolute diff (adiff) cp_out vs out: 0.001953125
    Relative diff (rdiff) cp_out vs out: 0.00182342529296875
  WQ (query proj) grad comparison (CP vs Non-CP):
    Absolute diff (adiff) cp_wq_grad vs wq_grad: 0.0078125
    Relative diff (rdiff) cp_wq_grad vs wq_grad: 0.00347900390625
  WK (key proj) grad comparison (CP vs Non-CP):
    Absolute diff (adiff) cp_wk_grad vs wk_grad: 0.0078125
    Relative diff (rdiff) cp_wk_grad vs wk_grad: 0.002471923828125
  WV (value proj) grad comparison (CP vs Non-CP):
    Absolute diff (adiff) cp_wv_grad vs wv_grad: 0.25
    Relative diff (rdiff) cp_wv_grad vs wv_grad: 0.0069580078125
  ==============================================================
  ```

- E2E training accuracy (Llama3.1-8B, unsynthesized Magpie):

  <img width="911" height="630" alt="image" src="https://github.com/user-attachments/assets/1ecacc7f-c720-494c-9c1b-b60e7ced7baa" />

- Peak memory reserved (Llama3.1-8B, 8xH100, train_length=4k):

  | cp_size | max_memory_allocated (MB) | max_memory_reserved (MB) |
  |---------|---------------------------|--------------------------|
  | 1       | 65040.20                  | 79018.00                 |
  | 2       | 50409.17                  | 73098.00                 |
  | 4       | 45120.92                  | 72052.00                 |
  | 8       | 38882.12                  | 66484.00                 |

- Max training length test (Llama3.1-8B, H100):

  | cp_size | 6k | 12k | 24k | 48k |
  |---------|----|-----|-----|-----|
  | 1       | ✅ | OOM | OOM | OOM |
  | 2       | ✅ | ✅  | OOM | OOM |
  | 4       | ✅ | ✅  | ✅  | OOM |
  | 8       | ✅ | ✅  | ✅  | ✅  |

**Checklist:**

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

**Release notes (auto-generated by coderabbit.ai):**

* **New Features**
  * Added context parallelism (CP) and data parallelism shard size configuration parameters to training arguments.
* **Enhancements**
  * Improved TTT attention masking support for speculative decoding workflows.
  * Enhanced training launch script with improved parallelism configuration handling.
* **Chores**
  * Updated core dependencies: torch, transformers, accelerate, and wandb.
  * Added an FSDP configuration file for distributed training setup.

---------

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
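The adiff/rdiff figures above are maximum absolute and relative element-wise differences between CP and non-CP tensors. As a minimal sketch of that kind of comparison on plain Python lists (helper names here are illustrative, not taken from the PR's test code):

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length sequences."""
    return max(abs(x - y) for x, y in zip(a, b))

def max_rel_diff(a, b, eps=1e-12):
    """Largest element-wise relative difference, guarded against division by zero."""
    return max(abs(x - y) / max(abs(y), eps) for x, y in zip(a, b))

# Toy stand-ins for flattened CP and non-CP forward outputs.
out = [1.0, 2.0, 4.0]
cp_out = [1.001, 2.0, 3.999]

adiff = max_abs_diff(cp_out, out)
rdiff = max_rel_diff(cp_out, out)
```

In the real test these run over the forward output and the WQ/WK/WV gradients of the sharded and unsharded attention.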
1 parent 282675b · commit 7739972

11 files changed: 279 additions & 77 deletions

examples/speculative_decoding/README.md

Lines changed: 3 additions & 5 deletions

````diff
@@ -30,7 +30,7 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM,
 ### Docker
 
-Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.06-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.
+Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.08-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.
 
 Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install dataset and example-specific dependencies.
 
@@ -56,7 +56,7 @@ See [other-datasets](#other-datasets) section for other dataset options and inst
 ## Getting Started: Simplified Workflow
 
 ```bash
-bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct --num_gpu 4
+bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct
 ```
 
 This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in Modelopt. Specifically, it
@@ -74,12 +74,11 @@ For small base models that fit in GPU memory, we can collocate them with draft m
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
     --data input_conversations/daring-anteater.jsonl \
-    --num_gpu $NUM_GPU \
     --num_epochs $NUM_EPOCH \
     --eagle_config eagle_config.json
 ```
 
-This command will launch `main.py` with `accelerate`. See [section: interact with modelopt.torch.speculative](#interact-with-modelopttorchspeculative) for more details.
+FSDP2 is used by default. To enable context parallelism for long-context training, specify `--cp_size n`.
 
 The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
 
 ## Training Draft Model with Offline Base Model
@@ -118,7 +117,6 @@ Once we finish dumping hidden states, launch offline training with an extra `--o
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
     --data $DATA \
-    --num_gpu $NUM_GPU \
     --num_epochs $NUM_EPOCH \
     --eagle_config eagle_config.json \
     --offline-data $HIDDEN_STATES_DIR
````
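With `--num_gpu` removed, the parallel sizes are derived from the visible GPU count and `--cp_size`. A sketch of the underlying constraint (variable names mirror the launch script; the GPU-count source is an assumption):

```shell
GPU_COUNT=8      # e.g. detected from the visible devices in the real script
CP_SIZE=2        # ring-attention (context-parallel) group size, from --cp_size

# The two parallel dimensions must tile the available GPUs exactly:
# GPU_COUNT == CP_SIZE * DP_SHARD_SIZE
if [ $((GPU_COUNT % CP_SIZE)) -ne 0 ]; then
    echo "cp_size must divide the GPU count" >&2
    exit 1
fi
DP_SHARD_SIZE=$((GPU_COUNT / CP_SIZE))
echo "$DP_SHARD_SIZE"    # 4 for 8 GPUs with cp_size=2
```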

examples/speculative_decoding/eagle_utils.py

Lines changed: 149 additions & 0 deletions

```diff
@@ -13,21 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import json
 import os
+from collections.abc import Callable
 from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from types import FrameType
 from typing import Any
 
 import numpy as np
 import torch
 import transformers
 from datasets import load_dataset
+from packaging.version import Version
 from PIL import Image
 from scripts.ar_validate import validate_ar
+from torch.distributed.tensor.experimental._attention import _SDPAMerger
 from torch.utils.data import Dataset
 from transformers import AutoProcessor, Trainer, TrainerCallback
 from transformers.trainer_pt_utils import LabelSmoother
 
+import modelopt
+from modelopt.torch.speculative.utils import get_ttt_msk_func
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import is_master
@@ -566,3 +576,142 @@ def on_step_end(self, args, state, control, **kwargs):
         except Exception:
             print_rank_0("AR validation not available.")
         return control
+
+
+def get_patched_templated_ring_attn(orig_templated_attn: Callable):
+    """
+    Return a patched version of
+    torch.distributed.tensor.experimental._attention._templated_ring_attention
+    to support TTT.
+    """
+
+    def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype):
+        """Get the chunk-interleaved TTT mask for the current rank.
+        e.g.:
+        2 ranks, ttt_step=1;
+        full_ttt_mask = [[0, 0, 0, 0, x, 0, 0, 0],
+                         [x, 0, 0, 0, 0, x, 0, 0],
+                         [x, x, 0, 0, 0, 0, x, 0],
+                         [x, x, x, 0, 0, 0, 0, x]]
+
+        rank 0, step0: [[0, 0, x, 0],
+                        [x, 0, 0, x]]
+
+        rank 1, step0: [[0, 0, x, 0],
+                        [x, 0, 0, x]]
+
+        rank 0, step1: [[0, 0, 0, 0],
+                        [0, 0, 0, 0]]
+
+        rank 1, step1: [[x, x, 0, 0],
+                        [x, x, 0, 0]]
+        """
+        device = torch.cuda.current_device()
+        q_indices = torch.arange(q_len * rank, q_len * (rank + 1), device=device)
+        kv_indices = (
+            torch.arange(q_len * size * (ttt_step + 1), device=device)
+            .view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :]
+            .reshape(-1)
+        )
+        msk_func = get_ttt_msk_func(q_len * size, ttt_step)
+        attn_mask = msk_func(
+            None,
+            None,
+            q_indices.view(1, 1, -1, 1),
+            kv_indices.view(1, 1, 1, -1),
+        )
+        attn_bias = torch.where(
+            attn_mask,
+            torch.zeros((), dtype=dtype, device=attn_mask.device),
+            torch.full((), torch.finfo(dtype).min, dtype=dtype, device=attn_mask.device),
+        )
+
+        return attn_bias
+
+    def patched_templated_attn(*args, **kwargs):
+        """Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
+        # Get the original attention op.
+        # Sensitive to the impl of _templated_ring_attention.
+        original_op = args[2]
+
+        # This patch is only enabled for the eagle model by a context manager, not the base model.
+        patch_enabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH
+
+        if patch_enabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
+            raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")
+
+        # Unset is_causal to use the custom attn mask.
+        if patch_enabled:
+            kwargs["is_causal"] = False
+
+        def patched_op(*args, **kwargs):
+            # Inspect the parent frame to get current shard info.
+            # This is sensitive to the torch _templated_ring_attention impl.
+            try:
+                frame: FrameType = inspect.currentframe()
+                f_back: FrameType = frame.f_back
+                rank = f_back.f_locals["rank"]
+                size = f_back.f_locals["size"]
+                query = f_back.f_locals["query"]
+                key = f_back.f_locals["key"]
+                i = f_back.f_locals["i"]
+                ttt_step = (key.shape[2] // query.shape[2]) - 1
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to capture loop variables in patched _templated_ring_attention: {e}"
+                ) from e
+            # Set the attn mask to the permuted TTT mask.
+            if "attn_bias" in kwargs:
+                kwargs["attn_bias"] = _get_sharded_ttt_msk(
+                    i, rank, size, query.shape[2], ttt_step, query.dtype
+                )
+            # Perform shard attention.
+            return original_op(*args, **kwargs)
+
+        return orig_templated_attn(args[0], args[1], patched_op, *args[3:], **kwargs)
+
+    return patched_templated_attn
+
+
+def patch_ring_attention_for_ttt():
+    """Patch torch ring attention to support context parallelism for TTT."""
+    # Torch ring attention only supports no mask or a causal mask. We apply the following patches to enable the TTT mask.
+
+    if not (
+        Version(torch.__version__) > Version("2.7.1")
+        and Version(torch.__version__) < Version("2.9.0")
+    ):
+        raise RuntimeError(
+            f"Context parallel TTT is only supported for PyTorch 2.8.0 now. "
+            f"Got {torch.__version__}. "
+            f"Please use nvcr.io/nvidia/pytorch:25.08-py3, torch 2.8.0, or cp_size=1."
+        )
+
+    # 1. Disable load balance, which is designed for the causal mask.
+    # This affects how buffers are sharded, so it needs to be done permanently before accelerate/HF trainer init.
+    torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False
+
+    # 2. Patch templated ring attention for the TTT mask.
+    original_templated_ring_attention = (
+        torch.distributed.tensor.experimental._attention._templated_ring_attention
+    )
+    original_templated_ring_attention_backward = (
+        torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
+    )
+    torch.distributed.tensor.experimental._attention._templated_ring_attention = (
+        get_patched_templated_ring_attn(original_templated_ring_attention)
+    )
+    torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
+        get_patched_templated_ring_attn(original_templated_ring_attention_backward)
+    )
+
+    # 3. Patch the merger to skip the blank shard to avoid a difference in output.
+    original_sdpa_merger_step = _SDPAMerger.step
+
+    def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
+        if lse.sum() <= 0:
+            return
+        return original_sdpa_merger_step(self, out, lse, partial)
+
+    _SDPAMerger.step = patched_sdpa_merger_step
```
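To see what `_get_sharded_ttt_msk` is indexing, the `kv_indices` expression above can be mirrored in pure Python: at ring step `i`, a rank attends to the KV chunk originally owned by rank `(rank - i) % size`, within each of the `ttt_step + 1` concatenated TTT segments. A torch-free sketch:

```python
def sharded_kv_indices(i, rank, size, q_len, ttt_step):
    """Pure-Python mirror of the kv_indices tensor expression in the patch."""
    src = (rank - i) % size   # rank whose KV shard arrives at ring step i
    seg = size * q_len        # length of one TTT segment across all ranks
    return [
        s * seg + src * q_len + t
        for s in range(ttt_step + 1)
        for t in range(q_len)
    ]

# 2 ranks, q_len=2, ttt_step=1 (cf. the docstring example):
step0 = sharded_kv_indices(0, 0, 2, 2, 1)  # rank 0 sees its own chunks: [0, 1, 4, 5]
step1 = sharded_kv_indices(1, 0, 2, 2, 1)  # then rank 1's chunks:       [2, 3, 6, 7]
```

These global KV positions are what the TTT mask function is evaluated against, which is why the sharded masks come out chunk-interleaved rather than contiguous.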
examples/speculative_decoding/fsdp_config.json

Lines changed: 1 addition & 0 deletions

```diff
@@ -0,0 +1 @@
+{"fsdp_version":2}
```

examples/speculative_decoding/launch_train.sh

Lines changed: 23 additions & 16 deletions

```diff
@@ -74,14 +74,6 @@ while [ $# -gt 0 ]; do
         if [[ "$1" != *=* ]]; then shift; fi
         EAGLE_CONFIG="${1#*=}"
         ;;
-    --fsdp_transformer_layer_cls_to_wrap*)
-        if [[ "$1" != *=* ]]; then shift; fi
-        FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
-        ;;
-    --num_gpu*)
-        if [[ "$1" != *=* ]]; then shift; fi
-        NUM_GPU="${1#*=}"
-        ;;
     --disable_tqdm*)
         if [[ "$1" != *=* ]]; then shift; fi
         DISABLE_TQDM="${1#*=}"
@@ -102,6 +94,14 @@ while [ $# -gt 0 ]; do
         if [[ "$1" != *=* ]]; then shift; fi
         AR_VALIDATE_STEPS="${1#*=}"
         ;;
+    --cp_size*)
+        if [[ "$1" != *=* ]]; then shift; fi
+        CP_SIZE="${1#*=}"
+        ;;
+    --dp_size*)
+        if [[ "$1" != *=* ]]; then shift; fi
+        DP_SHARD_SIZE="${1#*=}"
+        ;;
     *)
         >&2 printf "Error: Invalid argument ${1#*=}\n"
         exit 1
@@ -129,15 +129,15 @@ LR=${LR:-"1e-4"}
 TRAIN_BS=${TRAIN_BS:-4}
 MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
 MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
-FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
-NUM_GPU=${NUM_GPU:-1}
 TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
 OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""}
 DISABLE_TQDM=${DISABLE_TQDM:-False}
 VLM_PROCESSOR=${VLM_PROCESSOR:-}
 VLM_IMG_DIR=${VLM_IMG_DIR:-}
 AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}
 ESTIMATE_AR=${ESTIMATE_AR:-False}
+CP_SIZE=${CP_SIZE:-1}
+DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}
 
 if [[ "$MODE" == "medusa" ]]; then
     SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -163,21 +163,25 @@ else
     OFFLINE_TRAINING_ARGS=""
 fi
 
-if [[ "$NUM_GPU" == 1 ]]; then
-    MULTI_GPU=""
-else
-    MULTI_GPU="--multi_gpu"
-fi
 
 if [[ "$VLM_PROCESSOR" != "" ]]; then
     VLM_ARGS="--vlm_processor $VLM_PROCESSOR --vlm_img_dir $VLM_IMG_DIR"
 else
     VLM_ARGS=""
 fi
 
+if [[ "$GPU_COUNT" -gt 1 ]]; then
+    # Use FSDP2 when multiple GPUs are available
+    FSDP_ARGS="--fsdp 'full_shard' --fsdp_config fsdp_config.json"
+else
+    # Otherwise, single-GPU training
+    FSDP_ARGS=""
+fi
+
+
 # Disable tokenizers parallelism to avoid warning
 export TOKENIZERS_PARALLELISM=False
-CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+CMD="accelerate launch --mixed_precision bf16 main.py \
     --mode $MODE \
     --eagle_decoder_type $EAGLE_DECODER_TYPE \
     --model_name_or_path $MODEL \
@@ -206,6 +210,9 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
     $VLM_ARGS \
     $OFFLINE_TRAINING_ARGS \
    $SPECULATIVE_ARGS \
+    $FSDP_ARGS \
+    --cp_size $CP_SIZE \
+    --dp_shard_size $DP_SHARD_SIZE \
 "
 
 start_time=$(date +%s)
```
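The case arms above all share one idiom: `"${1#*=}"` strips everything through the first `=`, so both `--cp_size 2` and `--cp_size=2` parse. A self-contained sketch of that pattern (the helper function is illustrative, not from the script):

```shell
#!/usr/bin/env bash
# Minimal reproduction of the launch script's flag parsing.
parse_cp_size() {
    local CP_SIZE=""
    while [ $# -gt 0 ]; do
        case "$1" in
            --cp_size*)
                # Space-separated form has no "=": the value is the next token.
                if [[ "$1" != *=* ]]; then shift; fi
                CP_SIZE="${1#*=}"
                ;;
        esac
        shift
    done
    echo "$CP_SIZE"
}

parse_cp_size --cp_size 2     # prints 2
parse_cp_size --cp_size=4     # prints 4
```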

examples/speculative_decoding/main.py

Lines changed: 16 additions & 1 deletion

```diff
@@ -36,7 +36,13 @@
 
 import torch
 import transformers
-from eagle_utils import EagleTrainerWithAccLog, EagleTrainingPlot, make_eagle_supervised_data_module
+from accelerate import ParallelismConfig
+from eagle_utils import (
+    EagleTrainerWithAccLog,
+    EagleTrainingPlot,
+    make_eagle_supervised_data_module,
+    patch_ring_attention_for_ttt,
+)
 from medusa_utils import make_medusa_supervised_data_module
 from transformers.trainer_utils import get_last_checkpoint
 
@@ -100,6 +106,8 @@ class TrainingArguments(transformers.TrainingArguments):
     remove_unused_columns: bool = field(
         default=False, metadata={"help": "Set to False to keep extra args for VLM."}
     )
+    cp_size: int = field(default=1, metadata={"help": "Context parallelism size."})
+    dp_shard_size: int = field(default=1, metadata={"help": "Data parallelism shard size."})
 
 
 @dataclass
@@ -130,6 +138,13 @@ def train():
     model_args, data_args, training_args, medusa_args, eagle_args = (
         parser.parse_args_into_dataclasses()
     )
+    training_args.parallelism_config = ParallelismConfig(
+        cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
+    )
+    if training_args.cp_size > 1:
+        patch_ring_attention_for_ttt()
+        # Patch specific to accelerate 1.12.0. Removable after moving to 1.13.0.
+        training_args.parallelism_config.sp_backend = None
     print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}")
 
     # Detecting last checkpoint.
```
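`patch_ring_attention_for_ttt()` runs before the trainer is constructed, because it replaces module-level functions that would otherwise be bound early by the framework. A torch-free sketch of this wrap-and-delegate monkey-patch pattern (the module and function names are stand-ins, not the real torch internals):

```python
import types

# Stand-in for a framework module with an internal attention function.
framework = types.SimpleNamespace(ring_attention=lambda q, k: q + k)

def get_patched(orig):
    """Return a wrapper that adjusts the call, then delegates to the original."""
    def patched(q, k):
        # The real patch swaps in a TTT attention mask here; we just
        # scale the result so the interception is observable.
        return orig(q, k) * 2
    return patched

# Install the patch *before* anything captures a reference to the original.
framework.ring_attention = get_patched(framework.ring_attention)

result = framework.ring_attention(1, 2)  # (1 + 2) * 2 == 6
```

The same closure-over-the-original trick is what lets `patch_ring_attention_for_ttt` reuse one factory for both the forward and backward ring-attention functions.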
examples/speculative_decoding/requirements.txt

Lines changed: 2 additions & 5 deletions

```diff
@@ -1,5 +1,2 @@
-flash-attn
-openai
-py7zr
-sentencepiece>=0.2.0
-tensorboardX
+accelerate==1.12.0
+transformers==5.0.0rc1
```
