diff --git a/nerfstudio/data/pixel_samplers.py b/nerfstudio/data/pixel_samplers.py index f2bc6d96ef..d4ca82d030 100644 --- a/nerfstudio/data/pixel_samplers.py +++ b/nerfstudio/data/pixel_samplers.py @@ -18,6 +18,7 @@ import random import warnings +from collections import defaultdict from dataclasses import dataclass, field from typing import Dict, Optional, Type, Union @@ -335,8 +336,7 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int, # only sample within the mask, if the mask is in the batch all_indices = [] - all_images = [] - all_depth_images = [] + all_images = defaultdict(list) assert num_rays_per_batch % 2 == 0, "num_rays_per_batch must be divisible by 2" num_rays_per_image = divide_rays_per_image(num_rays_per_batch, num_images) @@ -350,10 +350,11 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int, ) indices[:, 0] = i all_indices.append(indices) - all_images.append(batch["image"][i][indices[:, 1], indices[:, 2]]) - if "depth_image" in batch: - all_depth_images.append(batch["depth_image"][i][indices[:, 1], indices[:, 2]]) + for key, value in batch.items(): + if key in ["image_idx", "mask"]: + continue + all_images[key].append(value[i][indices[:, 1], indices[:, 2]]) else: for i, num_rays in enumerate(num_rays_per_image): image_height, image_width, _ = batch["image"][i].shape @@ -363,26 +364,19 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int, indices = self.sample_method(num_rays, 1, image_height, image_width, device=device) indices[:, 0] = i all_indices.append(indices) - all_images.append(batch["image"][i][indices[:, 1], indices[:, 2]]) - if "depth_image" in batch: - all_depth_images.append(batch["depth_image"][i][indices[:, 1], indices[:, 2]]) + for key, value in batch.items(): + if key in ["image_idx", "mask"]: + continue + all_images[key].append(value[i][indices[:, 1], indices[:, 2]]) indices = torch.cat(all_indices, dim=0) - c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1)) - collated_batch = { - key: value[c, y, x] - for key, value in batch.items() - if key not in ("image_idx", "image", "mask", "depth_image") and value is not None - } - - collated_batch["image"] = torch.cat(all_images, dim=0) - if "depth_image" in batch: - collated_batch["depth_image"] = torch.cat(all_depth_images, dim=0) + collated_batch = {key: torch.cat(all_images[key], dim=0) for key in all_images} assert collated_batch["image"].shape[0] == num_rays_per_batch # Needed to correct the random indices to their actual camera idx locations. + c = indices[..., 0].flatten() indices[:, 0] = batch["image_idx"][c] collated_batch["indices"] = indices # with the abs camera indices