2323import h5py , pickle , argparse
2424from tqdm import tqdm
2525import numpy as np
26+ from torchvision import transforms
2627
2728import os , sys
2829BASE_DIR = os .path .abspath (os .path .join ( os .path .dirname ( __file__ ), '..' ))
@@ -185,8 +186,8 @@ def __call__(self, data_dict):
185186class HDF5Dataset (Dataset ):
186187 def __init__ (self , directory , \
187188 transform = None , n_frames = 2 , ssl_label = None , \
188- eval = False , eval_input_seq = False , leaderboard_version = 1 , \
189- vis_name = '' , flow_num = 1 ):
189+ eval = False , leaderboard_version = 1 , \
190+ vis_name = '' ):
190191 '''
191192 Args:
192193 directory: the directory of the dataset, the folder should contain some .h5 file and index_total.pkl.
@@ -196,10 +197,8 @@ def __init__(self, directory, \
196197 * n_frames: the number of frames we use, default is 2: current (pc0), next (pc1); if it's more than 2, then it read the history from current.
197198 * ssl_label: if attr, it will read the dynamic cluster label. Otherwise, no dynamic cluster label in data dict.
198199 * eval: if True, use the eval index (only used it for leaderboard evaluation)
199- * eval_input_seq: I forgot what it is.... xox...
200200 * leaderboard_version: 1st or 2nd, default is 1. If '2', we will use the index_eval_v2.pkl from assets/docs.
201201 * vis_name: the data of the visualization, default is ''.
202- * flow_num: the number of future frames we read, default is 1. (pc0->pc1 flow)
203202 '''
204203 super (HDF5Dataset , self ).__init__ ()
205204 self .directory = directory
@@ -209,12 +208,10 @@ def __init__(self, directory, \
209208 self .data_index = pickle .load (f )
210209
211210 self .eval_index = False
212- self .eval_input_seq = eval_input_seq
213211 self .ssl_label = import_func (f"src.autolabel.{ ssl_label } " ) if ssl_label is not None else None
214212 self .history_frames = n_frames - 2
215213 self .vis_name = vis_name if isinstance (vis_name , list ) else [vis_name ]
216214 self .transform = transform
217- self .flow_num = flow_num
218215
219216 if eval :
220217 eval_index_file = os .path .join (self .directory , 'index_eval.pkl' )
@@ -267,7 +264,7 @@ def __init__(self, directory, \
267264
268265 def __len__ (self ):
269266 # return 100 # for testing
270- if self .eval_index and not self . eval_input_seq :
267+ if self .eval_index :
271268 return len (self .eval_data_index )
272269 elif not self .eval_index and self .train_index is not None :
273270 return len (self .train_index )
@@ -278,25 +275,17 @@ def valid_index(self, index_):
278275 Check if the index is valid for the current mode and satisfy the constraints.
279276 """
280277 eval_flag = False
281- if self .eval_index and not self . eval_input_seq :
278+ if self .eval_index :
282279 eval_index_ = index_
283280 scene_id , timestamp = self .eval_data_index [eval_index_ ]
284281 index_ = self .data_index .index ([scene_id , timestamp ])
285282 max_idx = self .scene_id_bounds [scene_id ]["max_index" ]
286283 if index_ >= max_idx :
287284 _ , index_ = self .valid_index (eval_index_ - 1 )
288285 eval_flag = True
289- elif self .eval_index and self .eval_input_seq :
290- scene_id , timestamp = self .data_index [index_ ]
291- # to make sure we have continuous frames
292- if self .scene_id_bounds [scene_id ]["max_index" ] <= index_ :
293- index_ = index_ - 1
294- scene_id , timestamp = self .data_index [index_ ]
295- eval_flag = True if [scene_id , timestamp ] in self .eval_data_index else False
296286 elif self .train_index is not None :
297287 train_index_ = index_
298288 scene_id , timestamp = self .train_index [train_index_ ]
299- # FIXME: it works now, but self.flow_num is not possible in this case.
300289 max_idx = self .scene_id_bounds [scene_id ]["max_index" ]
301290 index_ = self .data_index .index ([scene_id , timestamp ])
302291 if index_ >= max_idx :
@@ -306,7 +295,7 @@ def valid_index(self, index_):
306295 max_idx = self .scene_id_bounds [scene_id ]["max_index" ]
307296 min_idx = self .scene_id_bounds [scene_id ]["min_index" ]
308297
309- max_valid_index_for_flow = max_idx - self . flow_num
298+ max_valid_index_for_flow = max_idx - 1
310299 min_valid_index_for_flow = min_idx + self .history_frames
311300 index_ = max (min_valid_index_for_flow , min (max_valid_index_for_flow , index_ ))
312301 return eval_flag , index_
0 commit comments