Skip to content

Commit 4ca04df

Browse files
committed
fix backgrounds shape, fix colors for 2dgs in packed mode
1 parent 73fad53 commit 4ca04df

2 files changed

Lines changed: 35 additions & 17 deletions

File tree

gsplat/cuda/_wrapper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,9 @@ def rasterize_to_pixels(
595595
assert colors.shape == image_dims + (N, channels), colors.shape
596596
assert opacities.shape == image_dims + (N,), opacities.shape
597597
if backgrounds is not None:
598-
assert backgrounds.shape == image_dims + (channels,), backgrounds.shape
598+
assert backgrounds.shape == (image_dims or (1,)) + (
599+
channels,
600+
), backgrounds.shape
599601
backgrounds = backgrounds.contiguous()
600602
if masks is not None:
601603
assert masks.shape == isect_offsets.shape, masks.shape
@@ -2281,7 +2283,9 @@ def rasterize_to_pixels_2dgs(
22812283
assert colors.shape[:-2] == image_dims, colors.shape
22822284
assert opacities.shape == image_dims + (N,), opacities.shape
22832285
if backgrounds is not None:
2284-
assert backgrounds.shape == image_dims + (channels,), backgrounds.shape
2286+
assert backgrounds.shape == (image_dims or (1,)) + (
2287+
channels,
2288+
), backgrounds.shape
22852289
backgrounds = backgrounds.contiguous()
22862290

22872291
# Pad the channels to the nearest supported number if necessary

gsplat/rendering.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,23 +1512,37 @@ def rasterization_2dgs(
15121512
# if packed:
15131513
# colors = colors.view(B, C, N, -1)[batch_ids, camera_ids, gaussian_ids, :]
15141514
if sh_degree is not None: # SH coefficients
1515-
camtoworlds = torch.inverse(viewmats)
1515+
# Colors are SH coefficients, with shape [..., N, K, 3] or [..., C, N, K, 3]
1516+
campos = torch.inverse(viewmats)[..., :3, 3] # [..., C, 3]
15161517
if packed:
1517-
dirs = means[..., gaussian_ids, :] - camtoworlds[..., camera_ids, :3, 3]
1518-
else:
1519-
dirs = means[..., None, :, :] - camtoworlds[..., None, :3, 3]
1520-
1521-
if colors.dim() == num_batch_dims + 3:
1522-
# Turn [..., N, K, 3] into [..., C, N, K, 3]
1523-
shs = torch.broadcast_to(
1524-
colors[..., None, :, :, :], batch_dims + (C, N, -1, 3)
1525-
) # [..., C, N, K, 3]
1518+
dirs = (
1519+
means.view(B, N, 3)[batch_ids, gaussian_ids]
1520+
- campos.view(B, C, 3)[batch_ids, camera_ids]
1521+
) # [nnz, 3]
1522+
masks = (radii > 0).all(dim=-1) # [nnz]
1523+
if colors.dim() == num_batch_dims + 3:
1524+
# Turn [..., N, K, 3] into [nnz, 3]
1525+
shs = colors.view(B, N, -1, 3)[batch_ids, gaussian_ids] # [nnz, K, 3]
1526+
else:
1527+
# Turn [..., C, N, K, 3] into [nnz, 3]
1528+
shs = colors.view(B, C, N, -1, 3)[
1529+
batch_ids, camera_ids, gaussian_ids
1530+
] # [nnz, K, 3]
1531+
colors = spherical_harmonics(sh_degree, dirs, shs, masks=masks) # [nnz, 3]
15261532
else:
1527-
# colors is already [..., C, N, K, 3]
1528-
shs = colors
1529-
colors = spherical_harmonics(
1530-
sh_degree, dirs, shs, masks=(radii > 0).all(dim=-1)
1531-
) # [nnz, D] or [..., C, N, 3]
1533+
dirs = means[..., None, :, :] - campos[..., None, :] # [..., C, N, 3]
1534+
masks = (radii > 0).all(dim=-1) # [..., C, N]
1535+
if colors.dim() == num_batch_dims + 3:
1536+
# Turn [..., N, K, 3] into [..., C, N, K, 3]
1537+
shs = torch.broadcast_to(
1538+
colors[..., None, :, :, :], batch_dims + (C, N, -1, 3)
1539+
)
1540+
else:
1541+
# colors is already [..., C, N, K, 3]
1542+
shs = colors
1543+
colors = spherical_harmonics(
1544+
sh_degree, dirs, shs, masks=masks
1545+
) # [..., C, N, 3]
15321546
# make it apple-to-apple with Inria's CUDA Backend.
15331547
colors = torch.clamp_min(colors + 0.5, 0.0)
15341548

0 commit comments

Comments
 (0)