Skip to content

Commit 6401aa9

Browse files
committed
Merge remote-tracking branch 'opensf/feature/deltaflow'
2 parents 0e432b5 + 6c8aafc commit 6401aa9

File tree

14 files changed

+195
-106
lines changed

14 files changed

+195
-106
lines changed

assets/slurm/dufolabel_sbatch.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

assets/slurm/ssl-process.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ cd /proj/berzelius-2023-154/users/x_qinzh/OpenSceneFlow
1818

1919

2020
# data directory containing the extracted h5py files
21-
DATA_DIR="/proj/berzelius-2023-364/data/truckscenes/h5py/val"
21+
DATA_DIR="/proj/berzelius-2023-364/data/av2/h5py/sensor/train"
2222

2323
TOTAL_SCENES=$(ls ${DATA_DIR}/*.h5 | wc -l)
2424
# Process every n-th frame into DUFOMap, no need to change at least for now.

conf/eval.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
dataset_path: /home/kin/data/av2/h5py/sensor
3-
checkpoint: /home/kin/model_zoo/deflow.ckpt
3+
checkpoint: /home/kin/data/model_zoo/deltaflow_public/deltaflow-av2.ckpt
44
data_mode: val # [val, test]
55
save_res: False # [True, False]
66

@@ -15,7 +15,7 @@ output: ${model.name}-${slurm_id}
1515
gpus: 1
1616
seed: 42069
1717
eval_only: True
18-
wandb_mode: offline # [offline, disabled, online]
18+
wandb_mode: disabled # [offline, disabled, online]
1919
defaults:
2020
- hydra: default
2121
- model: deflow

dataprocess/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,6 @@ Process train data for self-supervised learning. Only training data needs this s
247247
```bash
248248
python process.py --data_dir /home/kin/data/av2/h5py/sensor/train --scene_range 0,701
249249
```
250+
251+
As some users may need multiple nodes to run this, here is an example SLURM script to run the data processing in parallel.
252+
Check [assets/slurm/ssl-process.sh](../assets/slurm/ssl-process.sh) for more details.

eval.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
import torch
1414
from torch.utils.data import DataLoader
1515
import lightning.pytorch as pl
16-
from lightning.pytorch.loggers import WandbLogger
16+
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
1717
from omegaconf import DictConfig
1818
import hydra, wandb, os, sys
1919
from hydra.core.hydra_config import HydraConfig
2020
from src.dataset import HDF5Dataset
2121
from src.trainer import ModelWrapper
22+
from src.utils import InlineTee
2223

2324
def precheck_cfg_valid(cfg):
2425
if os.path.exists(cfg.dataset_path + f"/{cfg.data_mode}") is False:
@@ -36,8 +37,8 @@ def main(cfg):
3637

3738
if 'iter_only' in cfg.model and cfg.model.iter_only:
3839
from src.runner import launch_runner
39-
print(f"---LOG[eval]: Run optmization-based method: {cfg.model.name}")
40-
launch_runner(cfg, cfg.data_mode)
40+
launch_runner(cfg, cfg.data_mode, output_dir)
41+
print(f"---LOG[eval]: Finished optimization-based evaluation. Logging saved to {output_dir}/output.log")
4142
return
4243

4344
if not os.path.exists(cfg.checkpoint):
@@ -47,27 +48,39 @@ def main(cfg):
4748
torch_load_ckpt = torch.load(cfg.checkpoint)
4849
checkpoint_params = DictConfig(torch_load_ckpt["hyper_parameters"])
4950
cfg.output = checkpoint_params.cfg.output + f"-e{torch_load_ckpt['epoch']}-{cfg.data_mode}-v{cfg.leaderboard_version}"
51+
# replace output_dir ${old_output_dir} with ${output_dir}
52+
output_dir = output_dir.replace(HydraConfig.get().runtime.output_dir.split('/')[-2], checkpoint_params.cfg.output.split('/')[-1])
5053
cfg.model.update(checkpoint_params.cfg.model)
5154
cfg.num_frames = cfg.model.target.get('num_frames', checkpoint_params.cfg.get('num_frames', cfg.get('num_frames', 2)))
5255

5356
mymodel = ModelWrapper.load_from_checkpoint(cfg.checkpoint, cfg=cfg, eval=True)
54-
print(f"\n---LOG[eval]: Loaded model from {cfg.checkpoint}. The backbone network is {checkpoint_params.cfg.model.name}.\n")
57+
os.makedirs(output_dir, exist_ok=True)
58+
sys.stdout = InlineTee(f"{output_dir}/output.log")
59+
print(f"---LOG[eval]: Loaded model from {cfg.checkpoint}. The backbone network is {checkpoint_params.cfg.model.name}.")
60+
print(f"---LOG[eval]: Evaluation data: {cfg.dataset_path}/{cfg.data_mode} set.\n")
5561

56-
wandb_logger = WandbLogger(save_dir=output_dir,
57-
entity="kth-rpl",
58-
project=f"opensf-eval",
59-
name=f"{cfg.output}",
60-
offline=(cfg.wandb_mode == "offline"))
62+
if cfg.wandb_mode != "disabled":
63+
logger = WandbLogger(save_dir=output_dir,
64+
entity="kth-rpl",
65+
project=f"opensf-eval",
66+
name=f"{cfg.output}",
67+
offline=(cfg.wandb_mode == "offline"))
68+
logger.watch(mymodel, log_graph=False)
69+
else:
70+
# check local tensorboard logging: tensorboard --logdir logs/jobs/{log folder}
71+
logger = TensorBoardLogger(save_dir=output_dir, name="logs")
6172

62-
trainer = pl.Trainer(logger=wandb_logger, devices=1)
73+
trainer = pl.Trainer(logger=logger, devices=1)
6374
# NOTE(Qingwen): search & check: def eval_only_step_(self, batch, res_dict)
6475
trainer.validate(model = mymodel, \
6576
dataloaders = DataLoader( \
6677
HDF5Dataset(cfg.dataset_path + f"/{cfg.data_mode}", \
6778
n_frames=cfg.num_frames, \
6879
eval=True, leaderboard_version=cfg.leaderboard_version), \
6980
batch_size=1, shuffle=False))
70-
wandb.finish()
81+
if cfg.wandb_mode != "disabled":
82+
wandb.finish()
83+
print(f"---LOG[eval]: Finished feed-forward evaluation. Logging saved to {output_dir}/output.log")
7184

7285
if __name__ == "__main__":
7386
main()

process.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ def main(
186186
if not os.path.exists(gm_config_path) and run_gm:
187187
raise FileNotFoundError(f"Ground segmentation config file not found: {gm_config_path}. Please check folder")
188188

189-
190189
data_path = Path(data_dir)
191190
dataset = HDF5Data(data_path) # single frame reading.
192191
all_scene_ids = list(dataset.scene_id_bounds.keys())

src/dataset.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import h5py, pickle, argparse
2424
from tqdm import tqdm
2525
import numpy as np
26+
from torchvision import transforms
2627

2728
import os, sys
2829
BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '..' ))
@@ -185,8 +186,8 @@ def __call__(self, data_dict):
185186
class HDF5Dataset(Dataset):
186187
def __init__(self, directory, \
187188
transform=None, n_frames=2, ssl_label=None, \
188-
eval = False, eval_input_seq = False, leaderboard_version=1, \
189-
vis_name='', flow_num=1):
189+
eval = False, leaderboard_version=1, \
190+
vis_name=''):
190191
'''
191192
Args:
192193
directory: the directory of the dataset, the folder should contain some .h5 file and index_total.pkl.
@@ -196,10 +197,8 @@ def __init__(self, directory, \
196197
* n_frames: the number of frames we use, default is 2: current (pc0), next (pc1); if it's more than 2, then it read the history from current.
197198
* ssl_label: if attr, it will read the dynamic cluster label. Otherwise, no dynamic cluster label in data dict.
198199
* eval: if True, use the eval index (only used it for leaderboard evaluation)
199-
* eval_input_seq: I forgot what it is.... xox...
200200
* leaderboard_version: 1st or 2nd, default is 1. If '2', we will use the index_eval_v2.pkl from assets/docs.
201201
* vis_name: the data of the visualization, default is ''.
202-
* flow_num: the number of future frames we read, default is 1. (pc0->pc1 flow)
203202
'''
204203
super(HDF5Dataset, self).__init__()
205204
self.directory = directory
@@ -209,12 +208,10 @@ def __init__(self, directory, \
209208
self.data_index = pickle.load(f)
210209

211210
self.eval_index = False
212-
self.eval_input_seq = eval_input_seq
213211
self.ssl_label = import_func(f"src.autolabel.{ssl_label}") if ssl_label is not None else None
214212
self.history_frames = n_frames - 2
215213
self.vis_name = vis_name if isinstance(vis_name, list) else [vis_name]
216214
self.transform = transform
217-
self.flow_num = flow_num
218215

219216
if eval:
220217
eval_index_file = os.path.join(self.directory, 'index_eval.pkl')
@@ -267,7 +264,7 @@ def __init__(self, directory, \
267264

268265
def __len__(self):
269266
# return 100 # for testing
270-
if self.eval_index and not self.eval_input_seq:
267+
if self.eval_index:
271268
return len(self.eval_data_index)
272269
elif not self.eval_index and self.train_index is not None:
273270
return len(self.train_index)
@@ -278,25 +275,17 @@ def valid_index(self, index_):
278275
Check if the index is valid for the current mode and satisfy the constraints.
279276
"""
280277
eval_flag = False
281-
if self.eval_index and not self.eval_input_seq:
278+
if self.eval_index:
282279
eval_index_ = index_
283280
scene_id, timestamp = self.eval_data_index[eval_index_]
284281
index_ = self.data_index.index([scene_id, timestamp])
285282
max_idx = self.scene_id_bounds[scene_id]["max_index"]
286283
if index_ >= max_idx:
287284
_, index_ = self.valid_index(eval_index_ - 1)
288285
eval_flag = True
289-
elif self.eval_index and self.eval_input_seq:
290-
scene_id, timestamp = self.data_index[index_]
291-
# to make sure we have continuous frames
292-
if self.scene_id_bounds[scene_id]["max_index"] <= index_:
293-
index_ = index_ - 1
294-
scene_id, timestamp = self.data_index[index_]
295-
eval_flag = True if [scene_id, timestamp] in self.eval_data_index else False
296286
elif self.train_index is not None:
297287
train_index_ = index_
298288
scene_id, timestamp = self.train_index[train_index_]
299-
# FIXME: it works now, but self.flow_num is not possible in this case.
300289
max_idx = self.scene_id_bounds[scene_id]["max_index"]
301290
index_ = self.data_index.index([scene_id, timestamp])
302291
if index_ >= max_idx:
@@ -306,7 +295,7 @@ def valid_index(self, index_):
306295
max_idx = self.scene_id_bounds[scene_id]["max_index"]
307296
min_idx = self.scene_id_bounds[scene_id]["min_index"]
308297

309-
max_valid_index_for_flow = max_idx - self.flow_num
298+
max_valid_index_for_flow = max_idx - 1
310299
min_valid_index_for_flow = min_idx + self.history_frames
311300
index_ = max(min_valid_index_for_flow, min(max_valid_index_for_flow, index_))
312301
return eval_flag, index_

src/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,6 @@
4949
# * pip install pytorch3d assets/cuda/histlib
5050
try:
5151
from .icpflow import ICPFlow
52-
except ImportError:
52+
except ImportError as e:
5353
print("--- WARNING [model]: ICPFlow is not imported, as it requires pytorch3d lib which is not installed.")
5454
print(f"Detail error message\033[0m: {e}. Just ignore this warning if code runs without these models.")

src/runner.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#
1616
"""
1717

18-
import os
18+
import os, sys
1919
import torch
2020
import torch.distributed as dist
2121
import torch.multiprocessing as mp
@@ -34,7 +34,7 @@
3434
from .utils.eval_metric import OfficialMetrics, evaluate_leaderboard, evaluate_leaderboard_v2, evaluate_ssf
3535
from .utils.av2_eval import write_output_file
3636
from .utils.mics import zip_res
37-
37+
from .utils import InlineTee
3838
class SceneDistributedSampler(Sampler):
3939
"""
4040
A DistributedSampler that distributes data based on scene IDs, not individual indices.
@@ -101,11 +101,12 @@ def __init__(self, cfg, rank, world_size, mode):
101101
self.mode = mode
102102

103103
self.model.to(self.device)
104-
self.metrics = OfficialMetrics() if self.mode in ['val', 'eval'] else None
104+
self.metrics = OfficialMetrics() if self.mode in ['val', 'eval', 'valid'] else None
105+
self.res_name = cfg.get('res_name', cfg.model.name)
105106
self.save_res_path = cfg.get('save_res_path', None)
106107

107108
def _setup_dataloader(self):
108-
if self.mode in ['val', 'test', 'eval']:
109+
if self.mode in ['val', 'test', 'eval', 'valid']:
109110
dataset_path = self.cfg.dataset_path + f"/{self.cfg.data_mode}"
110111
is_eval_mode = True
111112
else: # 'save'
@@ -153,7 +154,7 @@ def _process_step(self, batch):
153154
final_flow = pose_flow.clone()
154155
final_flow[~batch['gm0']] = res_dict['flow'] + pose_flow[~batch['gm0']]
155156

156-
if self.mode in ['val', 'eval']:
157+
if self.mode in ['val', 'eval', 'valid']:
157158
eval_mask = batch['eval_mask'].squeeze()
158159
gt_flow = batch["flow"]
159160
v1_dict = evaluate_leaderboard(final_flow[eval_mask], pose_flow[eval_mask], pc0[eval_mask], \
@@ -257,7 +258,7 @@ def _run_process(cfg, mode):
257258
gathered_metrics_objects = [runner.metrics]
258259

259260
if rank == 0:
260-
if mode in ['val', 'eval']:
261+
if mode in ['val', 'eval', 'valid']:
261262
final_metrics = OfficialMetrics()
262263
print(f"\n--- [LOG] Finished processing. Aggregating results from {world_size} GPUs with {len(gathered_metrics_objects)} metrics objects...")
263264
for metrics_obj in gathered_metrics_objects:
@@ -299,17 +300,20 @@ def _run_process(cfg, mode):
299300

300301
runner.cleanup()
301302

302-
def _spawn_wrapper(rank, world_size, cfg, mode):
303+
def _spawn_wrapper(rank, world_size, cfg, mode, output_dir):
304+
log_filepath = f"{output_dir}/output.log" if output_dir else None
305+
if log_filepath and rank==0:
306+
sys.stdout = InlineTee(log_filepath, append=True)
307+
if rank == 0:
308+
print(f"---LOG[eval]: Run optimization-based method: {cfg.model.name} on {cfg.dataset_path}/{cfg.data_mode} set.\n")
303309
torch.cuda.set_device(rank)
304-
305-
# FIXME(Qingwen): better to set these through command, since we might have more nodes to connected.
306310
os.environ['RANK'] = str(rank)
307311
os.environ['WORLD_SIZE'] = str(world_size)
308312
os.environ['MASTER_ADDR'] = 'localhost'
309313
os.environ['MASTER_PORT'] = str(cfg.get('master_port', 12355))
310314
_run_process(cfg, mode)
311315

312-
def launch_runner(cfg, mode):
316+
def launch_runner(cfg, mode, output_dir):
313317
is_slurm_job = 'SLURM_PROCID' in os.environ
314318

315319
if not is_slurm_job and not dist.is_initialized():
@@ -321,7 +325,7 @@ def launch_runner(cfg, mode):
321325
cfg.save_res_path = Path(cfg.dataset_path).parent / "results" / cfg.output
322326

323327
mp.spawn(_spawn_wrapper,
324-
args=(world_size, cfg, mode),
328+
args=(world_size, cfg, mode, output_dir),
325329
nprocs=world_size,
326330
join=True)
327331

0 commit comments

Comments
 (0)