Skip to content

Commit bf7cf7d

Browse files
committed
init data loading refactor
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent c6c9905 commit bf7cf7d

File tree

2 files changed

+91
-38
lines changed

2 files changed

+91
-38
lines changed

examples/speculative_decoding/eagle_utils.py

Lines changed: 85 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from modelopt.torch.utils import print_rank_0
3232
from modelopt.torch.utils.distributed import is_master
33+
from modelopt.torch.utils.plugins.transformers_datasets import LanguageDataCollator, ShardedDataset
3334

3435
try:
3536
import wandb
@@ -227,75 +228,122 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
227228
class OfflineSupervisedDataset(Dataset):
    """Lazy offline dataset for supervised fine-tuning.

    Loads each sample on the fly from a pre-processed .pt dump containing the
    token ids and the base-model hidden states needed for EAGLE training.

    Args:
        dumped_files (list): File paths of the dumped .pt sample files.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer; only its
            ``model_max_length`` is used here, to truncate each sample.
        vlm_processor: Unused in offline mode; kept for interface compatibility.
        img_dir: Unused in offline mode; kept for interface compatibility.
    """

    def __init__(
        self,
        dumped_files,
        tokenizer: transformers.PreTrainedTokenizer,
        vlm_processor=None,
        img_dir=None,
    ):
        super().__init__()
        print_rank_0("Formatting inputs...Skip in offline mode")
        self.tokenizer = tokenizer
        self.dumped_files = dumped_files

        # Kept for interface compatibility but intentionally left unused:
        # caching loaded samples would pin every hidden-state tensor in
        # memory, which has an extremely large footprint.
        self.cached_data_dict = {}

    def __len__(self):
        """Return the number of dumped sample files."""
        return len(self.dumped_files)

    def __getitem__(self, i) -> dict[str, torch.Tensor]:
        """Load sample *i* from disk, truncated to the tokenizer's max length.

        Returns a dict with ``input_ids`` plus the base-model hidden states
        and aux hidden states nested under ``kwargs["base_model_outputs"]``.
        """
        offline_file_path = self.dumped_files[i]
        max_length = self.tokenizer.model_max_length
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # dumps produced by this project's own export step.
        offline_data = torch.load(offline_file_path)
        # Deliberately NOT cached: the hidden states dominate memory, so each
        # access re-reads the file instead of accumulating tensors in RAM.
        return {
            "input_ids": offline_data["input_ids"][:max_length],
            "kwargs": {
                "base_model_outputs": {
                    "base_model_hidden_states": offline_data["hidden_states"][:max_length, :],
                    "aux_hidden_states": offline_data["aux_hidden_states"][:max_length, :],
                }
            },
        }
293279

294280

295281
def make_eagle_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
    max_length=None,
) -> dict:
    """Build the training dataset and collator for EAGLE fine-tuning.

    When ``data_args.offline_data_path`` is set, trains from pre-dumped .pt
    files (offline mode); otherwise streams the default online dataset.

    Args:
        tokenizer: Tokenizer used for truncation/collation.
        data_args: Argument namespace; only ``offline_data_path`` is read here.
        max_length: Optional maximum sequence length passed to the collator.

    Returns:
        dict with ``train_dataset`` and ``data_collator`` entries.

    Raises:
        ValueError: If offline mode is requested but no .pt files are found.
    """
    if data_args.offline_data_path is not None:
        print_rank_0("Loading pre-processed data for offline training...")

        # Glob for all .pt files in the offline data directory.
        offline_data_path = Path(data_args.offline_data_path)
        all_files = [str(p) for p in offline_data_path.glob("*.pt")]
        if not all_files:
            raise ValueError(f"No .pt files found in {data_args.offline_data_path}")

        train_dataset = OfflineSupervisedDataset(
            all_files,
            tokenizer=tokenizer,
        )
        data_collator = DataCollatorForOffline(max_length=max_length)
    else:
        # Online mode: stream conversations and tokenize on the fly.
        train_dataset = ShardedDataset("nvidia/Daring-Anteater")
        data_collator = LanguageDataCollator(
            tokenizer=tokenizer,
            max_length=max_length,
        )

    return {
        "train_dataset": train_dataset,
        "data_collator": data_collator,
    }
341+
342+
343+
def make_eagle_supervised_data_module_old(
344+
tokenizer: transformers.PreTrainedTokenizer,
345+
data_args,
346+
max_length=None,
299347
) -> dict:
300348
"""Make dataset and collator for supervised fine-tuning.
301349

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,12 @@ def forward(
767767
assert past_key_values is None, "past_key_values should be None in training"
768768

769769
if loss_mask is None:
770-
loss_mask = torch.ones_like(input_ids, dtype=torch.bool, device=input_ids.device)
770+
# By default, mask out padding tokens in loss computation
771+
loss_mask = (
772+
attention_mask.clone().detach()
773+
if attention_mask is not None
774+
else torch.ones_like(input_ids, dtype=torch.bool)
775+
)
771776

772777
# ====First, we run base model forward====
773778
if "base_model_outputs" in kwargs:

0 commit comments

Comments
 (0)