3131from torch .distributed .tensor .experimental ._attention import _SDPAMerger
3232from torch .utils .data import Dataset
3333from transformers import Trainer , TrainerCallback
34+ from transformers .trainer_pt_utils import LabelSmoother
3435
3536import modelopt
3637from modelopt .torch .speculative .utils import get_ttt_msk_func
4344except ImportError :
4445 wandb = None
4546
47+ IGNORE_TOKEN_ID = LabelSmoother .ignore_index
48+
4649
class OfflineSupervisedDataset(Dataset):
    """Offline dataset for supervised fine-tuning.

    Each sample is a pre-dumped ``.pt`` file holding the tokenized input
    together with the base model's hidden states, so no tokenizer is needed.

    Args:
        dumped_files (list): A list of file paths to the dumped .pt files.
    """

    def __init__(
        self,
        dumped_files,
    ):
        super().__init__()
        self.dumped_files = dumped_files

    def __len__(self):
        return len(self.dumped_files)

    def __getitem__(self, i) -> dict[str, torch.Tensor]:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted dump files (weights_only=True would be safer if the dumps
        # contain tensors only; confirm before changing).
        offline_data = torch.load(self.dumped_files[i])

        # Next-token labels: position t predicts input_ids[t + 1]; the final
        # position has no target and carries IGNORE_TOKEN_ID.
        input_ids = offline_data["input_ids"]
        labels = input_ids.roll(-1, dims=-1)
        labels[..., -1] = IGNORE_TOKEN_ID

        return {
            "input_ids": input_ids,
            "base_model_hidden_states": offline_data["hidden_states"],
            "aux_hidden_states": offline_data["aux_hidden_states"],
            "attention_mask": torch.ones_like(input_ids),
            "loss_mask": torch.ones_like(input_ids),
            "labels": labels,
        }
8684
8785
class EagleOfflineDataCollator:
    """Data collator that truncates or pads offline samples to a fixed length."""

    def __init__(self, max_length):
        self.max_length = max_length

    def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0):
        """Return a copy of ``x`` whose size along ``dim`` is exactly ``length``.

        Shorter inputs are zero-padded at the end; longer inputs are truncated.
        """
        dim %= x.ndim  # normalize a negative axis
        keep = min(length, x.size(dim))

        # Allocate the fixed-size output, then copy the kept window in place.
        target_shape = list(x.shape)
        target_shape[dim] = length
        padded = x.new_zeros(target_shape)

        window = tuple(
            slice(0, keep) if axis == dim else slice(None) for axis in range(x.ndim)
        )
        padded[window] = x[window]
        return padded

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        def stack_key(key):
            return torch.stack(
                [self._pad_or_truncate(item[key], self.max_length) for item in features]
            )

        # NOTE(review): padding fills labels/loss_mask with 0 rather than
        # IGNORE_TOKEN_ID — downstream loss must honor loss_mask; confirm.
        batch = {
            key: stack_key(key)
            for key in ("input_ids", "attention_mask", "loss_mask", "labels")
        }
        batch["base_model_outputs"] = {
            key: stack_key(key)
            for key in ("base_model_hidden_states", "aux_hidden_states")
        }

        # NOTE: vlm does not support offline data yet.
        return batch
133+
134+
88135def make_eagle_supervised_data_module (
89136 tokenizer : transformers .PreTrainedTokenizer ,
90137 data_args ,
@@ -93,23 +140,18 @@ def make_eagle_supervised_data_module(
93140 if data_args .offline_data_path is not None :
94141 print_rank_0 ("Loading pre-processed data for offline training..." )
95142
96- # Glob for all .pt files in the data_path directory
97143 assert data_args .offline_data_path is not None , (
98144 "offline_data_path must be provided for offline training."
99145 )
100146 offline_data_path = Path (data_args .offline_data_path )
101- all_files = [str (p ) for p in offline_data_path .glob ("*.pt" )]
102- if not all_files :
147+ dumped_files = [str (p ) for p in offline_data_path .glob ("*.pt" )]
148+ if not dumped_files :
103149 raise ValueError (f"No .pt files found in { data_args .offline_data_path } " )
104150
105- train_dataset = OfflineSupervisedDataset (
106- all_files ,
107- tokenizer = tokenizer ,
108- )
109-
110- data_collator = DataCollatorForOffline (max_length = max_length )
151+ train_dataset = OfflineSupervisedDataset (dumped_files )
152+ data_collator = EagleOfflineDataCollator (max_length = max_length )
111153 else :
112- train_dataset = ShardedDataset ("nvidia/Daring-Anteater" )
154+ train_dataset = ShardedDataset ("json" , data_files = data_args . data_path )
113155 data_collator = LanguageDataCollator (
114156 tokenizer = tokenizer ,
115157 max_length = max_length ,
@@ -122,85 +164,6 @@ def make_eagle_supervised_data_module(
122164 }
123165
124166
class DataCollatorWithPadding:
    """Pad (or truncate) per-sample tensors to ``max_length`` and batch them."""

    def __init__(self, max_length):
        self.max_length = max_length

    def paddingtensor2d(self, intensors, length):
        """Zero-pad a (n, dim) tensor along dim 0 to ``length`` rows; truncate if longer."""
        rows, width = intensors.shape
        if rows > length:
            return intensors[:length, :]
        filler = torch.zeros(length - rows, width, dtype=intensors.dtype)
        return torch.cat((intensors, filler))

    def paddingtensor(self, intensors, length):
        """Zero-pad a 1-D tensor to ``length`` elements; truncate if longer."""
        size = intensors.shape[0]
        if size > length:
            return intensors[:length]
        filler = torch.zeros(length - size, dtype=intensors.dtype)
        return torch.cat((intensors, filler))

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        batch = {
            key: torch.stack(
                [self.paddingtensor(item[key], self.max_length) for item in features]
            )
            for key in ("input_ids", "attention_mask", "loss_mask", "labels")
        }

        # VLM extras are flattened across the batch rather than stacked.
        if "pixel_values" in features[0]:
            batch["pixel_values"] = torch.cat([item["pixel_values"] for item in features], dim=0)
            batch["image_flags"] = torch.cat([item["image_flags"] for item in features], dim=0)

        return batch
174-
class DataCollatorForOffline(DataCollatorWithPadding):
    """Collator for offline training: also batches the dumped base-model outputs."""

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        batch = super().__call__(features)
        if "kwargs" not in features[0]:
            raise ValueError("No kwargs found in batch features. Offline data required.")

        outputs = [item["kwargs"]["base_model_outputs"] for item in features]
        pad2d = self.paddingtensor2d
        batch["base_model_outputs"] = {
            "base_model_hidden_states": torch.stack(
                [pad2d(o["base_model_hidden_states"], self.max_length) for o in outputs]
            ),
            "aux_hidden_states": torch.stack(
                [pad2d(o["aux_hidden_states"], self.max_length) for o in outputs]
            ),
        }
        return batch
203-
204167class EagleTrainerWithAccLog (Trainer ):
205168 """Wrapper around Trainer that logs training accuracy."""
206169
0 commit comments