Skip to content

Commit ca7cba3

Browse files
committed
debug
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent bf7cf7d commit ca7cba3

File tree

2 files changed

+325
-1
lines changed

2 files changed

+325
-1
lines changed

examples/speculative_decoding/eagle_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from modelopt.torch.utils import print_rank_0
3232
from modelopt.torch.utils.distributed import is_master
33-
from modelopt.torch.utils.plugins.transformers_datasetse import LanguageDataCollator, ShardedDataset
33+
from modelopt.torch.utils.plugins.transformers_dataset import LanguageDataCollator, ShardedDataset
3434

3535
try:
3636
import wandb
Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Processing large data to tokenize for pretraining."""
17+
18+
import copy
19+
import itertools
20+
21+
import torch
22+
import transformers
23+
from datasets import load_dataset
24+
from transformers.trainer_pt_utils import LabelSmoother
25+
26+
REMOVE_THINK_CHAT_TEMPLATE = (
27+
"{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
28+
)
29+
30+
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
31+
32+
33+
def _sharegpt_to_openai_messages(conversations: list[dict]):
34+
role_mapping = {
35+
"user": "user",
36+
"User": "user",
37+
"human": "user",
38+
"assistant": "assistant",
39+
"Assistant": "assistant",
40+
"gpt": "assistant",
41+
"system": "system",
42+
"System": "system",
43+
}
44+
messages = []
45+
for msg in conversations:
46+
role = role_mapping[msg["from"]]
47+
content = msg["value"]
48+
messages.append({"role": role, "content": content})
49+
return messages
50+
51+
52+
class ShardedDataset(torch.utils.data.Dataset):
53+
"""ShardedDataset is a subclass of torch.utils.data.Dataset that is used to load data from a dataset."""
54+
55+
def __init__(
56+
self,
57+
name: str,
58+
subset: str | None = None,
59+
split: str = "train",
60+
num_shards: int = 1,
61+
shard_index: int = 0,
62+
num_streaming_samples: int | None = None,
63+
):
64+
"""Initialize the ShardedDataset."""
65+
self.name = name
66+
self.subset = subset
67+
self.split = split
68+
self.num_shards = num_shards
69+
self.shard_index = shard_index
70+
self.num_streaming_samples = num_streaming_samples
71+
72+
self._load_dataset()
73+
74+
def __len__(self):
75+
if self.num_streaming_samples is not None:
76+
return self.num_streaming_samples
77+
else:
78+
return len(self._raw_samples)
79+
80+
def __getitem__(self, index):
81+
index = index // self.num_shards
82+
83+
if self.num_streaming_samples is not None:
84+
while index >= len(self._raw_samples):
85+
self._raw_samples.append(next(self._stream_iterator))
86+
87+
return self._raw_samples[index]
88+
89+
def _load_dataset(self):
90+
dataset = load_dataset(
91+
self.name,
92+
self.subset,
93+
split=self.split,
94+
# num_proc=4, # TODO: Make this configurable
95+
streaming=self.num_streaming_samples is not None,
96+
)
97+
98+
shard = dataset.shard(num_shards=self.num_shards, index=self.shard_index)
99+
100+
if self.num_streaming_samples is not None:
101+
self._raw_samples = []
102+
self._stream_samples = shard
103+
self._stream_iterator = itertools.cycle(self._stream_samples)
104+
else:
105+
self._raw_samples = shard
106+
107+
108+
class LanguageDataCollator:
    """Collate text / chat samples into padded, tokenized batches.

    Each sample dict may provide plain text under ``json_key``, OpenAI-style
    ``messages``, or ShareGPT-style ``conversations`` (converted on the fly).
    """

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizerBase,
        max_length: int = 4096,
        chat_template: str | None = None,
        add_generation_prompt: bool = False,
        answer_only_loss: bool = False,
        json_key: str = "text",
    ):
        """Initialize the LanguageDataCollator.

        Args:
            tokenizer: Hugging Face tokenizer used for all tokenization.
            max_length: pad/truncate length for every batch.
            chat_template: optional template overriding the tokenizer's own.
            add_generation_prompt: forwarded to ``apply_chat_template``.
            answer_only_loss: request assistant-token masks for loss masking.
            json_key: dict key holding plain-text samples.

        Raises:
            ValueError: if ``tokenizer`` is not a PreTrainedTokenizerBase, or
                no usable chat template is available.
        """
        if not isinstance(tokenizer, transformers.PreTrainedTokenizerBase):
            raise ValueError(
                "The tokenizer must be a transformers.PreTrainedTokenizerBase but got {}".format(
                    type(tokenizer)
                )
            )
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.add_generation_prompt = add_generation_prompt
        self.answer_only_loss = answer_only_loss
        self.json_key = json_key

        if chat_template is not None:
            self.tokenizer.chat_template = chat_template
        elif self.tokenizer.chat_template is not None:
            # BUGFIX: only post-process an existing template. Previously a
            # missing (None) template hit ``None.replace`` and raised
            # AttributeError, masking the intended ValueError below.
            self._post_process_chat_template()

        if self.tokenizer.chat_template is None:
            raise ValueError("No valid chat template!")

    def _post_process_tokenizer(self):
        # NOTE(review): not called anywhere in this file — confirm whether
        # __init__ was meant to invoke it.
        if hasattr(self.tokenizer, "pad_token") and self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token == "<|eot_id|>":  # nosec
                self.tokenizer.pad_token = "<|end_of_text|>"  # nosec
            else:
                raise ValueError("The tokenizer has no pad_token!")

    def _post_process_chat_template(self):
        # [WAR]: For DeepSeek-V3/R1 tokenizer, we modify the chat_template such
        # that the <think> tokens are preserved for supervised learning.
        self.tokenizer.chat_template = self.tokenizer.chat_template.replace(
            REMOVE_THINK_CHAT_TEMPLATE, ""
        )

    def _process_chat_sample(self, examples: list):
        """Tokenize chat message lists into one padded batch."""
        tokenized_examples = self.tokenizer.apply_chat_template(
            examples,
            return_tensors="pt",
            return_dict=True,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            add_generation_prompt=self.add_generation_prompt,
            return_assistant_tokens_mask=self.answer_only_loss,
        )
        return tokenized_examples

    def _process_text_sample(self, examples: list):
        """Tokenize plain-text strings into one padded batch."""
        tokenized_examples = self.tokenizer(
            examples,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )
        return tokenized_examples

    def __call__(self, examples):
        """Collate a list of sample dicts into one tokenized batch.

        Raises:
            ValueError: if a sample is not a dict or is in no known format.
        """
        batch = []
        for example in examples:
            if not isinstance(example, dict):
                raise ValueError("The sample must be a Dict but got {}".format(type(example)))
            text = example.get(self.json_key, None)
            if isinstance(text, str):
                # NOTE(review): plain-text samples are routed through the chat
                # template below; _process_text_sample is unused — confirm
                # this is intentional.
                batch.append(text)
            else:
                messages = example.get("messages", None)
                if messages is None:
                    conversations = example.get("conversations", None)
                    if conversations is None:
                        raise ValueError(
                            "The sample must be in either OpenAI messages format "
                            "or ShareGPT conversations format."
                        )
                    messages = _sharegpt_to_openai_messages(conversations)
                batch.append(messages)

        return self._process_chat_sample(batch)
201+
202+
203+
class LanguageDataset(ShardedDataset):
    """ShardedDataset variant whose items come back already collated.

    Each ``__getitem__`` call tokenizes a single raw sample through a
    ``LanguageDataCollator`` and returns the resulting one-sample batch.
    """

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizerBase,
        name: str,
        subset: str | None = None,
        split: str = "train",
        num_shards: int = 1,
        shard_index: int = 0,
        max_length: int = 4096,
        chat_template: str | None = None,
        add_generation_prompt: bool = False,
        answer_only_loss: bool = False,
        json_key: str = "text",
    ):
        """Initialize the LanguageDataset."""
        super().__init__(
            name=name,
            subset=subset,
            split=split,
            num_shards=num_shards,
            shard_index=shard_index,
        )
        # Tokenization/padding is deferred to per-item collation.
        self.collator = LanguageDataCollator(
            tokenizer=tokenizer,
            max_length=max_length,
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            answer_only_loss=answer_only_loss,
            json_key=json_key,
        )

    def __getitem__(self, index):
        """Return the collated (tokenized) sample at ``index``."""
        local_index = index // self.num_shards
        if self.num_streaming_samples is not None:
            # Materialize streamed samples up to the requested position.
            while len(self._raw_samples) <= local_index:
                self._raw_samples.append(next(self._stream_iterator))
        return self.collator([self._raw_samples[local_index]])
246+
247+
248+
class VisionLanguageDataCollator(LanguageDataCollator):
    """Collate vision-language chat samples via a multimodal processor.

    Accepts OpenAI ``messages`` or ShareGPT ``conversations``; normalizes
    string contents into the ``[{"type": "text", ...}]`` structure and
    resolves relative image paths against ``local_image_path``.
    """

    def __init__(
        self,
        processor: transformers.ProcessorMixin,
        max_length: int = 8192,
        chat_template: str | None = None,
        add_generation_prompt: bool = False,
        answer_only_loss: bool = False,
        local_image_path: str | None = None,
    ):
        """Initialize the VisionLanguageDataCollator.

        Args:
            processor: multimodal processor; its ``tokenizer`` feeds the base class.
            max_length: pad/truncate length for tokenized batches.
            chat_template: optional override of the tokenizer's chat template.
            add_generation_prompt: forwarded to ``apply_chat_template``.
            answer_only_loss: request assistant-token masks for loss masking.
            local_image_path: directory prepended to relative image paths;
                if None, image paths are used as given.

        Raises:
            ValueError: if ``processor`` is not a transformers.ProcessorMixin.
        """
        if not isinstance(processor, transformers.ProcessorMixin):
            raise ValueError(
                "The processor must be a transformers.ProcessorMixin but got {}".format(
                    type(processor)
                )
            )

        self.processor = processor
        self.chat_template = chat_template
        self.local_image_path = local_image_path

        # The base-class init assigns tokenizer, max_length,
        # add_generation_prompt and answer_only_loss; the duplicate
        # pre-assignments the original made here were redundant and removed.
        super().__init__(
            tokenizer=self.processor.tokenizer,
            max_length=max_length,
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            answer_only_loss=answer_only_loss,
        )

    def _process_multimodal_sample(self, examples):
        """Tokenize chat messages (including image content) into one padded batch."""
        tokenized_messages = self.processor.apply_chat_template(
            examples,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            add_generation_prompt=self.add_generation_prompt,
            return_assistant_tokens_mask=self.answer_only_loss,
        )
        return tokenized_messages

    def __call__(self, examples):
        """Collate a list of multimodal sample dicts into one tokenized batch.

        Raises:
            ValueError: if a sample is in no known conversation format.
        """
        batch = []
        for example in examples:
            messages = example.get("messages", None)
            if messages is None:
                conversations = example.get("conversations", None)
                if conversations is None:
                    raise ValueError(
                        "The sample must be in either OpenAI messages format "
                        "or ShareGPT conversations format."
                    )
                messages = _sharegpt_to_openai_messages(conversations)

            # Deep-copy so path rewriting does not mutate the caller's sample.
            copy_messages = copy.deepcopy(messages)
            for msg in copy_messages:
                if isinstance(msg["content"], str):
                    msg["content"] = [{"type": "text", "text": msg["content"]}]
                for ctn in msg["content"]:
                    if ctn["type"] == "image" and "path" in ctn:
                        # BUGFIX: the original unconditionally concatenated
                        # ``self.local_image_path + "/"`` and crashed with a
                        # TypeError when local_image_path was None (its
                        # default); skip the prefix in that case.
                        if self.local_image_path is not None:
                            ctn["path"] = self.local_image_path + "/" + ctn["path"]

            batch.append(copy_messages)

        return self._process_multimodal_sample(batch)

0 commit comments

Comments
 (0)