1+
2+ """
3+ # Created: 2025-03-24 11:39
4+ # Copyright (C) 2025-now, RPL, KTH Royal Institute of Technology
5+ # Author: Qingwen Zhang (https://kin-zhang.github.io/)
6+ #
7+ # This work is licensed under the terms of the MIT license.
8+ # For a copy, see <https://opensource.org/licenses/MIT>.
9+
10+ # Description: view scene flow dataset after preprocess.
11+ """
12+
13+ import os , sys
14+ BASE_DIR = os .path .abspath (os .path .join ( os .path .dirname ( __file__ ), '..' ))
15+ sys .path .append (BASE_DIR )
16+ import time , fire , h5py , pickle
17+ import numpy as np
18+
19+ from av2 .datasets .sensor .constants import AnnotationCategories
20+ from typing import Final
21+
22+ from tqdm import tqdm
23+
24+ CATEGORY_TO_INDEX : Final = {
25+ ** {"NONE" : 0 },
26+ ** {k .value : i + 1 for i , k in enumerate (AnnotationCategories )},
27+ }
28+ INDEX_TO_CATEGORY : Final = {v : k for k , v in CATEGORY_TO_INDEX .items ()}
29+ NAME_MAPPING_K2A = {
30+ 'outlier' : 'NONE' ,
31+ 'unlabeled' : 'NONE' ,
32+ 'car' : 'REGULAR_VEHICLE' ,
33+ 'bicycle' : 'BICYCLE' ,
34+ 'motorcycle' : 'MOTORCYCLE' ,
35+ 'truck' : 'TRUCK' ,
36+ 'other-vehicle' : 'LARGE_VEHICLE' ,
37+ 'person' : 'PEDESTRIAN' ,
38+ 'bicyclist' : 'BICYCLIST' ,
39+ 'motorcyclist' : 'MOTORCYCLIST' ,
40+ 'road' : 'NONE' ,
41+ 'parking' : 'NONE' ,
42+ 'sidewalk' : 'NONE' ,
43+ 'other-ground' : 'NONE' ,
44+ 'building' : 'NONE' ,
45+ 'fence' : 'NONE' ,
46+ 'vegetation' : 'NONE' ,
47+ 'trunk' : 'NONE' ,
48+ 'terrain' : 'NONE' ,
49+ 'pole' : 'NONE' ,
50+ 'traffic-sign' : 'SIGN' ,
51+ }
52+
53+
54+ NAME_MAPPING_N2A = {
55+ 'ignore' : 'NONE' ,
56+ 'barrier' : 'NONE' ,
57+ 'bicycle' : 'BICYCLE' ,
58+ 'bus' : 'BUS' ,
59+ 'car' : 'REGULAR_VEHICLE' ,
60+ 'construction_vehicle' : 'LARGE_VEHICLE' ,
61+ 'motorcycle' : 'MOTORCYCLE' ,
62+ 'pedestrian' : 'PEDESTRIAN' ,
63+ 'traffic_cone' : 'NONE' ,
64+ 'trailer' : 'VEHICULAR_TRAILER' ,
65+ 'truck' : 'TRUCK' ,
66+ 'driveable_surface' : 'NONE' ,
67+ 'other_flat' : 'NONE' ,
68+ 'sidewalk' : 'NONE' ,
69+ 'terrain' : 'NONE' ,
70+ 'manmade' : 'NONE' ,
71+ 'vegetation' : 'NONE' ,
72+ }
73+ PEDESTRIAN_CATEGORIES = ["PEDESTRIAN" , "STROLLER" , "WHEELCHAIR" , "OFFICIAL_SIGNALER" ]
74+ WHEELED_VRU = [
75+ "BICYCLE" ,
76+ "BICYCLIST" ,
77+ "MOTORCYCLE" ,
78+ "MOTORCYCLIST" ,
79+ "WHEELED_DEVICE" ,
80+ "WHEELED_RIDER" ,
81+ ]
82+ CAR = ["REGULAR_VEHICLE" ]
83+ OTHER_VEHICLES = [
84+ "BOX_TRUCK" ,
85+ "LARGE_VEHICLE" ,
86+ "RAILED_VEHICLE" ,
87+ "TRUCK" ,
88+ "TRUCK_CAB" ,
89+ "VEHICULAR_TRAILER" ,
90+ "ARTICULATED_BUS" ,
91+ "BUS" ,
92+ "SCHOOL_BUS" ,
93+ ]
94+ class iouEval :
95+ def __init__ (self , n_classes = 2 , ignore = None ):
96+ # classes
97+ self .n_classes = n_classes
98+ # What to include and ignore from the means
99+ self .ignore = np .array (ignore , dtype = np .int64 )
100+ self .include = np .array (
101+ [n for n in range (self .n_classes ) if n not in self .ignore ], dtype = np .int64 )
102+ # print("[IOU EVAL] IGNORE: ", self.ignore)
103+ # print("[IOU EVAL] INCLUDE: ", self.include)
104+ # reset the class counters
105+ self .reset ()
106+
107+ def num_classes (self ):
108+ return self .n_classes
109+
110+ def reset (self ):
111+ self .conf_matrix = np .zeros ((self .n_classes , self .n_classes ), dtype = np .int64 )
112+
113+ def addBatch (self , x , y ): # x=preds, y=targets
114+ # to tensor
115+ x_row = x .astype (np .int64 )
116+ y_row = y .astype (np .int64 )
117+
118+ # sizes should be matching
119+ x_row = x_row .reshape (- 1 ) # de-batchify
120+ y_row = y_row .reshape (- 1 ) # de-batchify
121+
122+ # check
123+ assert (x_row .shape == x_row .shape )
124+
125+ # idxs are labels and predictions
126+ idxs = np .stack ([x_row , y_row ], axis = 0 )
127+
128+ # ones is what I want to add to conf when I
129+ ones = np .ones ((idxs .shape [- 1 ]), dtype = np .int64 )
130+
131+ # make confusion matrix (cols = gt, rows = pred)
132+ # self.conf_matrix = self.conf_matrix.index_put_(
133+ # tuple(idxs), ones, accumulate=True)
134+ np .add .at (self .conf_matrix , tuple (idxs ), ones )
135+
136+ def getStats (self ):
137+ # remove fp from confusion on the ignore classes cols
138+ conf = self .conf_matrix .astype (np .float64 )
139+ conf [:, self .ignore ] = 0
140+
141+ # get the clean stats
142+ tp = np .diag (conf )
143+ fp = conf .sum (axis = 1 ) - tp
144+ fn = conf .sum (axis = 0 ) - tp
145+ return tp , fp , fn
146+
147+ def getIoU (self ):
148+ tp , fp , fn = self .getStats ()
149+ intersection = tp
150+ union = tp + fp + fn + 1e-15
151+ iou = intersection / union
152+ iou_mean = (intersection [self .include ] / union [self .include ]).mean ()
153+ return iou_mean , iou # returns "iou mean", "iou per class" ALL CLASSES
154+
155+ class HDF5Data :
156+ def __init__ (self , directory , flow_view = False , vis_name = ["flow" ], val = True ):
157+ '''
158+ directory: the directory of the dataset
159+ t_x: how many past frames we want to extract
160+ '''
161+ self .flow_view = flow_view
162+ self .vis_name = vis_name if isinstance (vis_name , list ) else [vis_name ]
163+ self .directory = directory
164+ self .phase = 'val' if val else 'test'
165+ if os .path .exists (os .path .join (self .directory , 'index_eval.pkl' )) or self .phase == 'val' :
166+ eval_index_file = os .path .join (self .directory , 'index_eval.pkl' )
167+ with open (eval_index_file , 'rb' ) as f :
168+ self .evalim_idx = pickle .load (f )
169+ else :
170+ eval_index_file = None
171+ with open (os .path .join (self .directory , 'index_total.pkl' ), 'rb' ) as f :
172+ self .data_index = pickle .load (f )
173+
174+ self .scene_id_bounds = {} # 存储每个scene_id的最大最小timestamp和位置
175+ for idx , (scene_id , timestamp ) in enumerate (self .data_index ):
176+ if scene_id not in self .scene_id_bounds :
177+ self .scene_id_bounds [scene_id ] = {
178+ "min_timestamp" : timestamp ,
179+ "max_timestamp" : timestamp ,
180+ "min_index" : idx ,
181+ "max_index" : idx
182+ }
183+ else :
184+ bounds = self .scene_id_bounds [scene_id ]
185+ # 更新最小timestamp和位置
186+ if timestamp < bounds ["min_timestamp" ]:
187+ bounds ["min_timestamp" ] = timestamp
188+ bounds ["min_index" ] = idx
189+ # 更新最大timestamp和位置
190+ if timestamp > bounds ["max_timestamp" ]:
191+ bounds ["max_timestamp" ] = timestamp
192+ bounds ["max_index" ] = idx
193+
194+ def __len__ (self ):
195+ if self .phase == 'val' :
196+ return len (self .evalim_idx )
197+ return len (self .data_index )
198+
199+ def __getitem__ (self , index ):
200+ if self .phase == 'val' :
201+ scene_id , timestamp = self .evalim_idx [index ]
202+ index = self .data_index .index ([scene_id , timestamp ])
203+ scene_id , timestamp = self .data_index [index ]
204+ # to make sure we have continuous frames for flow view
205+ # if self.flow_view and self.scene_id_bounds[scene_id]["max_index"] == index:
206+ # index = index - 1
207+ # scene_id, timestamp = self.data_index[index]
208+
209+ key = str (timestamp )
210+ data_dict = {
211+ 'scene_id' : scene_id ,
212+ 'timestamp' : timestamp ,
213+ }
214+ with h5py .File (os .path .join (self .directory , f'{ scene_id } .h5' ), 'r' ) as f :
215+ # original data
216+ data_dict ['pc0' ] = f [key ]['lidar' ][:]
217+ data_dict ['gm0' ] = f [key ]['ground_mask' ][:]
218+ data_dict ['pose0' ] = f [key ]['pose' ][:]
219+ for flow_key in ['seg_valid' , 'flow_category_indices' ] + self .vis_name :
220+ if flow_key in f [key ]:
221+ data_dict [flow_key ] = f [key ][flow_key ][:]
222+ else :
223+ print (f"[Warning]: No { flow_key } in { scene_id } at { timestamp } , check the data." )
224+ # if self.flow_view:
225+ # next_timestamp = str(self.data_index[index+1][1])
226+ # data_dict['pose1'] = f[next_timestamp]['pose'][:]
227+ # data_dict['pc1'] = f[next_timestamp]['lidar'][:]
228+ # data_dict['gm1'] = f[next_timestamp]['ground_mask'][:]
229+ # elif self.flow_view:
230+ # print(f"[Warning]: No {self.vis_name} in {scene_id} at {timestamp}, check the data.")
231+ return data_dict
232+
233+ valid_index_ = [CATEGORY_TO_INDEX [l ] for l in CAR + OTHER_VEHICLES ]
234+ def main (
235+ data_dir : str = "/home/kin/data/av2/h5py/sensor/himo" ,
236+ res_names : list = ["seg_raw" ,"seg_flow" ]
237+ ):
238+ dataset = HDF5Data (data_dir , flow_view = True , vis_name = res_names , val = True )
239+ evaluators = {name : iouEval (n_classes = 3 , ignore = []) for name in res_names }
240+ # print(f"Total {len(dataset)} scenes.")
241+ car_index = [CATEGORY_TO_INDEX [l ] for l in CAR ]
242+ other_index = [CATEGORY_TO_INDEX [l ] for l in OTHER_VEHICLES ]
243+ for data_id in tqdm (range (len (dataset )), desc = "Evaluating" , total = len (dataset ), ncols = 120 ):
244+ data = dataset [data_id ]
245+ if 'flow_category_indices' not in data :
246+ print (f"[Warning]: No flow_category_indices in { data ['scene_id' ]} at { data ['timestamp' ]} , check the data." )
247+ continue
248+ valid_mask = data ['seg_valid' ]
249+ # mask only or all points
250+ valid_mask = np .ones_like (valid_mask )
251+ seg_gt = data ['flow_category_indices' ][valid_mask ]
252+
253+ # re-assign the label needed on 3 classes only
254+ # if not car_index and other_index, then it is 0
255+ seg_gt [~ np .isin (seg_gt , valid_index_ )] = 0
256+ seg_gt [np .isin (seg_gt , car_index )] = 1
257+ seg_gt [np .isin (seg_gt , other_index )] = 2
258+ seg_pred = {}
259+ for name in res_names :
260+ seg_pred [name ] = data [name ][valid_mask ]
261+ seg_pred [name ][~ np .isin (seg_pred [name ], valid_index_ )] = 0
262+ seg_pred [name ][np .isin (seg_pred [name ], car_index )] = 1
263+ seg_pred [name ][np .isin (seg_pred [name ], other_index )] = 2
264+
265+ evaluators [name ].addBatch (seg_pred [name ], seg_gt )
266+
267+ # if data_id > 10:
268+ # break
269+
270+
271+ # evaluate miou on valid mask only and maybe highspeed?
272+ print ("\n ========================== RESULTS ========================== " )
273+ for name in res_names :
274+ _ , class_jaccard = evaluators [name ].getIoU ()
275+ m_jaccard = class_jaccard [1 :].mean ()
276+
277+ ignore = [0 ]
278+ class_strings = {0 :'ignore' , 1 : 'car' , 2 : 'other_vehicle' }
279+ print ('{name} 100 frames val:\n IoU avg {m_jaccard:.3f}' .format (name = name , m_jaccard = m_jaccard * 100 ))
280+ # print also classwise
281+ for i , jacc in enumerate (class_jaccard ):
282+ if i in ignore :
283+ continue
284+ print ('IoU class {i:} [{class_str:}] = {jacc:.3f}' .format (
285+ i = i , class_str = class_strings [i ], jacc = jacc * 100 ))
286+ print ('-' * 20 )
287+ if __name__ == '__main__' :
288+ start_time = time .time ()
289+ fire .Fire (main )
290+ print (f"Time used: { time .time () - start_time :.2f} s" )
0 commit comments