22import functools
33import logging
44import time
5+ from collections .abc import (
6+ Iterable ,
7+ )
58from copy import (
69 deepcopy ,
710)
4750 dp_random ,
4851)
4952from deepmd .pt .utils .dataloader import (
50- BufferedIterator ,
5153 get_sampler_from_params ,
5254)
5355from deepmd .pt .utils .env import (
@@ -159,8 +161,24 @@ def get_opt_param(params):
159161 }
160162 return opt_type , opt_param
161163
164+ def cycle_iterator (iterable : Iterable ):
165+ """
166+ Produces an infinite iterator by repeatedly cycling through the given iterable.
167+
168+ Args:
169+ iterable (Iterable): The iterable to cycle through.
170+
171+ Yields
172+ ------
173+ Any: The next item from the iterable, cycling back to the beginning when the end is reached.
174+ """
175+ while True :
176+ with torch .device ("cpu" ):
177+ it = iter (iterable )
178+ yield from it
179+
162180 def get_data_loader (_training_data , _validation_data , _training_params ):
163- def get_dataloader_and_buffer (_data , _params ):
181+ def get_dataloader_and_iter (_data , _params ):
164182 _sampler = get_sampler_from_params (_data , _params )
165183 if _sampler is None :
166184 log .warning (
@@ -177,33 +195,32 @@ def get_dataloader_and_buffer(_data, _params):
177195 collate_fn = lambda batch : batch , # prevent extra conversion
178196 pin_memory = True ,
179197 )
180- with torch .device ("cpu" ):
181- _data_buffered = BufferedIterator (iter (_dataloader ))
182- return _dataloader , _data_buffered
198+ _data_iter = cycle_iterator (_dataloader )
199+ return _dataloader , _data_iter
183200
184- training_dataloader , training_data_buffered = get_dataloader_and_buffer (
201+ training_dataloader , training_data_iter = get_dataloader_and_iter (
185202 _training_data , _training_params ["training_data" ]
186203 )
187204
188205 if _validation_data is not None :
189206 (
190207 validation_dataloader ,
191- validation_data_buffered ,
192- ) = get_dataloader_and_buffer (
208+ validation_data_iter ,
209+ ) = get_dataloader_and_iter (
193210 _validation_data , _training_params ["validation_data" ]
194211 )
195212 valid_numb_batch = _training_params ["validation_data" ].get (
196213 "numb_btch" , 1
197214 )
198215 else :
199216 validation_dataloader = None
200- validation_data_buffered = None
217+ validation_data_iter = None
201218 valid_numb_batch = 1
202219 return (
203220 training_dataloader ,
204- training_data_buffered ,
221+ training_data_iter ,
205222 validation_dataloader ,
206- validation_data_buffered ,
223+ validation_data_iter ,
207224 valid_numb_batch ,
208225 )
209226
@@ -1064,48 +1081,15 @@ def save_model(self, save_path, lr=0.0, step=0) -> None:
10641081 checkpoint_files [0 ].unlink ()
10651082
10661083 def get_data (self , is_train = True , task_key = "Default" ):
1067- if not self .multi_task :
1068- if is_train :
1069- try :
1070- batch_data = next (iter (self .training_data ))
1071- except StopIteration :
1072- # Refresh the status of the dataloader to start from a new epoch
1073- with torch .device ("cpu" ):
1074- self .training_data = BufferedIterator (
1075- iter (self .training_dataloader )
1076- )
1077- batch_data = next (iter (self .training_data ))
1078- else :
1079- if self .validation_data is None :
1080- return {}, {}, {}
1081- try :
1082- batch_data = next (iter (self .validation_data ))
1083- except StopIteration :
1084- self .validation_data = BufferedIterator (
1085- iter (self .validation_dataloader )
1086- )
1087- batch_data = next (iter (self .validation_data ))
1084+ if is_train :
1085+ iterator = self .training_data
10881086 else :
1089- if is_train :
1090- try :
1091- batch_data = next (iter (self .training_data [task_key ]))
1092- except StopIteration :
1093- # Refresh the status of the dataloader to start from a new epoch
1094- self .training_data [task_key ] = BufferedIterator (
1095- iter (self .training_dataloader [task_key ])
1096- )
1097- batch_data = next (iter (self .training_data [task_key ]))
1098- else :
1099- if self .validation_data [task_key ] is None :
1100- return {}, {}, {}
1101- try :
1102- batch_data = next (iter (self .validation_data [task_key ]))
1103- except StopIteration :
1104- self .validation_data [task_key ] = BufferedIterator (
1105- iter (self .validation_dataloader [task_key ])
1106- )
1107- batch_data = next (iter (self .validation_data [task_key ]))
1108-
1087+ iterator = self .validation_data
1088+ if self .multi_task :
1089+ iterator = iterator [task_key ]
1090+ if iterator is None :
1091+ return {}, {}, {}
1092+ batch_data = next (iterator )
11091093 for key in batch_data .keys ():
11101094 if key == "sid" or key == "fid" or key == "box" or "find_" in key :
11111095 continue
0 commit comments