     DatasetConfig,
     FTDPTokenizeFnConfig,
     build_dataloader,
-    build_datasets,
-    get_dataloader_state,
-    load_dataloader_state,
 )
 from xtuner.v1.train.toy_tokenizer import UTF8ByteTokenizer
 from torch.multiprocessing import spawn
@@ -197,25 +194,22 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, group_by_length, pack_workers):
     dataset_configs = [
         {
             "dataset": DatasetConfig(anno_path=str(data_dir1)),
-            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024)
+            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024),
         },
     ]

     dataloader_config = DataloaderConfig(
+        dataset_config_list=dataset_configs,
         pack_max_length=1024,
         pack_level=pack_level,
         num_workers=num_workers,
         group_by_length=group_by_length,
         pack_workers=pack_workers,
     )

-    datasets = build_datasets(
-        dataset_config=dataset_configs,
+    dataloader1 = dataloader_config.build(
         tokenizer=tokenizer,
-    )
-    dataloader1 = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+        dp_mesh=None,
         global_batch_size=GLOBAL_BATCH_SIZE,
         micro_batch_size=BATCH_SIZE,
         seed=10,
@@ -225,26 +219,22 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, group_by_length, pack_workers):
     assert len(dataloader1) > 10

     dataloader_iter = iter(dataloader1)
-    consumed_sample = 0
     for _ in range(RESUME_ITER):
-        batch = next(dataloader_iter)
-        consumed_sample += len(batch)
+        next(dataloader_iter)

-    dataloader_state = get_dataloader_state(dataloader1, consumed_sample)
+    dataloader_state = dataloader1.get_state_dict()
     expected_data = []
     for _ in range(AFTER_RESUME_ITER):
-        batch = next(dataloader_iter)
-        consumed_sample += len(batch)
-        expected_data.append(batch)
+        expected_data.append(next(dataloader_iter))

-    new_dataloader1 = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+    new_dataloader1 = dataloader_config.build(
+        tokenizer=tokenizer,
+        dp_mesh=None,
         global_batch_size=GLOBAL_BATCH_SIZE,
         micro_batch_size=BATCH_SIZE,
         seed=10,
     )
-    load_dataloader_state(new_dataloader1, dataloader_state)
+    new_dataloader1.load_state_dict(dataloader_state)
     new_dataloader_iter = iter(new_dataloader1)

     resume_data = []
@@ -257,32 +247,29 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, group_by_length, pack_workers):
     # 2. Test resume after consuming multiple epochs
     while True:
         try:
-            batch = next(dataloader_iter)
-            consumed_sample += len(batch)
+            next(dataloader_iter)
         except StopIteration:
             break

-
     dataloader_iter = iter(dataloader1)

-    for batch in range(RESUME_ITER):
-        batch = next(dataloader_iter)
-        consumed_sample += len(batch)
+    for _ in range(RESUME_ITER):
+        next(dataloader_iter)

-    dataloader_state = get_dataloader_state(dataloader1, consumed_sample)
+    dataloader_state = dataloader1.get_state_dict()

     expected_data = []
     for _ in range(AFTER_RESUME_ITER):
         expected_data.append(next(dataloader_iter))

-    new_dataloader2 = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+    new_dataloader2 = dataloader_config.build(
+        tokenizer=tokenizer,
+        dp_mesh=None,
         global_batch_size=GLOBAL_BATCH_SIZE,
         micro_batch_size=BATCH_SIZE,
         seed=10,
     )
-    load_dataloader_state(new_dataloader2, dataloader_state)
+    new_dataloader2.load_state_dict(dataloader_state)
     new_dataloader_iter2 = iter(new_dataloader2)

     resume_data = []
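
For readers skimming the diff, the migrated usage pattern is summarized below. This is a condensed sketch based only on what the diff shows: the `xtuner.v1.datasets` import path, the no-argument `UTF8ByteTokenizer()` constructor, and the annotation path and batch-size values are assumptions or placeholders, not confirmed by the patch.

```python
# Condensed sketch of the resume flow after this refactor.
# ASSUMPTIONS: the import path for the config classes (the diff only shows
# the tail of the import block), the UTF8ByteTokenizer constructor, and the
# placeholder anno_path / batch sizes.
from xtuner.v1.datasets import DataloaderConfig, DatasetConfig, FTDPTokenizeFnConfig
from xtuner.v1.train.toy_tokenizer import UTF8ByteTokenizer

tokenizer = UTF8ByteTokenizer()

# Dataset configs now ride on the dataloader config up front,
# replacing the separate build_datasets() step.
dataloader_config = DataloaderConfig(
    dataset_config_list=[
        {
            "dataset": DatasetConfig(anno_path="path/to/annotations"),  # placeholder
            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024),
        },
    ],
    pack_max_length=1024,
)

dataloader = dataloader_config.build(
    tokenizer=tokenizer,
    dp_mesh=None,  # single-process test, so no data-parallel mesh
    global_batch_size=8,
    micro_batch_size=1,
    seed=10,
)

# Consume a few batches, then snapshot state on the loader itself, replacing
# the old get_dataloader_state(loader, consumed_sample) helper.
it = iter(dataloader)
for _ in range(5):
    next(it)
state = dataloader.get_state_dict()

# A freshly built loader restored via load_state_dict() should continue
# from the snapshot point, yielding the same batches as the original.
resumed = dataloader_config.build(
    tokenizer=tokenizer,
    dp_mesh=None,
    global_batch_size=8,
    micro_batch_size=1,
    seed=10,
)
resumed.load_state_dict(state)
resumed_it = iter(resumed)
```

Note the design change this diff makes: the old API required the caller to tally consumed samples by hand and pass the count into `get_dataloader_state()`, whereas the new `get_state_dict()`/`load_state_dict()` pair keeps that bookkeeping inside the dataloader, in line with the usual PyTorch state-dict convention.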