Skip to content

Commit 07ddfbe

Browse files
committed
refactor offline data loading
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 0cc817f commit 07ddfbe

File tree

7 files changed

+81
-110
lines changed

7 files changed

+81
-110
lines changed

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@
1919

2020
python3 collect_hidden_states/compute_hidden_states_hf.py \
2121
--model meta-llama/Llama-3.2-1B-Instruct \
22-
--input-file synthetic_conversations/daring-anteater.jsonl \
22+
--input-data synthetic_conversations/daring-anteater.jsonl \
2323
--output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FI
3030

3131
for i in $(seq 0 $((DP_SIZE-1)))
3232
do
33-
CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
33+
CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-data /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
3434
done
3535
wait
3636

examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@
2020
export TLLM_LOG_LEVEL="error";
2121
python3 collect_hidden_states/compute_hidden_states_trtllm.py \
2222
--model meta-llama/Llama-3.2-1B-Instruct \
23-
--input-file synthetic_conversations/daring-anteater.jsonl \
23+
--input-data synthetic_conversations/daring-anteater.jsonl \
2424
--output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
2525

examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FI
3333
for i in $(seq 0 $((DP_SIZE-1)))
3434
do
3535

36-
export CUDA_VISIBLE_DEVICES=$i; python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &
36+
export CUDA_VISIBLE_DEVICES=$i; python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-data /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &
3737

3838
done
3939
wait

examples/speculative_decoding/eagle_utils.py

Lines changed: 68 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from torch.distributed.tensor.experimental._attention import _SDPAMerger
3232
from torch.utils.data import Dataset
3333
from transformers import Trainer, TrainerCallback
34+
from transformers.trainer_pt_utils import LabelSmoother
3435

3536
import modelopt
3637
from modelopt.torch.speculative.utils import get_ttt_msk_func
@@ -43,6 +44,8 @@
4344
except ImportError:
4445
wandb = None
4546

47+
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
48+
4649

4750
class OfflineSupervisedDataset(Dataset):
4851
"""Offline dataset for supervised fine-tuning.
@@ -51,40 +54,84 @@ class OfflineSupervisedDataset(Dataset):
5154
5255
Args:
5356
dumped_files (list): A list of file paths to the dumped .pt files.
54-
tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
5557
"""
5658

5759
def __init__(
5860
self,
5961
dumped_files,
60-
tokenizer: transformers.PreTrainedTokenizer,
6162
):
6263
super().__init__()
63-
print_rank_0("Formatting inputs...Skip in offline mode")
64-
self.tokenizer = tokenizer
6564
self.dumped_files = dumped_files
6665

6766
def __len__(self):
6867
return len(self.dumped_files)
6968

7069
def __getitem__(self, i) -> dict[str, torch.Tensor]:
71-
# Load the conversational data, using the cache
72-
offline_file_path = self.dumped_files[i]
73-
# Extend the data sample with the hidden states from the .pt file
74-
max_length = self.tokenizer.model_max_length
75-
offline_data = torch.load(offline_file_path)
70+
offline_data = torch.load(self.dumped_files[i])
71+
72+
labels = torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID)
73+
labels[..., :-1] = offline_data["input_ids"][..., 1:]
74+
7675
ret = {
77-
"input_ids": offline_data["input_ids"][:max_length],
78-
"kwargs": {
79-
"base_model_outputs": {
80-
"base_model_hidden_states": offline_data["hidden_states"][:max_length, :],
81-
"aux_hidden_states": offline_data["aux_hidden_states"][:max_length, :],
82-
}
83-
},
76+
"input_ids": offline_data["input_ids"],
77+
"base_model_hidden_states": offline_data["hidden_states"],
78+
"aux_hidden_states": offline_data["aux_hidden_states"],
79+
"attention_mask": torch.ones_like(offline_data["input_ids"]),
80+
"loss_mask": torch.ones_like(offline_data["input_ids"]),
81+
"labels": labels,
8482
}
8583
return ret
8684

8785

86+
class EagleOfflineDataCollator:
87+
"""Data collator that truncate or pads data for offline training."""
88+
89+
def __init__(self, max_length):
90+
self.max_length = max_length
91+
92+
def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0):
93+
"""Pad or truncate a tensor to length along a given dimension."""
94+
dim = dim % x.ndim # support negative dimension
95+
96+
# allocate output tensor
97+
out_shape = list(x.shape)
98+
out_shape[dim] = length
99+
out = x.new_zeros(out_shape)
100+
101+
# construct copy slice
102+
slc = [slice(None)] * x.ndim
103+
slc[dim] = slice(0, min(length, x.size(dim)))
104+
105+
# populate output tensor
106+
out[tuple(slc)] = x[tuple(slc)]
107+
return out
108+
109+
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
110+
base_batch = {
111+
k: torch.stack([self._pad_or_truncate(item[k], self.max_length) for item in features])
112+
for k in ["input_ids", "attention_mask", "loss_mask", "labels"]
113+
}
114+
115+
base_model_outputs = {
116+
k: torch.stack([self._pad_or_truncate(item[k], self.max_length) for item in features])
117+
for k in ["base_model_hidden_states", "aux_hidden_states"]
118+
}
119+
120+
batch = {
121+
**base_batch,
122+
"base_model_outputs": base_model_outputs,
123+
}
124+
125+
# NOTE: vlm does not support offline data yet.
126+
# # Collate VLM data
127+
# if "pixel_values" in features[0]:
128+
# # pixel values and image flags should be flattened inside a batch
129+
# batch["pixel_values"] = torch.cat([item["pixel_values"] for item in features], dim=0)
130+
# batch["image_flags"] = torch.cat([item["image_flags"] for item in features], dim=0)
131+
132+
return batch
133+
134+
88135
def make_eagle_supervised_data_module(
89136
tokenizer: transformers.PreTrainedTokenizer,
90137
data_args,
@@ -93,23 +140,18 @@ def make_eagle_supervised_data_module(
93140
if data_args.offline_data_path is not None:
94141
print_rank_0("Loading pre-processed data for offline training...")
95142

96-
# Glob for all .pt files in the data_path directory
97143
assert data_args.offline_data_path is not None, (
98144
"offline_data_path must be provided for offline training."
99145
)
100146
offline_data_path = Path(data_args.offline_data_path)
101-
all_files = [str(p) for p in offline_data_path.glob("*.pt")]
102-
if not all_files:
147+
dumped_files = [str(p) for p in offline_data_path.glob("*.pt")]
148+
if not dumped_files:
103149
raise ValueError(f"No .pt files found in {data_args.offline_data_path}")
104150

105-
train_dataset = OfflineSupervisedDataset(
106-
all_files,
107-
tokenizer=tokenizer,
108-
)
109-
110-
data_collator = DataCollatorForOffline(max_length=max_length)
151+
train_dataset = OfflineSupervisedDataset(dumped_files)
152+
data_collator = EagleOfflineDataCollator(max_length=max_length)
111153
else:
112-
train_dataset = ShardedDataset("nvidia/Daring-Anteater")
154+
train_dataset = ShardedDataset("json", data_files=data_args.data_path)
113155
data_collator = LanguageDataCollator(
114156
tokenizer=tokenizer,
115157
max_length=max_length,
@@ -122,85 +164,6 @@ def make_eagle_supervised_data_module(
122164
}
123165

124166

125-
class DataCollatorWithPadding:
126-
def __init__(self, max_length):
127-
self.max_length = max_length
128-
129-
def paddingtensor2d(self, intensors, length):
130-
n, dim = intensors.shape
131-
if n > length:
132-
return intensors[:length, :]
133-
padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
134-
outtensors = torch.cat((intensors, padding_tensor))
135-
return outtensors
136-
137-
def paddingtensor(self, intensors, length):
138-
if intensors.shape[0] > length:
139-
return intensors[:length]
140-
padding_tensor = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype)
141-
outtensors = torch.cat((intensors, padding_tensor))
142-
return outtensors
143-
144-
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
145-
batch_input_ids = torch.stack(
146-
[self.paddingtensor(item["input_ids"], self.max_length) for item in features]
147-
)
148-
batch_attention_mask = torch.stack(
149-
[self.paddingtensor(item["attention_mask"], self.max_length) for item in features]
150-
)
151-
batch_loss_mask = torch.stack(
152-
[self.paddingtensor(item["loss_mask"], self.max_length) for item in features]
153-
)
154-
155-
batch_labels = torch.stack(
156-
[self.paddingtensor(item["labels"], self.max_length) for item in features]
157-
)
158-
159-
batch = {
160-
"input_ids": batch_input_ids,
161-
"attention_mask": batch_attention_mask,
162-
"loss_mask": batch_loss_mask,
163-
"labels": batch_labels,
164-
}
165-
166-
# Collate VLM data
167-
if "pixel_values" in features[0]:
168-
# pixel values and image flags should be flattened inside a batch
169-
batch["pixel_values"] = torch.cat([item["pixel_values"] for item in features], dim=0)
170-
batch["image_flags"] = torch.cat([item["image_flags"] for item in features], dim=0)
171-
172-
return batch
173-
174-
175-
class DataCollatorForOffline(DataCollatorWithPadding):
176-
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
177-
base_batch = super().__call__(features)
178-
if "kwargs" not in features[0]:
179-
raise ValueError("No kwargs found in batch features. Offline data required.")
180-
181-
features = [item["kwargs"]["base_model_outputs"] for item in features]
182-
183-
batch_hidden_states = torch.stack(
184-
[
185-
self.paddingtensor2d(item["base_model_hidden_states"], self.max_length)
186-
for item in features
187-
]
188-
)
189-
batch_aux_hidden_states = torch.stack(
190-
[self.paddingtensor2d(item["aux_hidden_states"], self.max_length) for item in features]
191-
)
192-
193-
batch = {
194-
**base_batch,
195-
"base_model_outputs": {
196-
"base_model_hidden_states": batch_hidden_states,
197-
"aux_hidden_states": batch_aux_hidden_states,
198-
},
199-
}
200-
201-
return batch
202-
203-
204167
class EagleTrainerWithAccLog(Trainer):
205168
"""Wrapper around Trainer that logs training accuracy."""
206169

examples/speculative_decoding/launch_train.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ while [ $# -gt 0 ]; do
102102
if [[ "$1" != *=* ]]; then shift; fi
103103
DP_SHARD_SIZE="${1#*=}"
104104
;;
105+
--log_steps*)
106+
if [[ "$1" != *=* ]]; then shift; fi
107+
LOG_STEPS="${1#*=}"
108+
;;
105109
*)
106110
>&2 printf "Error: Invalid argument ${1#*=}\n"
107111
exit 1
@@ -138,6 +142,7 @@ AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}
138142
ESTIMATE_AR=${ESTIMATE_AR:-False}
139143
CP_SIZE=${CP_SIZE:-1}
140144
DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}
145+
LOG_STEPS=${LOG_STEPS:-100}
141146

142147
if [[ "$MODE" == "medusa" ]]; then
143148
SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -201,7 +206,7 @@ CMD="accelerate launch --mixed_precision bf16 main.py \
201206
--weight_decay 0.0 \
202207
--warmup_steps 100 \
203208
--lr_scheduler_type linear \
204-
--logging_steps 100 \
209+
--logging_steps $LOG_STEPS \
205210
--tf32 True \
206211
--data_path $DATA \
207212
--disable_tqdm $DISABLE_TQDM \

modelopt/torch/utils/plugins/transformers_dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
self,
5858
name: str,
5959
subset: str | None = None,
60+
data_files: str | None = None,
6061
split: str = "train",
6162
num_shards: int = 1,
6263
shard_index: int = 0,
@@ -66,6 +67,7 @@ def __init__(
6667
self.name = name
6768
self.subset = subset
6869
self.split = split
70+
self.data_files = data_files
6971
self.num_shards = num_shards
7072
self.shard_index = shard_index
7173
self.num_streaming_samples = num_streaming_samples
@@ -91,6 +93,7 @@ def _load_dataset(self):
9193
dataset = load_dataset(
9294
self.name,
9395
self.subset,
96+
data_files=self.data_files,
9497
split=self.split,
9598
# num_proc=4, # TODO: Make this configurable
9699
streaming=self.num_streaming_samples is not None,

0 commit comments

Comments (0)