
Commit fed4664

(1) add support for ScanNet. (2) Visualize depths during testing.

1 parent 041e36f commit fed4664

8 files changed: 194 additions & 36 deletions

README.md

Lines changed: 3 additions & 0 deletions

```diff
@@ -29,6 +29,9 @@ The code-base has additional support for:
 * Total Variation Loss for smoother embeddings (use `--tv-loss-weight` to enable)
 * Sparsity-inducing loss on the ray weights (use `--sparse-loss-weight` to enable)
 
+## ScanNet dataset support
+The repo now supports training a NeRF model on a scene from the ScanNet dataset. I personally found setting up the ScanNet dataset to be a bit tricky. Please find some instructions/notes in [ScanNet.md](ScanNet.md).
+
 
 ## TODO:
 * Voxel pruning during training and/or inference
```

ScanNet.md

Lines changed: 22 additions & 0 deletions

````diff
@@ -0,0 +1,22 @@
+# ScanNet Instructions
+
+I personally found it a bit tricky to set up the ScanNet dataset the first time I tried it, so I am compiling some notes/instructions on how to do it in case someone finds them useful.
+
+### 1. Dataset download
+
+To download ScanNet data and its labels, follow the instructions [here](https://github.com/ScanNet/ScanNet). Basically, fill out the ScanNet Terms of Use agreement and email it to [scannet@googlegroups.com](mailto:scannet@googlegroups.com). You will receive a download link to the dataset. Download the dataset and unzip it.
+
+### 2. Use [SensReader](https://github.com/ScanNet/ScanNet/tree/master/SensReader/python) to extract RGB-D and camera data
+Use the `reader.py` script as follows for each scene you want to work with:
+```
+python reader.py --filename [.sens file to export data from] --output_path [output directory to export data to]
+Options:
+--export_depth_images: export all depth frames as 16-bit pngs (depth shift 1000)
+--export_color_images: export all color frames as 8-bit rgb jpgs
+--export_poses: export all camera poses (4x4 matrix, camera to world)
+--export_intrinsics: export camera intrinsics (4x4 matrix)
+```
+
+### 3. Use this [script](https://github.com/zju3dv/object_nerf/blob/main/data_preparation/scannet_sens_reader/convert_to_nerf_style_data.py) to convert the data to NeRF-style format. For instructions, see Step 1 [here](https://github.com/zju3dv/object_nerf/tree/main/data_preparation).
+1. The generated transforms_xxx.json stores each transformation matrix (camera-to-world) in the SLAM/OpenCV convention (xyz -> right, down, forward). You need to change it to the NeRF/OpenGL convention (xyz -> right, up, back) in the dataloader before training with the NeRF convention.
+2. For example, see the conversion done [here](https://github.com/cvg/nice-slam/blob/7af15cc33729aa5a8ca052908d96f495e34ab34c/src/utils/datasets.py#L205).
````
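For reference, the convention change described in note 1 above amounts to flipping the y and z axes of each camera-to-world matrix. A minimal sketch (the function name is mine; `c2w` is the 4x4 `transform_matrix` read from the generated `transforms_xxx.json`):

```python
import numpy as np

def opencv_to_nerf_pose(c2w: np.ndarray) -> np.ndarray:
    """Flip a camera-to-world pose from the OpenCV convention
    (x right, y down, z forward) to the NeRF/OpenGL convention
    (x right, y up, z back) by negating the y and z columns."""
    c2w = c2w.copy()
    c2w[:3, 1] *= -1  # y: down -> up
    c2w[:3, 2] *= -1  # z: forward -> back
    return c2w
```

This is exactly the in-place flip that the new `load_scannet.py` loader below applies to each `transform_matrix`.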

configs/scannet_scene0000.txt

Lines changed: 14 additions & 0 deletions

```diff
@@ -0,0 +1,14 @@
+expname = scannet_scene0000_00
+basedir = ./logs
+datadir = /work/yashsb/datasets/ScanNet/
+dataset_type = scannet
+
+no_batching = False
+
+use_viewdirs = True
+white_bkgd = False
+lrate_decay = 500
+
+N_samples = 64
+N_importance = 128
+N_rand = 1024
```
hash_encoding.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -60,7 +60,7 @@ def forward(self, x):
         x_embedded_all = []
         for i in range(self.n_levels):
             resolution = torch.floor(self.base_resolution * self.b**i)
-            voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices = get_voxel_vertices(\
+            voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices, keep_mask = get_voxel_vertices(\
                                                 x, self.bounding_box, \
                                                 resolution, self.log2_hashmap_size)
 
@@ -69,7 +69,8 @@ def forward(self, x):
             x_embedded = self.trilinear_interp(x, voxel_min_vertex, voxel_max_vertex, voxel_embedds)
             x_embedded_all.append(x_embedded)
 
-        return torch.cat(x_embedded_all, dim=-1)
+        keep_mask = keep_mask.sum(dim=-1)==keep_mask.shape[-1]
+        return torch.cat(x_embedded_all, dim=-1), keep_mask
 
 
 class SHEncoder(nn.Module):
```
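For intuition: `get_voxel_vertices` now also returns a validity mask for sample points that fall outside the scene bounding box, and the reduction above keeps a point only if every one of its mask entries is set. A small self-contained sketch of that reduction (shapes are illustrative):

```python
import torch

# hypothetical mask: 5 sample points, 3 validity entries each
keep_mask = torch.tensor([[1, 1, 1],
                          [1, 0, 1],
                          [1, 1, 1],
                          [0, 0, 0],
                          [1, 1, 0]], dtype=torch.bool)

# a point survives only if all of its entries are True, mirroring
# keep_mask.sum(dim=-1) == keep_mask.shape[-1]
per_point = keep_mask.sum(dim=-1) == keep_mask.shape[-1]
print(per_point)  # tensor([ True, False,  True, False, False])
```

Downstream, `run_network` in run_nerf.py uses this mask to zero out the predicted density of invalid points.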

load_blender.py

Lines changed: 1 addition & 3 deletions

```diff
@@ -88,6 +88,4 @@ def load_blender_data(basedir, half_res=False, testskip=1):
 
     bounding_box = get_bbox3d_for_blenderobj(metas["train"], H, W, near=2.0, far=6.0)
 
-    return imgs, poses, render_poses, [H, W, focal], i_split, bounding_box
-
-
+    return imgs, poses, render_poses, [H, W, focal], i_split, bounding_box
```

load_scannet.py

Lines changed: 107 additions & 0 deletions

```diff
@@ -0,0 +1,107 @@
+import os
+import torch
+import numpy as np
+import imageio
+import json
+import torch.nn.functional as F
+import cv2
+import pyvista as pv
+
+trans_t = lambda t : torch.Tensor([
+    [1,0,0,0],
+    [0,1,0,0],
+    [0,0,1,t],
+    [0,0,0,1]]).float()
+
+rot_phi = lambda phi : torch.Tensor([
+    [1,0,0,0],
+    [0,np.cos(phi),-np.sin(phi),0],
+    [0,np.sin(phi), np.cos(phi),0],
+    [0,0,0,1]]).float()
+
+rot_theta = lambda th : torch.Tensor([
+    [np.cos(th),0,-np.sin(th),0],
+    [0,1,0,0],
+    [np.sin(th),0, np.cos(th),0],
+    [0,0,0,1]]).float()
+
+
+def pose_spherical(theta, phi, radius):
+    c2w = trans_t(radius)
+    c2w = rot_phi(phi/180.*np.pi) @ c2w
+    c2w = rot_theta(theta/180.*np.pi) @ c2w
+    c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
+    return c2w
+
+
+def load_scannet_data(basedir, sceneID, half_res=False, trainskip=10, testskip=1):
+    '''
+    basedir is something like: "/work/yashsb/datasets/ScanNet/"
+    '''
+    scansdir = os.path.join(basedir, "scans")
+    basedir = os.path.join(basedir, "nerfstyle_"+sceneID)
+
+    splits = ['train', 'val', 'test']
+    metas = {}
+    for s in splits:
+        with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
+            metas[s] = json.load(fp)
+
+    all_imgs = []
+    all_poses = []
+    counts = [0]
+    for s in splits:
+        meta = metas[s]
+        imgs = []
+        poses = []
+        if s=='train':
+            skip = trainskip
+        else:
+            skip = testskip
+
+        for frame in meta['frames'][::skip]:
+            fname = os.path.join(basedir, frame['file_path'] + '.png')
+            imgs.append(imageio.imread(fname))
+            pose = np.array(frame['transform_matrix'])
+
+            ### NEED to do this because ScanNet uses the OpenCV convention
+            pose[:3, 1] *= -1
+            pose[:3, 2] *= -1
+
+            poses.append(pose)
+
+        imgs = (np.array(imgs) / 255.).astype(np.float32)  # keep all 4 channels (RGBA)
+        poses = np.array(poses).astype(np.float32)
+        counts.append(counts[-1] + imgs.shape[0])
+        all_imgs.append(imgs)
+        all_poses.append(poses)
+
+    i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
+
+    imgs = np.concatenate(all_imgs, 0)
+    poses = np.concatenate(all_poses, 0)
+
+    H, W = imgs[0].shape[:2]
+    camera_angle_x = float(meta['camera_angle_x'])
+    focal = .5 * W / np.tan(.5 * camera_angle_x)
+
+    render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
+
+    if half_res:
+        H = H//2
+        W = W//2
+        focal = focal/2.
+
+        imgs_half_res = np.zeros((imgs.shape[0], H, W, 3))
+        for i, img in enumerate(imgs):
+            imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
+        imgs = imgs_half_res
+        # imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
+
+    ## getting an approximate bounding box for the scene
+    # load the scene mesh and pad its axis-aligned bounds by 1 unit
+    mesh = pv.read(os.path.join(scansdir, sceneID, f"{sceneID}_vh_clean.ply"))
+    bounding_box = torch.tensor(mesh.bounds[::2]) - 1, torch.tensor(mesh.bounds[1::2]) + 1
+
+    return imgs, poses, render_poses, [H, W, focal], i_split, bounding_box
```
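A quick usage sketch for the new loader (the path is illustrative; `basedir` must contain both the raw `scans/<sceneID>` folder, for the mesh, and the converted `nerfstyle_<sceneID>` folder):

```python
from load_scannet import load_scannet_data

imgs, poses, render_poses, hwf, i_split, bounding_box = load_scannet_data(
    basedir="/path/to/ScanNet/",   # illustrative
    sceneID="scene0000_00",
    half_res=False, trainskip=10, testskip=1)

H, W, focal = hwf
i_train, i_val, i_test = i_split
print(imgs.shape, poses.shape, (H, W, focal))
print("approx. scene bounds:", bounding_box)  # (mesh mins - 1, mesh maxs + 1)
```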

run_nerf.py

Lines changed: 42 additions & 21 deletions

```diff
@@ -23,6 +23,7 @@
 from load_llff import load_llff_data
 from load_deepvoxels import load_dv_data
 from load_blender import load_blender_data
+from load_scannet import load_scannet_data
 from load_LINEMOD import load_LINEMOD_data
 
 
@@ -45,7 +46,7 @@ def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
     """Prepares inputs and applies network 'fn'.
     """
     inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
-    embedded = embed_fn(inputs_flat)
+    embedded, keep_mask = embed_fn(inputs_flat)
 
     if viewdirs is not None:
         input_dirs = viewdirs[:,None].expand(inputs.shape)
@@ -54,6 +55,7 @@ def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
         embedded = torch.cat([embedded, embedded_dirs], -1)
 
     outputs_flat = batchify(fn, netchunk)(embedded)
+    outputs_flat[~keep_mask, -1] = 0  # set sigma to 0 for invalid points
     outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
     return outputs
 
@@ -135,7 +137,7 @@ def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
         k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
         all_ret[k] = torch.reshape(all_ret[k], k_sh)
 
-    k_extract = ['rgb_map', 'disp_map', 'acc_map']
+    k_extract = ['rgb_map', 'depth_map', 'acc_map']
     ret_list = [all_ret[k] for k in k_extract]
     ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
     return ret_list + [ret_dict]
@@ -144,6 +146,7 @@
 def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
 
     H, W, focal = hwf
+    near, far = render_kwargs['near'], render_kwargs['far']
 
     if render_factor!=0:
         # Render downsampled for speed
@@ -152,18 +155,20 @@
         focal = focal/render_factor
 
     rgbs = []
-    disps = []
+    depths = []
     psnrs = []
 
     t = time.time()
     for i, c2w in enumerate(tqdm(render_poses)):
         print(i, time.time() - t)
         t = time.time()
-        rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
+        rgb, depth, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
         rgbs.append(rgb.cpu().numpy())
-        disps.append(disp.cpu().numpy())
+        # normalize depth to [0,1]
+        depth = (depth - near) / (far - near)
+        depths.append(depth.cpu().numpy())
         if i==0:
-            print(rgb.shape, disp.shape)
+            print(rgb.shape, depth.shape)
 
         if gt_imgs is not None and render_factor==0:
             try:
@@ -174,11 +179,21 @@
             print(p)
             psnrs.append(p)
 
-
         if savedir is not None:
+            # save rgb and depth as a figure
+            fig = plt.figure(figsize=(25,15))
+            ax = fig.add_subplot(1, 2, 1)
             rgb8 = to8b(rgbs[-1])
+            ax.imshow(rgb8)
+            ax.axis('off')
+            ax = fig.add_subplot(1, 2, 2)
+            ax.imshow(depths[-1], cmap='plasma', vmin=0, vmax=1)
+            ax.axis('off')
             filename = os.path.join(savedir, '{:03d}.png'.format(i))
-            imageio.imwrite(filename, rgb8)
+            # save as png
+            plt.savefig(filename, bbox_inches='tight', pad_inches=0)
+            plt.close(fig)
+            # imageio.imwrite(filename, rgb8)
 
 
     rgbs = np.stack(rgbs, 0)
```
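The depth maps returned by `render` are in metric scene units (the same units as `near`/`far`), so the normalization above simply maps [near, far] onto [0, 1] before the colormap is applied. A minimal standalone version of that visualization step, with a stand-in depth array:

```python
import numpy as np
import matplotlib.pyplot as plt

near, far = 0.1, 10.0  # the ScanNet bounds set in train()
depth = np.random.uniform(near, far, (480, 640))  # stand-in for a rendered depth map

depth_vis = np.clip((depth - near) / (far - near), 0, 1)  # [near, far] -> [0, 1]
plt.imsave("depth.png", depth_vis, cmap="plasma", vmin=0, vmax=1)
```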
```diff
@@ -224,9 +239,6 @@ def create_nerf(args):
 
     model_fine = None
 
-    # if args.i_embed==1:
-    #     args.N_importance = 0
-
     if args.N_importance > 0:
         if args.i_embed==1:
             model_fine = NeRFSmall(num_layers=2,
@@ -248,9 +260,6 @@ def create_nerf(args):
 
     # Create optimizer
     if args.i_embed==1:
-        # sparse_opt = torch.optim.SparseAdam(embedding_params, lr=args.lrate, betas=(0.9, 0.99), eps=1e-15)
-        # dense_opt = torch.optim.Adam(grad_vars, lr=args.lrate, betas=(0.9, 0.99), weight_decay=1e-6)
-        # optimizer = MultiOptimizer(optimizers={"sparse_opt": sparse_opt, "dense_opt": dense_opt})
         optimizer = RAdam([
             {'params': grad_vars, 'weight_decay': 1e-6},
             {'params': embedding_params, 'eps': 1e-15}
@@ -352,8 +361,8 @@ def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=F
     weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
     rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]
 
-    depth_map = torch.sum(weights * z_vals, -1)
-    disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
+    depth_map = torch.sum(weights * z_vals, -1) / torch.sum(weights, -1)
+    disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map)
     acc_map = torch.sum(weights, -1)
 
     if white_bkgd:
```
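The new `depth_map` is the weight-normalized expected termination depth, `sum(w_i * z_i) / sum(w_i)`, which stays in metric units even when the ray weights do not sum to 1 (e.g. rays that mostly miss geometry); the old expression under-estimated depth in exactly those cases. A toy check of the two expressions (note that a fully empty ray, with `sum(w_i) == 0`, would still need an epsilon guard):

```python
import torch

z_vals = torch.tensor([[1.0, 2.0, 3.0]])   # sample depths along one ray
weights = torch.tensor([[0.1, 0.2, 0.1]])  # ray weights summing to only 0.4

raw_depth = torch.sum(weights * z_vals, -1)                           # 0.8, biased low
exp_depth = torch.sum(weights * z_vals, -1) / torch.sum(weights, -1)  # 2.0
print(raw_depth.item(), exp_depth.item())
```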
```diff
@@ -445,13 +454,12 @@ def render_rays(ray_batch,
 
     pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]  # [N_rays, N_samples, 3]
 
-    # raw = run_network(pts)
     raw = network_query_fn(pts, viewdirs, network_fn)
     rgb_map, disp_map, acc_map, weights, depth_map, sparsity_loss = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
 
     if N_importance > 0:
 
-        rgb_map_0, disp_map_0, acc_map_0, sparsity_loss_0 = rgb_map, disp_map, acc_map, sparsity_loss
+        rgb_map_0, depth_map_0, acc_map_0, sparsity_loss_0 = rgb_map, depth_map, acc_map, sparsity_loss
 
         z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
         z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
@@ -466,12 +474,12 @@ def render_rays(ray_batch,
 
         rgb_map, disp_map, acc_map, weights, depth_map, sparsity_loss = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
 
-    ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map, 'sparsity_loss': sparsity_loss}
+    ret = {'rgb_map' : rgb_map, 'depth_map' : depth_map, 'acc_map' : acc_map, 'sparsity_loss': sparsity_loss}
     if retraw:
         ret['raw'] = raw
     if N_importance > 0:
         ret['rgb0'] = rgb_map_0
-        ret['disp0'] = disp_map_0
+        ret['depth0'] = depth_map_0
         ret['acc0'] = acc_map_0
         ret['sparsity_loss0'] = sparsity_loss_0
         ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]
@@ -571,6 +579,10 @@ def config_parser():
     parser.add_argument("--half_res", action='store_true',
                         help='load blender synthetic data at 400x400 instead of 800x800')
 
+    ## scannet flags
+    parser.add_argument("--scannet_sceneID", type=str, default='scene0000_00',
+                        help='sceneID to load from scannet')
+
     ## llff flags
     parser.add_argument("--factor", type=int, default=8,
                         help='downsample factor for LLFF images')
@@ -658,6 +670,15 @@ def train():
         else:
             images = images[...,:3]
 
+    elif args.dataset_type == 'scannet':
+        images, poses, render_poses, hwf, i_split, bounding_box = load_scannet_data(args.datadir, args.scannet_sceneID, args.half_res)
+        args.bounding_box = bounding_box
+        print('Loaded scannet', images.shape, render_poses.shape, hwf, args.datadir)
+        i_train, i_val, i_test = i_split
+
+        near = 0.1
+        far = 10.0
+
     elif args.dataset_type == 'LINEMOD':
         images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip)
         print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
@@ -854,7 +875,7 @@ def train():
         target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
 
         #####  Core optimization loop  #####
-        rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
+        rgb, depth, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                                 verbose=i < 10, retraw=True,
                                                 **render_kwargs_train)
```