gsplat/rendering.py (86 changes: 71 additions & 15 deletions)
@@ -1,5 +1,5 @@
import math
from typing import Dict, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.distributed
@@ -283,13 +283,17 @@ def rasterization(
assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode

def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
view_list = list(
map(
lambda x: x.split(int(x.shape[0] / C), dim=0),
world_view.split([C * N_i for N_i in N_world], dim=0),
)
)
return torch.stack([torch.cat(l, dim=0) for l in zip(*view_list)], dim=0)
world_view = world_view.contiguous()
blocks = world_view.split([C * N_i for N_i in N_world], dim=0)
reshaped_blocks = [
block.view(C, N_i, *block.shape[1:]) for block, N_i in zip(blocks, N_world)
]
return torch.cat(reshaped_blocks, dim=1)

def recover_world_view(output_tensor: torch.Tensor, N_world: list) -> torch.Tensor:
blocks = output_tensor.split(N_world, dim=1)
reshaped_blocks = [block.reshape(-1, *block.shape[2:]) for block in blocks]
return torch.cat(reshaped_blocks, dim=0)

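A quick way to see what these two helpers do is a single-process round trip. The sketch below uses standalone copies of the helpers with toy sizes (two cameras; two ranks owning 3 and 5 Gaussians); it assumes the flat [sum(C * N_i), ...] world-view layout the hunk above operates on:

import torch

def reshape_view(C, world_view, N_world):
    # split the flat world-view tensor into one block per source rank,
    # view each block as [C, N_i, ...], then concatenate along the Gaussian dim
    blocks = world_view.contiguous().split([C * N_i for N_i in N_world], dim=0)
    return torch.cat(
        [b.view(C, N_i, *b.shape[1:]) for b, N_i in zip(blocks, N_world)], dim=1
    )

def recover_world_view(output_tensor, N_world):
    # exact inverse: split along the Gaussian dim and re-flatten each block
    blocks = output_tensor.split(N_world, dim=1)
    return torch.cat([b.reshape(-1, *b.shape[2:]) for b in blocks], dim=0)

C, N_world = 2, [3, 5]                     # 2 cameras; ranks own 3 and 5 Gaussians
flat = torch.randn(C * sum(N_world), 4)    # [C*N_0 + C*N_1, channels] = [16, 4]
per_view = reshape_view(C, flat, N_world)  # -> [C, N, channels] = [2, 8, 4]
assert per_view.shape == (2, 8, 4)
assert torch.equal(recover_world_view(per_view, N_world), flat)
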
if sh_degree is None:
# treat colors as post-activation values, should be in shape [..., N, D] or [..., C, N, D]
@@ -322,9 +326,6 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
colors.dim() == num_batch_dims + 3
), "Distributed mode only supports per-Gaussian colors."

if absgrad:
assert not distributed, "AbsGrad is not supported in distributed mode."

if (
radial_coeffs is not None
or tangential_coeffs is not None
@@ -540,6 +541,31 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
output_splits=collected_splits,
)

# all-to-all communication in the backward pass
def register_absgrad_hook(
means2d: Tensor,
means2d_redistributed: Tensor,
splits: List[Union[int, Tensor]],
output_splits: List[Union[int, Tensor]],
):
def absgrad_hook(grad: Tensor):
(means2d.absgrad,) = all_to_all_tensor_list(
world_size,
[means2d_redistributed.absgrad],
splits,
output_splits=output_splits,
)
return grad

return means2d.register_hook(absgrad_hook)

means2d = means2d.contiguous()
if absgrad and meta["means2d"].requires_grad:
handle = register_absgrad_hook(
meta["means2d"], means2d, collected_splits, cnts
)
meta.update({"absgrad_hook": handle})

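This hook rides on autograd: when the gradient for the pre-exchange means2d is computed, the hook reads the absolute-gradient buffer that rasterization left on the redistributed copy and ships it home with a second all-to-all. A minimal single-process sketch of the register_hook pattern (names are illustrative, not gsplat API; the all-to-all is replaced by a plain copy):

import torch

src = torch.randn(4, 2, requires_grad=True)  # stands in for meta["means2d"]
redistributed = src * 1.0                    # stands in for the exchanged copy
redistributed.absgrad = torch.ones(4, 2)     # pretend the backward kernel wrote this

def absgrad_hook(grad):
    # the real hook does an all_to_all_tensor_list here instead of a clone
    src.absgrad = redistributed.absgrad.clone()
    return grad  # the gradient itself flows through unchanged

handle = src.register_hook(absgrad_hook)
redistributed.sum().backward()
assert torch.equal(src.absgrad, torch.ones(4, 2))
handle.remove()  # mirrors the cleanup added in strategy/default.py
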
# Before sending the data, turn the camera_ids from global to local:
# the camera_ids produced by the projection stage index over all cameras
# world-wide, so they must be remapped to ids local to each rank.
@@ -577,11 +603,13 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:

# all-to-all communication across all ranks. After this step, each rank
# has all the Gaussians it needs to render its own images.
splits = [C_i * N for C_i in C_world]
output_splits = [C * N_i for N_i in N_world]
(radii,) = all_to_all_tensor_list(
world_size,
[radii.flatten(0, 1)],
splits=[C_i * N for C_i in C_world],
output_splits=[C * N_i for N_i in N_world],
splits=splits,
output_splits=output_splits,
)
radii = reshape_view(C, radii, N_world)

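Hoisting splits and output_splits out of the call sites also makes the bookkeeping easy to sanity-check with concrete numbers. In this hypothetical two-rank layout, each rank projects its N local Gaussians into every camera world-wide and then receives every rank's Gaussians for its own C cameras:

# hypothetical layout, seen from rank 0
C_world = [1, 3]                # cameras rendered by each rank
N_world = [100, 60]             # Gaussians owned by each rank
C, N = C_world[0], N_world[0]   # rank 0: 1 camera, 100 local Gaussians
splits = [C_i * N for C_i in C_world]          # rows sent to each rank: [100, 300]
output_splits = [C * N_i for N_i in N_world]   # rows received:          [100, 60]
assert sum(splits) == sum(C_world) * N         # every projection rank 0 computed
assert sum(output_splits) == C * sum(N_world)  # every Gaussian for rank 0's camera
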
@@ -594,15 +622,43 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
opacities.flatten(0, 1),
colors.flatten(0, 1),
],
splits=[C_i * N for C_i in C_world],
output_splits=[C * N_i for N_i in N_world],
splits=splits,
output_splits=output_splits,
)
means2d = reshape_view(C, means2d, N_world)
depths = reshape_view(C, depths, N_world)
conics = reshape_view(C, conics, N_world)
opacities = reshape_view(C, opacities, N_world)
colors = reshape_view(C, colors, N_world)

# all-to-all communication in the backward pass
def register_absgrad_hook(
means2d: Tensor,
means2d_redistributed: Tensor,
splits: List[Union[int, Tensor]],
output_splits: List[Union[int, Tensor]],
):
def absgrad_hook(grad: Tensor):
absgrad_recovered = recover_world_view(
means2d_redistributed.absgrad, N_world
)
(absgrad,) = all_to_all_tensor_list(
world_size,
[absgrad_recovered],
splits,
output_splits=output_splits,
)
means2d.absgrad = absgrad.reshape(means2d.shape)
return grad

return means2d.register_hook(absgrad_hook)

if absgrad and meta["means2d"].requires_grad:
handle = register_absgrad_hook(
meta["means2d"], means2d, output_splits, splits
)
meta.update({"absgrad_hook": handle})

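Two things are easy to miss here. First, the registration passes output_splits and splits swapped relative to the forward exchange, because the gradient-side exchange routes each chunk back the way it came. Second, the hook undoes the per-view layout before sending: recover_world_view flattens the absgrad buffer, the all-to-all ships it, and the final reshape matches the pre-exchange tensor. With world_size = 1 the exchange degenerates to an identity, so the layout round trip can be checked on one process (illustrative sketch; helper restated from the first sketch):

import torch

def recover_world_view(t, N_world):  # same helper as defined earlier in the diff
    blocks = t.split(N_world, dim=1)
    return torch.cat([b.reshape(-1, *b.shape[2:]) for b in blocks], dim=0)

C, N_world = 2, [8]                          # a single rank owning all 8 Gaussians
means2d = torch.randn(C, 8, 2, requires_grad=True)
redistributed = means2d * 1.0
redistributed.absgrad = torch.rand(C, 8, 2)  # pretend rasterization wrote this

def absgrad_hook(grad):
    flat = recover_world_view(redistributed.absgrad, N_world)  # [C*N, 2]
    # all_to_all_tensor_list would exchange `flat` here; with one rank it is
    # an identity, so the reshape below restores the original [C, N, 2] layout
    means2d.absgrad = flat.reshape(means2d.shape)
    return grad

handle = means2d.register_hook(absgrad_hook)
redistributed.sum().backward()
assert torch.equal(means2d.absgrad, redistributed.absgrad)
handle.remove()
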
# Rasterize to pixels
if render_mode in ["RGB+D", "RGB+ED"]:
colors = torch.cat((colors, depths[..., None]), dim=-1)
gsplat/strategy/default.py (2 changes: 2 additions & 0 deletions)
@@ -159,6 +159,8 @@ def step_post_backward(
packed: bool = False,
):
"""Callback function to be executed after the `loss.backward()` call."""
if "absgrad_hook" in info:
info["absgrad_hook"].remove()
if step >= self.refine_stop_iter:
return

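This is the consumer side of the hook registered in rendering.py: after loss.backward() has fired the hook and meta["means2d"].absgrad has been filled, the handle is removed so the hook (and the redistributed tensors its closure presumably keeps alive) does not outlive the step. A minimal sketch of the per-step pattern, with an identity hook standing in for absgrad_hook (illustrative loop, not gsplat's trainer):

import torch

param = torch.randn(8, 2, requires_grad=True)
for step in range(3):
    handle = param.register_hook(lambda g: g)  # stand-in for absgrad_hook
    info = {"absgrad_hook": handle}            # as rasterization returns in meta
    (param * 2.0).sum().backward()
    if "absgrad_hook" in info:                 # the logic added in this hunk
        info["absgrad_hook"].remove()
    param.grad = None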