Skip to content

Commit 4f19d6b

Browse files
committed
style: clean up previously developed code.
* add readh5 and create_eval_pkl
1 parent a8d54af commit 4f19d6b

File tree

10 files changed

+111
-82
lines changed

10 files changed

+111
-82
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ wget https://huggingface.co/kin-zhang/OpenSceneFlow/resolve/main/deflow_best.ckp
198198
### Feed-Forward Self-Supervised Model Training
199199

200200
To train feed-forward SSL methods (e.g., SeFlow/SeFlow++/VoteFlow), we need to:
201-
1) process auto-label process.
201+
1) run the auto-label process for training. Check [dataprocess/README.md#self-supervised-process](dataprocess/README.md#self-supervised-process) for more details. These labels are already provided inside the demo dataset.
202202
2) specify the loss function, we set the config here for our best model in the leaderboard.
203203

204204
#### SeFlow
@@ -257,7 +257,8 @@ python save.py model=fastnsf
257257

258258
## 3. Evaluation
259259

260-
You can view Wandb dashboard for the training and evaluation results or upload result to online leaderboard.
260+
You can view Wandb dashboard for the training and evaluation results or upload result to online leaderboard.
261+
<!-- Three-way EPE and Dynamic Bucket-normalized are evaluated within a 70x70m range (followed Argoverse 2 online leaderboard). No ground points are considered in the evaluation. -->
261262

262263
Since in training, we save all hyper-parameters and model checkpoints, the only thing you need to do is to specify the checkpoint path. Remember to set the data path correctly also.
263264

assets/slurm/dufolabel_sbatch.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

assets/slurm/ssl-process.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ cd /proj/berzelius-2023-154/users/x_qinzh/OpenSceneFlow
1818

1919

2020
# data directory containing the extracted h5py files
21-
DATA_DIR="/proj/berzelius-2023-364/data/truckscenes/h5py/val"
21+
DATA_DIR="/proj/berzelius-2023-364/data/av2/h5py/sensor/train"
2222

2323
TOTAL_SCENES=$(ls ${DATA_DIR}/*.h5 | wc -l)
2424
# Process every n-th frame into DUFOMap, no need to change at least for now.

dataprocess/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,6 @@ Process train data for self-supervised learning. Only training data needs this s
247247
```bash
248248
python process.py --data_dir /home/kin/data/av2/h5py/sensor/train --scene_range 0,701
249249
```
250+
251+
Since some users may want to run across multiple nodes, here is an example SLURM script to run the data processing in parallel.
252+
Check [assets/slurm/ssl-process.sh](../assets/slurm/ssl-process.sh) for more details.

process.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ def main(
186186
if not os.path.exists(gm_config_path) and run_gm:
187187
raise FileNotFoundError(f"Ground segmentation config file not found: {gm_config_path}. Please check folder")
188188

189-
190189
data_path = Path(data_dir)
191190
dataset = HDF5Data(data_path) # single frame reading.
192191
all_scene_ids = list(dataset.scene_id_bounds.keys())

src/dataset.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import h5py, pickle, argparse
2424
from tqdm import tqdm
2525
import numpy as np
26+
from torchvision import transforms
2627

2728
import os, sys
2829
BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '..' ))
@@ -185,8 +186,8 @@ def __call__(self, data_dict):
185186
class HDF5Dataset(Dataset):
186187
def __init__(self, directory, \
187188
transform=None, n_frames=2, ssl_label=None, \
188-
eval = False, eval_input_seq = False, leaderboard_version=1, \
189-
vis_name='', flow_num=1):
189+
eval = False, leaderboard_version=1, \
190+
vis_name=''):
190191
'''
191192
Args:
192193
directory: the directory of the dataset, the folder should contain some .h5 file and index_total.pkl.
@@ -196,10 +197,8 @@ def __init__(self, directory, \
196197
* n_frames: the number of frames we use, default is 2: current (pc0), next (pc1); if it's more than 2, then it read the history from current.
197198
* ssl_label: if attr, it will read the dynamic cluster label. Otherwise, no dynamic cluster label in data dict.
198199
* eval: if True, use the eval index (only used it for leaderboard evaluation)
199-
* eval_input_seq: I forgot what it is.... xox...
200200
* leaderboard_version: 1st or 2nd, default is 1. If '2', we will use the index_eval_v2.pkl from assets/docs.
201201
* vis_name: the data of the visualization, default is ''.
202-
* flow_num: the number of future frames we read, default is 1. (pc0->pc1 flow)
203202
'''
204203
super(HDF5Dataset, self).__init__()
205204
self.directory = directory
@@ -209,12 +208,10 @@ def __init__(self, directory, \
209208
self.data_index = pickle.load(f)
210209

211210
self.eval_index = False
212-
self.eval_input_seq = eval_input_seq
213211
self.ssl_label = import_func(f"src.autolabel.{ssl_label}") if ssl_label is not None else None
214212
self.history_frames = n_frames - 2
215213
self.vis_name = vis_name if isinstance(vis_name, list) else [vis_name]
216214
self.transform = transform
217-
self.flow_num = flow_num
218215

219216
if eval:
220217
eval_index_file = os.path.join(self.directory, 'index_eval.pkl')
@@ -267,7 +264,7 @@ def __init__(self, directory, \
267264

268265
def __len__(self):
269266
# return 100 # for testing
270-
if self.eval_index and not self.eval_input_seq:
267+
if self.eval_index:
271268
return len(self.eval_data_index)
272269
elif not self.eval_index and self.train_index is not None:
273270
return len(self.train_index)
@@ -278,25 +275,17 @@ def valid_index(self, index_):
278275
Check if the index is valid for the current mode and satisfy the constraints.
279276
"""
280277
eval_flag = False
281-
if self.eval_index and not self.eval_input_seq:
278+
if self.eval_index:
282279
eval_index_ = index_
283280
scene_id, timestamp = self.eval_data_index[eval_index_]
284281
index_ = self.data_index.index([scene_id, timestamp])
285282
max_idx = self.scene_id_bounds[scene_id]["max_index"]
286283
if index_ >= max_idx:
287284
_, index_ = self.valid_index(eval_index_ - 1)
288285
eval_flag = True
289-
elif self.eval_index and self.eval_input_seq:
290-
scene_id, timestamp = self.data_index[index_]
291-
# to make sure we have continuous frames
292-
if self.scene_id_bounds[scene_id]["max_index"] <= index_:
293-
index_ = index_ - 1
294-
scene_id, timestamp = self.data_index[index_]
295-
eval_flag = True if [scene_id, timestamp] in self.eval_data_index else False
296286
elif self.train_index is not None:
297287
train_index_ = index_
298288
scene_id, timestamp = self.train_index[train_index_]
299-
# FIXME: it works now, but self.flow_num is not possible in this case.
300289
max_idx = self.scene_id_bounds[scene_id]["max_index"]
301290
index_ = self.data_index.index([scene_id, timestamp])
302291
if index_ >= max_idx:
@@ -306,7 +295,7 @@ def valid_index(self, index_):
306295
max_idx = self.scene_id_bounds[scene_id]["max_index"]
307296
min_idx = self.scene_id_bounds[scene_id]["min_index"]
308297

309-
max_valid_index_for_flow = max_idx - self.flow_num
298+
max_valid_index_for_flow = max_idx - 1
310299
min_valid_index_for_flow = min_idx + self.history_frames
311300
index_ = max(min_valid_index_for_flow, min(max_valid_index_for_flow, index_))
312301
return eval_flag, index_

tools/create_evalpkl.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
"""
# Created: 2025-11-21 15:13
# Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology
# Author: Qingwen Zhang (https://kin-zhang.github.io/)

# Description:
Create the evaluation index pickle file (index_eval.pkl) from the total index pickle file.
- a frame must have enough non-ground points (some waymo frames have data quality issues)
- sample every `interval` frames for evaluation (following the leaderboard setting), which
  also saves ~5x validation time for optimization-based methods.
"""

import os, fire, pickle, time
import h5py
from tqdm import tqdm

def _compute_scene_bounds(total_index):
    """Return {scene_id: {min/max timestamp, min/max index}} from the flat [scene_id, timestamp] list."""
    scene_id_bounds = {}
    for idx, (scene_id, timestamp) in enumerate(total_index):
        bounds = scene_id_bounds.setdefault(scene_id, {
            "min_timestamp": timestamp, "max_timestamp": timestamp,
            "min_index": idx, "max_index": idx
        })
        if timestamp < bounds["min_timestamp"]:
            bounds["min_timestamp"] = timestamp
            bounds["min_index"] = idx
        if timestamp > bounds["max_timestamp"]:
            bounds["max_timestamp"] = timestamp
            bounds["max_index"] = idx
    return scene_id_bounds

def create_evalpkl(
    data_dir: str = "/home/kin/data/waymo/valid",
    interval: int = 5,
    min_points: int = 10000,
):
    """Write index_eval.pkl into ``data_dir``.

    Args:
        data_dir: folder containing index_total.pkl and the per-scene .h5 files.
        interval: sample every ``interval``-th frame (leaderboard setting is 5).
        min_points: minimum number of non-ground points a frame needs to be kept
            (previously a hard-coded 10000).
    """
    with open(os.path.join(data_dir, "index_total.pkl"), 'rb') as f:
        total_index = pickle.load(f)

    scene_id_bounds = _compute_scene_bounds(total_index)

    # Sample every `interval`-th frame, skipping the first/last 2*interval frames of each scene.
    eval_data_index = []
    for scene_id, bounds in tqdm(scene_id_bounds.items(), desc="Creating eval index", total=len(scene_id_bounds), dynamic_ncols=True):
        with h5py.File(os.path.join(data_dir, f'{scene_id}.h5'), 'r') as f:
            for idx in range(bounds["min_index"] + interval*2, bounds["max_index"] - interval*2, interval):
                # NOTE: do not rebind `scene_id` here — it would shadow the outer-loop variable.
                _, timestamp = total_index[idx]
                key = str(timestamp)
                # Count non-ground points straight from the boolean mask; no need to load
                # the full point cloud (or pull in torch) just to count surviving rows.
                # Assumes 'ground_mask' is stored as a boolean array (as the original
                # torch-based masking already required).
                gm = f[key]['ground_mask'][:]
                if int((~gm).sum()) < min_points:
                    continue
                eval_data_index.append(total_index[idx])

    # print(f"Demo: {eval_data_index[:10]}")
    print(f"Total {len(eval_data_index)} frames for evaluation in {data_dir}.")
    with open(os.path.join(data_dir, "index_eval.pkl"), 'wb') as f:
        pickle.dump(eval_data_index, f)

if __name__ == '__main__':
    start_time = time.time()
    fire.Fire(create_evalpkl)
    print(f"Create reading index Successfully, cost: {time.time() - start_time:.2f} s")

tools/readh5.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
"""
# Created: 2023-12-31 22:19
# LastEdit: 2024-01-12 18:46
# Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology
# Author: Qingwen Zhang (https://kin-zhang.github.io/)

# Description:
# Quickly inspect an h5 file: print its timestamp keys, and each sub-dataset's shape and dtype.

# Example Running:
python tools/readh5.py --scene_path /home/kin/data/av2/h5py/sensor/test/0c6e62d7-bdfa-3061-8d3d-03b13aa21f68.h5
"""

import os
os.environ["OMP_NUM_THREADS"] = "1"
import fire
import time
import h5py

def readh5key(
    scene_path: str = "/home/kin/data/av2/h5py/sensor/test/0c6e62d7-bdfa-3061-8d3d-03b13aa21f68.h5"
):
    """Print every other timestamp key (up to id 10) of the h5 file, with per-dataset shape/dtype."""
    with h5py.File(scene_path, 'r') as f:
        timestamps = list(f.keys())
        # step by 2 so only even ids are shown; stop once id 10 has been printed
        for cnt in range(0, len(timestamps), 2):
            k = timestamps[cnt]
            print(f"id: {cnt}; Key (TimeStamp): {k}")
            group = f[k]
            for sub_k in group.keys():
                dset = group[sub_k]
                print(f" Sub-Key: {sub_k}, Shape: {dset.shape}, Dtype: {dset.dtype}")
            if cnt >= 10:
                break
        print(f"\nTotal {len(timestamps)} timestamps in the file.")

if __name__ == '__main__':
    t_start = time.time()
    fire.Fire(readh5key)
    print(f"\nTime used: {(time.time() - t_start)/60:.2f} mins")

tools/visualization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def vis(
102102
if res_name == 'raw': # no result, only show **raw point cloud**
103103
pcd.points = o3d.utility.Vector3dVector(pc0[:, :3])
104104
pcd.paint_uniform_color([1.0, 1.0, 1.0])
105-
elif res_name in ['dufo_label', 'label']:
105+
elif res_name in ['dufo', 'label']:
106106
labels = data[res_name]
107107
pcd_i = o3d.geometry.PointCloud()
108108
for label_i in np.unique(labels):
@@ -169,7 +169,7 @@ def vis_multiple(
169169
pcd_list = []
170170
for mode in res_name:
171171
pcd = o3d.geometry.PointCloud()
172-
if mode in ['dufo_label', 'label']:
172+
if mode in ['dufo', 'label']:
173173
labels = data[mode]
174174
pcd_i = o3d.geometry.PointCloud()
175175
for label_i in np.unique(labels):

tools/visualization_rerun.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def main(
104104
flow_color = np.tile(pcd_color, (pc0.shape[0], 1))
105105
flow_color[gm0] = ground_color
106106

107-
if mode in ['dufo_label', 'label']:
107+
if mode in ['dufo', 'label']:
108108
if mode in data:
109109
labels = data[mode]
110110
for label_i in np.unique(labels):

0 commit comments

Comments
 (0)