Commit 59a0984

h-guo18 committed
squash: cp ttt
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 04165ac commit 59a0984

7 files changed: 242 additions & 42 deletions

examples/speculative_decoding/eagle_utils.py

Lines changed: 159 additions & 0 deletions
@@ -13,9 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import json
 import os
+from collections.abc import Callable
 from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from types import FrameType
 from typing import Any
 
 import numpy as np
@@ -24,10 +30,13 @@
 from datasets import load_dataset
 from PIL import Image
 from scripts.ar_validate import validate_ar
+from torch.distributed.tensor.experimental._attention import _SDPAMerger
 from torch.utils.data import Dataset
 from transformers import AutoProcessor, Trainer, TrainerCallback
 from transformers.trainer_pt_utils import LabelSmoother
 
+import modelopt.torch.speculative.plugins.transformers
+from modelopt.torch.speculative.utils import get_ttt_msk_func
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import is_master
 
@@ -566,3 +575,153 @@ def on_step_end(self, args, state, control, **kwargs):
         except Exception:
             print_rank_0("AR validation not available.")
         return control
+
+
+def _compute_ttt_attention_mask(batch_size, seq_length, ttt_step, dtype) -> torch.Tensor:
+    """Return the TTT attention_mask tensor (BlockMask or Tensor, depending on the eagle attn impl)."""
+
+    msk_func = get_ttt_msk_func(seq_length, ttt_step)
+
+    dtypemin = torch.finfo(dtype).min
+    q_len = seq_length
+    kv_len = seq_length * (1 + ttt_step)
+    # Return a tensor mask for non-flex attention
+    tensor_mask = msk_func(
+        None,
+        None,
+        torch.arange(q_len).view(1, 1, q_len, 1),
+        torch.arange(kv_len).view(1, 1, 1, kv_len),
+    ).to(torch.cuda.current_device())
+    tensor_mask = torch.full_like(
+        tensor_mask, 0, dtype=dtype, device=torch.cuda.current_device()
+    ).masked_fill(~tensor_mask, dtypemin)
+    return tensor_mask
+
+
+def get_patched_templated_ring_attn(orig_templated_attn: Callable):
+    """
+    Return a patched version of
+    torch.distributed.tensor.experimental._attention._templated_ring_attention
+    to support TTT.
+    """
+
+    def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype):
+        """Get the chunk-interleaved TTT mask for the current rank.
+        e.g.:
+        2 ranks, ttt_step=1;
+        full_ttt_mask = [[0, 0, 0, 0, x, 0, 0, 0],
+                         [x, 0, 0, 0, 0, x, 0, 0],
+                         [x, x, 0, 0, 0, 0, x, 0],
+                         [x, x, x, 0, 0, 0, 0, x]]
+
+        rank 0, step0: [[0, 0, x, 0],
+                        [x, 0, 0, x]]
+
+        rank 1, step0: [[0, 0, x, 0],
+                        [x, 0, 0, x]]
+
+        rank 0, step1: [[0, 0, 0, 0],
+                        [0, 0, 0, 0]]
+
+        rank 1, step1: [[x, x, 0, 0],
+                        [x, x, 0, 0]]
+
+        """
+        device = torch.cuda.current_device()
+        q_indices = torch.arange(q_len * rank, q_len * (rank + 1), device=device)
+        kv_indices = (
+            torch.arange(q_len * size * (ttt_step + 1), device=device)
+            .view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :]
+            .reshape(-1)
+        )
+        msk_func = get_ttt_msk_func(q_len * size, ttt_step)
+        attn_mask = msk_func(
+            None,
+            None,
+            q_indices.view(1, 1, -1, 1),
+            kv_indices.view(1, 1, 1, -1),
+        )
+        attn_bias = torch.where(
+            attn_mask,
+            torch.zeros((), dtype=dtype, device=attn_mask.device),
+            torch.full((), torch.finfo(dtype).min, dtype=dtype, device=attn_mask.device),
+        )
+
+        return attn_bias
+
+    def patched_templated_attn(*args, **kwargs):
+        """Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
+        # Get the original attention op.
+        # Sensitive to the impl of _templated_ring_attention.
+        original_op = args[2]
+
+        # This patch is only enabled for the eagle model by a context manager, not for the base model.
+        patch_enabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH
+
+        if patch_enabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
+            raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")
+
+        # Unset is_causal to use the custom attn mask.
+        if patch_enabled:
+            kwargs["is_causal"] = False
+
+        def patched_op(*args, **kwargs):
+            # Inspect the parent frame to get the current shard info.
+            # This is sensitive to the torch _templated_ring_attention impl.
+            try:
+                frame: FrameType = inspect.currentframe()
+                f_back: FrameType = frame.f_back
+                rank = f_back.f_locals["rank"]
+                size = f_back.f_locals["size"]
+                query = f_back.f_locals["query"]
+                key = f_back.f_locals["key"]
+                i = f_back.f_locals["i"]
+                ttt_step = (key.shape[2] // query.shape[2]) - 1
+            except Exception as e:
+                print(f"Failed to capture loop variables in patched _templated_ring_attention: {e}")
+            # Set the attn mask to the permuted TTT mask.
+            if "attn_bias" in kwargs:
+                kwargs["attn_bias"] = _get_sharded_ttt_msk(
+                    i, rank, size, query.shape[2], ttt_step, query.dtype
+                )
+            # Perform shard attention.
+            return original_op(*args, **kwargs)
+
+        return orig_templated_attn(args[0], args[1], patched_op, *args[3:], **kwargs)
+
+    return patched_templated_attn
+
+
+def patch_ring_attention_for_ttt():
+    """Patch torch ring attention to support context parallelism for TTT."""
+    # Torch ring attention only supports no mask or a causal mask. We apply the following patches to enable the TTT mask.
+
+    # 1. Disable load balancing, which is designed for the causal mask.
+    # This affects how buffers are sharded, so it must be done permanently before the accelerate/HF trainer init.
+    torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False
+
+    # 2. Patch templated ring attention for the TTT mask.
+    original_templated_ring_attention = (
+        torch.distributed.tensor.experimental._attention._templated_ring_attention
+    )
+    original_templated_ring_attention_backward = (
+        torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
+    )
+    torch.distributed.tensor.experimental._attention._templated_ring_attention = (
+        get_patched_templated_ring_attn(original_templated_ring_attention)
+    )
+    torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
+        get_patched_templated_ring_attn(original_templated_ring_attention_backward)
+    )
+
+    # 3. Patch the merger to skip the blank shard to avoid a difference in output.
+    original_sdpa_merger_step = _SDPAMerger.step
+
+    def patched_sdpa_merger_step(
+        self, out: torch.Tensor, lse: torch.Tensor, partial: bool
+    ) -> torch.Tensor:
+        if lse.sum() <= 0:
+            return
+        return original_sdpa_merger_step(self, out, lse, partial)
+
+    _SDPAMerger.step = patched_sdpa_merger_step
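
For orientation, a minimal usage sketch of the new mask helper (not part of the commit). It assumes a CUDA device and an installed modelopt package, since the helper places the mask on torch.cuda.current_device() and relies on get_ttt_msk_func behaving like a flex-attention mask_mod over (b, h, q_idx, kv_idx):

    # Hypothetical usage sketch of _compute_ttt_attention_mask; assumptions noted above.
    import torch

    from eagle_utils import _compute_ttt_attention_mask

    seq_length, ttt_step = 8, 1
    mask = _compute_ttt_attention_mask(
        batch_size=1, seq_length=seq_length, ttt_step=ttt_step, dtype=torch.bfloat16
    )
    # The mask broadcasts to (1, 1, q_len, kv_len) with kv_len = seq_length * (1 + ttt_step);
    # allowed positions hold 0, masked positions hold the dtype's minimum value.
    assert mask.shape == (1, 1, seq_length, seq_length * (1 + ttt_step))
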
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+{"fsdp_version":2}

examples/speculative_decoding/launch_train.sh

Lines changed: 15 additions & 16 deletions
@@ -74,14 +74,6 @@ while [ $# -gt 0 ]; do
        if [[ "$1" != *=* ]]; then shift; fi
        EAGLE_CONFIG="${1#*=}"
        ;;
-    --fsdp_transformer_layer_cls_to_wrap*)
-       if [[ "$1" != *=* ]]; then shift; fi
-       FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
-       ;;
-    --num_gpu*)
-       if [[ "$1" != *=* ]]; then shift; fi
-       NUM_GPU="${1#*=}"
-       ;;
     --disable_tqdm*)
        if [[ "$1" != *=* ]]; then shift; fi
        DISABLE_TQDM="${1#*=}"
@@ -102,6 +94,14 @@
        if [[ "$1" != *=* ]]; then shift; fi
        AR_VALIDATE_STEPS="${1#*=}"
        ;;
+    --cp_size*)
+       if [[ "$1" != *=* ]]; then shift; fi
+       CP_SIZE="${1#*=}"
+       ;;
+    --dp_size*)
+       if [[ "$1" != *=* ]]; then shift; fi
+       DP_SHARD_SIZE="${1#*=}"
+       ;;
     *)
        >&2 printf "Error: Invalid argument ${1#*=}\n"
        exit 1
@@ -129,15 +129,15 @@ LR=${LR:-"1e-4"}
 TRAIN_BS=${TRAIN_BS:-4}
 MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
 MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
-FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
-NUM_GPU=${NUM_GPU:-1}
 TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
 OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""}
 DISABLE_TQDM=${DISABLE_TQDM:-False}
 VLM_PROCESSOR=${VLM_PROCESSOR:-}
 VLM_IMG_DIR=${VLM_IMG_DIR:-}
 AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}
 ESTIMATE_AR=${ESTIMATE_AR:-False}
+CP_SIZE=${CP_SIZE:-1}
+DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}
 
 if [[ "$MODE" == "medusa" ]]; then
    SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -163,11 +163,6 @@ else
    OFFLINE_TRAINING_ARGS=""
 fi
 
-if [[ "$NUM_GPU" == 1 ]]; then
-   MULTI_GPU=""
-else
-   MULTI_GPU="--multi_gpu"
-fi
 
 if [[ "$VLM_PROCESSOR" != "" ]]; then
    VLM_ARGS="--vlm_processor $VLM_PROCESSOR --vlm_img_dir $VLM_IMG_DIR"
@@ -177,7 +172,7 @@ fi
 
 # Disable tokenizers parallelism to avoid warning
 export TOKENIZERS_PARALLELISM=False
-CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+CMD="accelerate launch --mixed_precision bf16 main.py \
    --mode $MODE \
    --eagle_decoder_type $EAGLE_DECODER_TYPE \
    --model_name_or_path $MODEL \
@@ -206,6 +201,10 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
    $VLM_ARGS \
    $OFFLINE_TRAINING_ARGS \
    $SPECULATIVE_ARGS \
+   --fsdp 'full_shard' \
+   --fsdp_config fsdp_config.json \
+   --cp_size $CP_SIZE \
+   --dp_shard_size $DP_SHARD_SIZE \
 "
 
 start_time=$(date +%s)
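
With the new flags, DP_SHARD_SIZE defaults to GPU_COUNT/CP_SIZE, so the context-parallel size and the data-parallel shard size must tile the available GPUs. A small sketch of that invariant (hypothetical helper, assuming the script's GPU_COUNT semantics; not part of the script):

    # Hypothetical helper mirroring the launch script defaults above.
    def resolve_parallel_sizes(
        gpu_count: int, cp_size: int = 1, dp_shard_size: int | None = None
    ) -> tuple[int, int]:
        # Mirrors DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}.
        if dp_shard_size is None:
            dp_shard_size = gpu_count // cp_size
        if cp_size * dp_shard_size != gpu_count:
            raise ValueError(
                f"cp_size ({cp_size}) * dp_shard_size ({dp_shard_size}) must equal gpu_count ({gpu_count})"
            )
        return cp_size, dp_shard_size

    print(resolve_parallel_sizes(8, cp_size=2))  # -> (2, 4)
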

examples/speculative_decoding/main.py

Lines changed: 16 additions & 1 deletion
@@ -36,7 +36,13 @@
 
 import torch
 import transformers
-from eagle_utils import EagleTrainerWithAccLog, EagleTrainingPlot, make_eagle_supervised_data_module
+from accelerate import ParallelismConfig
+from eagle_utils import (
+    EagleTrainerWithAccLog,
+    EagleTrainingPlot,
+    make_eagle_supervised_data_module,
+    patch_ring_attention_for_ttt,
+)
 from medusa_utils import make_medusa_supervised_data_module
 from transformers.trainer_utils import get_last_checkpoint
 
@@ -100,6 +106,8 @@ class TrainingArguments(transformers.TrainingArguments):
     remove_unused_columns: bool = field(
         default=False, metadata={"help": "Set to False to keep extra args for VLM."}
     )
+    cp_size: int = field(default=1, metadata={"help": "Context parallelism size."})
+    dp_shard_size: int = field(default=1, metadata={"help": "Data parallelism shard size."})
 
 
 @dataclass
@@ -130,6 +138,13 @@ def train():
     model_args, data_args, training_args, medusa_args, eagle_args = (
         parser.parse_args_into_dataclasses()
    )
+    training_args.parallelism_config = ParallelismConfig(
+        cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
+    )
+    if training_args.cp_size > 1:
+        patch_ring_attention_for_ttt()
+        # Patch specific to accelerate 1.12.0; removable after moving to 1.13.0.
+        training_args.parallelism_config.sp_backend = None
     print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}")
 
     # Detecting last checkpoint.
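
Standalone, the wiring added above can be sketched as follows (hedged: it assumes accelerate's ParallelismConfig accepts cp_size/dp_shard_size exactly as used in this diff, and the chosen sizes are illustrative):

    # Minimal standalone sketch of the ordering main.py now enforces.
    from accelerate import ParallelismConfig

    from eagle_utils import patch_ring_attention_for_ttt

    cp_size, dp_shard_size = 2, 4
    parallelism_config = ParallelismConfig(cp_size=cp_size, dp_shard_size=dp_shard_size)
    if cp_size > 1:
        # Must run before the trainer/accelerator initializes, since disabling ring-attention
        # load balancing changes how sequence chunks are sharded across ranks.
        patch_ring_attention_for_ttt()
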
Lines changed: 4 additions & 5 deletions
@@ -1,5 +1,4 @@
-flash-attn
-openai
-py7zr
-sentencepiece>=0.2.0
-tensorboardX
+accelerate==1.12.0
+torch==2.8.0
+transformers==5.0.0rc1
+wandb
