diff --git a/gsplat/rendering.py b/gsplat/rendering.py
index 8559e615e..9f2cf7633 100644
--- a/gsplat/rendering.py
+++ b/gsplat/rendering.py
@@ -1,5 +1,5 @@
 import math
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed
@@ -283,13 +283,17 @@ def rasterization(
     assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode
 
     def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
-        view_list = list(
-            map(
-                lambda x: x.split(int(x.shape[0] / C), dim=0),
-                world_view.split([C * N_i for N_i in N_world], dim=0),
-            )
-        )
-        return torch.stack([torch.cat(l, dim=0) for l in zip(*view_list)], dim=0)
+        world_view = world_view.contiguous()
+        blocks = world_view.split([C * N_i for N_i in N_world], dim=0)
+        reshaped_blocks = [
+            block.view(C, N_i, *block.shape[1:]) for block, N_i in zip(blocks, N_world)
+        ]
+        return torch.cat(reshaped_blocks, dim=1)
+
+    def recover_world_view(output_tensor: torch.Tensor, N_world: list) -> torch.Tensor:
+        blocks = output_tensor.split(N_world, dim=1)
+        reshaped_blocks = [block.reshape(-1, *block.shape[2:]) for block in blocks]
+        return torch.cat(reshaped_blocks, dim=0)
 
     if sh_degree is None:
         # treat colors as post-activation values, should be in shape [..., N, D] or [..., C, N, D]
@@ -322,9 +326,6 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
             colors.dim() == num_batch_dims + 3
         ), "Distributed mode only supports per-Gaussian colors."
 
-    if absgrad:
-        assert not distributed, "AbsGrad is not supported in distributed mode."
-
     if (
         radial_coeffs is not None
         or tangential_coeffs is not None
@@ -540,6 +541,31 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
             output_splits=collected_splits,
         )
 
+        # all to all communication in backward pass
+        def register_absgrad_hook(
+            means2d: Tensor,
+            means2d_redistributed: Tensor,
+            splits: List[Union[int, Tensor]],
+            output_splits: List[Union[int, Tensor]],
+        ):
+            def absgrad_hook(grad: Tensor):
+                (means2d.absgrad,) = all_to_all_tensor_list(
+                    world_size,
+                    [means2d_redistributed.absgrad],
+                    splits,
+                    output_splits=output_splits,
+                )
+                return grad
+
+            return means2d.register_hook(absgrad_hook)
+
+        means2d = means2d.contiguous()
+        if absgrad and meta["means2d"].requires_grad:
+            handle = register_absgrad_hook(
+                meta["means2d"], means2d, collected_splits, cnts
+            )
+            meta.update({"absgrad_hook": handle})
+
         # before sending the data, we should turn the camera_ids from global to local.
         # i.e. the camera_ids produced by the projection stage are over all cameras world-wide,
         # so we need to turn them into camera_ids that are local to each rank.
@@ -577,11 +603,13 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
         # all to all communication across all ranks. After this step, each rank
         # would have all the necessary GSs to render its own images.
+        splits = [C_i * N for C_i in C_world]
+        output_splits = [C * N_i for N_i in N_world]
         (radii,) = all_to_all_tensor_list(
             world_size,
             [radii.flatten(0, 1)],
-            splits=[C_i * N for C_i in C_world],
-            output_splits=[C * N_i for N_i in N_world],
+            splits=splits,
+            output_splits=output_splits,
         )
         radii = reshape_view(C, radii, N_world)
 
         (means2d, depths, conics, opacities, colors) = all_to_all_tensor_list(
@@ -594,8 +622,8 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
                 opacities.flatten(0, 1),
                 colors.flatten(0, 1),
             ],
-            splits=[C_i * N for C_i in C_world],
-            output_splits=[C * N_i for N_i in N_world],
+            splits=splits,
+            output_splits=output_splits,
         )
         means2d = reshape_view(C, means2d, N_world)
         depths = reshape_view(C, depths, N_world)
@@ -603,6 +631,34 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
         opacities = reshape_view(C, opacities, N_world)
         colors = reshape_view(C, colors, N_world)
 
+        # all to all communication in backward pass
+        def register_absgrad_hook(
+            means2d: Tensor,
+            means2d_redistributed: Tensor,
+            splits: List[Union[int, Tensor]],
+            output_splits: List[Union[int, Tensor]],
+        ):
+            def absgrad_hook(grad: Tensor):
+                absgrad_recovered = recover_world_view(
+                    means2d_redistributed.absgrad, N_world
+                )
+                (absgrad,) = all_to_all_tensor_list(
+                    world_size,
+                    [absgrad_recovered],
+                    splits,
+                    output_splits=output_splits,
+                )
+                means2d.absgrad = absgrad.reshape(means2d.shape)
+                return grad
+
+            return means2d.register_hook(absgrad_hook)
+
+        if absgrad and meta["means2d"].requires_grad:
+            handle = register_absgrad_hook(
+                meta["means2d"], means2d, output_splits, splits
+            )
+            meta.update({"absgrad_hook": handle})
+
     # Rasterize to pixels
     if render_mode in ["RGB+D", "RGB+ED"]:
         colors = torch.cat((colors, depths[..., None]), dim=-1)
diff --git a/gsplat/strategy/default.py b/gsplat/strategy/default.py
index 9075278a9..381207c75 100644
--- a/gsplat/strategy/default.py
+++ b/gsplat/strategy/default.py
@@ -159,6 +159,8 @@ def step_post_backward(
         packed: bool = False,
     ):
         """Callback function to be executed after the `loss.backward()` call."""
+        if "absgrad_hook" in info:
+            info["absgrad_hook"].remove()
         if step >= self.refine_stop_iter:
             return
 
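Reviewer note: the round trip between the new `reshape_view` and `recover_world_view` helpers can be checked in isolation. A minimal sketch with lightly condensed copies of the two helpers (plain `torch`, no distributed setup; the camera count and per-rank Gaussian counts are made up for illustration):

```python
import torch

# Condensed copies of the two helpers added in this PR.
def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
    # Split the flat world view into one block per source rank, then give each
    # block a [C, N_i, ...] camera-major layout and concatenate along Gaussians.
    blocks = world_view.contiguous().split([C * N_i for N_i in N_world], dim=0)
    reshaped = [b.view(C, N_i, *b.shape[1:]) for b, N_i in zip(blocks, N_world)]
    return torch.cat(reshaped, dim=1)

def recover_world_view(output_tensor: torch.Tensor, N_world: list) -> torch.Tensor:
    # Inverse of reshape_view: re-flatten each per-rank block back into the
    # [C * N_i, ...] layout expected by the all-to-all exchange.
    blocks = output_tensor.split(N_world, dim=1)
    reshaped = [b.reshape(-1, *b.shape[2:]) for b in blocks]
    return torch.cat(reshaped, dim=0)

C, N_world = 2, [3, 5]  # 2 local cameras; 3 and 5 Gaussians received from two ranks
flat = torch.arange(C * sum(N_world) * 4, dtype=torch.float32).view(-1, 4)
view = reshape_view(C, flat, N_world)
assert view.shape == (2, 8, 4)
assert torch.equal(recover_world_view(view, N_world), flat)  # exact inverse
```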
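The backward relay of `.absgrad` relies on tensor hooks firing in reverse topological order: by the time the hook registered on the caller-visible `meta["means2d"]` runs, autograd has already propagated gradients through the redistributed tensor, so its `.absgrad` side channel is populated. A single-process sketch of that pattern (the `all_to_all_tensor_list` exchange is replaced by a plain attribute copy, and a second hook stands in for the rasterizer that writes `.absgrad` during backward):

```python
import torch

means2d = torch.randn(4, 2, requires_grad=True)
means2d_redistributed = means2d * 1.0  # stand-in for the all-to-all exchange

def write_absgrad(grad: torch.Tensor) -> torch.Tensor:
    # Stand-in for the rasterizer, which populates .absgrad during backward.
    means2d_redistributed.absgrad = grad.abs()
    return grad

def relay_absgrad(grad: torch.Tensor) -> torch.Tensor:
    # Fires only after grad has flowed back through means2d_redistributed, so
    # its .absgrad side channel already exists. In the PR this is where the
    # all_to_all_tensor_list exchange happens instead of a local copy.
    means2d.absgrad = means2d_redistributed.absgrad
    return grad

means2d_redistributed.register_hook(write_absgrad)
handle = means2d.register_hook(relay_absgrad)

means2d_redistributed.square().sum().backward()
assert means2d.absgrad.shape == means2d.shape
handle.remove()  # mirrors the cleanup added to DefaultStrategy.step_post_backward
```

The `handle.remove()` in `step_post_backward` matters because `rasterization` registers a fresh hook every call; without removing the previous one, stale hooks referencing old redistributed tensors would accumulate on the same parameter across training steps.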