|
30 | 30 |
|
31 | 31 | from modelopt.torch.utils import print_rank_0 |
32 | 32 | from modelopt.torch.utils.distributed import is_master |
| 33 | +from modelopt.torch.utils.plugins.transformers_datasets import LanguageDataCollator, ShardedDataset |
33 | 34 |
|
34 | 35 | try: |
35 | 36 | import wandb |
@@ -227,75 +228,122 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]: |
227 | 228 | class OfflineSupervisedDataset(Dataset): |
228 | 229 | """Lazy offline dataset for supervised fine-tuning. |
229 | 230 |
|
230 | | - This dataset loads data on-the-fly from pre-processed .pt data files as well as |
231 | | - input conversations in JSON format. |
| 231 | + This dataset loads data on-the-fly from pre-processed .pt data files. |
232 | 232 |
|
233 | 233 | Args: |
234 | | - data_entries (list): A list of tuples (raw_data_example, file_path). |
| 234 | + dumped_files (list): A list of file paths to the dumped .pt files. |
235 | 235 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. |
236 | 236 | """ |
237 | 237 |
|
238 | 238 | def __init__( |
239 | 239 | self, |
240 | | - data_entries, |
| 240 | + dumped_files, |
241 | 241 | tokenizer: transformers.PreTrainedTokenizer, |
242 | 242 | vlm_processor=None, |
243 | 243 | img_dir=None, |
244 | 244 | ): |
245 | 245 | super().__init__() |
246 | 246 | print_rank_0("Formatting inputs...Skip in offline mode") |
247 | 247 | self.tokenizer = tokenizer |
248 | | - self.data_entries = data_entries |
249 | | - self.vlm_processor = vlm_processor |
250 | | - self.img_dir = img_dir |
251 | | - self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess |
| 248 | + self.dumped_files = dumped_files |
| 249 | + # self.vlm_processor = vlm_processor |
| 250 | + # self.img_dir = img_dir |
| 251 | + # self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess |
252 | 252 |
|
253 | 253 | # Does not cache the hidden states, as those have an extremely large memory footprint. |
254 | 254 | self.cached_data_dict = {} |
255 | 255 |
|
256 | 256 | def __len__(self): |
257 | | - return len(self.data_entries) |
| 257 | + return len(self.dumped_files) |
258 | 258 |
|
259 | 259 | def __getitem__(self, i) -> dict[str, torch.Tensor]: |
260 | 260 | # Load the conversational data, using the cache |
261 | | - raw_data, offline_file_path = self.data_entries[i] |
262 | 261 | if i in self.cached_data_dict: |
263 | | - preprocessed_base = self.cached_data_dict[i] |
| 262 | + ret = self.cached_data_dict[i] |
264 | 263 | else: |
265 | | - ret = self.preprocess_fn( |
266 | | - [raw_data], self.tokenizer, processor=self.vlm_processor, img_dir=self.img_dir |
267 | | - ) |
268 | | - preprocessed_base = {k: ret[k][0] for k in ret} |
269 | | - self.cached_data_dict[i] = preprocessed_base |
270 | | - |
271 | | - # Extend the data sample with the hidden states from the .pt file |
272 | | - max_length = self.tokenizer.model_max_length |
273 | | - offline_data = torch.load(offline_file_path) |
274 | | - offline_data["input_ids"] = offline_data["input_ids"][:max_length] |
275 | | - offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :] |
276 | | - offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :] |
277 | | - |
278 | | - # Make sure the input_ids have the same shape |
279 | | - if preprocessed_base["input_ids"].shape != offline_data["input_ids"].shape: |
280 | | - msg = f"""Input IDs from offline data do not match the preprocessed input IDs |
281 | | - for offline data sample at {offline_file_path}.""" |
282 | | - raise ValueError(msg) |
283 | | - |
284 | | - ret = {**preprocessed_base} # Shallow copy so we don't accidentally modify the cache |
285 | | - ret["input_ids"] = offline_data["input_ids"] |
286 | | - ret["kwargs"] = { |
287 | | - "base_model_outputs": { |
288 | | - "base_model_hidden_states": offline_data["hidden_states"], |
289 | | - "aux_hidden_states": offline_data["aux_hidden_states"], |
| 264 | + offline_file_path = self.dumped_files[i] |
| 265 | + # Extend the data sample with the hidden states from the .pt file |
| 266 | + max_length = self.tokenizer.model_max_length |
| 267 | + offline_data = torch.load(offline_file_path) |
| 268 | + ret = { |
| 269 | + "input_ids": offline_data["input_ids"][:max_length], |
| 270 | + "kwargs": { |
| 271 | + "base_model_outputs": { |
| 272 | + "base_model_hidden_states": offline_data["hidden_states"][:max_length, :], |
| 273 | + "aux_hidden_states": offline_data["aux_hidden_states"][:max_length, :], |
| 274 | + } |
| 275 | + }, |
290 | 276 | } |
291 | | - } |
| 277 | + self.cached_data_dict[i] = ret |
292 | 278 | return ret |
293 | 279 |
|
294 | 280 |
|
295 | 281 | def make_eagle_supervised_data_module( |
296 | 282 | tokenizer: transformers.PreTrainedTokenizer, |
297 | 283 | data_args, |
298 | 284 | max_length=None, |
| 285 | +) -> dict: |
| 286 | + if data_args.offline_data_path is not None: |
| 287 | + print_rank_0("Loading pre-processed data for offline training...") |
| 288 | + |
| 289 | + # Glob for all .pt files in the data_path directory |
| 290 | + assert data_args.offline_data_path is not None, ( |
| 291 | + "offline_data_path must be provided for offline training." |
| 292 | + ) |
| 293 | + offline_data_path = Path(data_args.offline_data_path) |
| 294 | + all_files = [str(p) for p in offline_data_path.glob("*.pt")] |
| 295 | + if not all_files: |
| 296 | + raise ValueError(f"No .pt files found in {data_args.offline_data_path}") |
| 297 | + |
| 298 | + # # Filter to conversations that exist in the offline data and in the provided json |
| 299 | + # valid_entries = [] |
| 300 | + # for entry in train_dataset: |
| 301 | + # conv_id = entry.get("conversation_id") |
| 302 | + # if conv_id is None: |
| 303 | + # conv_id = entry.get("uuid") |
| 304 | + # if conv_id is None: |
| 305 | + # conv_id = entry.get("id") |
| 306 | + # if conv_id is None: |
| 307 | + # raise ValueError(f"Conversation ID required but not found for entry {entry}") |
| 308 | + # file_path = str(offline_data_path / f"{conv_id}.pt") |
| 309 | + # if file_path in all_files: |
| 310 | + # valid_entries.append((entry, file_path)) |
| 311 | + |
| 312 | + # if len(valid_entries) == 0: |
| 313 | + # msg = """No valid files found in the offline data path that match the conversation IDs |
| 314 | + # in the provided data json. Please ensure that the offline data path is correct and |
| 315 | + # contains .pt files named after the conversation IDs, and that the input conversations |
| 316 | + # json has the correct format (with 'conversation_id' or 'id' fields).""" |
| 317 | + # raise ValueError(msg) |
| 318 | + # elif len(valid_entries) < len(data_json): |
| 319 | + # print_rank_0( |
| 320 | + # f"Warning: Only {len(valid_entries)} out of {len(data_json)} conversations" |
| 321 | + # " have corresponding .pt files in the offline data path. Continuing..." |
| 322 | + # ) |
| 323 | + |
| 324 | + train_dataset = OfflineSupervisedDataset( |
| 325 | + all_files, |
| 326 | + tokenizer=tokenizer, |
| 327 | + ) |
| 328 | + |
| 329 | + data_collator = DataCollatorForOffline(max_length=max_length) |
| 330 | + else: |
| 331 | + train_dataset = ShardedDataset("nvidia/Daring-Anteater") |
| 332 | + data_collator = LanguageDataCollator( |
| 333 | + tokenizer=tokenizer, |
| 334 | + max_length=max_length, |
| 335 | + ) |
| 336 | + |
| 337 | + return { |
| 338 | + "train_dataset": train_dataset, |
| 339 | + "data_collator": data_collator, |
| 340 | + } |
| 341 | + |
| 342 | + |
| 343 | +def make_eagle_supervised_data_module_old( |
| 344 | + tokenizer: transformers.PreTrainedTokenizer, |
| 345 | + data_args, |
| 346 | + max_length=None, |
299 | 347 | ) -> dict: |
300 | 348 | """Make dataset and collator for supervised fine-tuning. |
301 | 349 |
|
|
0 commit comments