Skip to content

Commit f6e48b6

Browse files
Allow for >1 batch size in splatfacto
1 parent 4a3e3e6 commit f6e48b6

3 files changed

Lines changed: 90 additions & 51 deletions

File tree

nerfstudio/cameras/camera_optimizers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def apply_to_raybundle(self, raybundle: RayBundle) -> None:
152152
raybundle.origins = raybundle.origins + correction_matrices[:, :3, 3]
153153
raybundle.directions = torch.bmm(correction_matrices[:, :3, :3], raybundle.directions[..., None]).squeeze()
154154

155-
def apply_to_camera(self, camera: Cameras) -> torch.Tensor:
155+
def apply_to_camera(self, camera: Cameras) -> Float[Tensor, "b 3 4"]:
156156
"""Apply the pose correction to the world-to-camera matrix in a Camera object"""
157157
if self.config.mode == "off":
158158
return camera.camera_to_worlds

nerfstudio/data/datamanagers/full_images_datamanager.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class FullImageDatamanagerConfig(DataManagerConfig):
8989
More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
9090
cache_compressed_images: bool = False
9191
"""If True, cache raw image files as byte strings to RAM."""
92+
batch_size: int = 1
93+
"""The batch size for the dataloader."""
9294

9395

9496
class FullImageDatamanager(DataManager, Generic[TDataset]):
@@ -322,7 +324,7 @@ def setup_train(self):
322324
)
323325
self.train_image_dataloader = DataLoader(
324326
self.train_imagebatch_stream,
325-
batch_size=1,
327+
batch_size=self.config.batch_size,
326328
num_workers=self.config.dataloader_num_workers,
327329
collate_fn=identity_collate,
328330
)
@@ -385,28 +387,50 @@ def get_train_rays_per_batch(self) -> int:
385387
def next_train(self, step: int) -> Tuple[Cameras, Dict]:
386388
"""Returns the next training batch
387389
Returns a Camera instead of raybundle"""
390+
388391
self.train_count += 1
389392
if self.config.cache_images == "disk":
390-
camera, data = next(self.iter_train_image_dataloader)[0]
393+
output = next(self.iter_train_image_dataloader)
394+
print("Alex", output)
395+
camera, data = output[0]
391396
return camera, data
392397

393-
image_idx = self.train_unseen_cameras.pop(0)
394-
# Make sure to re-populate the unseen cameras list if we have exhausted it
395-
if len(self.train_unseen_cameras) == 0:
396-
self.train_unseen_cameras = self.sample_train_cameras()
397-
398-
data = self.cached_train[image_idx]
399-
# We're going to copy to make sure we don't mutate the cached dictionary.
400-
# This can cause a memory leak: https://github.com/nerfstudio-project/nerfstudio/issues/3335
401-
data = data.copy()
402-
data["image"] = data["image"].to(self.device)
403-
404-
assert len(self.train_cameras.shape) == 1, "Assumes single batch dimension"
405-
camera = self.train_cameras[image_idx : image_idx + 1].to(self.device)
406-
if camera.metadata is None:
407-
camera.metadata = {}
408-
camera.metadata["cam_idx"] = image_idx
409-
return camera, data
398+
image_indices = []
399+
for _ in range(self.config.batch_size):
400+
# Make sure to re-populate the unseen cameras list if we have exhausted it
401+
if len(self.train_unseen_cameras) == 0:
402+
self.train_unseen_cameras = self.sample_train_cameras()
403+
image_indices.append(self.train_unseen_cameras.pop(0))
404+
405+
all_keys = self.cached_train[0].keys()
406+
407+
data = {}
408+
for key in all_keys:
409+
if key == "image":
410+
data[key] = torch.stack([self.cached_train[i][key] for i in image_indices]).to(self.device)
411+
else:
412+
data[key] = [self.cached_train[i][key] for i in image_indices]
413+
414+
cameras = Cameras(
415+
camera_to_worlds=self.train_cameras.camera_to_worlds[image_indices],
416+
fx=self.train_cameras.fx[image_indices],
417+
fy=self.train_cameras.fy[image_indices],
418+
cx=self.train_cameras.cx[image_indices],
419+
cy=self.train_cameras.cy[image_indices],
420+
width=self.train_cameras.width[image_indices],
421+
height=self.train_cameras.height[image_indices],
422+
camera_type=self.train_cameras.camera_type[image_indices],
423+
).to(self.device)
424+
425+
if self.train_cameras.distortion_params is not None:
426+
cameras.distortion_params = self.train_cameras.distortion_params[image_indices]
427+
428+
if cameras.metadata is None:
429+
cameras.metadata = {}
430+
431+
cameras.metadata["cam_idx"] = image_indices
432+
433+
return cameras, data
410434

411435
def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
412436
"""Returns the next evaluation batch

nerfstudio/models/splatfacto.py

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,28 @@
4646
from nerfstudio.utils.spherical_harmonics import RGB2SH, SH2RGB, num_sh_bases
4747

4848

49-
def resize_image(image: torch.Tensor, d: int):
49+
def resize_image(image: torch.Tensor, d: int) -> torch.Tensor:
5050
"""
5151
Downscale images using the same 'area' method in opencv
5252
53-
:param image shape [H, W, C]
53+
:param image shape [B, H, W, C]
5454
:param d downscale factor (must be 2, 4, 8, etc.)
5555
56-
return downscaled image in shape [H//d, W//d, C]
56+
return downscaled image in shape [B, H//d, W//d, C]
5757
"""
5858
import torch.nn.functional as tf
5959

60-
image = image.to(torch.float32)
6160
weight = (1.0 / (d * d)) * torch.ones((1, 1, d, d), dtype=torch.float32, device=image.device)
62-
return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0)
61+
62+
B, H, W, C = image.shape
63+
image = image.permute(0, 3, 1, 2) # [B, C, H, W]
64+
image = image.reshape(B * C, 1, H, W) # Combine batch and channel dimensions for Conv2D
65+
66+
downscaled = tf.conv2d(image, weight, stride=d)
67+
downscaled = downscaled.reshape(B, C, downscaled.shape[-2], downscaled.shape[-1])
68+
downscaled = downscaled.permute(0, 2, 3, 1) # [B, H//d, W//d, C]
69+
70+
return downscaled
6371

6472

6573
@torch_compile()
@@ -482,32 +490,31 @@ def _apply_bilateral_grid(self, rgb: torch.Tensor, cam_idx: int, H: int, W: int)
482490
)
483491
return out["rgb"]
484492

485-
def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
486-
"""Takes in a camera and returns a dictionary of outputs.
493+
def get_outputs(self, cameras: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
494+
"""Takes in cameras and returns a dictionary of outputs.
487495
488496
Args:
489-
camera: The camera(s) for which output images are rendered. It should have
497+
cameras: The camera(s) for which output images are rendered. It should have
490498
all the needed information to compute the outputs.
491499
492500
Returns:
493501
Outputs of model. (ie. rendered colors)
494502
"""
495-
if not isinstance(camera, Cameras):
503+
if not isinstance(cameras, Cameras):
496504
print("Called get_outputs with not a camera")
497505
return {}
498506

499507
if self.training:
500-
assert camera.shape[0] == 1, "Only one camera at a time"
501-
optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)
508+
optimized_camera_to_world = self.camera_optimizer.apply_to_camera(cameras)
502509
else:
503-
optimized_camera_to_world = camera.camera_to_worlds
510+
optimized_camera_to_world = cameras.camera_to_worlds
504511

505512
# cropping
506513
if self.crop_box is not None and not self.training:
507514
crop_ids = self.crop_box.within(self.means).squeeze()
508515
if crop_ids.sum() == 0:
509516
return self.get_empty_outputs(
510-
int(camera.width.item()), int(camera.height.item()), self.background_color
517+
int(cameras.width.item()), int(cameras.height.item()), self.background_color
511518
)
512519
else:
513520
crop_ids = None
@@ -530,12 +537,16 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
530537
colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)
531538

532539
camera_scale_fac = self._get_downscale_factor()
533-
camera.rescale_output_resolution(1 / camera_scale_fac)
534-
viewmat = get_viewmat(optimized_camera_to_world)
535-
K = camera.get_intrinsics_matrices().cuda()
536-
W, H = int(camera.width.item()), int(camera.height.item())
540+
cameras.rescale_output_resolution(1 / camera_scale_fac)
541+
viewmats = get_viewmat(optimized_camera_to_world)
542+
Ks = cameras.get_intrinsics_matrices().cuda()
543+
544+
W, H = (
545+
int(cameras.width[0]),
546+
int(cameras.height[0]),
547+
) # assume all cameras have the same resolution
537548
self.last_size = (H, W)
538-
camera.rescale_output_resolution(camera_scale_fac) # type: ignore
549+
cameras.rescale_output_resolution(camera_scale_fac) # type: ignore
539550

540551
# apply the compensation of screen space blurring to gaussians
541552
if self.config.rasterize_mode not in ["antialiased", "classic"]:
@@ -558,8 +569,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
558569
scales=torch.exp(scales_crop),
559570
opacities=torch.sigmoid(opacities_crop).squeeze(-1),
560571
colors=colors_crop,
561-
viewmats=viewmat, # [1, 4, 4]
562-
Ks=K, # [1, 3, 3]
572+
viewmats=viewmats, # [1, 4, 4]
573+
Ks=Ks, # [1, 3, 3]
563574
width=W,
564575
height=H,
565576
packed=False,
@@ -585,24 +596,28 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
585596

586597
# apply bilateral grid
587598
if self.config.use_bilateral_grid and self.training:
588-
if camera.metadata is not None and "cam_idx" in camera.metadata:
589-
rgb = self._apply_bilateral_grid(rgb, camera.metadata["cam_idx"], H, W)
599+
if cameras.metadata is not None and "cam_idx" in cameras.metadata:
600+
rgb = self._apply_bilateral_grid(rgb, cameras.metadata["cam_idx"], H, W)
590601

591602
if render_mode == "RGB+ED":
592603
depth_im = render[:, ..., 3:4]
593-
depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max()).squeeze(0)
604+
depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max())
594605
else:
595606
depth_im = None
596607

597608
if background.shape[0] == 3 and not self.training:
598609
background = background.expand(H, W, 3)
599610

600-
return {
601-
"rgb": rgb.squeeze(0), # type: ignore
602-
"depth": depth_im, # type: ignore
603-
"accumulation": alpha.squeeze(0), # type: ignore
604-
"background": background, # type: ignore
605-
} # type: ignore
611+
outputs = {
612+
"rgb": rgb,
613+
"depth": depth_im,
614+
"accumulation": alpha,
615+
"background": background,
616+
}
617+
618+
if self.training:
619+
return outputs
620+
return {k: v.squeeze(0) if k != "background" else v for k, v in outputs.items()}
606621

607622
def get_gt_img(self, image: torch.Tensor):
608623
"""Compute groundtruth image with iteration dependent downscale factor for evaluation purpose
@@ -622,7 +637,7 @@ def composite_with_background(self, image, background) -> torch.Tensor:
622637
image: the image to composite
623638
background: the background color
624639
"""
625-
if image.shape[2] == 4:
640+
if image.shape[-1] == 4:
626641
alpha = image[..., -1].unsqueeze(-1).repeat((1, 1, 3))
627642
return alpha * image[..., :3] + (1 - alpha) * background
628643
else:
@@ -671,7 +686,7 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
671686
pred_img = pred_img * mask
672687

673688
Ll1 = torch.abs(gt_img - pred_img).mean()
674-
simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], pred_img.permute(2, 0, 1)[None, ...])
689+
simloss = 1 - self.ssim(gt_img.permute(0, 3, 1, 2), pred_img.permute(0, 3, 1, 2))
675690
if self.config.use_scale_regularization and self.step % 10 == 0:
676691
scale_exp = torch.exp(self.scales)
677692
scale_reg = (

0 commit comments

Comments
 (0)