@@ -148,9 +148,9 @@ def step_eval(self, pc, gt_flow, pc_dt0, gt_category, gt_instance, est_flow=None
148148
149149 self .frame_cnt += 1
150150
151- def print (self , flow_mode = "flow" , file_name = "result_av2.json" ):
151+ def print (self , res_name = "flow" , file_name = "result_av2.json" ):
152152 # --- Helper: Save detailed metrics to JSON (Preserves original structure) ---
153- def savejson (overall_data , vel_data , dis_data , flow_mode , category_name ):
153+ def savejson (overall_data , vel_data , dis_data , res_name , category_name ):
154154 if os .path .exists (file_name ):
155155 with open (file_name , "r" ) as f :
156156 try :
@@ -162,8 +162,8 @@ def savejson(overall_data, vel_data, dis_data, flow_mode, category_name):
162162
163163 if self .data_name not in data :
164164 data [self .data_name ] = {}
165- if flow_mode not in data [self .data_name ]:
166- data [self .data_name ][flow_mode ] = {}
165+ if res_name not in data [self .data_name ]:
166+ data [self .data_name ][res_name ] = {}
167167
168168 # Construct payload matching original format
169169 entry = {
@@ -188,7 +188,7 @@ def savejson(overall_data, vel_data, dis_data, flow_mode, category_name):
188188 "num_pts" : int (dis_data [i ][3 ]), "num_obj" : int (dis_data [i ][4 ])
189189 }
190190
191- data [self .data_name ][flow_mode ][category_name ] = entry
191+ data [self .data_name ][res_name ][category_name ] = entry
192192 with open (file_name , "w" ) as f :
193193 json .dump (data , f , indent = 4 )
194194 def safe_average (values , weights ):
@@ -205,7 +205,7 @@ def safe_std(values):
205205 total_data = {"mpe" : [], "cham" : [], "std_mpe" : [], "std_cham" : [], "num_pts" : []}
206206 table_rows = []
207207
208- print (f"\n HiMo refinement metrics for { flow_mode } in { self .data_name } :" )
208+ print (f"\n HiMo refinement metrics for { res_name } in { self .data_name } :" )
209209
210210 for cat in target_cats :
211211 if cat not in self .evaluate_data or len (self .evaluate_data [cat ]['mean' ]['num_pts' ]) == 0 :
@@ -238,7 +238,7 @@ def safe_std(values):
238238 d = raw ['dis' ][r ]
239239 dis_entries .append ([r , safe_average (d ['mpe' ], d ['num_pts' ]), safe_average (d ['cham' ], d ['num_pts' ]), np .sum (d ['num_pts' ]), len (d ['num_pts' ])])
240240
241- savejson (overall_entry , vel_entries , dis_entries , flow_mode , cat )
241+ savejson (overall_entry , vel_entries , dis_entries , res_name , cat )
242242
243243 table_rows .append ([
244244 cat_display_map .get (cat , cat ),
@@ -270,15 +270,15 @@ def safe_std(values):
270270def main (
271271 # data_dir: str = "/home/kin/data/Scania/preprocess/val",
272272 data_dir : str = "/home/kin/data/av2/h5py/sensor/himo" ,
273- flow_mode : str = "" ,
273+ res_name : str = "" ,
274274 comp_dis_zip : str = "" ,
275275):
276- data_name , EVAL_FLAG = check_valid (data_dir , flow_mode , comp_dis_zip )
276+ data_name , EVAL_FLAG = check_valid (data_dir , res_name , comp_dis_zip )
277277
278278 refinement_metrics = InstanceMetrics (data_name = data_name )
279- dataset = HDF5Dataset (data_dir , vis_name = flow_mode if EVAL_FLAG == 2 else '' , eval = True )
279+ dataset = HDF5Dataset (data_dir , vis_name = res_name if EVAL_FLAG == 2 else '' , eval = True )
280280
281- for data_id in tqdm (range (0 , len (dataset )), ncols = 80 , desc = f"Evaluating { flow_mode } on { data_name } " ):
281+ for data_id in tqdm (range (0 , len (dataset )), ncols = 80 , desc = f"Evaluating { res_name } on { data_name } " ):
282282 data = dataset [data_id ]
283283 pc0 , pose0 , pose1 = data ['pc0' ], data ['pose0' ], data ['pose1' ]
284284 ego_pose = np .linalg .inv (pose1 ) @ pose0
@@ -299,7 +299,7 @@ def main(
299299 dt0 = max (data ['lidar_dt' ]) - data ['lidar_dt' ]
300300
301301 if EVAL_FLAG == 2 :
302- est_flow = np .zeros_like (pose_flow ) if flow_mode == "raw" else (data [flow_mode ] - pose_flow )
302+ est_flow = np .zeros_like (pose_flow ) if res_name == "raw" else (data [res_name ] - pose_flow )
303303 refinement_metrics .step_eval (pc0 [mask_eval ,:], gt_flow [mask_eval ,:], \
304304 dt0 [mask_eval ], data ['flow_category_indices' ][mask_eval ], data ['flow_instance_id' ][mask_eval ],
305305 est_flow = est_flow [mask_eval ,:])
@@ -309,7 +309,7 @@ def main(
309309 dt0 [mask_eval ], data ['flow_category_indices' ][mask_eval ], data ['flow_instance_id' ][mask_eval ],
310310 est_dis = comp_dis [mask_eval ,:])
311311
312- refinement_metrics .print (flow_mode = flow_mode , file_name = f"res-{ data_name } .json" )
312+ refinement_metrics .print (res_name = res_name , file_name = f"res-{ data_name } .json" )
313313
314314if __name__ == '__main__' :
315315 start_time = time .time ()
0 commit comments