import json
import math
import os
import time
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import tyro
import viser
import yaml
from datasets.colmap import Dataset, Parser
from datasets.traj import (
    generate_ellipse_path_z,
    generate_interpolated_path,
    generate_spiral_path,
)
from fused_ssim import fused_ssim
from lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from typing_extensions import Literal, assert_never
from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed

from gsplat import export_splats
from gsplat.compression import PngCompression
from gsplat.distributed import cli
from gsplat.optimizers import SelectiveAdam
from gsplat.rendering import rasterization_3dcs
from gsplat.strategy import DefaultStrategy, ConvexSplattingStrategy
from gsplat.utils import fibonacci_sphere, save_ply
from gsplat_viewer import GsplatViewer, GsplatRenderTabState
from nerfview import CameraState, RenderTabState, apply_float_colormap


@dataclass
class Config:
    """Training configuration for the 3D Convex Splatting (3DCS) trainer.

    All options are plain dataclass fields so they are exposed on the CLI
    (tyro).  Scene-type dependent defaults (`init_sigma`) are resolved in
    `__post_init__`, i.e. per *instance*, after the `light`/`outdoor` flags
    are known.  (Doing this in the class body — as a naive port would —
    evaluates once at class-definition time against the field defaults and
    pins `init_sigma` to the indoor value forever.)
    """

    # Disable viewer
    disable_viewer: bool = False
    # Path to the .pt files. If provided, it will skip training and run evaluation only.
    ckpt: Optional[List[str]] = None
    # Name of compression strategy to use
    compression: Optional[Literal["png"]] = None
    # Render trajectory path
    render_traj_path: str = "interp"

    # Path to the Mip-NeRF 360 dataset
    data_dir: str = "data/360_v2/garden"
    # Downsample factor for the dataset
    data_factor: int = 4
    # Directory to save results
    result_dir: str = "results/garden"
    # Every N images there is a test image
    test_every: int = 8
    # Random crop size for training (experimental)
    patch_size: Optional[int] = None
    # A global scaler that applies to the scene size related parameters
    global_scale: float = 1.0
    # Normalize the world space. Don't do it for 3DCS as you need to scale
    # everything properly (learning rates mostly)
    normalize_world_space: bool = False
    # Camera model
    camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"

    # Indoor or outdoor model; default is indoor.
    outdoor: bool = False

    # Light mode. Disabled by default.
    light: bool = False

    # Port for the viewer server
    port: int = 8080

    # Batch size for training. Learning rates are scaled automatically
    batch_size: int = 1
    # A global factor to scale the number of training steps
    steps_scaler: float = 1.0

    # Number of training steps
    max_steps: int = 30_000
    # Steps to evaluate the model
    eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Steps to save the model
    save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Whether to save ply file (storage size can be large)
    save_ply: bool = False
    # Steps to save the model as ply
    ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])

    # Initialization strategy
    init_type: str = "sfm"
    # Initial number of GSs. Ignored if using sfm
    init_num_pts: int = 100_000
    # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
    init_extent: float = 3.0
    # Degree of spherical harmonics
    sh_degree: int = 3
    # Turn on another SH degree every this steps
    sh_degree_interval: int = 1000

    # 3DCS parameters
    # Initial smoothness sigma of the convexes.  This default is refined in
    # __post_init__ once the light/outdoor flags are known.
    init_sigma: float = 0.00095
    # Initial sharpness delta of the convexes
    init_delta: float = 0.1
    # Initial opacity of CS
    init_opa: float = 0.1

    # K: number of hull points per convex.
    num_points_per_convex: int = 6

    # Default 3D convex size
    convex_size: float = 1.2

    init_scale: float = 1.2
    # Weight for SSIM loss
    ssim_lambda: float = 0.2

    # Near plane clipping distance
    near_plane: float = 0.01
    # Far plane clipping distance
    far_plane: float = 1e10

    # Strategy for GS densification
    strategy: ConvexSplattingStrategy = field(default_factory=ConvexSplattingStrategy)
    # Use packed mode for rasterization, this leads to less memory usage but slightly slower.
    packed: bool = False
    # Use sparse gradients for optimization. (experimental)
    sparse_grad: bool = False
    # Use visible adam from Taming 3DGS. (experimental)
    visible_adam: bool = False
    # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
    antialiased: bool = False

    # Use random background for training to discourage transparency
    random_bkgd: bool = False

    # Opacity regularization
    opacity_reg: float = 0.0
    # Scale regularization
    scale_reg: float = 0.0

    # Enable camera optimization.
    pose_opt: bool = False
    # Learning rate for camera optimization
    pose_opt_lr: float = 1e-5
    # Regularization for camera optimization as weight decay
    pose_opt_reg: float = 1e-6
    # Add noise to camera extrinsics. This is only to test the camera pose optimization.
    pose_noise: float = 0.0

    # Enable appearance optimization. (experimental)
    app_opt: bool = False
    # Appearance embedding dimension
    app_embed_dim: int = 16
    # Learning rate for appearance optimization
    app_opt_lr: float = 1e-3
    # Regularization for appearance optimization as weight decay
    app_opt_reg: float = 1e-6

    # Enable bilateral grid. (experimental)
    use_bilateral_grid: bool = False
    # Shape of the bilateral grid (X, Y, W)
    bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8)

    # Enable depth loss. (experimental)
    depth_loss: bool = False
    # Weight for depth loss
    depth_lambda: float = 1e-2

    # Dump information to tensorboard every this steps
    tb_every: int = 100
    # Save training images to tensorboard
    tb_save_image: bool = False

    lpips_net: Literal["vgg", "alex"] = "alex"

    def __post_init__(self) -> None:
        """Resolve scene-type dependent defaults.

        Runs per instance so the chosen `light`/`outdoor` flags are honored.
        Mirrors the intended sigma table: outdoor -> 0.001, indoor -> 0.0009,
        light mode -> keep the 0.00095 default.
        NOTE(review): an explicit --init-sigma on the CLI is overridden here
        for non-light configs — confirm this is the desired precedence.
        """
        if not self.light and self.outdoor:
            self.init_sigma = 0.001
        elif not self.light and not self.outdoor:
            self.init_sigma = 0.0009

    def adjust_steps(self, factor: float):
        """Scale step-count hyperparameters by `factor` and apply the
        scene-type specific densification-strategy tweaks."""
        self.eval_steps = [int(i * factor) for i in self.eval_steps]
        self.save_steps = [int(i * factor) for i in self.save_steps]
        self.ply_steps = [int(i * factor) for i in self.ply_steps]
        self.max_steps = int(self.max_steps * factor)
        self.sh_degree_interval = int(self.sh_degree_interval * factor)

        strategy = self.strategy
        # Strategy adjustments for indoor/outdoor scene types.
        strategy.reset_opacity_until = 9000
        strategy.grow_grad_sigma = 0.000025
        if not self.light and self.outdoor:
            strategy.reset_opacity_until = 18_000
            strategy.grow_grad_sigma = 0.000001
            # NOTE(review): outdoor sets `scaling_cloning` while the indoor
            # branch sets `sigma_scaling_cloning` — confirm both attribute
            # names exist on ConvexSplattingStrategy.
            strategy.scaling_cloning = 0.6
        elif not self.light and not self.outdoor:
            strategy.refine_stop_iter = 9500
            strategy.grow_grad_sigma = 0.000004
            strategy.mask_threshold = 0.02
            strategy.sigma_scaling_cloning = 0.85

        if isinstance(strategy, ConvexSplattingStrategy):
            strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
            strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
            strategy.reset_every = int(strategy.reset_every * factor)
            strategy.refine_every = int(strategy.refine_every * factor)
        else:
            assert_never(strategy)
def create_splats_with_optimizers(
    parser: Parser,
    init_type: str = "sfm",
    init_num_pts: int = 100_000,
    init_extent: float = 3.0,
    init_opacity: float = 0.1,
    init_scale: float = 1.0,
    scene_scale: float = 1.0,
    sh_degree: int = 3,
    sparse_grad: bool = False,
    visible_adam: bool = False,
    batch_size: int = 1,
    feature_dim: Optional[int] = None,
    device: str = "cuda",
    world_rank: int = 0,
    world_size: int = 1,
    num_points_per_convex: int = 6,
    init_sigma: float = 0.00095,
    init_delta: float = 0.1,
    outdoor: bool = False,
    light: bool = False,
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer], Tensor]:
    """Initialize 3DCS convex primitives and one optimizer per parameter.

    Returns:
        splats: ParameterDict with ``convex_points``, ``opacities``,
            ``delta``, ``sigma``, ``mask`` and either SH colors
            (``sh0``/``shN``) or appearance ``features``/``colors``.
        optimizers: mapping param name -> optimizer (Adam by default,
            SparseAdam if ``sparse_grad``, SelectiveAdam if ``visible_adam``).
        cumsum_of_points_per_convex: int32 exclusive prefix sum of hull-point
            counts per convex, consumed by the 3DCS rasterizer.
    """
    if init_type == "sfm":
        points = torch.from_numpy(parser.points).float()
        rgbs = torch.from_numpy(parser.points_rgb / 255.0).float()
    elif init_type == "random":
        points = init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1)
        rgbs = torch.rand((init_num_pts, 3))
    else:
        raise ValueError("Please specify a correct init_type: sfm or random")

    # Initialize the 3D convex size to a fibonacci sphere scaled at the
    # average dist of the 3 nearest neighbors.
    dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
    dist_avg = torch.sqrt(dist2_avg)
    radii = init_scale * dist_avg.unsqueeze(1)

    x, y, z = points[:, 0], points[:, 1], points[:, 2]
    points_per_convex = fibonacci_sphere(x, y, z, radii, num_points_per_convex)

    # Distribute the convexes to ranks (also works for single rank) BEFORE
    # building any derived tensors, so that *every* parameter shares the same
    # first dimension.  (Slicing only some tensors afterwards would leave
    # opacities/mask/colors at full size whenever world_size > 1.)
    points_per_convex = points_per_convex[world_rank::world_size]
    rgbs = rgbs[world_rank::world_size]
    dist_avg = dist_avg[world_rank::world_size]

    N = points_per_convex.shape[0]

    # Exclusive prefix sum of hull-point counts, on the caller's device
    # (a hard-coded 'cuda:0' here would break other ranks/devices).
    counts = torch.tensor(
        [points_per_convex[i].shape[0] for i in range(N)],
        dtype=torch.int,
        device=device,
    )
    cumsum_of_points_per_convex = torch.cumsum(
        F.pad(counts, (1, 0), value=0), 0, dtype=torch.int
    )[:-1]

    mask = torch.ones((N,))
    opacities = torch.logit(torch.full((N,), init_opacity))  # [N,]
    # delta is stored in log space, scaled by the inverse local point spacing.
    deltas = torch.log(
        torch.ones((N, 1)) * init_delta * (1 / torch.reshape(dist_avg, (-1, 1)))
    )
    sigmas = torch.log(torch.ones((N, 1)) * init_sigma)

    lr_mask = 0.01
    lr_delta = 0.005
    lr_sigma = 0.0045
    lr_convex_points_init = 0.0005
    lr_opacities = 0.01

    # Scene-type specific learning-rate tweaks.
    if not light and outdoor:
        lr_sigma = 0.004
    elif not light and not outdoor:
        lr_convex_points_init = 0.0004

    params = [
        # name, value, lr
        # xyz/means are changed into our convex points.
        # FIXME: Better params as function inputs
        ("convex_points", torch.nn.Parameter(points_per_convex), lr_convex_points_init),
        ("opacities", torch.nn.Parameter(opacities), lr_opacities),
        ("delta", torch.nn.Parameter(deltas), lr_delta),
        ("sigma", torch.nn.Parameter(sigmas), lr_sigma),
        ("mask", torch.nn.Parameter(mask), lr_mask),
    ]

    if feature_dim is None:
        # color is SH coefficients.
        colors = torch.zeros((N, (sh_degree + 1) ** 2, 3))  # [N, K, 3]
        colors[:, 0, :] = rgb_to_sh(rgbs)
        params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3))
        params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20))
    else:
        # features will be used for appearance and view-dependent shading
        features = torch.rand(N, feature_dim)  # [N, feature_dim]
        params.append(("features", torch.nn.Parameter(features), 2.5e-3))
        colors = torch.logit(rgbs)  # [N, 3]
        params.append(("colors", torch.nn.Parameter(colors), 2.5e-3))

    splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)
    # Scale learning rate based on batch size, reference:
    # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
    # Note that this would not make the training exactly equivalent, see
    # https://arxiv.org/pdf/2402.18824v1
    BS = batch_size * world_size
    if sparse_grad:
        optimizer_class = torch.optim.SparseAdam
    elif visible_adam:
        optimizer_class = SelectiveAdam
    else:
        optimizer_class = torch.optim.Adam
    optimizers = {
        name: optimizer_class(
            [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
            eps=1e-15 / math.sqrt(BS),
            # TODO: check betas logic when BS is larger than 10 betas[0] will be zero.
            betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
        )
        for name, _, lr in params
    }

    return splats, optimizers, cumsum_of_points_per_convex
+ self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=cfg.normalize_world_space, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + print("Scene scale:", self.scene_scale) + + # Model + feature_dim = 32 if cfg.app_opt else None + self.splats, self.optimizers, cumsum_of_points_per_convex = create_splats_with_optimizers( + self.parser, + init_type=cfg.init_type, + init_num_pts=cfg.init_num_pts, + init_extent=cfg.init_extent, + init_opacity=cfg.init_opa, + init_scale=cfg.init_scale, + scene_scale=self.scene_scale, + sh_degree=cfg.sh_degree, + sparse_grad=cfg.sparse_grad, + visible_adam=cfg.visible_adam, + batch_size=cfg.batch_size, + feature_dim=feature_dim, + device=self.device, + world_rank=world_rank, + world_size=world_size, + num_points_per_convex=cfg.num_points_per_convex, + init_sigma=cfg.init_sigma, + init_delta=cfg.init_delta, + outdoor=cfg.outdoor, + light=cfg.light, + ) + cfg.cumsum_of_points_per_convex = cumsum_of_points_per_convex + print("Model initialized. 
Number of Convex points:", len(self.splats["convex_points"])) + + # Densification Strategy + self.cfg.strategy.check_sanity(self.splats, self.optimizers) + + if isinstance(self.cfg.strategy, ConvexSplattingStrategy): + self.strategy_state = self.cfg.strategy.initialize_state( + scene_scale=self.scene_scale + ) + else: + assert_never(self.cfg.strategy) + + # Compression Strategy + self.compression_method = None + if cfg.compression is not None: + if cfg.compression == "png": + self.compression_method = PngCompression() + else: + raise ValueError(f"Unknown compression strategy: {cfg.compression}") + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + if world_size > 1: + self.pose_adjust = DDP(self.pose_adjust) + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + if world_size > 1: + self.pose_perturb = DDP(self.pose_perturb) + + self.app_optimizers = [] + if cfg.app_opt: + assert feature_dim is not None + self.app_module = AppearanceOptModule( + len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + ).to(self.device) + # initialize the last layer to be zero so that the initial output is zero. 
+ torch.nn.init.zeros_(self.app_module.color_head[-1].weight) + torch.nn.init.zeros_(self.app_module.color_head[-1].bias) + self.app_optimizers = [ + torch.optim.Adam( + self.app_module.embeds.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.app_opt_reg, + ), + torch.optim.Adam( + self.app_module.color_head.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + if world_size > 1: + self.app_module = DDP(self.app_module) + + self.bil_grid_optimizers = [] + if cfg.use_bilateral_grid: + self.bil_grids = BilateralGrid( + len(self.trainset), + grid_X=cfg.bilateral_grid_shape[0], + grid_Y=cfg.bilateral_grid_shape[1], + grid_W=cfg.bilateral_grid_shape[2], + ).to(self.device) + self.bil_grid_optimizers = [ + torch.optim.Adam( + self.bil_grids.parameters(), + lr=2e-3 * math.sqrt(cfg.batch_size), + eps=1e-15, + ), + ] + + # Losses & Metrics. + self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + + if cfg.lpips_net == "alex": + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="alex", normalize=True + ).to(self.device) + elif cfg.lpips_net == "vgg": + # The 3DGS official repo uses lpips vgg, which is equivalent with the following: + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="vgg", normalize=False + ).to(self.device) + else: + raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}") + + # Viewer + if not self.cfg.disable_viewer: + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = GsplatViewer( + server=self.server, + render_fn=self._viewer_render_fn, + output_dir=Path(cfg.result_dir), + mode="training", + ) + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + masks: Optional[Tensor] = None, + rasterize_mode: Optional[Literal["classic", "antialiased"]] = None, + camera_model: Optional[Literal["pinhole", "ortho", 
"fisheye"]] = None, + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict]: + convex_points = self.splats["convex_points"] # [N, 6, 3] + delta = torch.exp(self.splats["delta"]) + sigma = torch.exp(self.splats["sigma"]) + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + + image_ids = kwargs.pop("image_ids", None) + if self.cfg.app_opt: + means3d = convex_points.mean(dim=1) + colors = self.app_module( + features=self.splats["features"], + embed_ids=image_ids, + dirs=means3d[None, :, :] - camtoworlds[:, None, :3, 3], + sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), + ) + colors = colors + self.splats["colors"] + colors = torch.sigmoid(colors) + else: + colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] + + # # Special treatment for the mask + mask = ((torch.sigmoid(self.splats["mask"]) > 0.01).float() - torch.sigmoid(self.splats["mask"])).detach() + torch.sigmoid(self.splats["mask"]) + opacities = opacities * mask + + if rasterize_mode is None: + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" + if camera_model is None: + camera_model = self.cfg.camera_model + render_colors, render_alphas, info = rasterization_3dcs( + convex_points=convex_points, + delta=delta, + sigma=sigma, + num_points_per_convex=self.cfg.num_points_per_convex, + cumsum_of_points_per_convex=self.cfg.cumsum_of_points_per_convex, + opacities=opacities, + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=self.cfg.packed, + absgrad=( + self.cfg.strategy.absgrad + if isinstance(self.cfg.strategy, ConvexSplattingStrategy) + else False + ), + sparse_grad=self.cfg.sparse_grad, + rasterize_mode=rasterize_mode, + distributed=self.world_size > 1, + camera_model=self.cfg.camera_model, + **kwargs, + ) + + if masks is not None: + render_colors[~masks] = 0 + + return render_colors, render_alphas, info + + def train(self): + cfg = self.cfg + device = self.device + world_rank = 
self.world_rank + world_size = self.world_size + + # Dump cfg. + if world_rank == 0: + with open(f"{cfg.result_dir}/cfg.yml", "w") as f: + yaml.dump(vars(cfg), f) + + max_steps = cfg.max_steps + init_step = 0 + + schedulers = [ + # means has a learning rate schedule, that end at 0.01 of the initial value + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["convex_points"], gamma=0.01 ** (1.0 / max_steps) + ), + ] + if cfg.pose_opt: + # pose optimization has a learning rate schedule + schedulers.append( + torch.optim.lr_scheduler.ExponentialLR( + self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ) + ) + if cfg.use_bilateral_grid: + # bilateral grid has a learning rate schedule. Linear warmup for 1000 steps. + schedulers.append( + torch.optim.lr_scheduler.ChainedScheduler( + [ + torch.optim.lr_scheduler.LinearLR( + self.bil_grid_optimizers[0], + start_factor=0.01, + total_iters=1000, + ), + torch.optim.lr_scheduler.ExponentialLR( + self.bil_grid_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ), + ] + ) + ) + + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=4, + persistent_workers=True, + pin_memory=True, + ) + trainloader_iter = iter(trainloader) + + # Training loop. 
+ global_tic = time.time() + pbar = tqdm.tqdm(range(init_step, max_steps)) + for step in pbar: + if not cfg.disable_viewer: + while self.viewer.state == "paused": + time.sleep(0.01) + self.viewer.lock.acquire() + tic = time.time() + + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + + camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] + Ks = data["K"].to(device) # [1, 3, 3] + pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + num_train_rays_per_step = ( + pixels.shape[0] * pixels.shape[1] * pixels.shape[2] + ) + image_ids = data["image_id"].to(device) + masks = data["mask"].to(device) if "mask" in data else None # [1, H, W] + if cfg.depth_loss: + points = data["points"].to(device) # [1, M, 2] + depths_gt = data["depths"].to(device) # [1, M] + + height, width = pixels.shape[1:3] + + if cfg.pose_noise: + camtoworlds = self.pose_perturb(camtoworlds, image_ids) + + if cfg.pose_opt: + camtoworlds = self.pose_adjust(camtoworlds, image_ids) + + # sh schedule + sh_degree_to_use = min(step // cfg.sh_degree_interval, cfg.sh_degree) + + # forward + renders, alphas, info = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=sh_degree_to_use, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + render_mode="RGB+ED" if cfg.depth_loss else "RGB", + masks=masks, + ) + if renders.shape[-1] == 4: + colors, depths = renders[..., 0:3], renders[..., 3:4] + else: + colors, depths = renders, None + + if cfg.use_bilateral_grid: + grid_y, grid_x = torch.meshgrid( + (torch.arange(height, device=self.device) + 0.5) / height, + (torch.arange(width, device=self.device) + 0.5) / width, + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"] + + if cfg.random_bkgd: + bkgd = torch.rand(1, 3, device=device) + 
colors = colors + bkgd * (1.0 - alphas) + + self.cfg.strategy.step_pre_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + ) + + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - fused_ssim( + colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" + ) + loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda + 0.0005*torch.mean((torch.sigmoid(self.splats["mask"]))) + if cfg.depth_loss: + # query depths from depth map + points = torch.stack( + [ + points[:, :, 0] / (width - 1) * 2 - 1, + points[:, :, 1] / (height - 1) * 2 - 1, + ], + dim=-1, + ) # normalize to [-1, 1] + grid = points.unsqueeze(2) # [1, M, 1, 2] + depths = F.grid_sample( + depths.permute(0, 3, 1, 2), grid, align_corners=True + ) # [1, 1, M, 1] + depths = depths.squeeze(3).squeeze(1) # [1, M] + # calculate loss in disparity space + disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) + disp_gt = 1.0 / depths_gt # [1, M] + depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale + loss += depthloss * cfg.depth_lambda + if cfg.use_bilateral_grid: + tvloss = 10 * total_variation_loss(self.bil_grids.grids) + loss += tvloss + + # regularizations + if cfg.opacity_reg > 0.0: + loss = ( + loss + + cfg.opacity_reg + * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() + ) + if cfg.scale_reg > 0.0: + loss = ( + loss + + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() + ) + + loss.backward() + + desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " + if cfg.depth_loss: + desc += f"depth loss={depthloss.item():.6f}| " + if cfg.pose_opt and cfg.pose_noise: + # monitor the pose error if we inject noise + pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) + desc += f"pose err={pose_err.item():.6f}| " + pbar.set_description(desc) + + if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0: + mem = torch.cuda.max_memory_allocated() / 1024**3 + 
self.writer.add_scalar("train/loss", loss.item(), step) + self.writer.add_scalar("train/l1loss", l1loss.item(), step) + self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar("train/num_3D convexes", len(self.splats["convex_points"]), step) + self.writer.add_scalar("train/mem", mem, step) + if cfg.depth_loss: + self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.use_bilateral_grid: + self.writer.add_scalar("train/tvloss", tvloss.item(), step) + if cfg.tb_save_image: + canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + canvas = canvas.reshape(-1, *canvas.shape[2:]) + self.writer.add_image("train/render", canvas, step) + self.writer.flush() + + # save checkpoint before updating the model + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["convex_points"]), + } + print("Step: ", step, stats) + with open( + f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", + "w", + ) as f: + json.dump(stats, f) + data = {"step": step, "splats": self.splats.state_dict()} + if cfg.pose_opt: + if world_size > 1: + data["pose_adjust"] = self.pose_adjust.module.state_dict() + else: + data["pose_adjust"] = self.pose_adjust.state_dict() + if cfg.app_opt: + if world_size > 1: + data["app_module"] = self.app_module.module.state_dict() + else: + data["app_module"] = self.app_module.state_dict() + torch.save( + data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" + ) + if ( + step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1 + ) and cfg.save_ply: + + if self.cfg.app_opt: + means3d = self.spalts["convex_points"].mean(dim=1) + # eval at origin to bake the appeareance into the colors + rgb = self.app_module( + features=self.splats["features"], + embed_ids=None, + dirs=torch.zeros_like(means3d[None, :, :]), + 
sh_degree=sh_degree_to_use, + ) + rgb = rgb + self.splats["colors"] + rgb = torch.sigmoid(rgb).squeeze(0) + sh0 = rgb_to_sh(rgb) + shN = torch.empty([sh0.shape[0], 0, 3], device=sh0.device) + else: + sh0 = self.splats["sh0"] + shN = self.splats["shN"] + + # export_splats( + # means=means, + # scales=scales, + # quats=quats, + # opacities=opacities, + # sh0=sh0, + # shN=shN, + # format="ply", + # save_to=f"{self.ply_dir}/point_cloud_{step}.ply", + # ) + + point_cloud_state_dict = {} + point_cloud_state_dict["convex_points"] = self.splats["convex_points"] + point_cloud_state_dict["delta"] = self.splats["delta"] + point_cloud_state_dict["sigma"] = self.splats["sigma"] + point_cloud_state_dict["active_sh_degree"] = cfg.sh_degree + point_cloud_state_dict["features_dc"] = self.splats["sh0"] + point_cloud_state_dict["features_rest"] = self.splats["shN"] + point_cloud_state_dict["opacity"] = self.splats["opacities"] + torch.save(point_cloud_state_dict, os.path.join(self.ply_dir, f"point_cloud_state_dict.pt")) + + hyperparameters = {} + hyperparameters["num_points_per_convex"] = cfg.num_points_per_convex + hyperparameters["cumsum_of_points_per_convex"] = cfg.cumsum_of_points_per_convex + torch.save(hyperparameters, os.path.join(self.ply_dir, 'hyperparameters.pt')) + #save_ply_convex(self.splats, cfg, f"{self.ply_dir}/point_cloud_{step}.ply", rgb) + + # Turn Gradients into Sparse Tensor before running optimizer + if cfg.sparse_grad: + assert cfg.packed, "Sparse gradients only work with packed mode." + gaussian_ids = info["gaussian_ids"] + for k in self.splats.keys(): + grad = self.splats[k].grad + if grad is None or grad.is_sparse: + continue + self.splats[k].grad = torch.sparse_coo_tensor( + indices=gaussian_ids[None], # [1, nnz] + values=grad[gaussian_ids], # [nnz, ...] + size=self.splats[k].size(), # [N, ...] 
+ is_coalesced=len(Ks) == 1, + ) + + if cfg.visible_adam: + gaussian_cnt = self.splats.means.shape[0] + if cfg.packed: + visibility_mask = torch.zeros_like( + self.splats["opacities"], dtype=bool + ) + visibility_mask.scatter_(0, info["gaussian_ids"], 1) + else: + visibility_mask = (info["radii"] > 0).all(-1).any(0) + + # optimize + for optimizer in self.optimizers.values(): + if cfg.visible_adam: + optimizer.step(visibility_mask) + else: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.pose_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.app_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.bil_grid_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for scheduler in schedulers: + scheduler.step() + + # Run post-backward steps after backward and optimizer + if isinstance(self.cfg.strategy, ConvexSplattingStrategy): + self.cfg.strategy.step_post_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + packed=cfg.packed + ) + else: + assert_never(self.cfg.strategy) + + # eval the full set + if step in [i - 1 for i in cfg.eval_steps]: + self.eval(step) + #self.render_traj(step) + + # run compression + if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: + self.run_compression(step=step) + + if not cfg.disable_viewer: + self.viewer.lock.release() + num_train_steps_per_sec = 1.0 / (time.time() - tic) + num_train_rays_per_sec = ( + num_train_rays_per_step * num_train_steps_per_sec + ) + # Update the viewer state. + self.viewer.render_tab_state.num_train_rays_per_sec = ( + num_train_rays_per_sec + ) + # Update the scene. 
+ self.viewer.update(step, num_train_rays_per_step) + + @torch.no_grad() + def eval(self, step: int, stage: str = "val"): + """Entry for evaluation.""" + print("Running evaluation...") + cfg = self.cfg + device = self.device + world_rank = self.world_rank + world_size = self.world_size + + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) + ellipse_time = 0 + + metrics = defaultdict(list) + for i, data in enumerate(valloader): + + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + masks = data["mask"].to(device) if "mask" in data else None + height, width = pixels.shape[1:3] + + torch.cuda.synchronize() + tic = time.time() + colors, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + masks=masks, + ) # [1, H, W, 3] + torch.cuda.synchronize() + ellipse_time += time.time() - tic + + colors = torch.clamp(colors, 0.0, 1.0) + canvas_list = [pixels, colors] + + if world_rank == 0: + # # write images + # canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + # canvas = (canvas * 255).astype(np.uint8) + # imageio.imwrite( + # f"{self.render_dir}/{stage}_step{step}_{i:04d}.png", + # canvas, + # ) + + pixels_p = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors_p = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors_p, pixels_p)) + metrics["ssim"].append(self.ssim(colors_p, pixels_p)) + metrics["lpips"].append(self.lpips(colors_p, pixels_p)) + if cfg.use_bilateral_grid: + cc_colors = color_correct(colors, pixels) + cc_colors_p = cc_colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p)) + + if world_rank == 0: + ellipse_time /= len(valloader) + + stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()} + stats.update( + { + 
"ellipse_time": ellipse_time, + "num_GS": len(self.splats["convex_points"]), + } + ) + print( + f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} " + f"Time: {stats['ellipse_time']:.3f}s/image " + f"Number of GS: {stats['num_GS']}" + ) + # save stats as json + with open(f"{self.stats_dir}/{stage}_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"{stage}/{k}", v, step) + self.writer.flush() + + @torch.no_grad() + def render_traj(self, step: int): + """Entry for trajectory rendering.""" + if self.cfg.disable_video: + return + print("Running trajectory rendering...") + cfg = self.cfg + device = self.device + + camtoworlds_all = self.parser.camtoworlds[5:-5] + if cfg.render_traj_path == "interp": + camtoworlds_all = generate_interpolated_path( + camtoworlds_all, 1 + ) # [N, 3, 4] + elif cfg.render_traj_path == "ellipse": + height = camtoworlds_all[:, 2, 3].mean() + camtoworlds_all = generate_ellipse_path_z( + camtoworlds_all, height=height + ) # [N, 3, 4] + elif cfg.render_traj_path == "spiral": + camtoworlds_all = generate_spiral_path( + camtoworlds_all, + bounds=self.parser.bounds * self.scene_scale, + spiral_scale_r=self.parser.extconf["spiral_radius_scale"], + ) + else: + raise ValueError( + f"Render trajectory type not supported: {cfg.render_traj_path}" + ) + + camtoworlds_all = np.concatenate( + [ + camtoworlds_all, + np.repeat( + np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0 + ), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device) + K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + for i in 
+ tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): + camtoworlds = camtoworlds_all[i : i + 1] + Ks = K[None] + + renders, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 4] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] + depths = renders[..., 3:4] # [1, H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list = [colors, depths.repeat(1, 1, 1, 3)] + + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def run_compression(self, step: int): + """Entry for running compression.""" + print("Running compression...") + world_rank = self.world_rank + + compress_dir = f"{self.cfg.result_dir}/compression/rank{world_rank}" # use self.cfg: unlike eval()/render_traj(), this method binds no local `cfg`, so bare `cfg` would raise NameError + os.makedirs(compress_dir, exist_ok=True) + + self.compression_method.compress(compress_dir, self.splats) + + # evaluate compression + splats_c = self.compression_method.decompress(compress_dir) + for k in splats_c.keys(): + self.splats[k].data = splats_c[k].to(self.device) + self.eval(step=step, stage="compress") + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: CameraState, render_tab_state: RenderTabState + ): + assert isinstance(render_tab_state, GsplatRenderTabState) + if render_tab_state.preview_render: + width = render_tab_state.render_width + height = render_tab_state.render_height + else: + width = render_tab_state.viewer_width + height = render_tab_state.viewer_height + c2w = camera_state.c2w + K = camera_state.get_K((width, height)) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + RENDER_MODE_MAP = { + "rgb": "RGB", + "depth(accumulated)": "D", + 
"depth(expected)": "ED", + "alpha": "RGB", + } + + render_colors, render_alphas, info = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + width=width, + height=height, + sh_degree=min(render_tab_state.max_sh_degree, self.cfg.sh_degree), + near_plane=render_tab_state.near_plane, + far_plane=render_tab_state.far_plane, + radius_clip=render_tab_state.radius_clip, + eps2d=render_tab_state.eps2d, + backgrounds=torch.tensor([render_tab_state.backgrounds], device=self.device) + / 255.0, + render_mode=RENDER_MODE_MAP[render_tab_state.render_mode], + rasterize_mode=render_tab_state.rasterize_mode, + camera_model=render_tab_state.camera_model, + ) # [1, H, W, 3] + render_tab_state.total_gs_count = len(self.splats["convex_points"]) + render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item() + + if render_tab_state.render_mode == "rgb": + # colors represented with sh are not guranteed to be in [0, 1] + render_colors = render_colors[0, ..., 0:3].clamp(0, 1) + renders = render_colors.cpu().numpy() + elif render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]: + # normalize depth to [0, 1] + depth = render_colors[0, ..., 0:1] + if render_tab_state.normalize_nearfar: + near_plane = render_tab_state.near_plane + far_plane = render_tab_state.far_plane + else: + near_plane = depth.min() + far_plane = depth.max() + depth_norm = (depth - near_plane) / (far_plane - near_plane + 1e-10) + depth_norm = torch.clip(depth_norm, 0, 1) + if render_tab_state.inverse: + depth_norm = 1 - depth_norm + renders = ( + apply_float_colormap(depth_norm, render_tab_state.colormap) + .cpu() + .numpy() + ) + elif render_tab_state.render_mode == "alpha": + alpha = render_alphas[0, ..., 0:1] + if render_tab_state.inverse: + alpha = 1 - alpha + renders = ( + apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy() + ) + return renders + + +def main(local_rank: int, world_rank, world_size: int, cfg: Config): + if world_size > 1 and not 
+ cfg.disable_viewer: + cfg.disable_viewer = True + if world_rank == 0: + print("Viewer is disabled in distributed training.") + + runner = Runner(local_rank, world_rank, world_size, cfg) + + if cfg.ckpt is not None: + # run eval only + ckpts = [ + torch.load(file, map_location=runner.device, weights_only=True) + for file in cfg.ckpt + ] + for k in runner.splats.keys(): + runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts]) + step = ckpts[0]["step"] + runner.eval(step=step) + #runner.render_traj(step=step) + if cfg.compression is not None: + runner.run_compression(step=step) + else: + runner.train() + + if not cfg.disable_viewer: runner.viewer.complete() # NOTE(review): every other runner.viewer access is guarded by `not cfg.disable_viewer`; an unguarded call would fail whenever the viewer was never created (e.g. distributed runs force-disable it above) + if not cfg.disable_viewer: + print("Viewer running... Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + """ + Usage: + + ```bash + # Single GPU training + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default + + # Distributed training on 4 GPUs: Effectively 4x batch size so run 4x less steps. + CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --steps_scaler 0.25 + + """ + + # Config objects we can choose between. + # Each is a tuple of (CLI description, config object). + configs = { + # FIXME: should it be 3dcs? 
+ "3dcs": ( + "Gaussian splatting training using densification heuristics from the '3D Convex Splatting: Radiance Field Rendering with 3D Smooth Convexes'.", + Config( + strategy=ConvexSplattingStrategy(verbose=True), + ), + ), + "default": ( + "Gaussian splatting training using densification heuristics from the original paper.", + Config( + strategy=DefaultStrategy(verbose=True), + ), + ), + } + cfg = tyro.extras.overridable_config_cli(configs) + cfg.adjust_steps(cfg.steps_scaler) + + # try import extra dependencies + if cfg.compression == "png": + try: + import plas + import torchpq + except: + raise ImportError( + "To use PNG compression, you need to install " + "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) " + "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') " + ) + + cli(main, cfg, verbose=True) diff --git a/examples/simple_viewer.py b/examples/simple_viewer.py index 7f9bd8f8f..ab6218620 100644 --- a/examples/simple_viewer.py +++ b/examples/simple_viewer.py @@ -12,7 +12,7 @@ from pathlib import Path from gsplat._helper import load_test_data from gsplat.distributed import cli -from gsplat.rendering import rasterization +from gsplat.rendering import rasterization, rasterization_3dcs from nerfview import CameraState, RenderTabState, apply_float_colormap from gsplat_viewer import GsplatViewer, GsplatRenderTabState @@ -101,55 +101,81 @@ def main(local_rank: int, world_rank, world_size: int, args): ) else: means, quats, scales, opacities, sh0, shN = [], [], [], [], [], [] - for ckpt_path in args.ckpt: - ckpt = torch.load(ckpt_path, map_location=device)["splats"] - means.append(ckpt["means"]) - quats.append(F.normalize(ckpt["quats"], p=2, dim=-1)) - scales.append(torch.exp(ckpt["scales"])) - opacities.append(torch.sigmoid(ckpt["opacities"])) - sh0.append(ckpt["sh0"]) - shN.append(ckpt["shN"]) - means = torch.cat(means, dim=0) - quats = torch.cat(quats, dim=0) - scales = torch.cat(scales, dim=0) - 
opacities = torch.cat(opacities, dim=0) - sh0 = torch.cat(sh0, dim=0) - shN = torch.cat(shN, dim=0) - colors = torch.cat([sh0, shN], dim=-2) - sh_degree = int(math.sqrt(colors.shape[-2]) - 1) - # # crop - # aabb = torch.tensor((-1.0, -1.0, -1.0, 1.0, 1.0, 0.7), device=device) - # edges = aabb[3:] - aabb[:3] - # sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1) - # sel = torch.where(sel)[0] - # means, quats, scales, colors, opacities = ( - # means[sel], - # quats[sel], - # scales[sel], - # colors[sel], - # opacities[sel], - # ) + convex_points, delta, sigma, num_points_per_convex, cumsum_of_points_per_convex = [], [], [], [], [] + if args.backend == "3dcs": + for ckpt_path in args.ckpt: + hyperparam = torch.load(os.path.join(ckpt_path, "hyperparameters.pt"), map_location=device, weights_only=False) + pc = torch.load(os.path.join(ckpt_path, "point_cloud_state_dict.pt"), map_location=device, weights_only=False) + convex_points.append(pc['convex_points']) + delta.append(torch.exp(pc['delta'])) + sigma.append(torch.exp(pc['sigma'])) + opacities.append(torch.sigmoid(pc["opacity"]).squeeze()) + num_points_per_convex.append(torch.tensor([6])) + cumsum_of_points_per_convex.append(hyperparam["cumsum_of_points_per_convex"]) + sh0.append(pc["features_dc"]) + shN.append(pc["features_rest"]) + convex_points = torch.cat(convex_points, dim=0) + delta = torch.cat(delta, dim=0) + sigma = torch.cat(sigma, dim=0) + num_points_per_convex = torch.cat(num_points_per_convex, dim=0) + cumsum_of_points_per_convex = torch.cat(cumsum_of_points_per_convex, dim=0) + opacities = torch.cat(opacities, dim=0) + sh0 = torch.cat(sh0, dim=0) + shN = torch.cat(shN, dim=0) + colors = torch.cat([sh0, shN], dim=-2) + sh_degree = int(pc["active_sh_degree"]) + print("Number of 3D convexes:", convex_points.shape[0]*convex_points.shape[1]) + else: + for ckpt_path in args.ckpt: + ckpt = torch.load(ckpt_path, map_location=device)["splats"] + means.append(ckpt["means"]) + 
quats.append(F.normalize(ckpt["quats"], p=2, dim=-1)) + scales.append(torch.exp(ckpt["scales"])) + opacities.append(torch.sigmoid(ckpt["opacities"])) + sh0.append(ckpt["sh0"]) + shN.append(ckpt["shN"]) + means = torch.cat(means, dim=0) + quats = torch.cat(quats, dim=0) + scales = torch.cat(scales, dim=0) + opacities = torch.cat(opacities, dim=0) + sh0 = torch.cat(sh0, dim=0) + shN = torch.cat(shN, dim=0) + colors = torch.cat([sh0, shN], dim=-2) + sh_degree = int(math.sqrt(colors.shape[-2]) - 1) + + # # crop + # aabb = torch.tensor((-1.0, -1.0, -1.0, 1.0, 1.0, 0.7), device=device) + # edges = aabb[3:] - aabb[:3] + # sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1) + # sel = torch.where(sel)[0] + # means, quats, scales, colors, opacities = ( + # means[sel], + # quats[sel], + # scales[sel], + # colors[sel], + # opacities[sel], + # ) - # # repeat the scene into a grid (to mimic a large-scale setting) - # repeats = args.scene_grid - # gridx, gridy = torch.meshgrid( - # [ - # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device), - # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device), - # ], - # indexing="ij", - # ) - # grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape( - # -1, 3 - # ) - # means = means[None, :, :] + grid[:, None, :] * edges[None, None, :] - # means = means.reshape(-1, 3) - # quats = quats.repeat(repeats**2, 1) - # scales = scales.repeat(repeats**2, 1) - # colors = colors.repeat(repeats**2, 1, 1) - # opacities = opacities.repeat(repeats**2) - print("Number of Gaussians:", len(means)) + # # repeat the scene into a grid (to mimic a large-scale setting) + # repeats = args.scene_grid + # gridx, gridy = torch.meshgrid( + # [ + # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device), + # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device), + # ], + # indexing="ij", + # ) + # grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape( + # -1, 3 + # ) + # means = 
means[None, :, :] + grid[:, None, :] * edges[None, None, :] + # means = means.reshape(-1, 3) + # quats = quats.repeat(repeats**2, 1) + # scales = scales.repeat(repeats**2, 1) + # colors = colors.repeat(repeats**2, 1, 1) + # opacities = opacities.repeat(repeats**2) + print("Number of Gaussians:", len(means)) # register and open viewer @torch.no_grad() @@ -174,34 +200,74 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState "alpha": "RGB", } - render_colors, render_alphas, info = rasterization( - means, # [N, 3] - quats, # [N, 4] - scales, # [N, 3] - opacities, # [N] - colors, # [N, S, 3] - viewmat[None], # [1, 4, 4] - K[None], # [1, 3, 3] - width, - height, - sh_degree=( - min(render_tab_state.max_sh_degree, sh_degree) - if sh_degree is not None - else None - ), - near_plane=render_tab_state.near_plane, - far_plane=render_tab_state.far_plane, - radius_clip=render_tab_state.radius_clip, - eps2d=render_tab_state.eps2d, - backgrounds=torch.tensor([render_tab_state.backgrounds], device=device) - / 255.0, - render_mode=RENDER_MODE_MAP[render_tab_state.render_mode], - rasterize_mode=render_tab_state.rasterize_mode, - camera_model=render_tab_state.camera_model, - packed=False, - ) - render_tab_state.total_gs_count = len(means) - render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item() + if args.backend == "gsplat": + rasterization_fn = rasterization + elif args.backend == "3dcs": + rasterization_fn = rasterization_3dcs + elif args.backend == "inria": + from gsplat import rasterization_inria_wrapper + + rasterization_fn = rasterization_inria_wrapper + else: + raise ValueError + + if args.backend == "3dcs": + render_colors, render_alphas, info = rasterization_fn( + convex_points, + delta, + sigma, + num_points_per_convex, + cumsum_of_points_per_convex, + opacities, + colors, + viewmat[None], # [1, 4, 4] + K[None], # [1, 3, 3] + width, + height, + packed=False, + sh_degree=( + min(render_tab_state.max_sh_degree, sh_degree) + 
if sh_degree is not None + else None + ), + near_plane=render_tab_state.near_plane, + far_plane=render_tab_state.far_plane, + radius_clip=render_tab_state.radius_clip, + eps2d=render_tab_state.eps2d, + backgrounds=torch.tensor([render_tab_state.backgrounds], device=device) + / 255.0, + render_mode=RENDER_MODE_MAP[render_tab_state.render_mode], + rasterize_mode=render_tab_state.rasterize_mode, + camera_model=render_tab_state.camera_model, + ) + else: + render_colors, render_alphas, info = rasterization( + means, # [N, 3] + quats, # [N, 4] + scales, # [N, 3] + opacities, # [N] + colors, # [N, S, 3] + viewmat[None], # [1, 4, 4] + K[None], # [1, 3, 3] + width, + height, + sh_degree=( + min(render_tab_state.max_sh_degree, sh_degree) + if sh_degree is not None + else None + ), + near_plane=render_tab_state.near_plane, + far_plane=render_tab_state.far_plane, + radius_clip=render_tab_state.radius_clip, + eps2d=render_tab_state.eps2d, + backgrounds=torch.tensor([render_tab_state.backgrounds], device=device) + / 255.0, + render_mode=RENDER_MODE_MAP[render_tab_state.render_mode], + rasterize_mode=render_tab_state.rasterize_mode, + camera_model=render_tab_state.camera_model, + ) + render_tab_state.total_gs_count = len(means) + render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item() if render_tab_state.render_mode == "rgb": # colors represented with sh are not guranteed to be in [0, 1] @@ -267,6 +333,12 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState parser.add_argument( "--ckpt", type=str, nargs="+", default=None, help="path to the .pt file" ) + parser.add_argument( + "--ply", type=str, nargs="+", default=None, help="path to the .ply file" + ) + parser.add_argument( + "--backend", type=str, default="gsplat", choices=["gsplat", "3dcs", "inria"], help="backend to use for rendering", + ) parser.add_argument( "--port", type=int, default=8080, help="port for the viewer server" ) diff --git a/gsplat/_helper.py 
b/gsplat/_helper.py index 86bca0520..d543d294c 100644 --- a/gsplat/_helper.py +++ b/gsplat/_helper.py @@ -5,6 +5,13 @@ import torch import torch.nn.functional as F +def load_ply_data( + data_path: Optional[str] = None, + device="cuda", + scene_crop: Tuple[float, float, float, float, float, float] = (-2, -2, -2, 2, 2, 2), + scene_grid: int = 1, +): + assert True def load_test_data( data_path: Optional[str] = None, diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 6a5d1a59d..c91336e9f 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -1983,3 +1983,874 @@ def backward( None, None, ) + + +##### 3DCS #### +def spherical_harmonics_3dcs( + degrees_to_use: int, + convex_points: Tensor, # [N, 6, 3] + dirs: Tensor, # [..., 3] + coeffs: Tensor, # [..., K, 3] + masks: Optional[Tensor] = None, +) -> Tensor: + """Computes spherical harmonics. + + Args: + degrees_to_use: The degree to be used. + convex_points: the 3D convex points. [N, 6, 3] + dirs: Directions. [..., 3] + coeffs: Coefficients. [..., K, 3] + masks: Optional boolen masks to skip some computation. [...,] Default: None. + + Returns: + Spherical harmonics. 
[..., 3] + """ + assert (degrees_to_use + 1) ** 2 <= coeffs.shape[-2], coeffs.shape + assert dirs.shape[:-1] == coeffs.shape[:-2], (dirs.shape, coeffs.shape) + assert dirs.shape[-1] == 3, dirs.shape + assert coeffs.shape[-1] == 3, coeffs.shape + if masks is not None: + assert masks.shape == dirs.shape[:-1], masks.shape + masks = masks.contiguous() + return _SphericalHarmonics_3dcs.apply( + degrees_to_use, convex_points.contiguous(), dirs.contiguous(), coeffs.contiguous(), masks + ) + +class _SphericalHarmonics_3dcs(torch.autograd.Function): + """Spherical Harmonics version for 3DCS""" + + @staticmethod + def forward( + ctx, sh_degree: int, convex_points: Tensor, dirs: Tensor, coeffs: Tensor, masks: Tensor + ) -> Tensor: + colors = _make_lazy_cuda_func("spherical_harmonics_fwd_3dcs")(sh_degree, convex_points, dirs, coeffs, masks) + ctx.save_for_backward(convex_points, dirs, coeffs, masks) + ctx.sh_degree = sh_degree + ctx.num_bases = coeffs.shape[-2] + return colors + + @staticmethod + def backward(ctx, v_colors: Tensor): + convex_points, dirs, coeffs, masks = ctx.saved_tensors + sh_degree = ctx.sh_degree + num_bases = ctx.num_bases + compute_v_convex_points = ctx.needs_input_grad[1] + compute_v_dirs = ctx.needs_input_grad[2] + v_convex_points, v_coeffs, v_dirs = _make_lazy_cuda_func("spherical_harmonics_bwd_3dcs")( + num_bases, + sh_degree, + convex_points, + dirs, + coeffs, + masks, + v_colors.contiguous(), + compute_v_convex_points, + compute_v_dirs, + ) + if not compute_v_dirs: + v_dirs = None + if not compute_v_convex_points: + v_convex_points = None + return None, v_convex_points, v_dirs, v_coeffs, None + + +def fully_fused_projection_3dcs( + convex_points: Tensor, # [N, 6, 3] + cumsum_of_points_per_convex: Tensor, # [N] + delta: Tensor, # [N] + sigma: Tensor, # [N] + scaling: Tensor, # [N] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + eps2d: float = 0.3, + near_plane: float = 0.01, + far_plane: float = 1e10, + 
radius_clip: float = 0.0, + packed: bool = False, + sparse_grad: bool = False, + calc_compensations: bool = False, + camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", + opacities: Optional[Tensor] = None, # [N] or None +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Projects Convex points to 2D convex hull. + + .. note:: + + During projection, we ignore the 3D convexes that are outside of the camera frustum. + So not all the elements in the output tensors are valid. The output `radii` could serve as + an indicator, in which zero radii means the corresponding elements are invalid in + the output tensors and will be ignored in the next rasterization process. If `packed=True`, + the output tensors will be packed into a flattened tensor, in which all elements are valid. + In this case, a `camera_ids` tensor and `gaussian_ids` tensor will be returned to indicate the + row (camera) and column (Gaussian) indices of the packed flattened tensor, which is essentially + following the COO sparse tensor format. + + .. note:: + + This functions supports projecting Gaussians with either covariances or {quaternions, scales}, + which will be converted to covariances internally in a fused CUDA kernel. Either `covars` or + {`quats`, `scales`} should be provided. + + Args: + convex_points: Gaussian means. [N, 6, 3] + cumsum_of_points_per_convex: Cumulative sum of points per convex. [N] + delta: Delta values. [N] + sigma: Sigma values. [N] + scaling: Scaling values. [N] + viewmats: Camera-to-world matrices. [C, 4, 4] + Ks: Camera intrinsics. [C, 3, 3] + width: Image width. + height: Image height. + eps2d: A epsilon added to the 2D covariance for numerical stability. Default: 0.3. + near_plane: Near plane distance. Default: 0.01. + far_plane: Far plane distance. Default: 1e10. + radius_clip: Gaussians with projected radii smaller than this value will be ignored. Default: 0.0. 
+ packed: If True, the output tensors will be packed into a flattened tensor. Default: False. + sparse_grad: This is only effective when `packed` is True. If True, during backward the gradients + of {`means`, `covars`, `quats`, `scales`} will be a sparse Tensor in COO layout. Default: False. + calc_compensations: If True, a view-dependent opacity compensation factor will be computed, which + is useful for anti-aliasing. Default: False. + + Returns: + A tuple: + + If `packed` is True: + + - **camera_ids**. The row indices of the projected Gaussians. Int32 tensor of shape [nnz]. + - **gaussian_ids**. The column indices of the projected Gaussians. Int32 tensor of shape [nnz]. + - **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [nnz]. + - **means**. Projected Gaussian means in 2D. [nnz, 2] + - **depths**. The z-depth of the projected Gaussians. [nnz] + - **conics**. Inverse of the projected covariances. Return the flattend upper triangle with [nnz, 3] + - **compensations**. The view-dependent opacity compensation factor. [nnz] + + If `packed` is False: + + - **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [C, N]. + - **means**. Projected Gaussian means in 2D. [C, N, 2] + - **depths**. The z-depth of the projected Gaussians. [C, N] + - **conics**. Inverse of the projected covariances. Return the flattend upper triangle with [C, N, 3] + - **compensations**. The view-dependent opacity compensation factor. 
[C, N] + """ + C = viewmats.size(0) + N = convex_points.size(0) + total_nb_points = 6 * N#//+ cumsum_of_points_per_convex[-1] + assert convex_points.size() == (N, 6, 3), convex_points.size() + assert viewmats.size() == (C, 4, 4), viewmats.size() + assert Ks.size() == (C, 3, 3), Ks.size() + convex_points = convex_points.contiguous() + if sparse_grad: + assert packed, "sparse_grad is only supported when packed is True" + if opacities is not None: + assert opacities.size() == (N,), opacities.size() + opacities = opacities.contiguous() + + viewmats = viewmats.contiguous() + Ks = Ks.contiguous() + # FIXME: Do packed later + #if packed: + # return _FullyFusedProjectionPacked_3dcs.apply( + # convex_points, + # viewmats, + # Ks, + # width, + # height, + # eps2d, + # near_plane, + # far_plane, + # radius_clip, + # sparse_grad, + # calc_compensations, + # camera_model, + # ) + #else: + return _FullyFusedProjection_3dcs.apply( + convex_points, + cumsum_of_points_per_convex, + delta, + sigma, + scaling, + viewmats, + Ks, + width, + height, + total_nb_points, + eps2d, + near_plane, + far_plane, + radius_clip, + calc_compensations, + camera_model, + opacities, + ) + +def rasterize_to_pixels_3dcs( + means2d: Tensor, # [C, N, 2] or [nnz, 2] + normals: Tensor, # [C, total_nb_points, 2] or [nnz, 3] + offsets: Tensor, # [C, total_nb_points] + num_points_per_convex_view: Tensor, # [C, N] + delta: Tensor, # [C, N] + sigma: Tensor, # [C, N] + num_points_per_convex: int, # 6 in practice + cumsum_of_points_per_convex: Tensor, # [N] + depths: Tensor, # [C, N] + conics: Tensor, # [C, N, 3] or [nnz, 3] + colors: Tensor, # [C, N, channels] or [nnz, channels] + opacities: Tensor, # [C, N] or [nnz] + image_width: int, + image_height: int, + tile_size: int, + isect_offsets: Tensor, # [C, tile_height, tile_width] + flatten_ids: Tensor, # [n_isects] + backgrounds: Optional[Tensor] = None, # [C, channels] + masks: Optional[Tensor] = None, # [C, tile_height, tile_width] + packed: bool = False, + 
absgrad: bool = False, +) -> Tuple[Tensor, Tensor]: + """Rasterizes 3D convexes to pixels. + + Args: + means2d: Projected 3D convex centers. [C, N, 2] if packed is False, [nnz, 2] if packed is True. + normals: The normals in camera space. [C, N, 3] + offsets: The offsets of the points in the convex. [C, N] + num_points_per_convex_view: The number of points per convex in the view. + delta: The delta of the points in the convex. [C, N] + sigma: The sigma of the points in the convex. [C, N] + num_points_per_convex: The number of points per convex. 6 in practice + cumsum_of_points_per_convex: The cumsum of the number of points per convex. [N] + depths: The z-depth of the projected 3D convexes center. [C, N] + conics: Inverse of the projected covariances with only upper triangle values. [C, N, 3] if packed is False, [nnz, 3] if packed is True. + colors: Gaussian colors or ND features. [C, N, channels] if packed is False, [nnz, channels] if packed is True. + opacities: Gaussian opacities that support per-view values. [C, N] if packed is False, [nnz] if packed is True. + image_width: Image width. + image_height: Image height. + tile_size: Tile size. + isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width] + flatten_ids: The global flatten indices in [C * N] or [nnz] from `isect_tiles()`. [n_isects] + backgrounds: Background colors. [C, channels]. Default: None. + masks: Optional tile mask to skip rendering GS to masked tiles. [C, tile_height, tile_width]. Default: None. + packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False. + absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False. + + Returns: + A tuple: + + - **Rendered colors**. [C, image_height, image_width, channels] + - **Rendered alphas**. 
[C, image_height, image_width, 1] + """ + + C = isect_offsets.size(0) + device = means2d.device + # FIXME: Check new 3DCS arguments + if packed: + nnz = means2d.size(0) + assert means2d.shape == (nnz, 2), means2d.shape + assert conics.shape == (nnz, 3), conics.shape + assert colors.shape[0] == nnz, colors.shape + assert opacities.shape == (nnz,), opacities.shape + else: + N = means2d.size(1) + assert means2d.shape == (C, N, 2), means2d.shape + assert conics.shape == (C, N, 3), conics.shape + assert colors.shape[:2] == (C, N), colors.shape + assert opacities.shape == (C, N), opacities.shape + if backgrounds is not None: + assert backgrounds.shape == (C, colors.shape[-1]), backgrounds.shape + backgrounds = backgrounds.contiguous() + if masks is not None: + assert masks.shape == isect_offsets.shape, masks.shape + masks = masks.contiguous() + + # Pad the channels to the nearest supported number if necessary + channels = colors.shape[-1] + if channels > 513 or channels == 0: + # TODO: maybe worth to support zero channels? 
+ raise ValueError(f"Unsupported number of color channels: {channels}") + if channels not in ( + 1, + 2, + 3, + 4, + 5, + 8, + 9, + 16, + 17, + 32, + 33, + 64, + 65, + 128, + 129, + 256, + 257, + 512, + 513, + ): + padded_channels = (1 << (channels - 1).bit_length()) - channels + colors = torch.cat( + [ + colors, + torch.zeros(*colors.shape[:-1], padded_channels, device=device), + ], + dim=-1, + ) + if backgrounds is not None: + backgrounds = torch.cat( + [ + backgrounds, + torch.zeros( + *backgrounds.shape[:-1], padded_channels, device=device + ), + ], + dim=-1, + ) + else: + padded_channels = 0 + + tile_height, tile_width = isect_offsets.shape[1:3] + assert ( + tile_height * tile_size >= image_height + ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}" + assert ( + tile_width * tile_size >= image_width + ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}" + + render_colors, render_alphas = _RasterizeToPixels_3dcs.apply( + means2d.contiguous(), + normals.contiguous(), + offsets.contiguous(), + num_points_per_convex_view.contiguous(), + delta.contiguous(), + sigma.contiguous(), + num_points_per_convex, + cumsum_of_points_per_convex.contiguous(), + depths.contiguous(), + conics.contiguous(), + colors.contiguous(), + opacities.contiguous(), + backgrounds, + masks, + image_width, + image_height, + tile_size, + isect_offsets.contiguous(), + flatten_ids.contiguous(), + absgrad, + ) + + if padded_channels > 0: + render_colors = render_colors[..., :-padded_channels] + return render_colors, render_alphas + + +class _FullyFusedProjection_3dcs(torch.autograd.Function): + """Projects 3D Convex to 2D.""" + + @staticmethod + def forward( + ctx, + convex_points: Tensor, # [N, 6, 3] + cumsum_of_points_per_convex, # [N] + delta, # [N] + sigma, # [N] + scaling, # [N] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + total_nb_points: int, + eps2d: float, + near_plane: float, + far_plane: float, + radius_clip: float, 
        calc_compensations: bool,
        camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole",
        opacities: Optional[Tensor] = None,  # [N] or None
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        # Resolve the camera-model string to the CUDA-side enum object.
        camera_model_type = _make_lazy_cuda_obj(
            f"CameraModelType.{camera_model.upper()}"
        )

        # Fused projection of 3D convex shapes. The kernel returns, per view:
        # convex-hull edge data (normals/offsets), projected vertex positions
        # (p_image), hull vertex indices, per-convex visible-point counts, and
        # the usual per-primitive quantities (radii, means2d, depths, conics,
        # compensations).
        normals, offsets, p_image, hull, num_points_per_convex_view, indices, radii, means2d, depths, conics, compensations = _make_lazy_cuda_func(
            "projection_ewa_3dcs_fused_fwd"
        )(
            convex_points,
            cumsum_of_points_per_convex,
            delta,
            sigma,
            scaling,
            opacities,
            viewmats,
            Ks,
            width,
            height,
            total_nb_points,
            eps2d,
            near_plane,
            far_plane,
            radius_clip,
            calc_compensations,
            camera_model_type,
        )
        if not calc_compensations:
            compensations = None
        # Save everything the backward kernel consumes (12 tensors; order must
        # match the unpacking in backward()).
        ctx.save_for_backward(
            convex_points, viewmats, Ks, normals, offsets, p_image, hull, num_points_per_convex_view, indices, radii, conics, compensations
        )
        ctx.width = width
        ctx.height = height
        ctx.eps2d = eps2d
        ctx.camera_model_type = camera_model_type
        ctx.cumsum_of_points_per_convex = cumsum_of_points_per_convex
        return normals, offsets, p_image, hull, num_points_per_convex_view, indices, radii, means2d, depths, conics, compensations

    @staticmethod
    def backward(ctx, v_normals, v_offsets, v_p_image, v_hull, v_num_points_per_convex_view, v_indices, v_radii, v_means2d, v_depths, v_conics, v_compensations):
        # Gradients are produced only for convex_points and (optionally)
        # viewmats; every other forward input is treated as non-differentiable
        # by the CUDA backward kernel.
        (
            convex_points,
            viewmats,
            Ks,
            _,  # normals are saved in forward() but unused by the bwd kernel
            offsets,
            p_image,
            hull,
            num_points_per_convex_view,
            indices,
            radii,
            conics,
            compensations,
        ) = ctx.saved_tensors
        width = ctx.width
        height = ctx.height
        eps2d = ctx.eps2d
        cumsum_of_points_per_convex = ctx.cumsum_of_points_per_convex
        camera_model_type = ctx.camera_model_type
        if v_compensations is not None:
            v_compensations = v_compensations.contiguous()
        v_convex_points, v_viewmats = _make_lazy_cuda_func(
            "projection_ewa_3dcs_fused_bwd"
        )(
            convex_points,
            cumsum_of_points_per_convex,
            viewmats,
            Ks,
            width,
            height,
            eps2d,
            camera_model_type,
            radii,
            hull,
            num_points_per_convex_view,
            # normals,
            offsets,
            p_image,
            indices,
            conics,
            compensations,
            v_normals.contiguous(),
            v_offsets.contiguous(),
            v_means2d.contiguous(),
            v_depths.contiguous(),
            v_conics.contiguous(),
            v_compensations,
            ctx.needs_input_grad[5],  # viewmats_requires_grad
        )

        if not ctx.needs_input_grad[0]:
            v_convex_points = None
        # FIXME: check for delta, sigma, scaling?
        if not ctx.needs_input_grad[5]:
            v_viewmats = None

        # One entry per forward input (17 total, positional): only
        # convex_points (index 0) and viewmats (index 5) receive gradients.
        return (
            v_convex_points,
            None,
            None,  # v_delta
            None,  # v_sigma
            None,  # v_scaling
            v_viewmats,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class _RasterizeToPixels_3dcs(torch.autograd.Function):
    """Rasterize projected 3D convex shapes (3DCS) to pixels."""

    @staticmethod
    def forward(
        ctx,
        means2d: Tensor,  # [C, N, 2]
        normals: Tensor,  # [C, N, 3]
        offsets: Tensor,  # [C, N]
        num_points_per_convex_view: Tensor,  # [C, N]
        delta: Tensor,  # [C, N]
        sigma: Tensor,  # [C, N]
        num_points_per_convex: int,  # 6 in practice
        cumsum_of_points_per_convex: Tensor,  # [N]
        depths: Tensor,  # [C, N]
        conics: Tensor,  # [C, N, 3]
        colors: Tensor,  # [C, N, D]
        opacities: Tensor,  # [C, N]
        backgrounds: Tensor,  # [C, D], Optional
        masks: Tensor,  # [C, tile_height, tile_width], Optional
        width: int,
        height: int,
        tile_size: int,
        isect_offsets: Tensor,  # [C, tile_height, tile_width]
        flatten_ids: Tensor,  # [n_isects]
        absgrad: bool,
    ) -> Tuple[Tensor, Tensor]:
        render_colors, render_alphas, last_ids = _make_lazy_cuda_func(
            "rasterize_to_pixels_3dcs_fwd"
        )(
            means2d,
            normals,
            offsets,
            num_points_per_convex_view,
            delta,
            sigma,
            num_points_per_convex,
            cumsum_of_points_per_convex,
            depths,
            conics,
            colors,
            opacities,
            backgrounds,
            masks,
            width,
            height,
            tile_size,
            isect_offsets,
            flatten_ids,
        )

        # Save what the backward kernel needs (order must match backward()).
        ctx.save_for_backward(
            means2d,
            normals,
            offsets,
            num_points_per_convex_view,
            delta,
            sigma,
            depths,
            conics,
            colors,
            opacities,
            backgrounds,
            masks,
            isect_offsets,
            flatten_ids,
            render_alphas,
            last_ids,
        )
        ctx.width = width
        ctx.height = height
        ctx.tile_size = tile_size
        ctx.absgrad = absgrad
        ctx.cumsum_of_points_per_convex = cumsum_of_points_per_convex
        ctx.num_points_per_convex = num_points_per_convex

        # double to float
        render_alphas = render_alphas.float()
        return render_colors, render_alphas

    @staticmethod
    def backward(
        ctx,
        v_render_colors: Tensor,  # [C, H, W, 3]
        v_render_alphas: Tensor,  # [C, H, W, 1]
    ):
        # Recover the forward inputs/outputs needed by the backward kernel
        # (unpacking order must match ctx.save_for_backward in forward()).
        (
            means2d,
            normals,
            offsets,
            num_points_per_convex_view,
            delta,
            sigma,
            depths,
            conics,
            colors,
            opacities,
            backgrounds,
            masks,
            isect_offsets,
            flatten_ids,
            render_alphas,
            last_ids,
        ) = ctx.saved_tensors

        width = ctx.width
        height = ctx.height
        tile_size = ctx.tile_size
        absgrad = ctx.absgrad
        cumsum_of_points_per_convex = ctx.cumsum_of_points_per_convex
        num_points_per_convex = ctx.num_points_per_convex

        (
            v_means2d_abs,
            v_means2d,
            v_normals,
            v_offsets,
            v_delta,
            v_sigma,
            v_conics,
            v_colors,
            v_opacities,
        ) = _make_lazy_cuda_func("rasterize_to_pixels_3dcs_bwd")(
            means2d,
            normals,
            offsets,
            num_points_per_convex,
            delta,
            sigma,
            num_points_per_convex_view,
            cumsum_of_points_per_convex,
            depths,
            conics,
            colors,
            opacities,
            backgrounds,
            masks,
            width,
            height,
            tile_size,
            isect_offsets,
            flatten_ids,
            render_alphas,
            last_ids,
            v_render_colors.contiguous(),
            v_render_alphas.contiguous(),
            absgrad,
        )
        if absgrad:
            # Stash the absolute 2D-mean gradients on the tensor itself so the
            # densification strategy can read them later.
            means2d.absgrad = v_means2d_abs

        # backgrounds is the 13th forward input (index 12 after ctx); its
        # gradient is the unblended residual summed over pixels.
        if ctx.needs_input_grad[12]:
            v_backgrounds = (v_render_colors * (1.0 - render_alphas).float()).sum(
                dim=(1, 2)
            )
        else:
            v_backgrounds = None

        # One entry per forward input (20 total, positional).
        return (
            v_means2d,
            v_normals,
            v_offsets,
            None,
            v_delta,
            v_sigma,
            None,
            None,
            None,  # v_depths
            v_conics,
            v_colors,
            v_opacities,
            v_backgrounds,
            None,
            None,
            None,
None, + None, + None, + None, + ) + +#class _FullyFusedProjectionPacked_3dcs(torch.autograd.Function): +# """Projects Gaussians to 2D. Return packed tensors.""" +# +# @staticmethod +# def forward( +# ctx, +# convex_points: Tensor, # [N, 6, 3] +# viewmats: Tensor, # [C, 4, 4] +# Ks: Tensor, # [C, 3, 3] +# width: int, +# height: int, +# eps2d: float, +# near_plane: float, +# far_plane: float, +# radius_clip: float, +# sparse_grad: bool, +# calc_compensations: bool, +# camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", +# ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +# camera_model_type = _make_lazy_cuda_obj( +# f"CameraModelType.{camera_model.upper()}" +# ) +# +# ( +# indptr, +# camera_ids, +# gaussian_ids, +# radii, +# means2d, +# depths, +# conics, +# compensations, +# ) = _make_lazy_cuda_func("fully_fused_projection_packed_fwd_3dcs")( +# convex_points, +# viewmats, +# Ks, +# width, +# height, +# eps2d, +# near_plane, +# far_plane, +# radius_clip, +# calc_compensations, +# camera_model_type, +# ) +# if not calc_compensations: +# compensations = None +# ctx.save_for_backward( +# camera_ids, +# gaussian_ids, +# means, +# covars, +# quats, +# scales, +# viewmats, +# Ks, +# conics, +# compensations, +# ) +# ctx.width = width +# ctx.height = height +# ctx.eps2d = eps2d +# ctx.sparse_grad = sparse_grad +# ctx.camera_model_type = camera_model_type +# +# return camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations +# +# @staticmethod +# def backward( +# ctx, +# v_camera_ids, +# v_gaussian_ids, +# v_radii, +# v_means2d, +# v_depths, +# v_conics, +# v_compensations, +# ): +# ( +# camera_ids, +# gaussian_ids, +# means, +# covars, +# quats, +# scales, +# viewmats, +# Ks, +# conics, +# compensations, +# ) = ctx.saved_tensors +# width = ctx.width +# height = ctx.height +# eps2d = ctx.eps2d +# sparse_grad = ctx.sparse_grad +# camera_model_type = ctx.camera_model_type +# +# if v_compensations is not None: +# v_compensations = 
v_compensations.contiguous() +# v_means, v_covars, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func( +# "fully_fused_projection_packed_bwd" +# )( +# means, +# covars, +# quats, +# scales, +# viewmats, +# Ks, +# width, +# height, +# eps2d, +# camera_model_type, +# camera_ids, +# gaussian_ids, +# conics, +# compensations, +# v_means2d.contiguous(), +# v_depths.contiguous(), +# v_conics.contiguous(), +# v_compensations, +# ctx.needs_input_grad[4], # viewmats_requires_grad +# sparse_grad, +# ) +# +# if not ctx.needs_input_grad[0]: +# v_means = None +# else: +# if sparse_grad: +# # TODO: gaussian_ids is duplicated so not ideal. +# # An idea is to directly set the attribute (e.g., .sparse_grad) of +# # the tensor but this requires the tensor to be leaf node only. And +# # a customized optimizer would be needed in this case. +# v_means = torch.sparse_coo_tensor( +# indices=gaussian_ids[None], # [1, nnz] +# values=v_means, # [nnz, 3] +# size=means.size(), # [N, 3] +# is_coalesced=len(viewmats) == 1, +# ) +# if not ctx.needs_input_grad[1]: +# v_covars = None +# else: +# if sparse_grad: +# v_covars = torch.sparse_coo_tensor( +# indices=gaussian_ids[None], # [1, nnz] +# values=v_covars, # [nnz, 6] +# size=covars.size(), # [N, 6] +# is_coalesced=len(viewmats) == 1, +# ) +# if not ctx.needs_input_grad[2]: +# v_quats = None +# else: +# if sparse_grad: +# v_quats = torch.sparse_coo_tensor( +# indices=gaussian_ids[None], # [1, nnz] +# values=v_quats, # [nnz, 4] +# size=quats.size(), # [N, 4] +# is_coalesced=len(viewmats) == 1, +# ) +# if not ctx.needs_input_grad[3]: +# v_scales = None +# else: +# if sparse_grad: +# v_scales = torch.sparse_coo_tensor( +# indices=gaussian_ids[None], # [1, nnz] +# values=v_scales, # [nnz, 3] +# size=scales.size(), # [N, 3] +# is_coalesced=len(viewmats) == 1, +# ) +# if not ctx.needs_input_grad[4]: +# v_viewmats = None +# +# return ( +# v_means, +# v_covars, +# v_quats, +# v_scales, +# v_viewmats, +# None, +# None, +# None, +# None, +# None, +# 
None, +# None, +# None, +# None, +# None, +# ) + diff --git a/gsplat/cuda/csrc/Intersect.cpp b/gsplat/cuda/csrc/Intersect.cpp index 838b8a328..baccbe8e2 100644 --- a/gsplat/cuda/csrc/Intersect.cpp +++ b/gsplat/cuda/csrc/Intersect.cpp @@ -144,73 +144,4 @@ at::Tensor intersect_offset( return offsets; } -// at::Tensor spherical_harmonics_fwd( -// const uint32_t degrees_to_use, -// const at::Tensor dirs, // [..., 3] -// const at::Tensor coeffs, // [..., K, 3] -// const at::optional masks // [...] -// ) { -// DEVICE_GUARD(dirs); -// CHECK_INPUT(dirs); -// CHECK_INPUT(coeffs); -// if (masks.has_value()) { -// CHECK_INPUT(masks.value()); -// } -// TORCH_CHECK(coeffs.size(-1) == 3, "coeffs must have last dimension 3"); -// TORCH_CHECK(dirs.size(-1) == 3, "dirs must have last dimension 3"); - -// at::Tensor colors = at::empty_like(dirs); // [..., 3] - -// launch_spherical_harmonics_fwd_kernel( -// degrees_to_use, -// dirs, -// coeffs, -// masks, -// colors -// ); -// return colors; // [..., 3] -// } - -// std::tuple spherical_harmonics_bwd( -// const uint32_t K, -// const uint32_t degrees_to_use, -// const at::Tensor dirs, // [..., 3] -// const at::Tensor coeffs, // [..., K, 3] -// const at::optional masks, // [...] 
-// const at::Tensor v_colors, // [..., 3] -// bool compute_v_dirs -// ) { -// DEVICE_GUARD(dirs); -// CHECK_INPUT(dirs); -// CHECK_INPUT(coeffs); -// CHECK_INPUT(v_colors); -// if (masks.has_value()) { -// CHECK_INPUT(masks.value()); -// } -// TORCH_CHECK(v_colors.size(-1) == 3, "v_colors must have last dimension -// 3"); TORCH_CHECK(coeffs.size(-1) == 3, "coeffs must have last dimension -// 3"); TORCH_CHECK(dirs.size(-1) == 3, "dirs must have last dimension 3"); -// const uint32_t N = dirs.numel() / 3; - -// at::Tensor v_coeffs = at::zeros_like(coeffs); -// at::Tensor v_dirs; -// if (compute_v_dirs) { -// v_dirs = at::zeros_like(dirs); -// } - -// at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); -// uint32_t n_elements = N; -// uint32_t shmem_size = 0; -// launch_spherical_harmonics_bwd_kernel( -// degrees_to_use, -// dirs, -// coeffs, -// masks, -// v_colors, -// v_coeffs, -// v_dirs.defined() ? at::optional(v_dirs) : c10::nullopt -// ); -// return std::make_tuple(v_coeffs, v_dirs); // [..., K, 3], [..., 3] -// } - } // namespace gsplat diff --git a/gsplat/cuda/csrc/Projection3DCS.cpp b/gsplat/cuda/csrc/Projection3DCS.cpp new file mode 100644 index 000000000..86ef77df0 --- /dev/null +++ b/gsplat/cuda/csrc/Projection3DCS.cpp @@ -0,0 +1,197 @@ +#include +#include +#include // for DEVICE_GUARD +#include + +#include +#include + +#include "Common.h" // where all the macros are defined +#include "Ops.h" // a collection of all gsplat operators +#include "Projection3DCS.h" // where the launch function is declared + +namespace gsplat { + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +projection_ewa_3dcs_fused_fwd( + const at::Tensor convex_points, // [N, 6, 3] + const at::Tensor cumsum_of_points_per_convex, // [N] + const at::Tensor delta, // [N] + const at::Tensor sigma, // [N] + at::Tensor scaling, // [N] + const at::optional 
opacities, // [N] optional + const at::Tensor viewmats, // [C, 4, 4] + const at::Tensor Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const uint32_t total_nb_points, + const float eps2d, + const float near_plane, + const float far_plane, + const float radius_clip, + const bool calc_compensations, + const CameraModelType camera_model +) { + DEVICE_GUARD(convex_points); + CHECK_INPUT(convex_points); + CHECK_INPUT(viewmats); + CHECK_INPUT(Ks); + + uint32_t N = convex_points.size(0); // number of 3D convex shapes + uint32_t C = viewmats.size(0); // number of cameras + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + at::Tensor radii = at::empty({C, N, 2}, convex_points.options().dtype(at::kInt)); + at::Tensor means2d = at::empty({C, N, 2}, convex_points.options()); + + at::Tensor normals = at::empty({C, total_nb_points, 2}, convex_points.options()); + at::Tensor offsets = at::empty({C, total_nb_points}, convex_points.options()); + at::Tensor p_image = at::empty({C, total_nb_points, 2}, convex_points.options()); + at::Tensor hull = at::empty({C, 2*total_nb_points}, convex_points.options().dtype(at::kInt)); + at::Tensor num_points_per_convex_view = at::empty({C, N}, convex_points.options().dtype(at::kInt)); + at::Tensor indices = at::empty({C, total_nb_points}, convex_points.options().dtype(at::kInt)); + + at::Tensor depths = at::empty({C, N}, convex_points.options()); + at::Tensor conics = at::empty({C, N, 3}, convex_points.options()); + at::Tensor compensations; + if (calc_compensations) { + // we dont want NaN to appear in this tensor, so we zero intialize it + compensations = at::zeros({C, N}, convex_points.options()); + } + + launch_projection_ewa_3dcs_fused_fwd_kernel( + // inputs + convex_points, + cumsum_of_points_per_convex, + delta, + sigma, + scaling, + opacities, + viewmats, + Ks, + image_width, + image_height, + eps2d, + near_plane, + far_plane, + radius_clip, + camera_model, + // outputs + radii, + 
normals, + offsets, + p_image, + hull, + num_points_per_convex_view, + indices, + means2d, + depths, + conics, + calc_compensations ? at::optional(compensations) + : c10::nullopt + ); + return std::make_tuple(normals, offsets, p_image, hull, num_points_per_convex_view, indices, radii, means2d, depths, conics, compensations); +} + +std::tuple +projection_ewa_3dcs_fused_bwd( + // fwd inputs + const at::Tensor convex_points, // [N, 3] + const at::Tensor cumsum_of_points_per_convex, // [N] + const at::Tensor viewmats, // [C, 4, 4] + const at::Tensor Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const CameraModelType camera_model, + // fwd outputs + const at::Tensor radii, // [C, N, 2] + const at::Tensor hull, // [C, 2*total_nb_points] + const at::Tensor num_points_per_convex_view, // [C, N] + //const at::Tensor normals, // [C, total_nb_points, 2] + const at::Tensor offsets, // [C, total_nb_points] + const at::Tensor p_image, // [C, total_nb_points, 2] + const at::Tensor indices, // [C, total_nb_points] + const at::Tensor conics, // [C, N, 3] + const at::optional compensations, // [C, N] optional + // grad inputs + const at::Tensor v_normals, // [C, total_nb_points, 2] + const at::Tensor v_offsets, // [C, total_nb_points] + // grad outputs + const at::Tensor v_means2d, // [C, N, 2] + const at::Tensor v_depths, // [C, N] + const at::Tensor v_conics, // [C, N, 3] + const at::optional v_compensations, // [C, N] optional + const bool viewmats_requires_grad +) { + DEVICE_GUARD(convex_points); + CHECK_INPUT(convex_points); + CHECK_INPUT(viewmats); + CHECK_INPUT(Ks); + CHECK_INPUT(radii); + CHECK_INPUT(conics); + CHECK_INPUT(v_normals); + CHECK_INPUT(v_offsets); + CHECK_INPUT(v_means2d); + CHECK_INPUT(v_depths); + CHECK_INPUT(v_conics); + if (compensations.has_value()) { + CHECK_INPUT(compensations.value()); + } + if (v_compensations.has_value()) { + CHECK_INPUT(v_compensations.value()); + 
assert(compensations.has_value()); + } + + at::Tensor v_convex_points = at::zeros_like(convex_points); + at::Tensor v_viewmats; + if (viewmats_requires_grad) { + v_viewmats = at::zeros_like(viewmats); + } + + launch_projection_ewa_3dcs_fused_bwd_kernel( + // inputs + convex_points, + cumsum_of_points_per_convex, + viewmats, + Ks, + image_width, + image_height, + eps2d, + camera_model, + radii, + hull, + num_points_per_convex_view, + //normals, + offsets, + p_image, + indices, + conics, + compensations, + v_normals, + v_offsets, + v_means2d, + v_depths, + v_conics, + v_compensations, + viewmats_requires_grad, + // outputs + v_convex_points, + v_viewmats + ); + + return std::make_tuple(v_convex_points, v_viewmats); +} + +} // namespace gsplat diff --git a/gsplat/cuda/csrc/Projection3DCS.h b/gsplat/cuda/csrc/Projection3DCS.h new file mode 100644 index 000000000..c07fe6a2c --- /dev/null +++ b/gsplat/cuda/csrc/Projection3DCS.h @@ -0,0 +1,136 @@ +#pragma once + +#include + +namespace at { +class Tensor; +} + +namespace gsplat { + +void launch_projection_ewa_3dcs_fused_fwd_kernel( + // inputs + const at::Tensor convex_points, // [N, 3] + const at::Tensor cumsum_of_points_per_convex, // [N] + const at::Tensor delta, // [N] + const at::Tensor sigma, // [N] + at::Tensor scaling, // [N] + const at::optional opacities, // [N] optional + const at::Tensor viewmats, // [C, 4, 4] + const at::Tensor Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const float near_plane, + const float far_plane, + const float radius_clip, + const CameraModelType camera_model, + // outputs + at::Tensor radii, // [C, N, 2] + at::Tensor normals, // [C, total_nb_points, 2] + at::Tensor offsets, // [C, total_nb_points] + at::Tensor p_image, // [C, total_nb_points, 2] + at::Tensor hull, // [C, 2*total_nb_points] + at::Tensor num_points_per_convex_view, // [C, N] + at::Tensor indices, // [C, total_nb_points] + at::Tensor means2d, // [C, N, 2] + 
at::Tensor depths, // [C, N] + at::Tensor conics, // [C, N, 3] + at::optional compensations // [C, N] optional +); +void launch_projection_ewa_3dcs_fused_bwd_kernel( + // inputs + // fwd inputs + const at::Tensor convex_points, // [N, K, 3] + const at::Tensor cumsum_of_points_per_convex, // [N] + const at::Tensor viewmats, // [C, 4, 4] + const at::Tensor Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const CameraModelType camera_model, + // fwd outputs + const at::Tensor radii, // [C, N, 2] + const at::Tensor hull, // [C, 2*total_nb_points] + const at::Tensor num_points_per_convex_view, // [C, N] + //const at::Tensor normals, // [C, total_nb_points, 2] + const at::Tensor offsets, // [C, total_nb_points] + const at::Tensor p_image, // [C, total_nb_points, 2] + const at::Tensor indices, // [C, total_nb_points] + const at::Tensor conics, // [C, N, 3] + const at::optional compensations, // [C, N] optional + // grad inputs + const at::Tensor v_normals, // [C, total_nb_points, 2] + const at::Tensor v_offsets, // [C, total_nb_points] + // grad outputs + const at::Tensor v_means2d, // [C, N, 2] + const at::Tensor v_depths, // [C, N] + const at::Tensor v_conics, // [C, N, 3] + const at::optional v_compensations, // [C, N] optional + const bool viewmats_requires_grad, + // outputs + at::Tensor v_convex_points, // [N, K, 3] + at::Tensor v_viewmats // [C, 4, 4] +); + +// void launch_projection_ewa_3dcs_packed_fwd_kernel( +// // inputs +// const at::Tensor means, // [N, 3] +// const at::optional covars, // [N, 6] optional +// const at::optional quats, // [N, 4] optional +// const at::optional scales, // [N, 3] optional +// const at::optional opacities, // [N] optional +// const at::Tensor viewmats, // [C, 4, 4] +// const at::Tensor Ks, // [C, 3, 3] +// const uint32_t image_width, +// const uint32_t image_height, +// const float eps2d, +// const float near_plane, +// const float far_plane, +// const float radius_clip, +// 
const at::optional +// block_accum, // [C * blocks_per_row] packing helper +// const CameraModelType camera_model, +// // outputs +// at::optional block_cnts, // [C * blocks_per_row] packing helper +// at::optional indptr, // [C + 1] +// at::optional camera_ids, // [nnz] +// at::optional gaussian_ids, // [nnz] +// at::optional radii, // [nnz, 2] +// at::optional means2d, // [nnz, 2] +// at::optional depths, // [nnz] +// at::optional conics, // [nnz, 3] +// at::optional compensations // [nnz] optional +// ); +// void launch_projection_ewa_3dcs_packed_bwd_kernel( +// // fwd inputs +// const at::Tensor means, // [N, 3] +// const at::optional covars, // [N, 6] +// const at::optional quats, // [N, 4] +// const at::optional scales, // [N, 3] +// const at::Tensor viewmats, // [C, 4, 4] +// const at::Tensor Ks, // [C, 3, 3] +// const uint32_t image_width, +// const uint32_t image_height, +// const float eps2d, +// const CameraModelType camera_model, +// // fwd outputs +// const at::Tensor camera_ids, // [nnz] +// const at::Tensor gaussian_ids, // [nnz] +// const at::Tensor conics, // [nnz, 3] +// const at::optional compensations, // [nnz] optional +// // grad outputs +// const at::Tensor v_means2d, // [nnz, 2] +// const at::Tensor v_depths, // [nnz] +// const at::Tensor v_conics, // [nnz, 3] +// const at::optional v_compensations, // [nnz] optional +// const bool sparse_grad, +// // grad inputs +// at::Tensor v_means, // [N, 3] or [nnz, 3] +// at::optional v_covars, // [N, 6] or [nnz, 6] Optional +// at::optional v_quats, // [N, 4] or [nnz, 4] Optional +// at::optional v_scales, // [N, 3] or [nnz, 3] Optional +// at::optional v_viewmats // [C, 4, 4] Optional +// ); + +} // namespace gsplat \ No newline at end of file diff --git a/gsplat/cuda/csrc/ProjectionEWA3DCSFused.cu b/gsplat/cuda/csrc/ProjectionEWA3DCSFused.cu new file mode 100644 index 000000000..c38aee06d --- /dev/null +++ b/gsplat/cuda/csrc/ProjectionEWA3DCSFused.cu @@ -0,0 +1,852 @@ +#include +#include +#include 
#include
#include

#include "Common.h"
#include "Projection.h"
#include "Utils.cuh"

// NOTE(review): angle-bracketed tokens (system #include targets and the
// kernel's `template <typename scalar_t>` parameter list) appear stripped in
// this paste — restore them from the original file before compiling.

#define MAX_NB_POINTS 8
namespace gsplat {

// Helper function to compute cross product of two vectors OA and OB
// A positive cross product indicates a counterclockwise turn,
// a negative cross product indicates a clockwise turn,
// and a zero cross product indicates the points are collinear.
__forceinline__ __device__ float crossProduct(const float* O, const float* A, const float* B)
{
    return (A[0] - O[0]) * (B[1] - O[1]) - (A[1] - O[1]) * (B[0] - O[0]);
}

namespace cg = cooperative_groups;

// Fused forward projection for 3D Convex Splatting (3DCS): one thread per
// (camera, convex) pair projects the convex's 6 vertices to the image plane,
// builds their 2D convex hull (Graham scan), derives per-edge half-plane
// normals/offsets and a conservative screen-space radius, and applies the
// near/far, radius_clip, and image-bounds culling tests.
// All per-point output buffers (p_image, normals, offsets, hull, indices) are
// addressed with a fixed 6-points-per-convex stride, i.e. they assume
// total_nb_points == 6 * N.
template
__global__ void projection_ewa_3dcs_fused_fwd_kernel(
    const uint32_t C,
    const uint32_t N,
    const scalar_t *__restrict__ convex_points, // [N, 3]
    const int32_t *__restrict__ cumsum_of_points_per_convex, // [N]
    const scalar_t *__restrict__ delta, // [N]
    const scalar_t *__restrict__ sigma, // [N]
    scalar_t *__restrict__ scaling, // [N] FIXME move it to outputs
    const scalar_t *__restrict__ opacities, // [N] optional
    const scalar_t *__restrict__ viewmats, // [C, 4, 4]
    const scalar_t *__restrict__ Ks, // [C, 3, 3]
    const int32_t image_width,
    const int32_t image_height,
    const float eps2d,
    const float near_plane,
    const float far_plane,
    const float radius_clip,
    const CameraModelType camera_model,
    // outputs
    int32_t *__restrict__ radii, // [C, N, 2]
    scalar_t *__restrict__ normals, // [C, total_nb_points, 2]
    scalar_t *__restrict__ offsets, // [C, total_nb_points]
    scalar_t *__restrict__ p_image, // [C, total_nb_points, 2]
    int32_t *__restrict__ hull, // [C, 2*total_nb_points]
    int32_t *__restrict__ num_points_per_convex_view, // [C, N]
    int32_t *__restrict__ indices, // [C, total_nb_points]
    scalar_t *__restrict__ means2d, // [C, N, 2]
    scalar_t *__restrict__ depths, // [C, N]
    scalar_t *__restrict__ conics, // [C, N, 3]
    scalar_t *__restrict__ compensations // [C, N] optional
) {
    // parallelize over C * N.
    uint32_t idx = cg::this_grid().thread_rank();
    if (idx >= C * N) {
        return;
    }
    const uint32_t cid = idx / N; // camera id
    const uint32_t conv_id = idx % N; // gaussian id

    //const int cumsum_for_convex = cumsum_of_points_per_convex[conv_id];

    // shift pointers to the current camera and 3D convex
    convex_points += conv_id * 6 * 3; //cumsum_for_convex * 3;
    viewmats += cid * 16;
    Ks += cid * 9;

    // Default state: mark the convex as culled/invisible in this view; the
    // valid values are written only if every test below passes.
    // FIXME: to add?
    radii[idx * 2] = 0;
    radii[idx * 2 + 1] = 0;
    num_points_per_convex_view[idx] = 0;

    // glm is column-major but input is row-major
    mat3 R = mat3(
        viewmats[0],
        viewmats[4],
        viewmats[8], // 1st column
        viewmats[1],
        viewmats[5],
        viewmats[9], // 2nd column
        viewmats[2],
        viewmats[6],
        viewmats[10] // 3rd column
    );
    vec3 translation = vec3(viewmats[3], viewmats[7], viewmats[11]);

    // Compute 3D convex center projection (camera and screenspace).
    vec3 center_convex{0.0f, 0.0f, 0.0f};

    // FIXME: If the number of points per convex is expected to change, change this.
    int num_points_per_convex = 6;
    for (int i = 0; i < num_points_per_convex; i++)
    {
        indices[idx*6 + i] = i; // identity permutation; reordered by the sort below
        center_convex.x += convex_points[3 * i];
        center_convex.y += convex_points[3 * i + 1];
        center_convex.z += convex_points[3 * i + 2];
    }

    // Centroid of the 6 vertices.
    center_convex.x /= num_points_per_convex;
    center_convex.y /= num_points_per_convex;
    center_convex.z /= num_points_per_convex;

    // transform 3D Convex center to camera space
    vec3 center_convex_c;
    posW2C(R, translation, center_convex, center_convex_c);
    // Near/far culling on the centroid depth.
    if (center_convex_c.z < near_plane || center_convex_c.z > far_plane)
    {
        radii[idx * 2] = 0;
        radii[idx * 2 + 1] = 0;
        return;
    }

    // perspective projection
    vec2 center_convex_2D;
    mat2 covar2d;

    // FIXME: Not used?
    mat3 covar_c(0.f);

    switch (camera_model) {
    case CameraModelType::PINHOLE: // perspective projection
        persp_proj(
            center_convex_c,
            covar_c,
            Ks[0],
            Ks[4],
            Ks[2],
            Ks[5],
            image_width,
            image_height,
            covar2d,
            center_convex_2D
        );
        break;
    case CameraModelType::ORTHO: // orthographic projection
        ortho_proj(
            center_convex_c,
            covar_c,
            Ks[0],
            Ks[4],
            Ks[2],
            Ks[5],
            image_width,
            image_height,
            covar2d,
            center_convex_2D
        );
        break;
    case CameraModelType::FISHEYE: // fisheye projection
        fisheye_proj(
            center_convex_c,
            covar_c,
            Ks[0],
            Ks[4],
            Ks[2],
            Ks[5],
            image_width,
            image_height,
            covar2d,
            center_convex_2D
        );
        break;
    }

    // Now project every points from the 3D convex
    // TODO: Should be num_points_per_convex[idx]. See if it changes somewhere.
    // TODO: Could be merge with previous loop?

    // Calculation of points in 2D image space. We need to compute the cov3D.
    // float cov3D[6] = {0.0f};
    vec2 point_convex_2D[6];

    for (int i = 0; i < num_points_per_convex; i++)
    {
        vec3 convex_point(convex_points[3 * i], convex_points[3 * i + 1], convex_points[3 * i + 2]);
        vec3 convex_point_c{0.0f, 0.0f, 0.0f};
        posW2C(R, translation, convex_point, convex_point_c);

        // Now project in 2D and keep it.
        mat2 point_covar2d;

        switch (camera_model) {
        case CameraModelType::PINHOLE: // perspective projection
            persp_proj(
                convex_point_c,
                covar_c,
                Ks[0],
                Ks[4],
                Ks[2],
                Ks[5],
                image_width,
                image_height,
                point_covar2d,
                point_convex_2D[i]
            );
            break;
        case CameraModelType::ORTHO: // orthographic projection
            ortho_proj(
                convex_point_c,
                covar_c,
                Ks[0],
                Ks[4],
                Ks[2],
                Ks[5],
                image_width,
                image_height,
                point_covar2d,
                point_convex_2D[i]
            );
            break;
        case CameraModelType::FISHEYE: // fisheye projection
            fisheye_proj(
                convex_point_c,
                covar_c,
                Ks[0],
                Ks[4],
                Ks[2],
                Ks[5],
                image_width,
                image_height,
                point_covar2d,
                point_convex_2D[i]
            );
            break;
        }
        // Now add the projection.
        p_image[2*idx*6 + 2*i] = point_convex_2D[i].x;
        p_image[2*idx*6 + 2*i + 1] = point_convex_2D[i].y;
    }

    // Distance from the projected centroid to the first projected vertex.
    float max_distance = sqrtf(
        (p_image[2*idx*6] - center_convex_2D.x) * (p_image[2*idx*6] - center_convex_2D.x) +
        (p_image[2*idx*6 + 1] - center_convex_2D.y) * (p_image[2*idx*6 + 1] - center_convex_2D.y)
    );
    // ref_point: the lowest (then leftmost) projected vertex — the Graham-scan pivot.
    float2 ref_point{p_image[2*idx*6], p_image[2*idx*6 + 1]};

    // Find the furthest point from the center of the convex
    for (int i = 1; i < num_points_per_convex; i++)
    {
        float distance = sqrtf(
            (p_image[2*idx*6 + 2*i] - center_convex_2D.x) * (p_image[2*idx*6 + 2*i] - center_convex_2D.x) +
            (p_image[2*idx*6 + 2*i + 1] - center_convex_2D.y) * (p_image[2*idx*6 + 2*i + 1] - center_convex_2D.y)
        );

        // Update max_distance if the current distance is greater
        if (distance > max_distance)
        {
            max_distance = distance;
        }

        if (p_image[2*idx*6 + 2*i + 1] < ref_point.y || (p_image[2*idx*6 + 2*i + 1] == ref_point.y && p_image[2*idx*6 + 2*i] < ref_point.x))
        {
            ref_point = make_float2(p_image[2*idx*6 + 2*i], p_image[2*idx*6 + 2*i + 1]);
        }
    }

    // Reject degenerate projections that blow up in screen space.
    if (max_distance > 10000.0f)
    {
        return;
    }

    // Sort the points based on their polar angle with respect to ref_point.
    // There exist definitely better sorting algos
    for (int i = 0; i < num_points_per_convex - 1; i++)
    {
        for (int j = i + 1; j < num_points_per_convex; j++)
        {
            float angle1 = atan2f(p_image[2*idx*6 + 2*i + 1] - ref_point.y, p_image[2*idx*6 + 2*i] - ref_point.x);
            float angle2 = atan2f(p_image[2*idx*6 + 2*j + 1] - ref_point.y, p_image[2*idx*6 + 2*j] - ref_point.x);
            if (angle1 > angle2)
            {
                float2 temp{p_image[2*idx*6 + 2*i], p_image[2*idx*6 + 2*i + 1]};
                p_image[2*idx*6 + 2*i] = p_image[2*idx*6 + 2*j];
                p_image[2*idx*6 + 2*i + 1] = p_image[2*idx*6 + 2*j + 1];
                p_image[2*idx*6 + 2*j] = temp.x;
                p_image[2*idx*6 + 2*j + 1] = temp.y;

                // Swap their corresponding indices
                int temp_idx = indices[idx*6 + i];
                indices[idx*6 + i] = indices[idx*6 + j];
                indices[idx*6 + j] = temp_idx;
            }
        }
    }

    // Now we apply the Graham scan algorithm to find the convex hull.
    // hull stores vertex indices (into the sorted p_image order); up to
    // 2*num_points_per_convex slots per (camera, convex).
    int hull_size = 0;

    // Lower hull
    for (int i = 0; i < num_points_per_convex; i++)
    {
        while (hull_size >= 2 && crossProduct(&(p_image[2*idx*6 + 2*hull[2*idx*6 + hull_size - 2]]), &(p_image[2*idx*6 + 2*hull[2*idx*6 + hull_size - 1]]), &(p_image[2*idx*6 + 2*i])) <= 0)
            hull_size--;
        hull[2*idx*6 + hull_size] = i;
        hull_size++;
    }

    // Upper hull
    int t = hull_size + 1;
    for (int i = num_points_per_convex - 2; i >= 0; i--)
    {
        while (hull_size >= t && crossProduct(&(p_image[2*idx*6 + 2*hull[2*idx*6 + hull_size - 2]]), &(p_image[2*idx*6 + 2*hull[2*idx*6 + hull_size - 1]]), &(p_image[2*idx*6 + 2*i])) <= 0)
            hull_size--;
        hull[2*idx*6 + hull_size] = i;
        hull_size++;
    }

    float max_distance_off = 0.0f;
    float previous_offset = 0.0f;
    int counter = 0;
    // Half-plane dilation amount; shrinks with depth and with delta/sigma.
    // NOTE(review): the 2.1f constant looks empirical — confirm against the
    // 3DCS paper/reference implementation.
    float max_distance_x = (2.1f / (center_convex_c.z * center_convex_c.z * delta[conv_id] * sigma[conv_id]));

    // For each hull edge: emit its outward normal and offset (half-plane
    // form n·p + d = 0), dilate the offset away from the centroid, and
    // accumulate distances from the centroid to consecutive-edge
    // intersections to estimate the splat extent.
    for (int i = 0; i < hull_size - 1; i++)
    {
        // Points forming the segment
        float2 p1_conv{p_image[2*idx*6 + 2*hull[2*idx*6 + i]], p_image[2*idx*6 + 2*hull[2*idx*6 + i] + 1]};
        float2 p2_conv{p_image[2*idx*6 + 2*hull[2*idx*6 + (i + 1) % hull_size]], p_image[2*idx*6 + 2*hull[2*idx*6 + (i + 1) % hull_size] + 1]};

        // Calculate the normal vector (90-degree counterclockwise rotation)
        float2 normal = { p2_conv.y - p1_conv.y, -(p2_conv.x - p1_conv.x)};

        // Calculate the offset (dot product of normal vector and point p1)
        float offset = - (normal.x * p1_conv.x + normal.y * p1_conv.y);

        normals[2*idx*6 + 2*i] = normal.x;
        normals[2*idx*6 + 2*i + 1] = normal.y;
        offsets[idx*6 + i] = offset;

        // Push the half-plane outward relative to the centroid side.
        if (normal.x * center_convex_2D.x + normal.y * center_convex_2D.y + offset < 0)
        {
            offset -= max_distance_x;
        }
        else
        {
            offset += max_distance_x;
        }

        if (i != 0)
        {
            float denominator = normal.x * normals[2*idx*6 + 2*(i-1) + 1] - normal.y * normals[2*idx*6 + 2*(i-1)]; // to avoid division by small numbers

            // calculate the point of intersection between normals[i] and normals[i-1]
            float2 intersection_point = { (-offset * normals[2*idx*6 + 2*(i-1) + 1] + previous_offset * normal.y) / denominator, (-previous_offset * normal.x + offset * normals[2*idx*6 + 2*(i-1)]) / denominator};

            // Angle between consecutive edge normals; near-collinear (or
            // near-opposite) pairs give unstable intersections and are skipped.
            float angle = acosf( (normal.x * normals[2*idx*6 + 2*(i-1)] + normal.y * normals[2*idx*6 + 2*(i-1) + 1]) / (sqrtf(normal.x * normal.x + normal.y * normal.y) * sqrtf(normals[2*idx*6 + 2*(i-1)] * normals[2*idx*6 + 2*(i-1)] + normals[2*idx*6 + 2*(i-1) + 1] * normals[2*idx*6 + 2*(i-1) + 1])));

            float distance = sqrtf((intersection_point.x - center_convex_2D.x) * (intersection_point.x - center_convex_2D.x) + (intersection_point.y - center_convex_2D.y) * (intersection_point.y - center_convex_2D.y));

            if (angle > 0.1f && angle < 3.0f)
            {
                max_distance_off += distance;
                counter++;
            }
        }

        previous_offset = offset;
        num_points_per_convex_view[idx] = i + 1; // number of hull edges emitted so far
    }

    // A valid 2D hull needs at least 3 edges and at least one stable
    // edge-pair intersection.
    if (num_points_per_convex_view[idx] < 3 || counter == 0 || num_points_per_convex_view[idx] > num_points_per_convex)
    {
        radii[idx * 2] = 0;
        radii[idx * 2 + 1] = 0;
        num_points_per_convex_view[idx] = 0;
        return;
    }

    max_distance_off = max_distance_off / counter;
    // Final conservative radius: max of the inflated vertex extent and the
    // mean dilated-edge intersection distance.
    max_distance = ceil(max(max_distance * 1.1f, max_distance_off));
    // // END 3DCS


    // float extend = 3.33f;
    // if (opacities != nullptr) {
    //     float opacity = opacities[conv_id];
    //     // if (compensations != nullptr)
    //     // {
    //     //     // we assume compensation term will be applied later on.
    //     //     opacity *= compensation;
    //     // }
    //     if (opacity < ALPHA_THRESHOLD)
    //     {
    //         radii[idx * 2] = 0;
    //         radii[idx * 2 + 1] = 0;
    //         return;
    //     }
    //     // Compute opacity-aware bounding box.
    //     // https://arxiv.org/pdf/2402.00525 Section B.2
    //     extend = min(extend, sqrt(2.0f * __logf(opacity / ALPHA_THRESHOLD)));
    // }

    // // compute tight rectangular bounding box (non differentiable)
    // // https://arxiv.org/pdf/2402.00525
    // float radius_x = ceilf(extend * max_distance);
    // float radius_y = ceilf(extend * max_distance);

    if (max_distance <= radius_clip)
    // if (radius_x <= radius_clip && radius_y <= radius_clip)
    {
        radii[idx * 2] = 0;
        radii[idx * 2 + 1] = 0;
        num_points_per_convex_view[idx] = 0;
        return;
    }

    // mask out gaussians outside the image region
    if (center_convex_2D.x + max_distance <= 0 || center_convex_2D.x - max_distance >= image_width ||
        center_convex_2D.y + max_distance <= 0 || center_convex_2D.y - max_distance >= image_height)
    // if (center_convex_2D.x + radius_x <= 0 || center_convex_2D.x - radius_x >= image_width ||
    //     center_convex_2D.y + radius_y <= 0 || center_convex_2D.y - radius_y >= image_height)
    {
        radii[idx * 2] = 0;
        radii[idx * 2 + 1] = 0;
        num_points_per_convex_view[idx] = 0;
        return;
    }

    // write to outputs
    radii[idx * 2] = (int32_t)max_distance;
    radii[idx * 2 + 1] = (int32_t)max_distance;
    means2d[idx * 2] = center_convex_2D.x;
    means2d[idx * 2 + 1] = center_convex_2D.y;
    depths[idx] = center_convex_c.z;
    // Conics are unused by 3DCS (coverage comes from the half-planes);
    // zeroed to keep the output layout compatible with the Gaussian path.
    conics[idx * 3] = 0;
    conics[idx * 3 + 1] = 0;
    conics[idx * 3 + 2] = 0;
}

void
launch_projection_ewa_3dcs_fused_fwd_kernel( + // inputs + const at::Tensor convex_points, // [N, 3] + const at::Tensor cumsum_of_points_per_convex, // [N] + const at::Tensor delta, // [N] + const at::Tensor sigma, // [N] + at::Tensor scaling, // [N] + const at::optional opacities, // [N] optional + const at::Tensor viewmats, // [C, 4, 4] + const at::Tensor Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const float near_plane, + const float far_plane, + const float radius_clip, + const CameraModelType camera_model, + // outputs + at::Tensor radii, // [C, N, 2] + at::Tensor normals, // [C, total_nb_points, 2] + at::Tensor offsets, // [C, total_nb_points] + at::Tensor p_image, // [C, total_nb_points, 2] + at::Tensor hull, // [C, 2*total_nb_points] + at::Tensor num_points_per_convex_view, // [C, N] + at::Tensor indices, // [C, total_nb_points] + at::Tensor means2d, // [C, N, 2] + at::Tensor depths, // [C, N] + at::Tensor conics, // [C, N, 3] + at::optional compensations // [C, N] optional +) { + uint32_t N = convex_points.size(0); // number of 3D convexes + uint32_t C = viewmats.size(0); // number of cameras + + int64_t n_elements = C * N; + dim3 threads(256); + dim3 grid((n_elements + threads.x - 1) / threads.x); + int64_t shmem_size = 0; // No shared memory used in this kernel + + if (n_elements == 0) { + // skip the kernel launch if there are no elements + return; + } + + AT_DISPATCH_FLOATING_TYPES( + convex_points.scalar_type(), + "projection_ewa_3dcs_fused_fwd_kernel", + [&]() { + projection_ewa_3dcs_fused_fwd_kernel + <<>>( + C, + N, + convex_points.data_ptr(), + cumsum_of_points_per_convex.data_ptr(), // FIXME: Good type? + delta.data_ptr(), + sigma.data_ptr(), + scaling.data_ptr(), + opacities.has_value() ? 
opacities.value().data_ptr() + : nullptr, + viewmats.data_ptr(), + Ks.data_ptr(), + image_width, + image_height, + eps2d, + near_plane, + far_plane, + radius_clip, + camera_model, + radii.data_ptr(), + normals.data_ptr(), + offsets.data_ptr(), + p_image.data_ptr(), + hull.data_ptr(), + num_points_per_convex_view.data_ptr(), + indices.data_ptr(), + means2d.data_ptr(), + depths.data_ptr(), + conics.data_ptr(), + compensations.has_value() + ? compensations.value().data_ptr() + : nullptr + ); + } + ); +} + +template +__global__ void projection_ewa_3dcs_fused_bwd_kernel( + // fwd inputs + const uint32_t C, + const uint32_t N, + const scalar_t *__restrict__ convex_points, // [N, K, 3] + const int32_t *__restrict__ cumsum_of_points_per_convex, // [N] + // float* p_w, + const scalar_t *__restrict__ viewmats, // [C, 4, 4] + const scalar_t *__restrict__ Ks, // [C, 3, 3] + const int32_t image_width, + const int32_t image_height, + const float eps2d, + const CameraModelType camera_model, + // fwd outputs + const int32_t *__restrict__ radii, // [C, N, 2] + + const int32_t *__restrict__ hull, // [C, 2*total_nb_points] + int32_t *__restrict__ num_points_per_convex_view, // [C, N] + //const scalar_t *__restrict__ normals, // [C, total_nb_points, 2] + const scalar_t *__restrict__ offsets, // [C, total_nb_points] + const scalar_t *__restrict__ p_image, // [C, total_nb_points, 2] + const int32_t *__restrict__ indices, // [C, total_nb_points] + + const scalar_t *__restrict__ conics, // [C, N, 3] + const scalar_t *__restrict__ compensations, // [C, N] optional + // grad outputs + const scalar_t *__restrict__ v_means2d, // [C, N, 2] + const scalar_t *__restrict__ v_depths, // [C, N] + const scalar_t *__restrict__ v_conics, // [C, N, 3] + const scalar_t *__restrict__ v_compensations, // [C, N] optional + // grad inputs + scalar_t *__restrict__ v_convex_points, // [N, K, 3] + scalar_t *__restrict__ v_viewmats, // [C, 4, 4] optional + const scalar_t *__restrict__ v_normals, // [C, 
total_nb_points, 2] + const scalar_t *__restrict__ v_offsets // [C, total_nb_points] +) { + // parallelize over C * N. + uint32_t idx = cg::this_grid().thread_rank(); + if (idx >= C * N || radii[idx * 2] <= 0 || radii[idx * 2 + 1] <= 0) { + return; + } + const uint32_t cid = idx / N; // camera id + const uint32_t conv_id = idx % N; // convex id + + // shift pointers to the current camera and gaussian + convex_points += conv_id * 6 * 3;//cumsum_for_convex * 3; + viewmats += cid * 16; + Ks += cid * 9; + + conics += idx * 3; + + v_means2d += idx * 2; + v_depths += idx; + v_conics += idx * 3; + + v_normals += idx * 6 * 2; + v_offsets += idx * 6; + + // vjp: compute the inverse of the 2d covariance + mat2 v_covar2d(0.f); + + // FIXME: If the number of points per convex is expected to change, change this. + int num_points_per_convex = 6; + + // Initialize loss accumulators for normals and offsets + float loss_points_x[MAX_NB_POINTS] = {0.0f}; + float loss_points_y[MAX_NB_POINTS] = {0.0f}; + + // Iterate over the convex hull. 
+ for (int i = 0; i < num_points_per_convex_view[idx]; i++) + { + float v_normal_x = v_normals[2*i]; + float v_normal_y = v_normals[2*i + 1]; + float v_offset = v_offsets[i]; + + float2 p1_conv{p_image[2*idx*6 + 2*hull[2*idx*6 + i]], p_image[2*idx*6 + 2*hull[2*idx*6 + i] + 1]}; + float2 p2_conv{p_image[2*idx*6 + 2*hull[2*idx*6 + (i + 1) % (num_points_per_convex_view[idx]+1)]], p_image[2*idx*6 + 2*hull[2*idx*6 + (i + 1) % (num_points_per_convex_view[idx]+1)] + 1]}; + + // Calculate the normal vector (90-degree counterclockwise rotation) + float2 normal = { p2_conv.y - p1_conv.y, -(p2_conv.x - p1_conv.x)}; + + // Calculate the gradient of the loss with respect to the points p1 and p2 + // Gradient with respect to p1 (due to normal and offset) + float2 v_p1_conv = { (v_normal_y - v_offset * p2_conv.y), (-v_normal_x + v_offset * p2_conv.x)}; + + float2 v_p2_conv = {(-v_normal_y + v_offset * p1_conv.y), (v_normal_x - v_offset * p1_conv.x)}; + + loss_points_x[indices[idx*6 + hull[2*idx*6 + i]]] += v_p1_conv.x; + loss_points_y[indices[idx*6 + hull[2*idx*6 + i]]] += v_p1_conv.y; + + loss_points_x[indices[idx*6 + hull[2*idx*6 + (i + 1) % (num_points_per_convex_view[idx]+1)]]] += v_p2_conv.x; + loss_points_y[indices[idx*6 + hull[2*idx*6 + (i + 1) % (num_points_per_convex_view[idx]+1)]]] += v_p2_conv.y; + } + + vec2 v_means2d_tmp(0.f); + v_means2d_tmp.x = 0; + v_means2d_tmp.y = 0; + + // transform Gaussian to camera space + mat3 R = mat3( + viewmats[0], + viewmats[4], + viewmats[8], // 1st column + viewmats[1], + viewmats[5], + viewmats[9], // 2nd column + viewmats[2], + viewmats[6], + viewmats[10] // 3rd column + ); + vec3 translation = vec3(viewmats[3], viewmats[7], viewmats[11]); + + // vjp: perspective projection + float fx = Ks[0], cx = Ks[2], fy = Ks[4], cy = Ks[5]; + + vec2 point_convex_2D[6]; + vec3 v_mean[6]; + for (int i = 0; i < num_points_per_convex; i++) + { + v_means2d_tmp.x += loss_points_x[i]; + v_means2d_tmp.y += loss_points_y[i]; + vec2 
loss_point{loss_points_x[i], loss_points_y[i]}; + vec3 convex_point(convex_points[3 * i], convex_points[3 * i + 1], convex_points[3 * i + 2]); + vec3 convex_point_c{0.0f, 0.0f, 0.0f}; + posW2C(R, translation, convex_point, convex_point_c); + + mat3 covar_c(0.f); + + mat3 v_covar_c(0.f); + vec3 v_convex_point_c(0.f); + + switch (camera_model) { + case CameraModelType::PINHOLE: // perspective projection + persp_proj_vjp( + convex_point_c, + covar_c, + fx, + fy, + cx, + cy, + image_width, + image_height, + v_covar2d, + loss_point, + v_convex_point_c, + v_covar_c + ); + break; + case CameraModelType::ORTHO: // orthographic projection + ortho_proj_vjp( + convex_point_c, + covar_c, + fx, + fy, + cx, + cy, + image_width, + image_height, + v_covar2d, + loss_point, + v_convex_point_c, + v_covar_c + ); + break; + case CameraModelType::FISHEYE: // fisheye projection + fisheye_proj_vjp( + convex_point_c, + covar_c, + fx, + fy, + cx, + cy, + image_width, + image_height, + v_covar2d, + loss_point, + v_convex_point_c, + v_covar_c + ); + break; + } + + // add contribution from v_depths --> Basically put the correct Z to the point. + v_convex_point_c.z += v_depths[0]; + + // vjp: transform Gaussian covariance to camera space + v_mean[i] = vec3(0.f); + mat3 v_covar(0.f); + mat3 v_R(0.f); + vec3 v_t(0.f); + posW2C_VJP(R, translation, convex_point, v_convex_point_c, v_R, v_t, v_mean[i]); + } + + // #if __CUDA_ARCH__ >= 700 + // write out results with warp-level reduction + auto warp = cg::tiled_partition<32>(cg::this_thread_block()); + auto warp_group_g = cg::labeled_partition(warp, conv_id); + if (v_convex_points != nullptr) + { + // FIXME: Reactivate this! 
+ //warpSum(v_mean, warp_group_g); + warpSum(v_mean, warp_group_g); + if (warp_group_g.thread_rank() == 0) + { + float *v_convex_points_ptr = (float *)(v_convex_points) + 6*conv_id*3; +#pragma unroll + for (int i = 0; i < num_points_per_convex; i++) + { +#pragma unroll + for (uint32_t h = 0; h < 3; h++) + { + gpuAtomicAdd(v_convex_points_ptr + h + 3*i, v_mean[i][h]); + } + } + } + } + +// if (v_viewmats != nullptr) { +// auto warp_group_c = cg::labeled_partition(warp, cid); +// warpSum(v_R, warp_group_c); +// warpSum(v_t, warp_group_c); +// if (warp_group_c.thread_rank() == 0) { +// v_viewmats += cid * 16; +// #pragma unroll +// for (uint32_t i = 0; i < 3; i++) { // rows +// #pragma unroll +// for (uint32_t j = 0; j < 3; j++) { // cols +// gpuAtomicAdd(v_viewmats + i * 4 + j, v_R[j][i]); +// } +// gpuAtomicAdd(v_viewmats + i * 4 + 3, v_t[i]); +// } +// } +// } +} + +void launch_projection_ewa_3dcs_fused_bwd_kernel( + // inputs + // fwd inputs + const at::Tensor convex_points, // [N, K, 3] + const at::Tensor cumsum_of_points_per_convex, // [N] + const at::Tensor viewmats, // [C, 4, 4] + const at::Tensor Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const CameraModelType camera_model, + // fwd outputs + const at::Tensor radii, // [C, N, 2] + const at::Tensor hull, // [C, 2*total_nb_points] + const at::Tensor num_points_per_convex_view, // [C, N] + //const at::Tensor normals, // [C, total_nb_points, 2] + const at::Tensor offsets, // [C, total_nb_points] + const at::Tensor p_image, // [C, total_nb_points, 2] + const at::Tensor indices, // [C, total_nb_points] + const at::Tensor conics, // [C, N, 3] + const at::optional compensations, // [C, N] optional + // grad inputs + const at::Tensor v_normals, // [C, total_nb_points, 2] + const at::Tensor v_offsets, // [C, total_nb_points] + // grad outputs + const at::Tensor v_means2d, // [C, N, 2] + const at::Tensor v_depths, // [C, N] + const at::Tensor v_conics, // [C, N, 
3] + const at::optional v_compensations, // [C, N] optional + const bool viewmats_requires_grad, + // outputs + at::Tensor v_convex_points, // [N, K, 3] + at::Tensor v_viewmats // [C, 4, 4] +) { + uint32_t N = convex_points.size(0); // number of gaussians + uint32_t C = viewmats.size(0); // number of cameras + + int64_t n_elements = C * N; + dim3 threads(256); + dim3 grid((n_elements + threads.x - 1) / threads.x); + int64_t shmem_size = 0; // No shared memory used in this kernel + + if (n_elements == 0) { + // skip the kernel launch if there are no elements + return; + } + + // FIXME: Why does is it a double here? + AT_DISPATCH_FLOATING_TYPES( + convex_points.scalar_type(), + "projection_ewa_3dcs_fused_bwd_kernel", + [&]() { + projection_ewa_3dcs_fused_bwd_kernel + <<>>( + C, + N, + convex_points.data_ptr(), + cumsum_of_points_per_convex.data_ptr(), // FIXME: Good type? + viewmats.data_ptr(), + Ks.data_ptr(), + image_width, + image_height, + eps2d, + camera_model, + radii.data_ptr(), + hull.data_ptr(), + num_points_per_convex_view.data_ptr(), + // normals.data_ptr(), + offsets.data_ptr(), + p_image.data_ptr(), + indices.data_ptr(), + conics.data_ptr(), + compensations.has_value() + ? compensations.value().data_ptr() + : nullptr, + v_means2d.data_ptr(), + v_depths.data_ptr(), + v_conics.data_ptr(), + v_compensations.has_value() + ? v_compensations.value().data_ptr() + : nullptr, + v_convex_points.data_ptr(), + viewmats_requires_grad ? 
v_viewmats.data_ptr() + : nullptr, + v_normals.data_ptr(), + v_offsets.data_ptr() + ); + } + ); +} + +} // namespace gsplat diff --git a/gsplat/cuda/csrc/Rasterization.cpp b/gsplat/cuda/csrc/Rasterization.cpp index 504e6a83f..2012e8e45 100644 --- a/gsplat/cuda/csrc/Rasterization.cpp +++ b/gsplat/cuda/csrc/Rasterization.cpp @@ -668,4 +668,276 @@ std::tuple rasterize_to_indices_2dgs( return std::make_tuple(gaussian_ids, pixel_ids); } +//////////////////////////////////////////////////// +// 3DCS +//////////////////////////////////////////////////// + +std::tuple rasterize_to_pixels_3dcs_fwd( + // 3D convex parameters + const at::Tensor means2d, // [C, N, 2] or [nnz, 2] + const at::Tensor normals, // [C, total_nb_points, 2] or [nnz, 2] + const at::Tensor offsets, // [C, N, 1] + const at::Tensor num_points_per_convex_view, // [C, N] + const at::Tensor delta, // [C, N] + const at::Tensor sigma, // [C, N] + const uint32_t num_points_per_convex, // 6 + const at::Tensor cumsum_of_points_per_convex, // [C] + const at::Tensor depths, // [C, N] + const at::Tensor conics, // [C, N, 3] or [nnz, 3] + const at::Tensor colors, // [C, N, channels] or [nnz, channels] + const at::Tensor opacities, // [C, N] or [nnz] + const at::optional backgrounds, // [C, channels] + const at::optional masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const at::Tensor tile_offsets, // [C, tile_height, tile_width] + const at::Tensor flatten_ids // [n_isects] +) { + DEVICE_GUARD(means2d); + CHECK_INPUT(means2d); + CHECK_INPUT(normals); + CHECK_INPUT(offsets); + CHECK_INPUT(num_points_per_convex_view); + CHECK_INPUT(delta); + CHECK_INPUT(sigma); + CHECK_INPUT(cumsum_of_points_per_convex); + CHECK_INPUT(depths); + CHECK_INPUT(conics); + CHECK_INPUT(colors); + CHECK_INPUT(opacities); + CHECK_INPUT(tile_offsets); + CHECK_INPUT(flatten_ids); + if (backgrounds.has_value()) { + 
CHECK_INPUT(backgrounds.value()); + } + if (masks.has_value()) { + CHECK_INPUT(masks.value()); + } + + uint32_t C = tile_offsets.size(0); // number of cameras + uint32_t channels = colors.size(-1); + + at::Tensor renders = + at::empty({C, image_height, image_width, channels}, means2d.options()); + at::Tensor alphas = + at::empty({C, image_height, image_width, 1}, means2d.options()); + at::Tensor last_ids = at::empty( + {C, image_height, image_width}, means2d.options().dtype(at::kInt) + ); + +#define __LAUNCH_KERNEL__(N) \ + case N: \ + launch_rasterize_to_pixels_3dcs_fwd_kernel( \ + means2d, \ + normals, \ + offsets, \ + num_points_per_convex_view, \ + delta, \ + sigma, \ + num_points_per_convex, \ + cumsum_of_points_per_convex, \ + depths, \ + conics, \ + colors, \ + opacities, \ + backgrounds, \ + masks, \ + image_width, \ + image_height, \ + tile_size, \ + tile_offsets, \ + flatten_ids, \ + renders, \ + alphas, \ + last_ids \ + ); \ + break; + + // TODO: an optimization can be done by passing the actual number of + // channels into the kernel functions and avoid necessary global memory + // writes. This requires moving the channel padding from python to C side. 
+ switch (channels) { + __LAUNCH_KERNEL__(1) + __LAUNCH_KERNEL__(2) + __LAUNCH_KERNEL__(3) + __LAUNCH_KERNEL__(4) + __LAUNCH_KERNEL__(5) + __LAUNCH_KERNEL__(8) + __LAUNCH_KERNEL__(9) + __LAUNCH_KERNEL__(16) + __LAUNCH_KERNEL__(17) + __LAUNCH_KERNEL__(32) + __LAUNCH_KERNEL__(33) + __LAUNCH_KERNEL__(64) + __LAUNCH_KERNEL__(65) + __LAUNCH_KERNEL__(128) + __LAUNCH_KERNEL__(129) + __LAUNCH_KERNEL__(256) + __LAUNCH_KERNEL__(257) + __LAUNCH_KERNEL__(512) + __LAUNCH_KERNEL__(513) + default: + AT_ERROR("Unsupported number of channels: ", channels); + } +#undef __LAUNCH_KERNEL__ + + return std::make_tuple(renders, alphas, last_ids); +} + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +rasterize_to_pixels_3dcs_bwd( + // Gaussian parameters + const at::Tensor means2d, // [C, N, 2] or [nnz, 2] + const at::Tensor normals, // [C, total_nb_points, 2] or [nnz*6, 2] + const at::Tensor offsets, // [C, K, 3] or [nnz, 3] + const uint32_t num_points_per_convex, // 6 + const at::Tensor delta, // [C, N] + const at::Tensor sigma, // [C, N] + const at::Tensor num_points_per_convex_view, // [C, N] + const at::Tensor cumsum_of_points_per_convex, // [C, N] + const at::Tensor depths, // [C, N] + const at::Tensor conics, // [C, N, 3] or [nnz, 3] + const at::Tensor colors, // [C, N, 3] or [nnz, 3] + const at::Tensor opacities, // [C, N] or [nnz] + const at::optional backgrounds, // [C, 3] + const at::optional masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const at::Tensor tile_offsets, // [C, tile_height, tile_width] + const at::Tensor flatten_ids, // [n_isects] + // forward outputs + const at::Tensor render_alphas, // [C, image_height, image_width, 1] + const at::Tensor last_ids, // [C, image_height, image_width] + // gradients of outputs + const at::Tensor v_render_colors, // [C, 
image_height, image_width, 3] + const at::Tensor v_render_alphas, // [C, image_height, image_width, 1] + // options + bool absgrad +) { + DEVICE_GUARD(means2d); + CHECK_INPUT(means2d); + CHECK_INPUT(normals); + CHECK_INPUT(offsets); + CHECK_INPUT(delta); + CHECK_INPUT(sigma); + CHECK_INPUT(cumsum_of_points_per_convex); + CHECK_INPUT(depths); + CHECK_INPUT(conics); + CHECK_INPUT(colors); + CHECK_INPUT(opacities); + CHECK_INPUT(tile_offsets); + CHECK_INPUT(flatten_ids); + CHECK_INPUT(render_alphas); + CHECK_INPUT(last_ids); + CHECK_INPUT(v_render_colors); + CHECK_INPUT(v_render_alphas); + if (backgrounds.has_value()) { + CHECK_INPUT(backgrounds.value()); + } + if (masks.has_value()) { + CHECK_INPUT(masks.value()); + } + + uint32_t channels = colors.size(-1); + + at::Tensor v_means2d = at::zeros_like(means2d); + at::Tensor v_normals = at::zeros_like(normals); + at::Tensor v_offsets = at::zeros_like(offsets); + at::Tensor v_delta = at::zeros_like(delta); + at::Tensor v_sigma = at::zeros_like(sigma); + at::Tensor v_conics = at::zeros_like(conics); + at::Tensor v_colors = at::zeros_like(colors); + at::Tensor v_opacities = at::zeros_like(opacities); + at::Tensor v_means2d_abs; + if (absgrad) { + v_means2d_abs = at::zeros_like(means2d); + } + +#define __LAUNCH_KERNEL__(N) \ + case N: \ + launch_rasterize_to_pixels_3dcs_bwd_kernel( \ + means2d, \ + normals, \ + offsets, \ + num_points_per_convex, \ + delta, \ + sigma, \ + num_points_per_convex_view, \ + cumsum_of_points_per_convex, \ + depths, \ + conics, \ + colors, \ + opacities, \ + backgrounds, \ + masks, \ + image_width, \ + image_height, \ + tile_size, \ + tile_offsets, \ + flatten_ids, \ + render_alphas, \ + last_ids, \ + v_render_colors, \ + v_render_alphas, \ + absgrad ? 
c10::optional(v_means2d_abs) : c10::nullopt, \ + v_means2d, \ + v_normals, \ + v_offsets, \ + v_delta, \ + v_sigma, \ + v_conics, \ + v_colors, \ + v_opacities \ + ); \ + break; + + // TODO: an optimization can be done by passing the actual number of + // channels into the kernel functions and avoid necessary global memory + // writes. This requires moving the channel padding from python to C side. + switch (channels) { + __LAUNCH_KERNEL__(1) + __LAUNCH_KERNEL__(2) + __LAUNCH_KERNEL__(3) + __LAUNCH_KERNEL__(4) + __LAUNCH_KERNEL__(5) + __LAUNCH_KERNEL__(8) + __LAUNCH_KERNEL__(9) + __LAUNCH_KERNEL__(16) + __LAUNCH_KERNEL__(17) + __LAUNCH_KERNEL__(32) + __LAUNCH_KERNEL__(33) + __LAUNCH_KERNEL__(64) + __LAUNCH_KERNEL__(65) + __LAUNCH_KERNEL__(128) + __LAUNCH_KERNEL__(129) + __LAUNCH_KERNEL__(256) + __LAUNCH_KERNEL__(257) + __LAUNCH_KERNEL__(512) + __LAUNCH_KERNEL__(513) + default: + AT_ERROR("Unsupported number of channels: ", channels); + } +#undef __LAUNCH_KERNEL__ + + return std::make_tuple( + v_means2d_abs, v_means2d, v_normals, v_offsets, v_delta, v_sigma, v_conics, v_colors, v_opacities + ); +} + + } // namespace gsplat diff --git a/gsplat/cuda/csrc/Rasterization.h b/gsplat/cuda/csrc/Rasterization.h index 82477863b..e18ae232d 100644 --- a/gsplat/cuda/csrc/Rasterization.h +++ b/gsplat/cuda/csrc/Rasterization.h @@ -191,4 +191,81 @@ void launch_rasterize_to_indices_2dgs_kernel( at::optional pixel_ids // [n_elems] ); +///////////////////////////////////////////////// +// rasterize_to_pixels 3DCS +///////////////////////////////////////////////// + +template +void launch_rasterize_to_pixels_3dcs_fwd_kernel( + // 3D convex parameters + const at::Tensor means2d, // [C, N, 2] or [nnz, 2] + const at::Tensor normals, // [C, total_nb_points, 2] or [nnz, 2] + const at::Tensor offsets, // [C, N, 1] + const at::Tensor num_points_per_convex_view, // [C, N] + const at::Tensor delta, // [C, N] + const at::Tensor sigma, // [C, N] + const uint32_t num_points_per_convex, // 6 + 
const at::Tensor cumsum_of_points_per_convex, // [C] + const at::Tensor depths, // [C, N] + const at::Tensor conics, // [C, N, 3] or [nnz, 3] + const at::Tensor colors, // [C, N, channels] or [nnz, channels] + const at::Tensor opacities, // [C, N] or [nnz] + const at::optional backgrounds, // [C, channels] + const at::optional masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const at::Tensor tile_offsets, // [C, tile_height, tile_width] + const at::Tensor flatten_ids, // [n_isects] + // outputs + at::Tensor renders, // [C, image_height, image_width, channels] + at::Tensor alphas, // [C, image_height, image_width] + at::Tensor last_ids // [C, image_height, image_width] +); + +template +void launch_rasterize_to_pixels_3dcs_bwd_kernel( + // 3D convex parameters + const at::Tensor means2d, // [C, N, 2] or [nnz, 2] + const at::Tensor normals, // [C, total_nb_points, 2] or [nnz*6, 2] + const at::Tensor offsets, // [C, K, 3] or [nnz, 3] + const uint32_t num_points_per_convex, // 6 + const at::Tensor delta, // [C, N] + const at::Tensor sigma, // [C, N] + const at::Tensor num_points_per_convex_view, // [C, N] + const at::Tensor cumsum_of_points_per_convex, // [C, N] + const at::Tensor depths, // [C, N] + const at::Tensor conics, // [C, N, 3] or [nnz, 3] + const at::Tensor colors, // [C, N, 3] or [nnz, 3] + const at::Tensor opacities, // [C, N] or [nnz] + const at::optional backgrounds, // [C, 3] + const at::optional masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const at::Tensor tile_offsets, // [C, tile_height, tile_width] + const at::Tensor flatten_ids, // [n_isects] + // forward outputs + const at::Tensor render_alphas, // [C, image_height, image_width, 1] + const at::Tensor last_ids, // [C, image_height, image_width] + // gradients of 
outputs + const at::Tensor v_render_colors, // [C, image_height, image_width, 3] + const at::Tensor v_render_alphas, // [C, image_height, image_width, 1] + // outputs + at::optional v_means2d_abs, // [C, N, 2] or [nnz, 2] + at::Tensor v_means2d, // [C, N, 2] or [nnz, 2] + at::Tensor v_normals, + at::Tensor v_offsets, + at::Tensor v_delta, + at::Tensor v_sigma, + at::Tensor v_conics, // [C, N, 3] or [nnz, 3] + at::Tensor v_colors, // [C, N, 3] or [nnz, 3] + at::Tensor v_opacities // [C, N] or [nnz] +); + + } // namespace gsplat \ No newline at end of file diff --git a/gsplat/cuda/csrc/RasterizeToPixels3DCSBwd.cu b/gsplat/cuda/csrc/RasterizeToPixels3DCSBwd.cu new file mode 100644 index 000000000..7e49bcb1e --- /dev/null +++ b/gsplat/cuda/csrc/RasterizeToPixels3DCSBwd.cu @@ -0,0 +1,593 @@ +#include +#include +#include +#include +#include + +#include "Common.h" +#include "Rasterization.h" +#include "Utils.cuh" + +#define MAX_NB_POINTS 8 + +namespace gsplat { + +namespace cg = cooperative_groups; + +template +__global__ void rasterize_to_pixels_3dcs_bwd_kernel( + const uint32_t C, + const uint32_t N, + const uint32_t n_isects, + const bool packed, + // fwd inputs + const vec2 *__restrict__ means2d, // [C, N, 2] or [nnz, 2] + const scalar_t *__restrict__ normals, // [C, total_nb_points, 2] or [nnz*6, 2] + const scalar_t *__restrict__ offsets, // [C, total_nb_points] or [nnz*6] + const uint32_t num_points_per_convex, // [K] + const scalar_t *__restrict__ delta, // [C, N] + const scalar_t *__restrict__ sigma, // [C, N] + const int32_t *__restrict__ num_points_per_convex_view, // [C, N] + const int32_t *__restrict__ cumsum_of_points_per_convex, // [C, N] + const scalar_t *__restrict__ depths, // [C, N] + const vec3 *__restrict__ conics, // [C, N, 3] or [nnz, 3] + const scalar_t *__restrict__ colors, // [C, N, CDIM] or [nnz, CDIM] + const scalar_t *__restrict__ opacities, // [C, N] or [nnz] + const scalar_t *__restrict__ backgrounds, // [C, CDIM] or [nnz, CDIM] + const bool 
*__restrict__ masks, // [C, tile_height, tile_width] + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + const uint32_t tile_width, + const uint32_t tile_height, + const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] + const int32_t *__restrict__ flatten_ids, // [n_isects] + // fwd outputs + const scalar_t + *__restrict__ render_alphas, // [C, image_height, image_width, 1] + const int32_t *__restrict__ last_ids, // [C, image_height, image_width] + // grad outputs + const scalar_t *__restrict__ v_render_colors, // [C, image_height, + // image_width, CDIM] + const scalar_t + *__restrict__ v_render_alphas, // [C, image_height, image_width, 1] + // grad inputs + vec2 *__restrict__ v_means2d_abs, // [C, N, 2] or [nnz, 2] + vec2 *__restrict__ v_means2d, // [C, N, 2] or [nnz, 2] + scalar_t *__restrict__ v_normals, // [C, total_nb_points, 2] or [nnz*6, 2] + scalar_t *__restrict__ v_offsets, // [C, k] or [nnz*6] + scalar_t *__restrict__ v_delta, // [C, N] + scalar_t *__restrict__ v_sigma, // [C, N] + vec3 *__restrict__ v_conics, // [C, N, 3] or [nnz, 3] + scalar_t *__restrict__ v_colors, // [C, N, CDIM] or [nnz, CDIM] + scalar_t *__restrict__ v_opacities // [C, N] or [nnz] +) { + auto block = cg::this_thread_block(); + uint32_t camera_id = block.group_index().x; + uint32_t tile_id = + block.group_index().y * tile_width + block.group_index().z; + uint32_t i = block.group_index().y * tile_size + block.thread_index().y; + uint32_t j = block.group_index().z * tile_size + block.thread_index().x; + + tile_offsets += camera_id * tile_height * tile_width; + render_alphas += camera_id * image_height * image_width; + last_ids += camera_id * image_height * image_width; + v_render_colors += camera_id * image_height * image_width * CDIM; + v_render_alphas += camera_id * image_height * image_width; + if (backgrounds != nullptr) { + backgrounds += camera_id * CDIM; + } + if (masks != nullptr) { + masks += camera_id * tile_height 
* tile_width; + } + + // when the mask is provided, do nothing and return if + // this tile is labeled as False + if (masks != nullptr && !masks[tile_id]) { + return; + } + + const float px = (float)j + 0.5f; + const float py = (float)i + 0.5f; + // clamp this value to the last pixel + const int32_t pix_id = + min(i * image_width + j, image_width * image_height - 1); + + // keep not rasterizing threads around for reading data + bool inside = (i < image_height && j < image_width); + + // have all threads in tile process the same gaussians in batches + // first collect gaussians between range.x and range.y in batches + // which gaussians to look through in this tile + int32_t range_start = tile_offsets[tile_id]; + int32_t range_end = + (camera_id == C - 1) && (tile_id == tile_width * tile_height - 1) + ? n_isects + : tile_offsets[tile_id + 1]; + const uint32_t block_size = block.size(); + const uint32_t num_batches = + (range_end - range_start + block_size - 1) / block_size; + + extern __shared__ int s[]; + int32_t *id_batch = (int32_t *)s; // [block_size] + vec3 *xy_opacity_batch = + reinterpret_cast(&id_batch[block_size]); // [block_size] + vec3 *conic_batch = + reinterpret_cast(&xy_opacity_batch[block_size]); // [block_size] + float *rgbs_batch = + (float *)&conic_batch[block_size]; // [block_size * CDIM] + + // 3DCS part + scalar_t *normals_batch = (scalar_t *)&rgbs_batch[block_size*CDIM]; + // FIXME: Should offset be a vec3? 
+ scalar_t *offsets_batch = (scalar_t *)&normals_batch[block_size*MAX_NB_POINTS*2]; + int32_t *num_points_per_convex_view_batch = (int32_t *)&offsets_batch[block_size*MAX_NB_POINTS]; + //int32_t *cumsum_of_points_per_convex_batch = reinterpret_cast(&num_points_per_convex_view_batch[block_size]); + scalar_t *delta_batch = (scalar_t *)&num_points_per_convex_view_batch[block_size]; + scalar_t *sigma_batch = (scalar_t *)&delta_batch[block_size]; + scalar_t *depths_batch = (scalar_t *)&sigma_batch[block_size]; + + // this is the T AFTER the last gaussian in this pixel + float T_final = 1.0f - render_alphas[pix_id]; + float T = T_final; + // the contribution from gaussians behind the current one + float buffer[CDIM] = {0.f}; + // index of last gaussian to contribute to this pixel + const int32_t bin_final = inside ? last_ids[pix_id] : 0; + + // df/d_out for this pixel + float v_render_c[CDIM]; +#pragma unroll + for (uint32_t k = 0; k < CDIM; ++k) { + v_render_c[k] = v_render_colors[pix_id * CDIM + k]; + } + const float v_render_a = v_render_alphas[pix_id]; + + // collect and process batches of gaussians + // each thread loads one gaussian at a time before rasterizing + const uint32_t tr = block.thread_rank(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + const int32_t warp_bin_final = + cg::reduce(warp, bin_final, cg::greater()); + for (uint32_t b = 0; b < num_batches; ++b) + { + // resync all threads before writing next batch of shared mem + block.sync(); + + // each thread fetch 1 gaussian from back to front + // 0 index will be furthest back in batch + // index of gaussian to load + // batch end is the index of the last gaussian in the batch + // These values can be negative so must be int32 instead of uint32 + const int32_t batch_end = range_end - 1 - block_size * b; + const int32_t batch_size = min(block_size, batch_end + 1 - range_start); + const int32_t idx = batch_end - tr; + if (idx >= range_start) + { + int32_t g = flatten_ids[idx]; // 
flatten index in [C * N] or [nnz] + id_batch[tr] = g; + const vec2 xy = means2d[g]; + const float opac = opacities[g]; + xy_opacity_batch[tr] = {xy.x, xy.y, opac}; + conic_batch[tr] = conics[g]; +#pragma unroll + for (uint32_t k = 0; k < CDIM; ++k) + { + rgbs_batch[tr * CDIM + k] = colors[g * CDIM + k]; + } + + // 3DCS part + num_points_per_convex_view_batch[tr] = num_points_per_convex_view[g]; + delta_batch[tr] = delta[g]; + sigma_batch[tr] = sigma[g]; + depths_batch[tr] = depths[g]; +#pragma unroll + for (uint32_t k = 0; k < num_points_per_convex_view[g]; k++) + { + normals_batch[6*tr*2 + 2*k] = normals[6*g*2 + 2*k]; + } +#pragma unroll + for (uint32_t k = 0; k < num_points_per_convex_view[g]; k++) + { + normals_batch[6*tr*2 + 2*k + 1] = normals[6*g*2 + 2*k + 1]; + } +#pragma unroll + for (uint32_t k = 0; k < num_points_per_convex_view[g]; k++) + { + offsets_batch[6*tr + k] = offsets[6*g + k]; + } + + } + // wait for other threads to collect the gaussians in batch + block.sync(); + // process gaussians in the current batch for this pixel + // 0 index is the furthest back gaussian in the batch + for (uint32_t t = max(0, batch_end - warp_bin_final); t < batch_size; ++t) + { + bool valid = inside; + if (batch_end - t > bin_final) { + valid = 0; + } + float alpha; + float opac; + + // Propagate gradients to per-Convex colors and keep + // gradients w.r.t. alpha (blending factor for a Convex/pixel + // pair). + float distances[MAX_NB_POINTS]; + float max_val = -INFINITY; + float Cx = 0.0f; + float phi_x = 0.0f; + float sum_exp = 0.0f; + + if (valid) + { + vec3 xy_opac = xy_opacity_batch[t]; + opac = xy_opac.z; + + // 3DCS part. 
+ for (uint32_t k = 0; k < num_points_per_convex_view_batch[t]; k++) + { + distances[k] = normals_batch[6*t*2 + 2*k] * px + normals_batch[6*t*2 + 2*k + 1] * py + offsets_batch[6*t + k]; + + if (distances[k] > max_val) + { + max_val = distances[k]; + } + } + +#pragma unroll + for (uint32_t k = 0; k < num_points_per_convex_view_batch[t]; k++) + { + sum_exp += __expf(depths_batch[t] * delta_batch[t] * (distances[k]-max_val)); + } + + phi_x = depths_batch[t] * delta_batch[t]*max_val + __logf(sum_exp); + + Cx = 1.0f / (1.0f + __expf(depths_batch[t] * sigma_batch[t] * phi_x)); + + alpha = min(0.999f, opac * Cx); + if (alpha < ALPHA_THRESHOLD) + { + valid = false; + } + } + + // if all threads are inactive in this warp, skip this loop + if (!warp.any(valid)) + { + continue; + } + float v_rgb_local[CDIM] = {0.f}; + vec2 v_xy_local = {0.f, 0.f}; + vec2 v_xy_abs_local = {0.f, 0.f}; + float v_opacity_local = 0.f; + + float v_delta_value_aux = 0.0f; + float v_sigma_value = 0.0f; + + float v_normal_local[MAX_NB_POINTS*2] = {0.f}; + float v_offset_local[MAX_NB_POINTS]= {0.f}; + + // initialize everything to 0, only set if the lane is valid + if (valid) + { + // compute the current T for this gaussian + float ra = 1.0f / (1.0f - alpha); + T *= ra; + // update v_rgb for this gaussian + const float fac = alpha * T; +#pragma unroll + for (uint32_t k = 0; k < CDIM; ++k) { + v_rgb_local[k] = fac * v_render_c[k]; + } + // contribution from this pixel + float v_alpha = 0.f; +#pragma unroll + for (uint32_t k = 0; k < CDIM; ++k) { + v_alpha += (rgbs_batch[t * CDIM + k] * T - buffer[k] * ra) * + v_render_c[k]; + } + + v_alpha += T_final * ra * v_render_a; + // contribution from background pixel + if (backgrounds != nullptr) { + float accum = 0.f; +#pragma unroll + for (uint32_t k = 0; k < CDIM; ++k) { + accum += backgrounds[k] * v_render_c[k]; + } + v_alpha += -T_final * ra * accum; + } + + if (opac * Cx <= 0.999f) + { + const float alpha = min(0.99f, opac * Cx); + v_opacity_local = Cx * 
v_alpha; + + // Helpful reusable temporary variables + float v_C = opac * v_alpha; + + // Calculate gradients w.r.t sigma + v_sigma_value = -depths_batch[t] * phi_x * Cx * (1.0f - Cx) * v_C; // remove depth here + + // Calculate gradient w.r.t phi_x + float v_phi_x = -sigma_batch[t] * depths_batch[t] * Cx * (1.0f - Cx) * v_C; // remove depth here + + // Calculate gradients with respect to distances + float v_distances[MAX_NB_POINTS]; + for (uint32_t k = 0; k < num_points_per_convex_view_batch[t]; k++) + { + float exp_val = __expf(depths_batch[t] * delta_batch[t] * (distances[k]-max_val)); + v_distances[k] = (exp_val / sum_exp) * v_phi_x * delta_batch[t] * depths_batch[t]; + } + + for (uint32_t k = 0; k < num_points_per_convex_view_batch[t]; k++) + { + v_normal_local[2*k] = v_distances[k] * px; + v_normal_local[2*k + 1] = v_distances[k] * py; + } + + for (uint32_t k = 0; k < num_points_per_convex_view_batch[t]; k++) + { + v_offset_local[k] = v_distances[k]; + } + + // Gradient with respect to delta + float v_delta_value = 0.0f; + + for (uint32_t k = 0; k < num_points_per_convex_view_batch[t]; k++) + { + float exp_val = __expf(delta_batch[t] * depths_batch[t] * (distances[k]-max_val)); + v_delta_value += depths_batch[t] * (distances[k]-max_val) * exp_val / sum_exp; + } + + // Multiply by the chain rule term v_phi_x + v_delta_value_aux = (depths_batch[t] * max_val + v_delta_value) * v_phi_x; + + } + +#pragma unroll + for (uint32_t k = 0; k < CDIM; ++k) + { + buffer[k] += rgbs_batch[t * CDIM + k] * fac; + } + } + warpSum(v_rgb_local, warp); + warpSum(v_opacity_local, warp); + warpSum(v_sigma_value, warp); + warpSum(v_delta_value_aux, warp); + warpSum(v_normal_local, warp); + warpSum(v_offset_local, warp); + if (warp.thread_rank() == 0) + { + int32_t g = id_batch[t]; // flatten index in [C * N] or [nnz] + float *v_rgb_ptr = (float *)(v_colors) + CDIM * g; +#pragma unroll + for (uint32_t k = 0; k < CDIM; ++k) + { + gpuAtomicAdd(v_rgb_ptr + k, v_rgb_local[k]); + } + + 
// --- Tail of rasterize_to_pixels_3dcs_bwd_kernel: lane 0 of each warp
// commits the warp-summed gradients for convex `g` to global memory via
// atomics (multiple tiles/pixels may touch the same convex). ---
                gpuAtomicAdd(v_opacities + g, v_opacity_local);

                // Apply the gradient update to sigma
                gpuAtomicAdd(v_sigma + g, v_sigma_value);

                // Apply the gradient update to delta
                gpuAtomicAdd(v_delta + g, v_delta_value_aux);

                // Calculate gradients w.r.t normals and offsets
                // NOTE(review): the stride 6 hard-codes six planes per convex,
                // while the local/shared buffers are sized MAX_NB_POINTS (8) —
                // confirm num_points_per_convex_view_batch[t] <= 6 always holds.
                float *v_normals_ptr = (float *)(v_normals) + 2 * g * 6;
#pragma unroll
                for (uint32_t k = 0; k < num_points_per_convex_view_batch[t]; k++)
                {
                    // x components of the per-plane normal gradients
                    gpuAtomicAdd(v_normals_ptr + 2 * k, v_normal_local[2 * k]);
                }

#pragma unroll
                for (uint32_t k = 0; k < num_points_per_convex_view_batch[t]; k++)
                {
                    // y components of the per-plane normal gradients
                    gpuAtomicAdd(v_normals_ptr + 2 * k + 1, v_normal_local[2 * k + 1]);
                }

                float *v_offsets_ptr = (float *)(v_offsets) + g * 6;
#pragma unroll
                for (uint32_t k = 0; k < num_points_per_convex_view_batch[t]; k++)
                {
                    gpuAtomicAdd(v_offsets_ptr + k, v_offset_local[k]);
                }
            } // if (warp.thread_rank() == 0)
        } // for t: convexes in this batch (back to front)
    } // for b: batches
}

// Host-side launcher for the 3DCS backward rasterization kernel: one CUDA
// block per image tile (C * tile_height * tile_width blocks), one thread per
// pixel. Sizes the dynamic shared memory, raises the per-kernel shared-memory
// cap, and launches. Gradients are accumulated into the v_* output tensors.
// NOTE(review): the bare `template`, `<<>>` launch config and missing
// `data_ptr<...>()` / `reinterpret_cast<...>` template arguments below look
// like extraction-stripped text — restore them against the upstream file.
template
void launch_rasterize_to_pixels_3dcs_bwd_kernel(
    // 3D convex parameters
    const at::Tensor means2d, // [C, N, 2] or [nnz, 2]
    const at::Tensor normals, // [C, total_nb_points, 2] or [nnz, 2]
    const at::Tensor offsets, // [C, total_nb_points] or [nnz]
    const uint32_t num_points_per_convex, // 6
    const at::Tensor delta, // [C, N]
    const at::Tensor sigma, // [C, N]
    const at::Tensor num_points_per_convex_view, // [C, N]
    const at::Tensor cumsum_of_points_per_convex, // [C, N]
    const at::Tensor depths, // [C, N]
    const at::Tensor conics, // [C, N, 3] or [nnz, 3]
    const at::Tensor colors, // [C, N, 3] or [nnz, 3]
    const at::Tensor opacities, // [C, N] or [nnz]
    const at::optional backgrounds, // [C, 3]
    const at::optional masks, // [C, tile_height, tile_width]
    // image size
    const uint32_t image_width,
    const uint32_t image_height,
    const uint32_t tile_size,
    // intersections
    const at::Tensor tile_offsets, // [C, tile_height, tile_width]
    const at::Tensor flatten_ids, // [n_isects]
    // forward outputs
    const at::Tensor render_alphas, // [C, image_height, image_width, 1]
    const at::Tensor last_ids,      // [C, image_height, image_width]
    // gradients of outputs
    const at::Tensor v_render_colors, // [C, image_height, image_width, 3]
    const at::Tensor v_render_alphas, // [C, image_height, image_width, 1]
    // outputs
    at::optional v_means2d_abs, // [C, N, 2] or [nnz, 2]
    at::Tensor v_means2d, // [C, N, 2] or [nnz, 2]
    at::Tensor v_normals,
    at::Tensor v_offsets,
    at::Tensor v_delta,
    at::Tensor v_sigma,
    at::Tensor v_conics, // [C, N, 3] or [nnz, 3]
    at::Tensor v_colors, // [C, N, 3] or [nnz, 3]
    at::Tensor v_opacities // [C, N] or [nnz]
) {
    // Packed mode stores a flat [nnz, ...] layout instead of [C, N, ...].
    bool packed = means2d.dim() == 2;

    uint32_t C = tile_offsets.size(0); // number of cameras
    uint32_t N = packed ? 0 : means2d.size(1); // number of 3D convexes
    uint32_t tile_height = tile_offsets.size(1);
    uint32_t tile_width = tile_offsets.size(2);
    uint32_t n_isects = flatten_ids.size(0);

    // Each block covers a tile on the image. In total there are
    // C * tile_height * tile_width blocks.
    dim3 threads = {tile_size, tile_size, 1};
    dim3 grid = {C, tile_height, tile_width};

    // Per-thread shared-memory footprint; mirrors the kernel's
    // extern __shared__ layout: id, xy+opacity (vec3), conic (vec3),
    // CDIM colors, MAX_NB_POINTS*2 normals, MAX_NB_POINTS offsets,
    // point count (int32), then delta, sigma and depth (one float each).
    int64_t shmem_size =
        tile_size * tile_size *
        (sizeof(int32_t) + sizeof(vec3) + sizeof(vec3) +
         sizeof(float) * CDIM + sizeof(float) * 2 * MAX_NB_POINTS + sizeof(float) * MAX_NB_POINTS +
         sizeof(int32_t) /*+ sizeof(int32_t)*/ + sizeof(float) + sizeof(float) + sizeof(float));

    if (n_isects == 0) {
        // skip the kernel launch if there are no elements
        return;
    }

    // TODO: an optimization can be done by passing the actual number of
    // channels into the kernel functions and avoid necessary global memory
    // writes. This requires moving the channel padding from python to C side.
    if (cudaFuncSetAttribute(
            rasterize_to_pixels_3dcs_bwd_kernel,
            cudaFuncAttributeMaxDynamicSharedMemorySize,
            shmem_size
        ) != cudaSuccess) {
        AT_ERROR(
            "Failed to set maximum shared memory size (requested ",
            shmem_size,
            " bytes), try lowering tile_size."
        );
    }

    rasterize_to_pixels_3dcs_bwd_kernel
        <<>>(
            C,
            N,
            n_isects,
            packed,
            reinterpret_cast(means2d.data_ptr()),
            normals.data_ptr(),
            offsets.data_ptr(),
            num_points_per_convex,
            delta.data_ptr(),
            sigma.data_ptr(),
            num_points_per_convex_view.data_ptr(),
            cumsum_of_points_per_convex.data_ptr(),
            depths.data_ptr(),
            reinterpret_cast(conics.data_ptr()),
            colors.data_ptr(),
            opacities.data_ptr(),
            backgrounds.has_value() ? backgrounds.value().data_ptr()
                                    : nullptr,
            masks.has_value() ? masks.value().data_ptr() : nullptr,
            image_width,
            image_height,
            tile_size,
            tile_width,
            tile_height,
            tile_offsets.data_ptr(),
            flatten_ids.data_ptr(),
            render_alphas.data_ptr(),
            last_ids.data_ptr(),
            v_render_colors.data_ptr(),
            v_render_alphas.data_ptr(),
            v_means2d_abs.has_value()
                ? reinterpret_cast(
                      v_means2d_abs.value().data_ptr()
                  )
                : nullptr,
            reinterpret_cast(v_means2d.data_ptr()),
            v_normals.data_ptr(),
            v_offsets.data_ptr(),
            v_delta.data_ptr(),
            v_sigma.data_ptr(),
            reinterpret_cast(v_conics.data_ptr()),
            v_colors.data_ptr(),
            v_opacities.data_ptr()
        );
}

// Explicit Instantiation: this should match how it is being called in .cpp
// file.
// TODO: this is slow to compile, can we do something about it?
+#define __INS__(CDIM) \ + template void launch_rasterize_to_pixels_3dcs_bwd_kernel( \ + const at::Tensor means2d, \ + const at::Tensor normals, \ + const at::Tensor offsets, \ + const uint32_t num_points_per_convex, \ + const at::Tensor delta, \ + const at::Tensor sigma, \ + const at::Tensor num_points_per_convex_view, \ + const at::Tensor cumsum_of_points_per_convex, \ + const at::Tensor depths, \ + const at::Tensor conics, \ + const at::Tensor colors, \ + const at::Tensor opacities, \ + const at::optional backgrounds, \ + const at::optional masks, \ + uint32_t image_width, \ + uint32_t image_height, \ + uint32_t tile_size, \ + const at::Tensor tile_offsets, \ + const at::Tensor flatten_ids, \ + const at::Tensor render_alphas, \ + const at::Tensor last_ids, \ + const at::Tensor v_render_colors, \ + const at::Tensor v_render_alphas, \ + at::optional v_means2d_abs, \ + at::Tensor v_means2d, \ + at::Tensor v_normals, \ + at::Tensor v_offsets, \ + at::Tensor v_delta, \ + at::Tensor v_sigma, \ + at::Tensor v_conics, \ + at::Tensor v_colors, \ + at::Tensor v_opacities \ + ); + +__INS__(1) +__INS__(2) +__INS__(3) +__INS__(4) +__INS__(5) +__INS__(8) +__INS__(9) +__INS__(16) +__INS__(17) +__INS__(32) +__INS__(33) +__INS__(64) +__INS__(65) +__INS__(128) +__INS__(129) +__INS__(256) +__INS__(257) +__INS__(512) +__INS__(513) +#undef __INS__ + +} // namespace gsplat diff --git a/gsplat/cuda/csrc/RasterizeToPixels3DCSFwd.cu b/gsplat/cuda/csrc/RasterizeToPixels3DCSFwd.cu new file mode 100644 index 000000000..001ac97fd --- /dev/null +++ b/gsplat/cuda/csrc/RasterizeToPixels3DCSFwd.cu @@ -0,0 +1,396 @@ +#include +#include +#include +#include + +#include "Common.h" +#include "Rasterization.h" + +#define MAX_NB_POINTS 8 + +namespace gsplat { + +namespace cg = cooperative_groups; + +//////////////////////////////////////////////////////////////// +// Forward +//////////////////////////////////////////////////////////////// + +template +__global__ void 
rasterize_to_pixels_3dcs_fwd_kernel(
    const uint32_t C,
    const uint32_t N,
    const uint32_t n_isects,
    const bool packed,
    // NOTE(review): several template-argument lists in this file (after
    // `reinterpret_cast`, `static_cast`, bare `template`, and the `<<>>`
    // launch configs) appear to have been stripped by text extraction —
    // restore them against the upstream sources before building.
    const vec2 *__restrict__ means2d,         // [C, N, 2] or [nnz, 2]
    const scalar_t *__restrict__ normals,     // [C, total_nb_points, 2] or [nnz*6, 2]
    const scalar_t *__restrict__ offsets,     // [C, total_nb_points, 1]
    const int32_t *__restrict__ num_points_per_convex_view, // [C, N]
    const scalar_t *__restrict__ delta,       // [C, N]
    const scalar_t *__restrict__ sigma,       // [C, N]
    const uint32_t num_points_per_convex,     // [C]
    const int32_t *__restrict__ cumsum_of_points_per_convex, // [C]
    const scalar_t *__restrict__ depths,      // [C, N]
    const vec3 *__restrict__ conics,          // [C, N, 3] or [nnz, 3]
    const scalar_t *__restrict__ colors,      // [C, N, CDIM] or [nnz, CDIM]
    const scalar_t *__restrict__ opacities,   // [C, N] or [nnz]
    const scalar_t *__restrict__ backgrounds, // [C, CDIM]
    const bool *__restrict__ masks,           // [C, tile_height, tile_width]
    const uint32_t image_width,
    const uint32_t image_height,
    const uint32_t tile_size,
    const uint32_t tile_width,
    const uint32_t tile_height,
    const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width]
    const int32_t *__restrict__ flatten_ids,  // [n_isects]
    scalar_t *__restrict__ render_colors,     // [C, image_height, image_width, CDIM]
    scalar_t *__restrict__ render_alphas,     // [C, image_height, image_width, 1]
    int32_t *__restrict__ last_ids            // [C, image_height, image_width]
) {
    // Forward 3DCS (3D convex splatting) rasterization: each thread draws one
    // pixel, but also timeshares caching convexes in a shared tile.
    auto block = cg::this_thread_block();
    int32_t camera_id = block.group_index().x;
    int32_t tile_id =
        block.group_index().y * tile_width + block.group_index().z;
    uint32_t i = block.group_index().y * tile_size + block.thread_index().y;
    uint32_t j = block.group_index().z * tile_size + block.thread_index().x;

    // Shift the global pointers to this camera's slice.
    tile_offsets += camera_id * tile_height * tile_width;
    render_colors += camera_id * image_height * image_width * CDIM;
    render_alphas += camera_id * image_height * image_width;
    last_ids += camera_id * image_height * image_width;
    if (backgrounds != nullptr) {
        backgrounds += camera_id * CDIM;
    }
    if (masks != nullptr) {
        masks += camera_id * tile_height * tile_width;
    }

    // Pixel center in image coordinates.
    float px = (float)j + 0.5f;
    float py = (float)i + 0.5f;
    int32_t pix_id = i * image_width + j;

    // return if out of bounds
    // keep not rasterizing threads around for reading data
    bool inside = (i < image_height && j < image_width);
    bool done = !inside;

    // when the mask is provided, render the background color and return
    // if this tile is labeled as False
    if (masks != nullptr && inside && !masks[tile_id]) {
#pragma unroll
        for (uint32_t k = 0; k < CDIM; ++k) {
            render_colors[pix_id * CDIM + k] =
                backgrounds == nullptr ? 0.0f : backgrounds[k];
        }
        // NOTE(review): render_alphas and last_ids are left unwritten on this
        // early-out — confirm callers pre-zero those buffers.
        return;
    }

    // have all threads in tile process the same convexes in batches
    // first collect convexes between range.x and range.y in batches
    // which convexes to look through in this tile
    int32_t range_start = tile_offsets[tile_id];
    int32_t range_end =
        (camera_id == C - 1) && (tile_id == tile_width * tile_height - 1)
            ? n_isects
            : tile_offsets[tile_id + 1];
    const uint32_t block_size = block.size();
    uint32_t num_batches =
        (range_end - range_start + block_size - 1) / block_size;

    // Dynamic shared memory layout; must stay in sync with shmem_size in the
    // launcher below.
    extern __shared__ int s[];
    int32_t *id_batch = (int32_t *)s; // [block_size]
    vec3 *xy_opacity_batch =
        reinterpret_cast(&id_batch[block_size]); // [block_size]
    vec3 *conic_batch =
        reinterpret_cast(&xy_opacity_batch[block_size]); // [block_size]

    // 3DCS per-convex data: up to MAX_NB_POINTS plane normals/offsets each.
    scalar_t *normals_batch = (scalar_t *)&conic_batch[block_size];
    scalar_t *offsets_batch = (scalar_t *)&normals_batch[block_size * MAX_NB_POINTS * 2];
    int32_t *num_points_per_convex_view_batch = (int32_t *)&offsets_batch[block_size * MAX_NB_POINTS];
    scalar_t *delta_batch = (scalar_t *)&num_points_per_convex_view_batch[block_size];
    scalar_t *sigma_batch = (scalar_t *)&delta_batch[block_size];
    scalar_t *depths_batch = (scalar_t *)&sigma_batch[block_size];

    // current visibility left to render
    // transmittance is gonna be used in the backward pass which requires a high
    // numerical precision so we use double for it. However double make bwd 1.5x
    // slower so we stick with float for now.
    float T = 1.0f;
    // index of most recent convex to write to this thread's pixel
    uint32_t cur_idx = 0;

    // collect and process batches of convexes
    // each thread loads one convex at a time before rasterizing its
    // designated pixel
    uint32_t tr = block.thread_rank();

    float pix_out[CDIM] = {0.f};
    for (uint32_t b = 0; b < num_batches; ++b) {
        // resync all threads before beginning next batch
        // end early if entire tile is done
        if (__syncthreads_count(done) >= block_size) {
            break;
        }

        // each thread fetch 1 convex from front to back
        // index of convex to load
        uint32_t batch_start = range_start + block_size * b;
        uint32_t idx = batch_start + tr;
        if (idx < range_end) {
            int32_t g = flatten_ids[idx]; // flatten index in [C * N] or [nnz]
            id_batch[tr] = g;
            const vec2 xy = means2d[g];
            const float opac = opacities[g];
            xy_opacity_batch[tr] = {xy.x, xy.y, opac};
            conic_batch[tr] = conics[g];

            // 3DCS batches.
            // NOTE(review): the stride 6 below hard-codes six planes per
            // convex, while the shared buffers are sized for MAX_NB_POINTS
            // (8) — confirm num_points_per_convex_view[g] <= 6 always holds.
#pragma unroll
            for (uint32_t k = 0; k < num_points_per_convex_view[g]; k++) {
                normals_batch[6 * tr * 2 + 2 * k] = normals[6 * g * 2 + 2 * k];
            }
#pragma unroll
            for (uint32_t k = 0; k < num_points_per_convex_view[g]; k++) {
                normals_batch[6 * tr * 2 + 2 * k + 1] = normals[6 * g * 2 + 2 * k + 1];
            }
#pragma unroll
            for (uint32_t k = 0; k < num_points_per_convex_view[g]; k++) {
                offsets_batch[6 * tr + k] = offsets[6 * g + k];
            }
            delta_batch[tr] = delta[g];
            sigma_batch[tr] = sigma[g];
            depths_batch[tr] = depths[g];
            num_points_per_convex_view_batch[tr] = num_points_per_convex_view[g];
        }

        // wait for other threads to collect the convexes in batch
        block.sync();

        // process convexes in the current batch for this pixel
        uint32_t batch_size = min(block_size, range_end - batch_start);
        for (uint32_t t = 0; (t < batch_size) && !done; ++t) {
            const vec3 xy_opac = xy_opacity_batch[t];
            const float opac = xy_opac.z;

            // Signed distances of the pixel center to each convex plane.
            // Sized MAX_NB_POINTS; only the first
            // num_points_per_convex_view_batch[t] entries are valid.
            float distances[MAX_NB_POINTS];
            float max_val = -INFINITY;
            for (uint32_t k = 0; k < num_points_per_convex_view_batch[t]; k++) {
                distances[k] = normals_batch[6 * t * 2 + 2 * k] * px +
                               normals_batch[6 * t * 2 + 2 * k + 1] * py +
                               offsets_batch[6 * t + k];

                if (distances[k] > max_val) {
                    max_val = distances[k];
                }
            }

            // Smooth-max (log-sum-exp) of the plane distances, shifted by
            // max_val for numerical stability.
            // BUGFIX: was `#pragma unrol` (typo) — an unknown pragma is
            // ignored, silently disabling unrolling.
            float sum_exp = 0.0f;
#pragma unroll
            for (uint32_t k = 0; k < num_points_per_convex_view_batch[t]; k++) {
                sum_exp += __expf(depths_batch[t] * delta_batch[t] * (distances[k] - max_val));
            }

            float phi_x = depths_batch[t] * delta_batch[t] * max_val + __logf(sum_exp);
            // Sigmoid of the smooth-max: ~1 inside the convex, falls off
            // outside with sharpness sigma.
            float Cx = 1.0f / (1.0f + __expf(depths_batch[t] * sigma_batch[t] * phi_x));

            float alpha = min(0.999f, opac * Cx);
            if (alpha < ALPHA_THRESHOLD) {
                continue;
            }

            const float next_T = T * (1.0f - alpha);
            if (next_T <= 1e-4f) { // this pixel is done: exclusive
                done = true;
                break;
            }

            int32_t g = id_batch[t];
            const float vis = alpha * T;
            const float *c_ptr = colors + g * CDIM;
#pragma unroll
            for (uint32_t k = 0; k < CDIM; ++k) {
                pix_out[k] += c_ptr[k] * vis;
            }
            cur_idx = batch_start + t;

            T = next_T;
        }
    }

    if (inside) {
        // Here T is the transmittance AFTER the last convex in this pixel.
        // We (should) store double precision as T would be used in backward
        // pass and it can be very small and causing large diff in gradients
        // with float32. However, double precision makes the backward pass 1.5x
        // slower so we stick with float for now.
        render_alphas[pix_id] = 1.0f - T;
#pragma unroll
        for (uint32_t k = 0; k < CDIM; ++k) {
            render_colors[pix_id * CDIM + k] =
                backgrounds == nullptr ? pix_out[k]
                                       : (pix_out[k] + T * backgrounds[k]);
        }
        // index in bin of last convex in this pixel
        last_ids[pix_id] = static_cast(cur_idx);
    }
}

// Host-side launcher for the 3DCS forward rasterization kernel: one CUDA
// block per image tile (C * tile_height * tile_width blocks), one thread per
// pixel. Sizes the dynamic shared memory, raises the per-kernel shared-memory
// cap, and launches. Note: no n_isects == 0 early-out here (unlike the bwd
// launcher) — the kernel must still write backgrounds/alphas/last_ids.
template
void launch_rasterize_to_pixels_3dcs_fwd_kernel(
    // 3D convex parameters
    const at::Tensor means2d, // [C, N, 2] or [nnz, 2]
    const at::Tensor normals, // [C, total_nb_points, 2] or [nnz, 2]
    const at::Tensor offsets, // [C, total_nb_points] or [nnz]
    const at::Tensor num_points_per_convex_view, // [C, N]
    const at::Tensor delta, // [C, N]
    const at::Tensor sigma, // [C, N]
    const uint32_t num_points_per_convex, // 6
    const at::Tensor cumsum_of_points_per_convex, // [C]
    const at::Tensor depths, // [C, N]
    const at::Tensor conics, // [C, N, 3] or [nnz, 3]
    const at::Tensor colors, // [C, N, channels] or [nnz, channels]
    const at::Tensor opacities, // [C, N] or [nnz]
    const at::optional backgrounds, // [C, channels]
    const at::optional masks, // [C, tile_height, tile_width]
    // image size
    const uint32_t image_width,
    const uint32_t image_height,
    const uint32_t tile_size,
    // intersections
    const at::Tensor tile_offsets, // [C, tile_height, tile_width]
    const at::Tensor flatten_ids, // [n_isects]
    // outputs
    at::Tensor renders, // [C, image_height, image_width, channels]
    at::Tensor alphas, // [C, image_height, image_width]
    at::Tensor last_ids // [C, image_height, image_width]
) {
    // Packed mode stores a flat [nnz, ...] layout instead of [C, N, ...].
    bool packed = means2d.dim() == 2;

    uint32_t C = tile_offsets.size(0); // number of cameras
    uint32_t N = packed ? 0 : means2d.size(1); // number of convexes
    uint32_t tile_height = tile_offsets.size(1);
    uint32_t tile_width = tile_offsets.size(2);
    uint32_t n_isects = flatten_ids.size(0);

    // Each block covers a tile on the image. In total there are
    // C * tile_height * tile_width blocks.
    dim3 threads = {tile_size, tile_size, 1};
    dim3 grid = {C, tile_height, tile_width};

    // Per-thread shared-memory footprint; mirrors the kernel's
    // extern __shared__ layout: id, xy+opacity (vec3), conic (vec3),
    // MAX_NB_POINTS*2 normals, MAX_NB_POINTS offsets, point count (int32),
    // then delta, sigma and depth (one float each).
    int64_t shmem_size =
        tile_size * tile_size *
        (sizeof(int32_t) + sizeof(vec3) + sizeof(vec3) + sizeof(float) * MAX_NB_POINTS * 2 +
         sizeof(float) * MAX_NB_POINTS + sizeof(int32_t) + sizeof(float) + sizeof(float) + sizeof(float));

    // TODO: an optimization can be done by passing the actual number of
    // channels into the kernel functions and avoid necessary global memory
    // writes. This requires moving the channel padding from python to C side.
    if (cudaFuncSetAttribute(
            rasterize_to_pixels_3dcs_fwd_kernel,
            cudaFuncAttributeMaxDynamicSharedMemorySize,
            shmem_size
        ) != cudaSuccess) {
        AT_ERROR(
            "Failed to set maximum shared memory size (requested ",
            shmem_size,
            " bytes), try lowering tile_size."
        );
    }

    rasterize_to_pixels_3dcs_fwd_kernel
        <<>>(
            C,
            N,
            n_isects,
            packed,
            reinterpret_cast(means2d.data_ptr()),
            normals.data_ptr(),
            offsets.data_ptr(),
            num_points_per_convex_view.data_ptr(),
            delta.data_ptr(),
            sigma.data_ptr(),
            num_points_per_convex,
            cumsum_of_points_per_convex.data_ptr(),
            depths.data_ptr(),
            reinterpret_cast(conics.data_ptr()),
            colors.data_ptr(),
            opacities.data_ptr(),
            backgrounds.has_value() ? backgrounds.value().data_ptr()
                                    : nullptr,
            masks.has_value() ? masks.value().data_ptr() : nullptr,
            image_width,
            image_height,
            tile_size,
            tile_width,
            tile_height,
            tile_offsets.data_ptr(),
            flatten_ids.data_ptr(),
            renders.data_ptr(),
            alphas.data_ptr(),
            last_ids.data_ptr()
        );
}

// Explicit Instantiation: this should match how it is being called in .cpp
// file.
// TODO: this is slow to compile, can we do something about it?
+#define __INS__(CDIM) \ + template void launch_rasterize_to_pixels_3dcs_fwd_kernel( \ + const at::Tensor means2d, \ + const at::Tensor normals, \ + const at::Tensor offsets, \ + const at::Tensor num_points_per_convex_view, \ + const at::Tensor delta, \ + const at::Tensor sigma, \ + const uint32_t num_points_per_convex, \ + const at::Tensor cumsum_of_points_per_convex, \ + const at::Tensor depths, \ + const at::Tensor conics, \ + const at::Tensor colors, \ + const at::Tensor opacities, \ + const at::optional backgrounds, \ + const at::optional masks, \ + uint32_t image_width, \ + uint32_t image_height, \ + uint32_t tile_size, \ + const at::Tensor tile_offsets, \ + const at::Tensor flatten_ids, \ + at::Tensor renders, \ + at::Tensor alphas, \ + at::Tensor last_ids \ + ); + +__INS__(1) +__INS__(2) +__INS__(3) +__INS__(4) +__INS__(5) +__INS__(8) +__INS__(9) +__INS__(16) +__INS__(17) +__INS__(32) +__INS__(33) +__INS__(64) +__INS__(65) +__INS__(128) +__INS__(129) +__INS__(256) +__INS__(257) +__INS__(512) +__INS__(513) +#undef __INS__ + +} // namespace gsplat diff --git a/gsplat/cuda/csrc/SphericalHarmonics.cpp b/gsplat/cuda/csrc/SphericalHarmonics.cpp index 91e5a8a65..cda5424f0 100644 --- a/gsplat/cuda/csrc/SphericalHarmonics.cpp +++ b/gsplat/cuda/csrc/SphericalHarmonics.cpp @@ -77,4 +77,83 @@ std::tuple spherical_harmonics_bwd( return std::make_tuple(v_coeffs, v_dirs); // [..., K, 3], [..., 3] } +// 3DCS part +at::Tensor spherical_harmonics_fwd_3dcs( + const uint32_t degrees_to_use, + const at::Tensor convex_points, // [N, 6, 3] + const at::Tensor dirs, // [..., 3] + const at::Tensor coeffs, // [..., K, 3] + const at::optional masks // [...] 
+) { + DEVICE_GUARD(convex_points); + CHECK_INPUT(convex_points); + CHECK_INPUT(dirs); + CHECK_INPUT(coeffs); + if (masks.has_value()) { + CHECK_INPUT(masks.value()); + } + TORCH_CHECK(coeffs.size(-1) == 3, "coeffs must have last dimension 3"); + TORCH_CHECK(dirs.size(-1) == 3, "dirs must have last dimension 3"); + + at::Tensor colors = at::empty_like(dirs); // [..., 3] + + launch_spherical_harmonics_fwd_kernel_3dcs( + degrees_to_use, convex_points, dirs, coeffs, masks, colors + ); + return colors; // [..., 3] +} + +std::tuple spherical_harmonics_bwd_3dcs( + const uint32_t K, + const uint32_t degrees_to_use, + const at::Tensor convex_points, // [N, 6, 3] + const at::Tensor dirs, // [..., 3] + const at::Tensor coeffs, // [..., K, 3] + const at::optional masks, // [...] + const at::Tensor v_colors, // [..., 3] + bool compute_v_convex_points, + bool compute_v_dirs +) { + DEVICE_GUARD(convex_points); + CHECK_INPUT(convex_points); + CHECK_INPUT(dirs); + CHECK_INPUT(coeffs); + CHECK_INPUT(v_colors); + if (masks.has_value()) { + CHECK_INPUT(masks.value()); + } + TORCH_CHECK(v_colors.size(-1) == 3, "v_colors must have last dimension 3"); + TORCH_CHECK(coeffs.size(-1) == 3, "coeffs must have last dimension 3"); + TORCH_CHECK(dirs.size(-1) == 3, "dirs must have last dimension 3"); + const uint32_t N = dirs.numel() / 3; + + at::Tensor v_coeffs = at::zeros_like(coeffs); + at::Tensor v_convex_points; + at::Tensor v_dirs; + if (compute_v_convex_points) + { + v_convex_points = at::zeros_like(convex_points); + } + if (compute_v_dirs) + { + v_dirs = at::zeros_like(dirs); + } + + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + uint32_t n_elements = N; + uint32_t shmem_size = 0; + launch_spherical_harmonics_bwd_kernel_3dcs( + degrees_to_use, + convex_points, + dirs, + coeffs, + masks, + v_colors, + v_convex_points, + v_coeffs, + v_dirs.defined() ? 
at::optional(v_dirs) : c10::nullopt + ); + return std::make_tuple(v_convex_points, v_coeffs, v_dirs); // [..., K, 3], [..., 3] +} + } // namespace gsplat diff --git a/gsplat/cuda/csrc/SphericalHarmonics.h b/gsplat/cuda/csrc/SphericalHarmonics.h index 664a6e31b..b14f10b52 100644 --- a/gsplat/cuda/csrc/SphericalHarmonics.h +++ b/gsplat/cuda/csrc/SphericalHarmonics.h @@ -30,4 +30,30 @@ void launch_spherical_harmonics_bwd_kernel( at::optional v_dirs // [..., 3] ); +// 3DCS +void launch_spherical_harmonics_fwd_kernel_3dcs( + // inputs + const uint32_t degrees_to_use, + const at::Tensor convex_points, // [N, 6, 3] + const at::Tensor dirs, // [..., 3] + const at::Tensor coeffs, // [..., K, 3] + const at::optional masks, // [...] + // outputs + at::Tensor colors // [..., 2] +); + +void launch_spherical_harmonics_bwd_kernel_3dcs( + // inputs + const uint32_t degrees_to_use, + const at::Tensor convex_points, // [N, 6, 3] + const at::Tensor dirs, // [..., 3] + const at::Tensor coeffs, // [..., K, 3] + const at::optional masks, // [...] + const at::Tensor v_colors, // [..., 3] + // outputs + at::Tensor v_convex_points, // [N, 6, 3] + at::Tensor v_coeffs, // [N, K, 3] + at::optional v_dirs // [N, 3] optional +); + } // namespace gsplat \ No newline at end of file diff --git a/gsplat/cuda/csrc/SphericalHarmonicsCUDA.cu b/gsplat/cuda/csrc/SphericalHarmonicsCUDA.cu index 4fd67d69b..27c36b080 100644 --- a/gsplat/cuda/csrc/SphericalHarmonicsCUDA.cu +++ b/gsplat/cuda/csrc/SphericalHarmonicsCUDA.cu @@ -534,4 +534,178 @@ void launch_spherical_harmonics_bwd_kernel( ); } +// 3DCS. Could be moved. 
// Forward SH-to-RGB evaluation for 3DCS: one thread per (element, channel)
// pair, N * 3 threads in total. Delegates the actual basis evaluation to
// sh_coeffs_to_color_fast.
template
__global__ void spherical_harmonics_fwd_kernel_3dcs(
    const uint32_t N,
    const uint32_t K,
    const uint32_t degrees_to_use,
    // NOTE(review): convex_points is never read in this kernel body — it is
    // carried through the interface only (presumably for a future
    // per-vertex-SH variant); confirm before removing.
    const scalar_t *__restrict__ convex_points, // [N, 6, 3]
    const vec3 *__restrict__ dirs, // [N, 3]
    const scalar_t *__restrict__ coeffs, // [N, K, 3]
    const bool *__restrict__ masks, // [N]
    scalar_t *__restrict__ colors // [N, 3] (one channel written per thread)
) {
    // parallelize over N * 3
    uint32_t idx = cg::this_grid().thread_rank();
    if (idx >= N * 3) {
        return;
    }
    uint32_t elem_id = idx / 3;
    uint32_t c = idx % 3; // color channel
    // Masked elements keep whatever was in `colors` (allocated uninitialized
    // with empty_like on the wrapper side).
    if (masks != nullptr && !masks[elem_id]) {
        return;
    }
    sh_coeffs_to_color_fast(
        degrees_to_use,
        c,
        dirs[elem_id],
        coeffs + elem_id * K * 3,
        colors + elem_id * 3
    );
}

// Host-side launcher: flat 1D grid over N * 3 (element, channel) pairs,
// dispatched over the floating dtype of `dirs`.
// NOTE(review): the empty `<<>>` launch config below looks like
// extraction-stripped text — restore against the upstream file.
void launch_spherical_harmonics_fwd_kernel_3dcs(
    // inputs
    const uint32_t degrees_to_use,
    const at::Tensor convex_points, // [N, 6, 3]
    const at::Tensor dirs, // [..., 3]
    const at::Tensor coeffs, // [..., K, 3]
    const at::optional masks, // [...]
    // outputs
    at::Tensor colors // [..., 3] (allocated as empty_like(dirs) by the caller)
) {
    const uint32_t K = coeffs.size(-2);
    const uint32_t N = dirs.numel() / 3;

    // parallelize over N * 3
    int64_t n_elements = N * 3;
    dim3 threads(256);
    dim3 grid((n_elements + threads.x - 1) / threads.x);
    int64_t shmem_size = 0; // No shared memory used in this kernel

    if (n_elements == 0) {
        // skip the kernel launch if there are no elements
        return;
    }

    AT_DISPATCH_FLOATING_TYPES(
        dirs.scalar_type(),
        "spherical_harmonics_fwd_kernel_3dcs",
        [&]() {
            spherical_harmonics_fwd_kernel_3dcs
                <<>>(
                    N,
                    K,
                    degrees_to_use,
                    convex_points.data_ptr(),
                    reinterpret_cast(dirs.data_ptr()),
                    coeffs.data_ptr(),
                    masks.has_value() ?
masks.value().data_ptr() + : nullptr, + colors.data_ptr() + ); + } + ); +} + +template +__global__ void spherical_harmonics_bwd_kernel_3dcs( + const uint32_t N, + const uint32_t K, + const uint32_t degrees_to_use, + const scalar_t *__restrict__ convex_points, // [N, 6, 3] + const vec3 *__restrict__ dirs, // [N, 3] + const scalar_t *__restrict__ coeffs, // [N, K, 3] + const bool *__restrict__ masks, // [N] + const scalar_t *__restrict__ v_colors, // [N, 3] + scalar_t *__restrict__ v_convex_points, // [N, 6, 3] + scalar_t *__restrict__ v_coeffs, // [N, K, 3] + scalar_t *__restrict__ v_dirs // [N, 3] optional +) { + // parallelize over N * 3 + uint32_t idx = cg::this_grid().thread_rank(); + if (idx >= N * 3) { + return; + } + uint32_t elem_id = idx / 3; + uint32_t c = idx % 3; // color channel + if (masks != nullptr && !masks[elem_id]) { + return; + } + + vec3 v_dir = {0.f, 0.f, 0.f}; + sh_coeffs_to_color_fast_vjp( + degrees_to_use, + c, + dirs[elem_id], + coeffs + elem_id * K * 3, + v_colors + elem_id * 3, + v_coeffs + elem_id * K * 3, + v_dirs == nullptr ? nullptr : &v_dir + ); + if (v_dirs != nullptr) { + gpuAtomicAdd(v_dirs + elem_id * 3, v_dir.x / 6.0f); + gpuAtomicAdd(v_dirs + elem_id * 3 + 1, v_dir.y / 6.0f); + gpuAtomicAdd(v_dirs + elem_id * 3 + 2, v_dir.z / 6.0f); + } +} + +void launch_spherical_harmonics_bwd_kernel_3dcs( + const uint32_t degrees_to_use, + const at::Tensor convex_points, // [N, 6, 3] + const at::Tensor dirs, // [..., 3] + const at::Tensor coeffs, // [..., K, 3] + const at::optional masks, // [...] 
+ const at::Tensor v_colors, // [..., 3] + // outputs + at::Tensor v_convex_points, // [N, 6, 3] + at::Tensor v_coeffs, // [N, K, 3] + at::optional v_dirs // [N, 3] optional +) { + const uint32_t K = coeffs.size(-2); + const uint32_t N = dirs.numel() / 3; + + // parallelize over N * 3 + int64_t n_elements = N * 3; + dim3 threads(256); + dim3 grid((n_elements + threads.x - 1) / threads.x); + int64_t shmem_size = 0; // No shared memory used in this kernel + + if (n_elements == 0) { + // skip the kernel launch if there are no elements + return; + } + + AT_DISPATCH_FLOATING_TYPES( + dirs.scalar_type(), + "spherical_harmonics_bwd_kernel_3dcs", + [&]() { + spherical_harmonics_bwd_kernel_3dcs + <<>>( + N, + K, + degrees_to_use, + convex_points.data_ptr(), + reinterpret_cast(dirs.data_ptr()), + coeffs.data_ptr(), + masks.has_value() ? masks.value().data_ptr() + : nullptr, + v_colors.data_ptr(), + v_convex_points.data_ptr(), + v_coeffs.data_ptr(), + v_dirs.has_value() ? v_dirs.value().data_ptr() + : nullptr + ); + } + ); +} + + } // namespace gsplat diff --git a/gsplat/cuda/ext.cpp b/gsplat/cuda/ext.cpp index 90b93c2df..1bab51db9 100644 --- a/gsplat/cuda/ext.cpp +++ b/gsplat/cuda/ext.cpp @@ -65,4 +65,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "rasterize_to_pixels_2dgs_bwd", &gsplat::rasterize_to_pixels_2dgs_bwd ); m.def("rasterize_to_indices_2dgs", &gsplat::rasterize_to_indices_2dgs); -} \ No newline at end of file + + // 3DCS + m.def("projection_ewa_3dcs_fused_fwd", &gsplat::projection_ewa_3dcs_fused_fwd); + m.def("projection_ewa_3dcs_fused_bwd", &gsplat::projection_ewa_3dcs_fused_bwd); + + m.def("rasterize_to_pixels_3dcs_fwd", &gsplat::rasterize_to_pixels_3dcs_fwd); + m.def("rasterize_to_pixels_3dcs_bwd", &gsplat::rasterize_to_pixels_3dcs_bwd); + + m.def("spherical_harmonics_fwd_3dcs", &gsplat::spherical_harmonics_fwd_3dcs); + m.def("spherical_harmonics_bwd_3dcs", &gsplat::spherical_harmonics_bwd_3dcs); +} diff --git a/gsplat/cuda/include/Ops.h 
b/gsplat/cuda/include/Ops.h index a664b996c..e2e7be3e8 100644 --- a/gsplat/cuda/include/Ops.h +++ b/gsplat/cuda/include/Ops.h @@ -147,7 +147,7 @@ projection_ewa_3dgs_packed_bwd( const bool sparse_grad ); -// Sphereical harmonics +// Spherical harmonics at::Tensor spherical_harmonics_fwd( const uint32_t degrees_to_use, const at::Tensor dirs, // [..., 3] @@ -453,4 +453,161 @@ std::tuple rasterize_to_indices_2dgs( const at::Tensor flatten_ids // [n_isects] ); +//====== 3DCS ======// + +// Spherical harmonics +at::Tensor spherical_harmonics_fwd_3dcs( + const uint32_t degrees_to_use, + const at::Tensor convex_points, // [N, 6, 3] + const at::Tensor dirs, // [..., 3] + const at::Tensor coeffs, // [..., K, 3] + const at::optional masks // [...] +); +std::tuple spherical_harmonics_bwd_3dcs( + const uint32_t K, + const uint32_t degrees_to_use, + const at::Tensor convex_points, // [N, 6, 3] + const at::Tensor dirs, // [..., 3] + const at::Tensor coeffs, // [..., K, 3] + const at::optional masks, // [...] 
+ const at::Tensor v_colors, // [..., 3] + bool compute_v_convex_points, + bool compute_v_dirs +); + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +projection_ewa_3dcs_fused_fwd( + const at::Tensor convex_points, // [N, 6, 3] + const at::Tensor cumsum_of_points_per_convex, // [N] + const at::Tensor delta, // [N] + const at::Tensor sigma, // [N] + at::Tensor scaling, // [N] + const at::optional opacities, // [N] optional + const at::Tensor viewmats, // [C, 4, 4] + const at::Tensor Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const uint32_t total_nb_points, + const float eps2d, + const float near_plane, + const float far_plane, + const float radius_clip, + const bool calc_compensations, + const CameraModelType camera_model +); + +std::tuple< + at::Tensor, + at::Tensor> +projection_ewa_3dcs_fused_bwd( + // fwd inputs + const at::Tensor convex_points, // [N, 3] + const at::Tensor cumsum_of_points_per_convex, // [N] + const at::Tensor viewmats, // [C, 4, 4] + const at::Tensor Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const CameraModelType camera_model, + // fwd outputs + const at::Tensor radii, // [C, N, 2] + const at::Tensor hull, // [C, 2*total_nb_points] + const at::Tensor num_points_per_convex_view, // [C, N] + //const at::Tensor normals, // [C, total_nb_points, 2] + const at::Tensor offsets, // [C, total_nb_points] + const at::Tensor p_image, // [C, total_nb_points, 2] + const at::Tensor indices, // [C, total_nb_points] + const at::Tensor conics, // [C, N, 3] + const at::optional compensations, // [C, N] optional + // grad inputs + const at::Tensor v_normals, // [C, total_nb_points, 2] + const at::Tensor v_offsets, // [C, total_nb_points] + // grad outputs + const at::Tensor v_means2d, // [C, N, 2] + const at::Tensor v_depths, // [C, N] + const 
at::Tensor v_conics, // [C, N, 3] + const at::optional v_compensations, // [C, N] optional + const bool viewmats_requires_grad +); + +std::tuple +rasterize_to_pixels_3dcs_fwd( + // 3D convex parameters + const at::Tensor means2d, // [C, N, 2] or [nnz, 2] + const at::Tensor normals, // [C, total_nb_points, 2] or [nnz, 2] + const at::Tensor offsets, // [C, total_nb_points] or [nnz] + const at::Tensor num_points_per_convex_view, // [C, N] + const at::Tensor delta, // [C, N] + const at::Tensor sigma, // [C, N] + const uint32_t num_points_per_convex, // 6 + const at::Tensor cumsum_of_points_per_convex, // [C] + const at::Tensor depths, // [C, N] + const at::Tensor conics, // [C, N, 3] or [nnz, 3] + const at::Tensor colors, // [C, N, channels] or [nnz, channels] + const at::Tensor opacities, // [C, N] or [nnz] + const at::optional backgrounds, // [C, channels] + const at::optional masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const at::Tensor tile_offsets, // [C, tile_height, tile_width] + const at::Tensor flatten_ids // [n_isects] +); + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +rasterize_to_pixels_3dcs_bwd( + // 3D convex parameters + const at::Tensor means2d, // [C, N, 2] or [nnz, 2] + const at::Tensor normals, // [C, total_nb_points, 2] or [nnz, 2] + const at::Tensor offsets, // [C, total_nb_points] or [nnz] + const uint32_t num_points_per_convex, // 6 + const at::Tensor delta, // [C, N] + const at::Tensor sigma, // [C, N] + const at::Tensor num_points_per_convex_view, // [C, N] + const at::Tensor cumsum_of_points_per_convex, // [C, N] + const at::Tensor depths, // [C, N] + const at::Tensor conics, // [C, N, 3] or [nnz, 3] + const at::Tensor colors, // [C, N, 3] or [nnz, 3] + const at::Tensor opacities, // [C, N] or [nnz] + const at::optional backgrounds, 
// [C, 3] + const at::optional masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const at::Tensor tile_offsets, // [C, tile_height, tile_width] + const at::Tensor flatten_ids, // [n_isects] + // forward outputs + const at::Tensor render_alphas, // [C, image_height, image_width, 1] + const at::Tensor last_ids, // [C, image_height, image_width] + // gradients of outputs + const at::Tensor v_render_colors, // [C, image_height, image_width, 3] + const at::Tensor v_render_alphas, // [C, image_height, image_width, 1] + // options + bool absgrad +); + } // namespace gsplat diff --git a/gsplat/cuda/include/Utils.cuh b/gsplat/cuda/include/Utils.cuh index 128959b57..565da1480 100644 --- a/gsplat/cuda/include/Utils.cuh +++ b/gsplat/cuda/include/Utils.cuh @@ -91,6 +91,16 @@ inline __device__ void warpSum(float *val, WarpT &warp) { } } +template +inline __device__ void warpSum(vec3 *val, WarpT &warp) { +#pragma unroll + for (uint32_t i = 0; i < DIM; i++) { + val[i].x = cg::reduce(warp, val[i].x, cg::plus()); + val[i].y = cg::reduce(warp, val[i].y, cg::plus()); + val[i].z = cg::reduce(warp, val[i].z, cg::plus()); + } +} + template inline __device__ void warpSum(float &val, WarpT &warp) { val = cg::reduce(warp, val, cg::plus()); } diff --git a/gsplat/cuda/include/bindings.h b/gsplat/cuda/include/bindings.h new file mode 100644 index 000000000..f0500dd39 --- /dev/null +++ b/gsplat/cuda/include/bindings.h @@ -0,0 +1,715 @@ +#ifndef GSPLAT_CUDA_BINDINGS_H +#define GSPLAT_CUDA_BINDINGS_H + +#include +#include +#include + +#define GSPLAT_N_THREADS 256 + +#define GSPLAT_CHECK_CUDA(x) \ + TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define GSPLAT_CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define GSPLAT_CHECK_INPUT(x) \ + GSPLAT_CHECK_CUDA(x); \ + GSPLAT_CHECK_CONTIGUOUS(x) +#define GSPLAT_DEVICE_GUARD(_ten) \ + const 
at::cuda::OptionalCUDAGuard device_guard(device_of(_ten)); + +#define GSPLAT_PRAGMA_UNROLL _Pragma("unroll") + +// https://github.com/pytorch/pytorch/blob/233305a852e1cd7f319b15b5137074c9eac455f6/aten/src/ATen/cuda/cub.cuh#L38-L46 +#define GSPLAT_CUB_WRAPPER(func, ...) \ + do { \ + size_t temp_storage_bytes = 0; \ + func(nullptr, temp_storage_bytes, __VA_ARGS__); \ + auto &caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \ + auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \ + func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \ + } while (false) + +namespace gsplat { + +enum CameraModelType +{ + PINHOLE = 0, + ORTHO = 1, + FISHEYE = 2, +}; + +std::tuple quat_scale_to_covar_preci_fwd_tensor( + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const bool compute_covar, + const bool compute_preci, + const bool triu +); + +std::tuple quat_scale_to_covar_preci_bwd_tensor( + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const at::optional &v_covars, // [N, 3, 3] + const at::optional &v_precis, // [N, 3, 3] + const bool triu +); + +std::tuple proj_fwd_tensor( + const torch::Tensor &means, // [C, N, 3] + const torch::Tensor &covars, // [C, N, 3, 3] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t width, + const uint32_t height, + const CameraModelType camera_model +); + +std::tuple proj_bwd_tensor( + const torch::Tensor &means, // [C, N, 3] + const torch::Tensor &covars, // [C, N, 3, 3] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t width, + const uint32_t height, + const CameraModelType camera_model, + const torch::Tensor &v_means2d, // [C, N, 2] + const torch::Tensor &v_covars2d // [C, N, 2, 2] +); + +std::tuple world_to_cam_fwd_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &covars, // [N, 3, 3] + const torch::Tensor &viewmats // [C, 4, 4] +); + +std::tuple world_to_cam_bwd_tensor( + const torch::Tensor &means, // [N, 3] + 
const torch::Tensor &covars, // [N, 3, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const at::optional &v_means_c, // [C, N, 3] + const at::optional &v_covars_c, // [C, N, 3, 3] + const bool means_requires_grad, + const bool covars_requires_grad, + const bool viewmats_requires_grad +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +fully_fused_projection_fwd_tensor( + const torch::Tensor &means, // [N, 3] + const at::optional &covars, // [N, 6] optional + const at::optional &quats, // [N, 4] optional + const at::optional &scales, // [N, 3] optional + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const float near_plane, + const float far_plane, + const float radius_clip, + const bool calc_compensations, + const CameraModelType camera_model +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +fully_fused_projection_bwd_tensor( + // fwd inputs + const torch::Tensor &means, // [N, 3] + const at::optional &covars, // [N, 6] optional + const at::optional &quats, // [N, 4] optional + const at::optional &scales, // [N, 3] optional + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const CameraModelType camera_model, + // fwd outputs + const torch::Tensor &radii, // [C, N] + const torch::Tensor &conics, // [C, N, 3] + const at::optional &compensations, // [C, N] optional + // grad outputs + const torch::Tensor &v_means2d, // [C, N, 2] + const torch::Tensor &v_depths, // [C, N] + const torch::Tensor &v_conics, // [C, N, 3] + const at::optional &v_compensations, // [C, N] optional + const bool viewmats_requires_grad +); + +std::tuple isect_tiles_tensor( + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const 
torch::Tensor &radii, // [C, N] or [nnz] + const torch::Tensor &depths, // [C, N] or [nnz] + const at::optional &camera_ids, // [nnz] + const at::optional &gaussian_ids, // [nnz] + const uint32_t C, + const uint32_t tile_size, + const uint32_t tile_width, + const uint32_t tile_height, + const bool sort, + const bool double_buffer +); + +torch::Tensor isect_offset_encode_tensor( + const torch::Tensor &isect_ids, // [n_isects] + const uint32_t C, + const uint32_t tile_width, + const uint32_t tile_height +); + +std::tuple +rasterize_to_pixels_fwd_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] + const torch::Tensor &conics, // [C, N, 3] + const torch::Tensor &colors, // [C, N, D] + const torch::Tensor &opacities, // [N] + const at::optional &backgrounds, // [C, D] + const at::optional &mask, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +rasterize_to_pixels_bwd_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] + const torch::Tensor &conics, // [C, N, 3] + const torch::Tensor &colors, // [C, N, 3] + const torch::Tensor &opacities, // [N] + const at::optional &backgrounds, // [C, 3] + const at::optional &mask, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // forward outputs + const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &last_ids, // [C, image_height, image_width] + // gradients of outputs + const torch::Tensor 
&v_render_colors, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + // options + bool absgrad +); + +std::tuple rasterize_to_indices_in_range_tensor( + const uint32_t range_start, + const uint32_t range_end, // iteration steps + const torch::Tensor transmittances, // [C, image_height, image_width] + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] + const torch::Tensor &conics, // [C, N, 3] + const torch::Tensor &opacities, // [N] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +); + +torch::Tensor compute_sh_fwd_tensor( + const uint32_t degrees_to_use, + const torch::Tensor &dirs, // [..., 3] + const torch::Tensor &coeffs, // [..., K, 3] + const at::optional masks // [...] +); +std::tuple compute_sh_bwd_tensor( + const uint32_t K, + const uint32_t degrees_to_use, + const torch::Tensor &dirs, // [..., 3] + const torch::Tensor &coeffs, // [..., K, 3] + const at::optional masks, // [...] 
+ const torch::Tensor &v_colors, // [..., 3] + bool compute_v_dirs +); + +/**************************************************************************************** + * Packed Version + ****************************************************************************************/ +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +fully_fused_projection_packed_fwd_tensor( + const torch::Tensor &means, // [N, 3] + const at::optional &covars, // [N, 6] + const at::optional &quats, // [N, 3] + const at::optional &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const float near_plane, + const float far_plane, + const float radius_clip, + const bool calc_compensations, + const CameraModelType camera_model +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +fully_fused_projection_packed_bwd_tensor( + // fwd inputs + const torch::Tensor &means, // [N, 3] + const at::optional &covars, // [N, 6] + const at::optional &quats, // [N, 4] + const at::optional &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const CameraModelType camera_model, + // fwd outputs + const torch::Tensor &camera_ids, // [nnz] + const torch::Tensor &gaussian_ids, // [nnz] + const torch::Tensor &conics, // [nnz, 3] + const at::optional &compensations, // [nnz] optional + // grad outputs + const torch::Tensor &v_means2d, // [nnz, 2] + const torch::Tensor &v_depths, // [nnz] + const torch::Tensor &v_conics, // [nnz, 3] + const at::optional &v_compensations, // [nnz] optional + const bool viewmats_requires_grad, + const bool sparse_grad +); + +std::tuple compute_relocation_tensor( + 
torch::Tensor &opacities, + torch::Tensor &scales, + torch::Tensor &ratios, + torch::Tensor &binoms, + const int n_max +); + +//====== 2DGS ======// +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +fully_fused_projection_fwd_2dgs_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const float near_plane, + const float far_plane, + const float radius_clip +); + +std::tuple +fully_fused_projection_bwd_2dgs_tensor( + // fwd inputs + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + // fwd outputs + const torch::Tensor &radii, // [C, N] + const torch::Tensor &ray_transforms, // [C, N, 3, 3] + // grad outputs + const torch::Tensor &v_means2d, // [C, N, 2] + const torch::Tensor &v_depths, // [C, N] + const torch::Tensor &v_normals, // [C, N, 3] + const torch::Tensor &v_ray_transforms, // [C, N, 3, 3] + const bool viewmats_requires_grad +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +rasterize_to_pixels_fwd_2dgs_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &ray_transforms, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &normals, // [C, K, 3] or [nnz, 3] + const at::optional &backgrounds, // [C, channels] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const 
uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +rasterize_to_pixels_bwd_2dgs_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &ray_transforms, // [C, N, 3, 3] or [nnz, 3, 3] + const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &normals, // [C, N, 3] or [nnz, 3], + const torch::Tensor &densify, + const at::optional &backgrounds, // [C, 3] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // ray_crossions + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // forward outputs + const torch::Tensor + &render_colors, // [C, image_height, image_width, COLOR_DIM] + const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &last_ids, // [C, image_height, image_width] + const torch::Tensor &median_ids, // [C, image_height, image_width] + // gradients of outputs + const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_normals, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_distort, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_median, // [C, image_height, image_width, 1] + // options + bool absgrad +); + +std::tuple +rasterize_to_indices_in_range_2dgs_tensor( + const uint32_t range_start, + const uint32_t range_end, // iteration steps + const 
torch::Tensor transmittances, // [C, image_height, image_width] + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] + const torch::Tensor &ray_transforms, // [C, N, 3, 3] + const torch::Tensor &opacities, // [C, N] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +fully_fused_projection_packed_fwd_2dgs_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 3] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float near_plane, + const float far_plane, + const float radius_clip +); + +std::tuple +fully_fused_projection_packed_bwd_2dgs_tensor( + // fwd inputs + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + // fwd outputs + const torch::Tensor &camera_ids, // [nnz] + const torch::Tensor &gaussian_ids, // [nnz] + const torch::Tensor &ray_transforms, // [nnz, 3, 3] + // grad outputs + const torch::Tensor &v_means2d, // [nnz, 2] + const torch::Tensor &v_depths, // [nnz] + const torch::Tensor &v_normals, // [nnz, 3] + const torch::Tensor &v_ray_transforms, // [nnz, 3, 3] + const bool viewmats_requires_grad, + const bool sparse_grad +); + +void selective_adam_update( + torch::Tensor ¶m, + torch::Tensor ¶m_grad, + torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, + torch::Tensor &tiles_touched, + const float lr, + const 
float b1, + const float b2, + const float eps, + const uint32_t N, + const uint32_t M); + +//====== 3DCS ======// +/**************************************************************************************** + * Packed Version + ****************************************************************************************/ +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +fully_fused_projection_packed_fwd_tensor_3dcs( + const torch::Tensor &convex_points, // [N, 6, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const float near_plane, + const float far_plane, + const float radius_clip, + const bool calc_compensations, + const CameraModelType camera_model +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +fully_fused_projection_packed_bwd_tensor_3dcs( + // fwd inputs + const torch::Tensor &convex_points, // [N, 6, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const CameraModelType camera_model, + // fwd outputs + const torch::Tensor &camera_ids, // [nnz] + const torch::Tensor &gaussian_ids, // [nnz] + const torch::Tensor &conics, // [nnz, 3] + const at::optional &compensations, // [nnz] optional + // grad outputs + const torch::Tensor &v_means2d, // [nnz, 2] + const torch::Tensor &v_depths, // [nnz] + const torch::Tensor &v_conics, // [nnz, 3] + const at::optional &v_compensations, // [nnz] optional + const bool viewmats_requires_grad, + const bool sparse_grad +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> 
+fully_fused_projection_fwd_tensor_3dcs( + const torch::Tensor &convex_points, // [N, K, 3] + const torch::Tensor &cumsum_of_points_per_convex, // [N] + const torch::Tensor &delta, // [N] + const torch::Tensor &sigma, // [N] + torch::Tensor &scaling, // [N] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const uint32_t total_nb_points, + const float eps2d, + const float near_plane, + const float far_plane, + const float radius_clip, + const bool calc_compensations, + const CameraModelType camera_model +); + +std::tuple< + torch::Tensor, + torch::Tensor> +fully_fused_projection_bwd_tensor_3dcs( + // fwd inputs + const torch::Tensor &convex_points, // [N, K, 3] + const torch::Tensor &cumsum_of_points_per_convex, // [N] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const CameraModelType camera_model, + // fwd outputs + const torch::Tensor &radii, // [C, N] + const torch::Tensor &hull, // [C, 2*total_nb_points] + const torch::Tensor &num_points_per_convex_view, // [C, N] + //const torch::Tensor &normals, // [C, total_nb_points, 2] + const torch::Tensor &offsets, // [C, total_nb_points] + const torch::Tensor &p_image, // [C, total_nb_points, 2] + const torch::Tensor &indices, // [C, total_nb_points] + const torch::Tensor &conics, // [C, N, 3] + const at::optional &compensations, // [C, N] optional + // grad inputs + const torch::Tensor &v_normals, // [C, total_nb_points, 2] + const torch::Tensor &v_offsets, // [C, total_nb_points] + // grad outputs + torch::Tensor &v_means2d, // [C, N, 2] + const torch::Tensor &v_depths, // [C, N] + const torch::Tensor &v_conics, // [C, N, 3] + const at::optional &v_compensations, // [C, N] optional + const bool viewmats_requires_grad +); + +std::tuple +rasterize_to_pixels_fwd_tensor_3dcs( + // Gaussian parameters + 
const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &normals, // [C, total_nb_points, 2] or [nnz, 2] + const torch::Tensor &offsets, // [C, total_nb_points] or [nnz] + const torch::Tensor &num_points_per_convex_view, // [C, N] + const torch::Tensor &delta, // [C, N] + const torch::Tensor &sigma, // [C, N] + const uint32_t num_points_per_convex, // 6 + const torch::Tensor &cumsum_of_points_per_convex, // [C] + const torch::Tensor &depths, // [C, N] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const at::optional &backgrounds, // [C, channels] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> + rasterize_to_pixels_bwd_tensor_3dcs( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &normals, // [C, total_nb_points, 2] or [nnz, 2] + const torch::Tensor &offsets, // [C, total_nb_points] or [nnz] + const uint32_t num_points_per_convex, // 6 + const torch::Tensor &delta, // [C, N] + const torch::Tensor &sigma, // [C, N] + const torch::Tensor &num_points_per_convex_view, // [C, N] + const torch::Tensor &cumsum_of_points_per_convex, // [C, N] + const torch::Tensor &depths, // [C, N] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] + const torch::Tensor &opacities, // [C, N] or [nnz] + const at::optional &backgrounds, // [C, 3] + const at::optional &masks, // [C, tile_height, tile_width] 
+ // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // forward outputs + const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &last_ids, // [C, image_height, image_width] + // gradients of outputs + const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + // options + bool absgrad +); + +torch::Tensor compute_sh_fwd_tensor_3dcs( + const uint32_t degrees_to_use, + const torch::Tensor &convex_points, // [N, 6, 3] + const torch::Tensor &dirs, // [..., 3] + const torch::Tensor &coeffs, // [..., K, 3] + const at::optional masks // [...] +); +std::tuple compute_sh_bwd_tensor_3dcs( + const uint32_t K, + const uint32_t degrees_to_use, + const torch::Tensor &convex_points, // [N, 6, 3] + const torch::Tensor &dirs, // [..., 3] + const torch::Tensor &coeffs, // [..., K, 3] + const at::optional masks, // [...] 
+ const torch::Tensor &v_colors, // [..., 3] + bool compute_v_dirs, + bool compute_v_convex_points +); + +} // namespace gsplat + +#endif // GSPLAT_CUDA_BINDINGS_H diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 8b8b70a82..2bfa9987d 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -10,11 +10,14 @@ from .cuda._wrapper import ( fully_fused_projection, fully_fused_projection_2dgs, + fully_fused_projection_3dcs, isect_offset_encode, isect_tiles, rasterize_to_pixels, rasterize_to_pixels_2dgs, + rasterize_to_pixels_3dcs, spherical_harmonics, + spherical_harmonics_3dcs, ) from .distributed import ( all_gather_int32, @@ -510,7 +513,6 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso camera_ids=camera_ids, gaussian_ids=gaussian_ids, ) - # print("rank", world_rank, "Before isect_offset_encode") isect_offsets = isect_offset_encode(isect_ids, C, tile_width, tile_height) meta.update( @@ -1474,3 +1476,796 @@ def rasterization_2dgs_inria_wrapper( "gaussian_ids": None, } return (render_colors, render_alphas), meta + +### 3DCS ### +def rasterization_3dcs( + convex_points: Tensor, # [N, K, 3] + delta: Tensor, # [N] + sigma: Tensor, # [N] + num_points_per_convex: int, # K, default is 6 + cumsum_of_points_per_convex: Tensor, + opacities: Tensor, # [N] + colors: Tensor, # [(C,) N, D] or [(C,) N, K, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float = 0.01, + far_plane: float = 1e10, + radius_clip: float = 0.0, + eps2d: float = 0.3, + sh_degree: Optional[int] = None, + packed: bool = True, + tile_size: int = 16, + backgrounds: Optional[Tensor] = None, + render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED"] = "RGB", + sparse_grad: bool = False, + absgrad: bool = False, + rasterize_mode: Literal["classic", "antialiased"] = "classic", + channel_chunk: int = 32, + distributed: bool = False, + camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", +) -> 
Tuple[Tensor, Tensor, Dict]: + + """Rasterize a set of 3D Convex Points (N) to a batch of image planes (C). + + This function provides a handful of features for 3D Convex rasterization, which + we detail in the following notes. A complete profiling of these features + can be found in the :ref:`profiling` page. + + .. note:: + **Multi-GPU Distributed Rasterization**: This function can be used in a multi-GPU + distributed scenario by setting `distributed` to True. When `distributed` is True, + a subset of total Gaussians could be passed into this function in each rank, and + the function will collaboratively render a set of images using Gaussians from all ranks. Note + to achieve balanced computation, it is recommended (not enforced) to have a similar number of + Gaussians in each rank. But we do enforce that the number of cameras to be rendered + in each rank is the same. The function will return the rendered images + corresponding to the input cameras in each rank, and allows for gradients to flow back to the + Gaussians living in other ranks. For the details, please refer to the paper + `On Scaling Up 3D Gaussian Splatting Training `_. + + .. note:: + **Batch Rasterization**: This function allows for rasterizing a set of 3D Gaussians + to a batch of images in one go, by simply providing the batched `viewmats` and `Ks`. + + .. note:: + **Support N-D Features**: If `sh_degree` is None, + the `colors` is expected to be with shape [N, D] or [C, N, D], in which D is the channel of + the features to be rendered. The computation is slow when D > 32 at the moment. + If `sh_degree` is set, the `colors` is expected to be the SH coefficients with + shape [N, K, 3] or [C, N, K, 3], where K is the number of SH bases. In this case, it is expected + that :math:`(\\textit{sh_degree} + 1) ^ 2 \\leq K`, where `sh_degree` controls the + activated bases in the SH coefficients. + + .. note:: + **Depth Rendering**: This function supports colors and/or depths via `render_mode`. 
+ The supported modes are "RGB", "D", "ED", "RGB+D", and "RGB+ED". "RGB" renders the + colored image that respects the `colors` argument. "D" renders the accumulated z-depth + :math:`\\sum_i w_i z_i`. "ED" renders the expected z-depth + :math:`\\frac{\\sum_i w_i z_i}{\\sum_i w_i}`. "RGB+D" and "RGB+ED" render both + the colored image and the depth, in which the depth is the last channel of the output. + + .. note:: + **Memory-Speed Trade-off**: The `packed` argument provides a trade-off between + memory footprint and runtime. If `packed` is True, the intermediate results are + packed into sparse tensors, which is more memory efficient but might be slightly + slower. This is especially helpful when the scene is large and each camera sees only + a small portion of the scene. If `packed` is False, the intermediate results are + with shape [C, N, ...], which is faster but might consume more memory. + + .. note:: + **Sparse Gradients**: If `sparse_grad` is True, the gradients for {convex_points, quats, scales} + will be stored in a `COO sparse layout `_. + This can be helpful for saving memory + for training when the scene is large and each iteration only activates a small portion + of the Gaussians. Usually a sparse optimizer is required to work with sparse gradients, + such as `torch.optim.SparseAdam `_. + This argument is only effective when `packed` is True. + + .. note:: + **Speed-up for Large Scenes**: The `radius_clip` argument is extremely helpful for + speeding up large scale scenes or scenes with large depth of fields. Gaussians with + 2D radius smaller or equal than this value (in pixel unit) will be skipped during rasterization. + This will skip all the far-away Gaussians that are too small to be seen in the image. + But be warned that if there are close-up Gaussians that are also below this threshold, they will + also get skipped (which is rarely happened in practice). This is by default disabled by setting + `radius_clip` to 0.0. + + .. 
note:: + **Antialiased Rendering**: If `rasterize_mode` is "antialiased", the function will + apply a view-dependent compensation factor + :math:`\\rho=\\sqrt{\\frac{Det(\\Sigma)}{Det(\\Sigma+ \\epsilon I)}}` to Gaussian + opacities, where :math:`\\Sigma` is the projected 2D covariance matrix and :math:`\\epsilon` + is the `eps2d`. This will make the rendered image more antialiased, as proposed in + the paper `Mip-Splatting: Alias-free 3D Gaussian Splatting `_. + + .. note:: + **AbsGrad**: If `absgrad` is True, the absolute gradients of the projected + 2D means will be computed during the backward pass, which could be accessed by + `meta["means2d"].absgrad`. This is an implementation of the paper + `AbsGS: Recovering Fine Details for 3D Gaussian Splatting `_, + which is shown to be more effective for splitting Gaussians during training. + + .. warning:: + This function is currently not differentiable w.r.t. the camera intrinsics `Ks`. + + Args: + convex_points: The 3D Points representing the 3D convex structure. [N, K, 3] + delta: The delta values of the 3D convex. [N] + sigma: The sigma values of the 3D convex. [N] + num_points_per_convex: The number of points per convex. Default is 6. + cumsum_of_points_per_convex: The cumsum of points per convex. [N] + opacities: The opacities of the Gaussians. [N] + colors: The colors of the Gaussians. [(C,) N, D] or [(C,) N, K, 3] for SH coefficients. + viewmats: The world-to-cam transformation of the cameras. [C, 4, 4] + Ks: The camera intrinsics. [C, 3, 3] + width: The width of the image. + height: The height of the image. + near_plane: The near plane for clipping. Default is 0.01. + far_plane: The far plane for clipping. Default is 1e10. + radius_clip: Gaussians with 2D radius smaller or equal than this value will be + skipped. This is extremely helpful for speeding up large scale scenes. + Default is 0.0. + eps2d: An epsilon added to the egienvalues of projected 2D covariance matrices. 
+ This will prevent the projected GS from being too small. For example eps2d=0.3 + leads to minimal 3 pixel unit. Default is 0.3. + sh_degree: The SH degree to use, which can be smaller than the total + number of bands. If set, the `colors` should be [(C,) N, K, 3] SH coefficients, + else the `colors` should be [(C,) N, D] post-activation color values. Default is None. + packed: Whether to use packed mode which is more memory efficient but might or + might not be as fast. Default is True. + tile_size: The size of the tiles for rasterization. Default is 16. + (Note: other values are not tested) + backgrounds: The background colors. [C, D]. Default is None. + render_mode: The rendering mode. Supported modes are "RGB", "D", "ED", "RGB+D", + and "RGB+ED". "RGB" renders the colored image, "D" renders the accumulated depth, and + "ED" renders the expected depth. Default is "RGB". + sparse_grad: If true, the gradients for {convex_points, quats, scales} will be stored in + a COO sparse layout. This can be helpful for saving memory. Default is False. + absgrad: If true, the absolute gradients of the projected 2D means + will be computed during the backward pass, which could be accessed by + `meta["means2d"].absgrad`. Default is False. + rasterize_mode: The rasterization mode. Supported modes are "classic" and + "antialiased". Default is "classic". + channel_chunk: The number of channels to render in one go. Default is 32. + If the required rendering channels are larger than this value, the rendering + will be done in chunks in a loop. + distributed: Whether to use distributed rendering. Default is False. If True, + the input Gaussians are expected to be a subset of the scene in each rank, and + the function will collaboratively render the images for all ranks. + camera_model: The camera model to use. Supported models are "pinhole", "ortho", + and "fisheye". Default is "pinhole". + + Returns: + A tuple: + + **render_colors**: The rendered colors. [C, height, width, X]. 
+ X depends on the `render_mode` and input `colors`. If `render_mode` is "RGB", + X is D; if `render_mode` is "D" or "ED", X is 1; if `render_mode` is "RGB+D" or + "RGB+ED", X is D+1. + + **render_alphas**: The rendered alphas. [C, height, width, 1]. + + **meta**: A dictionary of intermediate results of the rasterization. + + Examples: + + .. code-block:: python + + >>> # define Gaussians + >>> convex_points = torch.randn((100, 3), device=device) + >>> colors = torch.rand((100, 3), device=device) + >>> opacities = torch.rand((100,), device=device) + >>> # define cameras + >>> viewmats = torch.eye(4, device=device)[None, :, :] + >>> Ks = torch.tensor([ + >>> [300., 0., 150.], [0., 300., 100.], [0., 0., 1.]], device=device)[None, :, :] + >>> width, height = 300, 200 + >>> # render + >>> colors, alphas, meta = rasterization( + >>> convex_points, quats, scales, opacities, colors, viewmats, Ks, width, height + >>> ) + >>> print (colors.shape, alphas.shape) + torch.Size([1, 200, 300, 3]) torch.Size([1, 200, 300, 1]) + >>> print (meta.keys()) + dict_keys(['camera_ids', 'gaussian_ids', 'radii', 'means2d', 'depths', 'conics', + 'opacities', 'tile_width', 'tile_height', 'tiles_per_gauss', 'isect_ids', + 'flatten_ids', 'isect_offsets', 'width', 'height', 'tile_size']) + + """ + meta = {} + + N = convex_points.shape[0] + C = viewmats.shape[0] + device = convex_points.device + assert convex_points.shape == (N, num_points_per_convex, 3), convex_points.shape + assert opacities.shape == (N,), opacities.shape + assert viewmats.shape == (C, 4, 4), viewmats.shape + assert Ks.shape == (C, 3, 3), Ks.shape + assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode + + def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor: + view_list = list( + map( + lambda x: x.split(int(x.shape[0] / C), dim=0), + world_view.split([C * N_i for N_i in N_world], dim=0), + ) + ) + return torch.stack([torch.cat(l, dim=0) for l in zip(*view_list)], dim=0) + + if 
sh_degree is None: + # treat colors as post-activation values, should be in shape [N, D] or [C, N, D] + assert (colors.dim() == 2 and colors.shape[0] == N) or ( + colors.dim() == 3 and colors.shape[:2] == (C, N) + ), colors.shape + if distributed: + assert ( + colors.dim() == 2 + ), "Distributed mode only supports per-Gaussian colors." + else: + # treat colors as SH coefficients, should be in shape [N, K, 3] or [C, N, K, 3] + # Allowing for activating partial SH bands + assert ( + colors.dim() == 3 and colors.shape[0] == N and colors.shape[2] == 3 + ) or ( + colors.dim() == 4 and colors.shape[:2] == (C, N) and colors.shape[3] == 3 + ), colors.shape + assert (sh_degree + 1) ** 2 <= colors.shape[-2], colors.shape + if distributed: + assert ( + colors.dim() == 3 + ), "Distributed mode only supports per-Gaussian colors." + + if absgrad: + assert not distributed, "AbsGrad is not supported in distributed mode." + + # Special for 3DCS + scaling = torch.zeros_like(convex_points[:,0,0].squeeze(), dtype=convex_points.dtype, requires_grad=True, device="cuda").detach() + + # If in distributed mode, we distribute the projection computation over Gaussians + # and the rasterize computation over cameras. So first we gather the cameras + # from all ranks for projection. + # FIXME: Probably rework this part. + if distributed: + world_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + # Gather the number of Gaussians in each rank. + N_world = all_gather_int32(world_size, N, device=device) + + # Enforce that the number of cameras is the same across all ranks. + C_world = [C] * world_size + viewmats, Ks = all_gather_tensor_list(world_size, [viewmats, Ks]) + + # Silently change C from local #Cameras to global #Cameras. + C = len(viewmats) + + # Project 3D convex points to 2D convex hull. 
+ proj_results = fully_fused_projection_3dcs( + convex_points, + cumsum_of_points_per_convex, + delta, + sigma, + scaling, + viewmats, + Ks, + width, + height, + eps2d=eps2d, + packed=packed, + near_plane=near_plane, + far_plane=far_plane, + radius_clip=radius_clip, + sparse_grad=sparse_grad, + calc_compensations=(rasterize_mode == "antialiased"), + camera_model=camera_model, + opacities=opacities, # use opacities to compute a tigher bound for radii. + ) + if packed: + # The results are packed into shape [nnz, ...]. All elements are valid. + ( + normals, + offsets, + _, + _, + num_points_per_convex_view, + _, + radii, + means2d, + depths, + conics, + compensations + ) = proj_results + opacities = opacities[gaussian_ids] # [nnz] + else: + # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid. + normals, offsets, _, _, num_points_per_convex_view, _, radii, means2d, depths, conics, compensations = proj_results + opacities = opacities.repeat(C, 1) # [C, N] + camera_ids, gaussian_ids = None, None + + if compensations is not None: + opacities = opacities * compensations + + meta.update( + { + # global camera_ids + "camera_ids": camera_ids, + # local gaussian_ids + "gaussian_ids": gaussian_ids, + "radii": radii, + "means2d": means2d, + "depths": depths, + "opacities": opacities, + "sigma": sigma, + "delta": delta, + "scaling": scaling, + } + ) + + # Turn colors into [C, N, D] or [nnz, D] to pass into rasterize_to_pixels() + if sh_degree is None: + # Colors are post-activation values, with shape [N, D] or [C, N, D] + if packed: + if colors.dim() == 2: + # Turn [N, D] into [nnz, D] + colors = colors[gaussian_ids] + else: + # Turn [C, N, D] into [nnz, D] + colors = colors[camera_ids, gaussian_ids] + else: + if colors.dim() == 2: + # Turn [N, D] into [C, N, D] + colors = colors.expand(C, -1, -1) + else: + # colors is already [C, N, D] + pass + else: + # Colors are SH coefficients, with shape [N, K, 3] or [C, N, K, 3] + camtoworlds = 
torch.inverse(viewmats) # [C, 4, 4] + if packed: + dirs = convex_points[None, gaussian_ids, :] - camtoworlds[camera_ids, None, :3, 3] # [nnz, 3] + masks = (radii > 0).all(dim=-1) # [nnz] + if colors.dim() == 3: + # Turn [N, K, 3] into [nnz, 3] + shs = colors[gaussian_ids, :, :] # [nnz, K, 3] + else: + # Turn [C, N, K, 3] into [nnz, 3] + shs = colors[camera_ids, gaussian_ids, :, :] # [nnz, K, 3] + colors = spherical_harmonics_3dcs(sh_degree, convex_points, dirs, shs, masks=masks) # [nnz, 3] + else: + # Compute direction from the center of the 3D convex i.e means3d + # FIXME: Not the most efficient? Should we do that in the cuda call? + means3d = convex_points.mean(dim=1) + dirs = means3d[None, :, :] - camtoworlds[:, None, :3, 3] # [C, N, 3] + masks = (radii > 0).all(dim=-1) # [C, N] + if colors.dim() == 3: + # Turn [N, K, 3] into [C, N, K, 3] + shs = colors.expand(C, -1, -1, -1) # [C, N, K, 3] + else: + # colors is already [C, N, K, 3] + shs = colors + colors = spherical_harmonics_3dcs(sh_degree, convex_points, dirs, shs, masks=masks) # [C, N, 3] + + # make it apple-to-apple with Inria's CUDA Backend. + colors = torch.clamp_min(colors + 0.5, 0.0) + + # If in distributed mode, we need to scatter the GSs to the destination ranks, based + # on which cameras they are visible to, which we already figured out in the projection + # stage. + # if distributed: + # if packed: + # # count how many elements need to be sent to each rank + # cnts = torch.bincount(camera_ids, minlength=C) # all cameras + # cnts = cnts.split(C_world, dim=0) + # cnts = [cuts.sum() for cuts in cnts] + + # # all to all communication across all ranks. After this step, each rank + # # would have all the necessary GSs to render its own images. 
+ # collected_splits = all_to_all_int32(world_size, cnts, device=device) + # (radii,) = all_to_all_tensor_list( + # world_size, [radii], cnts, output_splits=collected_splits + # ) + # (means2d, depths, conics, opacities, colors) = all_to_all_tensor_list( + # world_size, + # [means2d, depths, conics, opacities, colors], + # cnts, + # output_splits=collected_splits, + # ) + + # # before sending the data, we should turn the camera_ids from global to local. + # # i.e. the camera_ids produced by the projection stage are over all cameras world-wide, + # # so we need to turn them into camera_ids that are local to each rank. + # offsets = torch.tensor( + # [0] + C_world[:-1], device=camera_ids.device, dtype=camera_ids.dtype + # ) + # offsets = torch.cumsum(offsets, dim=0) + # offsets = offsets.repeat_interleave(torch.stack(cnts)) + # camera_ids = camera_ids - offsets + + # # and turn gaussian ids from local to global. + # offsets = torch.tensor( + # [0] + N_world[:-1], + # device=gaussian_ids.device, + # dtype=gaussian_ids.dtype, + # ) + # offsets = torch.cumsum(offsets, dim=0) + # offsets = offsets.repeat_interleave(torch.stack(cnts)) + # gaussian_ids = gaussian_ids + offsets + + # # all to all communication across all ranks. + # (camera_ids, gaussian_ids) = all_to_all_tensor_list( + # world_size, + # [camera_ids, gaussian_ids], + # cnts, + # output_splits=collected_splits, + # ) + + # # Silently change C from global #Cameras to local #Cameras. + # C = C_world[world_rank] + + # else: + # # Silently change C from global #Cameras to local #Cameras. + # C = C_world[world_rank] + + # # all to all communication across all ranks. After this step, each rank + # # would have all the necessary GSs to render its own images. 
+ # (radii,) = all_to_all_tensor_list( + # world_size, + # [radii.flatten(0, 1)], + # splits=[C_i * N for C_i in C_world], + # output_splits=[C * N_i for N_i in N_world], + # ) + # radii = reshape_view(C, radii, N_world) + + # (means2d, depths, conics, opacities, colors) = all_to_all_tensor_list( + # world_size, + # [ + # means2d.flatten(0, 1), + # depths.flatten(0, 1), + # conics.flatten(0, 1), + # opacities.flatten(0, 1), + # colors.flatten(0, 1), + # ], + # splits=[C_i * N for C_i in C_world], + # output_splits=[C * N_i for N_i in N_world], + # ) + # means2d = reshape_view(C, means2d, N_world) + # depths = reshape_view(C, depths, N_world) + # conics = reshape_view(C, conics, N_world) + # opacities = reshape_view(C, opacities, N_world) + # colors = reshape_view(C, colors, N_world) + + # Rasterize to pixels + if render_mode in ["RGB+D", "RGB+ED"]: + colors = torch.cat((colors, depths[..., None]), dim=-1) + if backgrounds is not None: + backgrounds = torch.cat( + [backgrounds, torch.zeros(C, 1, device=backgrounds.device)], dim=-1 + ) + elif render_mode in ["D", "ED"]: + colors = depths[..., None] + if backgrounds is not None: + backgrounds = torch.zeros(C, 1, device=backgrounds.device) + else: # RGB + pass + + # Identify intersecting tiles + tile_width = math.ceil(width / float(tile_size)) + tile_height = math.ceil(height / float(tile_size)) + + tiles_per_gauss, isect_ids, flatten_ids = isect_tiles( + means2d, + radii, + depths, + tile_size, + tile_width, + tile_height, + packed=packed, + n_cameras=C, + camera_ids=camera_ids, + gaussian_ids=gaussian_ids, + ) + isect_offsets = isect_offset_encode(isect_ids, C, tile_width, tile_height) + + meta.update( + { + "tile_width": tile_width, + "tile_height": tile_height, + "tiles_per_gauss": tiles_per_gauss, + "isect_ids": isect_ids, + "flatten_ids": flatten_ids, + "isect_offsets": isect_offsets, + "width": width, + "height": height, + "tile_size": tile_size, + "n_cameras": C, + } + ) + + if colors.shape[-1] > channel_chunk: 
+ # slice into chunks + n_chunks = (colors.shape[-1] + channel_chunk - 1) // channel_chunk + render_colors, render_alphas = [], [] + for i in range(n_chunks): + colors_chunk = colors[..., i * channel_chunk : (i + 1) * channel_chunk] + backgrounds_chunk = ( + backgrounds[..., i * channel_chunk : (i + 1) * channel_chunk] + if backgrounds is not None + else None + ) + render_colors_, render_alphas_ = rasterize_to_pixels_3dcs( + means2d, + normals, + offsets, + num_points_per_convex_view, + delta, + sigma, + num_points_per_convex, + cumsum_of_points_per_convex, + depths, + conics, + colors_chunk, + opacities, + width, + height, + tile_size, + isect_offsets, + flatten_ids, + backgrounds=backgrounds_chunk, + packed=packed, + absgrad=absgrad, + ) + render_colors.append(render_colors_) + render_alphas.append(render_alphas_) + render_colors = torch.cat(render_colors, dim=-1) + render_alphas = render_alphas[0] # discard the rest + else: + render_colors, render_alphas = rasterize_to_pixels_3dcs( + means2d, + normals, + offsets, + num_points_per_convex_view, + delta, + sigma, + num_points_per_convex, + cumsum_of_points_per_convex, + depths, + conics, + colors, + opacities, + width, + height, + tile_size, + isect_offsets, + flatten_ids, + backgrounds=backgrounds, + packed=packed, + absgrad=absgrad, + ) + if render_mode in ["ED", "RGB+ED"]: + # normalize the accumulated depth to get the expected depth + render_colors = torch.cat( + [ + render_colors[..., :-1], + render_colors[..., -1:] / render_alphas.clamp(min=1e-10), + ], + dim=-1, + ) + + return render_colors, render_alphas, meta + + +#def _rasterization_3dcs( +# convex_points: Tensor, # [N, 3] +# opacities: Tensor, # [N] +# colors: Tensor, # [(C,) N, D] or [(C,) N, K, 3] +# viewmats: Tensor, # [C, 4, 4] +# Ks: Tensor, # [C, 3, 3] +# width: int, +# height: int, +# near_plane: float = 0.01, +# far_plane: float = 1e10, +# eps2d: float = 0.3, +# sh_degree: Optional[int] = None, +# tile_size: int = 16, +# backgrounds: 
Optional[Tensor] = None, +# render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED"] = "RGB", +# rasterize_mode: Literal["classic", "antialiased"] = "classic", +# channel_chunk: int = 32, +# batch_per_iter: int = 100, +#) -> Tuple[Tensor, Tensor, Dict]: +# """A version of rasterization() that utilies on PyTorch's autograd. +# +# .. note:: +# This function still relies on gsplat's CUDA backend for some computation, but the +# entire differentiable graph is on of PyTorch (and nerfacc) so could use Pytorch's +# autograd for backpropagation. +# +# .. note:: +# This function relies on installing latest nerfacc, via: +# pip install git+https://github.com/nerfstudio-project/nerfacc +# +# .. note:: +# Compared to rasterization(), this function does not support some arguments such as +# `packed`, `sparse_grad` and `absgrad`. +# """ +# from gsplat.cuda._torch_impl import ( +# _fully_fused_projection, +# _quat_scale_to_covar_preci, +# _rasterize_to_pixels, +# ) +# +# N = convex_points.shape[0] +# C = viewmats.shape[0] +# # FIXME: Replace 6 with the num_points_per_convex +# assert convex_points.shape == (N, 6, 3), convex_points.shape +# assert opacities.shape == (N,), opacities.shape +# assert viewmats.shape == (C, 4, 4), viewmats.shape +# assert Ks.shape == (C, 3, 3), Ks.shape +# assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode +# +# if sh_degree is None: +# # treat colors as post-activation values, should be in shape [N, D] or [C, N, D] +# assert (colors.dim() == 2 and colors.shape[0] == N) or ( +# colors.dim() == 3 and colors.shape[:2] == (C, N) +# ), colors.shape +# else: +# # treat colors as SH coefficients, should be in shape [N, K, 3] or [C, N, K, 3] +# # Allowing for activating partial SH bands +# assert ( +# colors.dim() == 3 and colors.shape[0] == N and colors.shape[2] == 3 +# ) or ( +# colors.dim() == 4 and colors.shape[:2] == (C, N) and colors.shape[3] == 3 +# ), colors.shape +# assert (sh_degree + 1) ** 2 <= colors.shape[-2], colors.shape 
+# +# # Project Gaussians to 2D. +# # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid. +# covars, _ = _quat_scale_to_covar_preci(quats, scales, True, False, triu=False) +# radii, means2d, depths, conics, compensations = _fully_fused_projection_3dcs( +# convex_points, +# covars, +# viewmats, +# Ks, +# width, +# height, +# eps2d=eps2d, +# near_plane=near_plane, +# far_plane=far_plane, +# calc_compensations=(rasterize_mode == "antialiased"), +# ) +# opacities = opacities.repeat(C, 1) # [C, N] +# camera_ids, gaussian_ids = None, None +# +# if compensations is not None: +# opacities = opacities * compensations +# +# # Identify intersecting tiles +# tile_width = math.ceil(width / float(tile_size)) +# tile_height = math.ceil(height / float(tile_size)) +# tiles_per_gauss, isect_ids, flatten_ids = isect_tiles( +# means2d, +# radii, +# depths, +# tile_size, +# tile_width, +# tile_height, +# packed=False, +# n_cameras=C, +# camera_ids=camera_ids, +# gaussian_ids=gaussian_ids, +# ) +# isect_offsets = isect_offset_encode(isect_ids, C, tile_width, tile_height) +# +# # Turn colors into [C, N, D] or [nnz, D] to pass into rasterize_to_pixels() +# if sh_degree is None: +# # Colors are post-activation values, with shape [N, D] or [C, N, D] +# if colors.dim() == 2: +# # Turn [N, D] into [C, N, D] +# colors = colors.expand(C, -1, -1) +# else: +# # colors is already [C, N, D] +# pass +# else: +# # Colors are SH coefficients, with shape [N, K, 3] or [C, N, K, 3] +# camtoworlds = torch.inverse(viewmats) # [C, 4, 4] +# dirs = convex_points[None, :, :] - camtoworlds[:, None, :3, 3] # [C, N, 3] +# masks = radii > 0 # [C, N] +# if colors.dim() == 3: +# # Turn [N, K, 3] into [C, N, 3] +# shs = colors.expand(C, -1, -1, -1) # [C, N, K, 3] +# else: +# # colors is already [C, N, K, 3] +# shs = colors +# colors = spherical_harmonics(sh_degree, dirs, shs, masks=masks) # [C, N, 3] +# # make it apple-to-apple with Inria's CUDA Backend. 
+# colors = torch.clamp_min(colors + 0.5, 0.0) +# +# # Rasterize to pixels +# if render_mode in ["RGB+D", "RGB+ED"]: +# colors = torch.cat((colors, depths[..., None]), dim=-1) +# if backgrounds is not None: +# backgrounds = torch.cat( +# [backgrounds, torch.zeros(C, 1, device=backgrounds.device)], dim=-1 +# ) +# elif render_mode in ["D", "ED"]: +# colors = depths[..., None] +# if backgrounds is not None: +# backgrounds = torch.zeros(C, 1, device=backgrounds.device) +# else: # RGB +# pass +# if colors.shape[-1] > channel_chunk: +# # slice into chunks +# n_chunks = (colors.shape[-1] + channel_chunk - 1) // channel_chunk +# render_colors, render_alphas = [], [] +# for i in range(n_chunks): +# colors_chunk = colors[..., i * channel_chunk : (i + 1) * channel_chunk] +# backgrounds_chunk = ( +# backgrounds[..., i * channel_chunk : (i + 1) * channel_chunk] +# if backgrounds is not None +# else None +# ) +# render_colors_, render_alphas_ = _rasterize_to_pixels( +# means2d, +# conics, +# colors_chunk, +# opacities, +# width, +# height, +# tile_size, +# isect_offsets, +# flatten_ids, +# backgrounds=backgrounds_chunk, +# batch_per_iter=batch_per_iter, +# ) +# render_colors.append(render_colors_) +# render_alphas.append(render_alphas_) +# render_colors = torch.cat(render_colors, dim=-1) +# render_alphas = render_alphas[0] # discard the rest +# else: +# render_colors, render_alphas = _rasterize_to_pixels( +# means2d, +# conics, +# colors, +# opacities, +# width, +# height, +# tile_size, +# isect_offsets, +# flatten_ids, +# backgrounds=backgrounds, +# batch_per_iter=batch_per_iter, +# ) +# if render_mode in ["ED", "RGB+ED"]: +# # normalize the accumulated depth to get the expected depth +# render_colors = torch.cat( +# [ +# render_colors[..., :-1], +# render_colors[..., -1:] / render_alphas.clamp(min=1e-10), +# ], +# dim=-1, +# ) +# +# meta = { +# "camera_ids": camera_ids, +# "gaussian_ids": gaussian_ids, +# "radii": radii, +# "means2d": means2d, +# "depths": depths, +# "conics": 
conics, +# "opacities": opacities, +# "tile_width": tile_width, +# "tile_height": tile_height, +# "tiles_per_gauss": tiles_per_gauss, +# "isect_ids": isect_ids, +# "flatten_ids": flatten_ids, +# "isect_offsets": isect_offsets, +# "width": width, +# "height": height, +# "tile_size": tile_size, +# "n_cameras": C, +# } +# return render_colors, render_alphas, meta + + diff --git a/gsplat/strategy/__init__.py b/gsplat/strategy/__init__.py index 305dc8129..82a71b880 100644 --- a/gsplat/strategy/__init__.py +++ b/gsplat/strategy/__init__.py @@ -1,3 +1,4 @@ from .base import Strategy from .default import DefaultStrategy +from .convex_splatting import ConvexSplattingStrategy from .mcmc import MCMCStrategy diff --git a/gsplat/strategy/convex_splatting.py b/gsplat/strategy/convex_splatting.py new file mode 100644 index 000000000..c9bf776bc --- /dev/null +++ b/gsplat/strategy/convex_splatting.py @@ -0,0 +1,329 @@ +from dataclasses import dataclass +from typing import Any, Dict, Tuple, Union + +import torch + +from .base import Strategy +from .ops_convex import remove, reset_opa, split +from typing_extensions import Literal + + +@dataclass +class ConvexSplattingStrategy(Strategy): + """A strategy that follows the convex splatting strategy of the paper: + + `3D Convex Splatting: Radiance Field Rendering with 3D Smooth Convexes `_ + + The strategy will: + + - Periodically duplicate GSs with high image plane gradients and small scales. + - Periodically split 3D convexes with high sigma gradients and large scales. + - Periodically prune GSs with low opacity. + - Periodically reset GSs to a lower opacity. + + If `absgrad=True`, it will use the absolute gradients instead of average gradients + for GS duplicating & splitting, following the AbsGS paper: + + `AbsGS: Recovering Fine Details for 3D Gaussian Splatting `_ + + Which typically leads to better results but requires to set the `grow_grad2d` to a + higher value, e.g., 0.0008. 
Also, the :func:`rasterization` function should be called + with `absgrad=True` as well so that the absolute gradients are computed. + + Args: + prune_opa (float): GSs with opacity below this value will be pruned. Default is 0.005. + grow_grad2d (float): GSs with image plane gradient above this value will be + split/duplicated. Default is 0.0002. + grow_scale3d (float): GSs with 3d scale (normalized by scene_scale) below this + value will be duplicated. Above will be split. Default is 0.01. + grow_scale2d (float): GSs with 2d scale (normalized by image resolution) above + this value will be split. Default is 0.05. + prune_scale3d (float): GSs with 3d scale (normalized by scene_scale) above this + value will be pruned. Default is 0.1. + prune_scale2d (float): GSs with 2d scale (normalized by image resolution) above + this value will be pruned. Default is 0.15. + refine_scale2d_stop_iter (int): Stop refining GSs based on 2d scale after this + iteration. Default is 0. Set to a positive value to enable this feature. + refine_start_iter (int): Start refining GSs after this iteration. Default is 500. + refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000. + reset_every (int): Reset opacities every this steps. Default is 3000. + refine_every (int): Refine GSs every this steps. Default is 100. + pause_refine_after_reset (int): Pause refining GSs until this number of steps after + reset, Default is 0 (no pause at all) and one might want to set this number to the + number of images in training set. + absgrad (bool): Use absolute gradients for GS splitting. Default is False. + revised_opacity (bool): Whether to use revised opacity heuristic from + arXiv:2404.06109 (experimental). Default is False. + verbose (bool): Whether to print verbose information. Default is False. + key_for_gradient (str): Which variable uses for densification strategy. 
+ 3DGS uses "means2d" gradient and 2DGS uses a similar gradient which stores + in variable "gradient_2dgs". + + Examples: + + >>> from gsplat import ConvexSplattingStrategy, rasterization + >>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ... + >>> optimizers: Dict[str, torch.optim.Optimizer] = ... + >>> strategy = ConvexSplattingStrategy() + >>> strategy.check_sanity(params, optimizers) + >>> strategy_state = strategy.initialize_state() + >>> for step in range(1000): + ... render_image, render_alpha, info = rasterization(...) + ... strategy.step_pre_backward(params, optimizers, strategy_state, step, info) + ... loss = ... + ... loss.backward() + ... strategy.step_post_backward(params, optimizers, strategy_state, step, info) + + """ + + prune_opa: float = 0.02 + grow_grad_sigma: float = 0.000025 + mask_threshold: float = 0.01 + refine_scale2d_stop_iter: int = 0 + scaling_clone = 0.5 + reset_opacity_until = 9000 + refine_start_iter: int = 500 + refine_stop_iter: int = 9000 + reset_every: int = 3000 + refine_every: int = 200 + pause_refine_after_reset: int = 0 + absgrad: bool = False + revised_opacity: bool = False + verbose: bool = False + key_for_gradient: str = "sigma" + sigma_scaling_cloning: float = 0.88 + scaling_cloning: float = 0.63 + + def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: + """Initialize and return the running state for this strategy. + + The returned state should be passed to the `step_pre_backward()` and + `step_post_backward()` functions. + """ + # Postpone the initialization of the state to the first step so that we can + # put them on the correct device. + # - grad2d: running accum of the norm of the image plane gradients for each GS. + # - grad_sigma: running accum of the norm of the 3D scale gradients for each 3D convexes. + # - count: running accum of how many time each GS is visible. + # - radii: the radii of the GSs (normalized by the image resolution). 
+ state = {"count": None, "scene_scale": scene_scale, "sigma": None} + #if self.refine_scale2d_stop_iter > 0: + state["radii"] = None + return state + + def check_sanity( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + ): + """Sanity check for the parameters and optimizers. + + Check if: + * `params` and `optimizers` have the same keys. + * Each optimizer has exactly one param_group, corresponding to each parameter. + * The following keys are present: {"convex_points", "opacities", "delta", "sigma", "mask"}. + + Raises: + AssertionError: If any of the above conditions is not met. + + .. note:: + It is not required but highly recommended for the user to call this function + after initializing the strategy to ensure the convention of the parameters + and optimizers is as expected. + """ + + super().check_sanity(params, optimizers) + # The following keys are required for this strategy. + for key in ["convex_points", "opacities", "delta", "sigma", "mask"]: + assert key in params, f"{key} is required in params but missing." + + def step_pre_backward( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + step: int, + info: Dict[str, Any], + ): + """Callback function to be executed before the `loss.backward()` call.""" + assert ( + self.key_for_gradient in info + ), "The Sigma of the Convexes is required but missing." 
+ info[self.key_for_gradient].retain_grad() + + def step_post_backward( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + step: int, + info: Dict[str, Any], + packed: bool = False + ): + """Callback function to be executed after the `loss.backward()` call.""" + if step >= self.refine_stop_iter and step % 1000 != 0: + return + self._update_state(params, state, info, packed=packed) + + if ( + step > self.refine_start_iter + and step % self.refine_every == 0 + and step % self.reset_every >= self.pause_refine_after_reset + ): + if step < self.refine_stop_iter: + # grow CSs + n_split = self._grow_sigma_big(params, optimizers, state, step) + if self.verbose: + print( + f"Step {step}: {n_split} Convexes split. " + f"Now having {len(params['convex_points'])} Convexes." + ) + + # prune CSs + n_prune = self._prune_convexes(params, optimizers, state, step) + if self.verbose: + print( + f"Step {step}: {n_prune} Convexes pruned. " + f"Now having {len(params['convex_points'])} Convexes." + ) + + # reset running stats + state["count"].zero_() + state["sigma"].zero_() + #if self.refine_scale2d_stop_iter > 0: + state["radii"].zero_() + torch.cuda.empty_cache() + + if step % self.reset_every == 0 and step <= self.reset_opacity_until: + reset_opa( + params=params, + optimizers=optimizers, + state=state, + value=0.2, + ) + + def _update_state( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + state: Dict[str, Any], + info: Dict[str, Any], + packed: bool = False, + ): + for key in [ + "width", + "height", + "n_cameras", + "radii", + "gaussian_ids", + self.key_for_gradient, + ]: + assert key in info, f"{key} is required but missing." 
+ + # normalize grads to [-1, 1] screen space + if self.absgrad: + grads = info[self.key_for_gradient].absgrad.clone() + else: + grads = info[self.key_for_gradient].grad.clone() + + # initialize state on the first run + n_gaussian = len(list(params.values())[0]) + + if state["sigma"] is None: + state["sigma"] = torch.zeros(n_gaussian, device=grads.device) + if state["count"] is None: + state["count"] = torch.zeros(n_gaussian, device=grads.device) + #if self.refine_scale2d_stop_iter > 0 and state["radii"] is None: + if state["radii"] is None: + assert "radii" in info, "radii is required but missing." + state["radii"] = torch.zeros(n_gaussian, device=grads.device) + + # update the running state + if packed: + # grads is [nnz, 2] + gs_ids = info["gaussian_ids"] # [nnz] + radii = info["radii"] # [nnz] + else: + # grads is [C, N, 2] + sel = (info["radii"] > 0.0).all(dim=-1) # [C, N] + gs_ids = torch.where(sel)[1] # [nnz] + grads = grads[sel.permute(1, 0)] # [nnz] + radii = info["radii"][sel].max(dim=-1).values # [nnz] + + state["sigma"].index_add_(0, gs_ids, grads) + state["count"].index_add_( + 0, gs_ids, torch.ones_like(gs_ids, dtype=torch.float32) + ) + #if self.refine_scale2d_stop_iter > 0: + # Should be ideally using scatter max + + state["radii"][gs_ids] = state["radii"][gs_ids] + # state["radii"][gs_ids] = torch.maximum( + # state["radii"][gs_ids], + # # normalize radii to [0, 1] screen space + # #radii / float(max(info["width"], info["height"])), + # radii, + #) + + @torch.no_grad() + def _grow_sigma_big( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + step: int, + ) -> Tuple[int, int]: + count = state["count"] + # 3DCS + grads_sigma = state["sigma"] / count.clamp_min(1) + is_grad_high = grads_sigma >= self.grow_grad_sigma + + n_split = is_grad_high.sum().item() + + # is_small = ( + # state["radii"].max(dim=-1).values + # <= 0.3 * state["scene_scale"] + # ) 
+ # print("before", is_grad_high.sum().item() ) + # is_dupli = is_grad_high & is_small + # n_dupli = is_dupli.sum().item() + # print("after", is_dupli.sum().item() ) + + #is_large = ~is_small + #is_split = is_grad_high & is_large + # if step < self.refine_scale2d_stop_iter: + # is_split |= state["sigma"] > self.grow_scale2d + # n_split = is_split.sum().item() + + if n_split > 0: + split(params=params, optimizers=optimizers, state=state, mask=is_grad_high, sigma_scaling_cloning=self.sigma_scaling_cloning, scaling_cloning=self.scaling_cloning) + + return n_split + + @torch.no_grad() + def _prune_convexes( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + step: int, + ) -> int: + #is_prune = torch.sigmoid(params["opacities"].flatten()) < self.prune_opa + is_prune = torch.logical_or((torch.sigmoid(params["mask"]) <= self.mask_threshold).squeeze(), (torch.sigmoid(params["opacities"]) < 0.03).squeeze()) + + #if step > self.reset_every: + is_too_big = state['radii'] >= 0.03 + # # The official code also implements sreen-size pruning but + # # it's actually not being used due to a bug: + # # https://github.com/graphdeco-inria/gaussian-splatting/issues/123 + # # We implement it here for completeness but set `refine_scale2d_stop_iter` + # # to 0 by default to disable it. 
+ # if step < self.refine_scale2d_stop_iter: + # is_too_big |= state["radii"] > self.prune_scale2d + + is_prune = is_prune | is_too_big + + n_prune = is_prune.sum().item() + if n_prune > 0: + remove(params=params, optimizers=optimizers, state=state, mask=is_prune) + + return n_prune diff --git a/gsplat/strategy/ops_convex.py b/gsplat/strategy/ops_convex.py new file mode 100644 index 000000000..e5c701814 --- /dev/null +++ b/gsplat/strategy/ops_convex.py @@ -0,0 +1,174 @@ +import numpy as np +from typing import Callable, Dict, List, Union + +import torch +from torch import Tensor + +@torch.no_grad() +def _update_param_with_optimizer( + param_fn: Callable[[str, Tensor], Tensor], + optimizer_fn: Callable[[str, Tensor], Tensor], + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + names: Union[List[str], None] = None, +): + """Update the parameters and the state in the optimizers with defined functions. + + Args: + param_fn: A function that takes the name of the parameter and the parameter itself, + and returns the new parameter. + optimizer_fn: A function that takes the key of the optimizer state and the state value, + and returns the new state value. + params: A dictionary of parameters. + optimizers: A dictionary of optimizers, each corresponding to a parameter. + names: A list of key names to update. If None, update all. Default: None. + """ + if names is None: + # If names is not provided, update all parameters + names = list(params.keys()) + + for name in names: + param = params[name] + new_param = param_fn(name, param) + params[name] = new_param + if name not in optimizers: + assert not param.requires_grad, ( + f"Optimizer for {name} is not found, but the parameter is trainable." 
+ f"Got requires_grad={param.requires_grad}" + ) + continue + optimizer = optimizers[name] + for i in range(len(optimizer.param_groups)): + param_state = optimizer.state[param] + del optimizer.state[param] + for key in param_state.keys(): + if key != "step": + v = param_state[key] + param_state[key] = optimizer_fn(key, v) + optimizer.param_groups[i]["params"] = [new_param] + optimizer.state[new_param] = param_state + +@torch.no_grad() +def split( + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Tensor], + mask: Tensor, + sigma_scaling_cloning: float, + scaling_cloning: float +): + """Inplace split the Gaussian with the given mask. + + Args: + params: A dictionary of parameters. + optimizers: A dictionary of optimizers, each corresponding to a parameter. + mask: A boolean mask to split the Gaussians. + """ + device = mask.device + sel = torch.where(mask)[0] + rest = torch.where(~mask)[0] + nb_dupli = 6 + + def param_fn(name: str, p: Tensor) -> Tensor: + repeats = [nb_dupli] + [1] * (p.dim() - 1) + if name == "convex_points": + new_convex_points_list = torch.empty(0, device=device) + selected_convex_points = p[sel, :, :] + centroids = selected_convex_points.mean(dim=1, keepdim=True) + for i in range(nb_dupli): + shift_point = selected_convex_points[:, i % p.shape[1], :] + shift_vector = (shift_point - centroids.squeeze(1)) * 1 + new_centroid = centroids.squeeze(1) + shift_vector + relative_positions = selected_convex_points - centroids + scaled_relative_positions = relative_positions * scaling_cloning + new_convex = new_centroid.unsqueeze(1) + scaled_relative_positions + new_convex_points_list = torch.cat([new_convex_points_list, new_convex], dim=0) + p_split = new_convex_points_list + elif name == "opacities": + new_opacities = torch.sigmoid(p[sel]) + p_split = torch.logit(new_opacities * 0.5).repeat(nb_dupli) # [2N] + elif name == "delta": + p_split = p[sel].repeat(nb_dupli, 
1) * 1 + elif name == "sigma": + p_split = p[sel].repeat(nb_dupli, 1) * sigma_scaling_cloning + elif name == "sh0" or name == "shN": + p_split = p[sel, :, :].repeat(nb_dupli, 1, 1) + else: + p_split = p[sel].repeat(repeats) + p_new = torch.cat([p[rest], p_split]) + p_new = torch.nn.Parameter(p_new, requires_grad=p.requires_grad) + return p_new + + def optimizer_fn(key: str, v: Tensor) -> Tensor: + v_split = torch.zeros((nb_dupli * len(sel), *v.shape[1:]), device=device) + return torch.cat([v[rest], v_split]) + + # update the parameters and the state in the optimizers + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) + # update the extra running state + for k, v in state.items(): + if isinstance(v, torch.Tensor): + repeats = [nb_dupli] + [1] * (v.dim() - 1) + v_new = v[sel].repeat(repeats) + state[k] = torch.cat((v[rest], v_new)) + + +@torch.no_grad() +def remove( + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Tensor], + mask: Tensor, +): + """Inplace remove the Gaussian with the given mask. + + Args: + params: A dictionary of parameters. + optimizers: A dictionary of optimizers, each corresponding to a parameter. + mask: A boolean mask to remove the Gaussians. 
+ """ + sel = torch.where(~mask)[0] + + def param_fn(name: str, p: Tensor) -> Tensor: + return torch.nn.Parameter(p[sel], requires_grad=p.requires_grad) + + def optimizer_fn(key: str, v: Tensor) -> Tensor: + return v[sel] + + # update the parameters and the state in the optimizers + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) + # update the extra running state + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v[sel] + + +@torch.no_grad() +def reset_opa( + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Tensor], + value: float, +): + """Inplace reset the opacities to the given post-sigmoid value. + + Args: + params: A dictionary of parameters. + optimizers: A dictionary of optimizers, each corresponding to a parameter. + value: The value to reset the opacities + """ + + def param_fn(name: str, p: Tensor) -> Tensor: + if name == "opacities": + opacities = torch.clamp(p, max=torch.logit(torch.tensor(value)).item()) + return torch.nn.Parameter(opacities, requires_grad=p.requires_grad) + else: + raise ValueError(f"Unexpected parameter name: {name}") + + def optimizer_fn(key: str, v: Tensor) -> Tensor: + return torch.zeros_like(v) + + # update the parameters and the state in the optimizers + _update_param_with_optimizer( + param_fn, optimizer_fn, params, optimizers, names=["opacities"] + ) diff --git a/gsplat/utils.py b/gsplat/utils.py index e56692958..31436abad 100644 --- a/gsplat/utils.py +++ b/gsplat/utils.py @@ -7,6 +7,124 @@ import torch.nn.functional as F from torch import Tensor +def fibonacci_sphere(x, y, z, radii, nb_points): + # Prepare storage for the generated points + points_per_convex = torch.zeros((x.shape[0], nb_points, 3), device=x.device) + + # Generate nb_points on a unit sphere using the Fibonacci lattice + for i in range(nb_points): + # z-coordinates, linearly spaced between 1 and -1, converted to tensor + 
z_coord = torch.tensor(1 - (2 * i / (nb_points - 1)), device=x.device) # Tensor + + # Calculate the radial distance in the xy-plane + radii_xy = torch.sqrt(1 - z_coord**2) # Tensor, radial distance in the xy-plane + + # Theta, spaced by the golden angle + theta = torch.pi * (3.0 - torch.sqrt(torch.tensor(5.0))) * i # Scalar, but used later in tensor ops + + # Generate unit vectors for each point on the sphere + x_unit = radii_xy * torch.cos(theta.clone().detach()) # Tensor + y_unit = radii_xy * torch.sin(theta.clone().detach()) # Tensor + z_unit = z_coord # Already a tensor + + # Stack the unit vector (shape: [3]) and scale it by radii (shape: [100, 1]) + unit_sphere_point = torch.stack([x_unit, y_unit, z_unit], dim=0) # Shape: [3] + + # Apply the scaling by radii and add the center coordinates + points_per_convex[:, i, :] = radii * unit_sphere_point + torch.stack([x, y, z], dim=1) + + return points_per_convex + +def save_ply_convex(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None): + # Convert all tensors to numpy arrays in one go + print(f"Saving ply to {dir}") + numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()} + + convex_points = numpy_data["convex_points"] + opacities = numpy_data["opacities"] + delta = numpy_data["delta"] + sigma = numpy_data["sigma"] + + sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(convex_points.shape[0], -1) + shN = numpy_data["shN"].transpose(0, 2, 1).reshape(convex_points.shape[0], -1) + + # Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays + invalid_mask = ( + np.isnan(convex_points).any(axis=2) + | np.isinf(convex_points).any(axis=2) + | np.isnan(opacities).any(axis=0) + | np.isinf(opacities).any(axis=0) + | np.isnan(delta).any(axis=0) + | np.isinf(delta).any(axis=0) + | np.isnan(sigma).any(axis=0) + | np.isinf(sigma).any(axis=0) + | np.isnan(sh0).any(axis=1) + | np.isinf(sh0).any(axis=1) + | np.isnan(shN).any(axis=1) + | np.isinf(shN).any(axis=1) + ) + + # 
Filter out rows with NaNs or Infs from all data arrays + convex_points = convex_points[~invalid_mask] + opacities = opacities[~invalid_mask] + delta = delta[~invalid_mask] + sigma = sigma[~invalid_mask] + sh0 = sh0[~invalid_mask] + shN = shN[~invalid_mask] + + num_points = convex_points.shape[0] + + with open(dir, "wb") as f: + # Write PLY header + f.write(b"ply\n") + f.write(b"format binary_little_endian 1.0\n") + f.write(f"element vertex {num_points}\n".encode()) + f.write(b"property float x\n") + f.write(b"property float y\n") + f.write(b"property float z\n") + f.write(b"property float nx\n") + f.write(b"property float ny\n") + f.write(b"property float nz\n") + + if colors is not None: + for j in range(colors.shape[1]): + f.write(f"property float f_dc_{j}\n".encode()) + else: + for i, data in enumerate([sh0, shN]): + prefix = "f_dc" if i == 0 else "f_rest" + for j in range(data.shape[1]): + f.write(f"property float {prefix}_{j}\n".encode()) + + f.write(b"property float opacity\n") + + for i in range(scales.shape[1]): + f.write(f"property float scale_{i}\n".encode()) + for i in range(quats.shape[1]): + f.write(f"property float rot_{i}\n".encode()) + + f.write(b"end_header\n") + + # Write vertex data + for i in range(num_points): + f.write(struct.pack("