Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2768,6 +2768,7 @@ def rasterize_to_pixels_2dgs(
packed: bool = False,
absgrad: bool = False,
distloss: bool = False,
has_depth_channel: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Rasterize Gaussians to pixels.

Expand Down Expand Up @@ -2873,6 +2874,7 @@ def rasterize_to_pixels_2dgs(
flatten_ids.contiguous(),
absgrad,
distloss,
has_depth_channel,
)

if padded_channels > 0:
Expand Down Expand Up @@ -2989,6 +2991,7 @@ def forward(
flatten_ids: Tensor,
absgrad: bool,
distloss: bool,
has_depth_channel: bool,
) -> Tuple[Tensor, Tensor]:
(
render_colors,
Expand All @@ -3011,6 +3014,7 @@ def forward(
tile_size,
isect_offsets,
flatten_ids,
has_depth_channel,
)

ctx.save_for_backward(
Expand All @@ -3034,6 +3038,7 @@ def forward(
ctx.tile_size = tile_size
ctx.absgrad = absgrad
ctx.distloss = distloss
ctx.has_depth_channel = has_depth_channel

# double to float
render_alphas = render_alphas.float()
Expand Down Expand Up @@ -3108,6 +3113,7 @@ def backward(
v_render_distort.contiguous(),
v_render_median.contiguous(),
absgrad,
ctx.has_depth_channel,
)
torch.cuda.synchronize()
if absgrad:
Expand Down Expand Up @@ -3136,4 +3142,5 @@ def backward(
None, # flatten_ids
None, # absgrad
None, # distloss
None, # has_depth_channel
)
8 changes: 6 additions & 2 deletions gsplat/cuda/csrc/Rasterization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ rasterize_to_pixels_2dgs_fwd(
int64_t tile_size,
// intersections
const at::Tensor &tile_offsets, // [..., tile_height, tile_width]
const at::Tensor &flatten_ids // [n_isects]
const at::Tensor &flatten_ids, // [n_isects]
bool has_depth_channel // Fix #863: last channel is depth
) {
DEVICE_GUARD(means2d);
CHECK_INPUT(means2d);
Expand Down Expand Up @@ -394,6 +395,7 @@ rasterize_to_pixels_2dgs_fwd(
tile_size, \
tile_offsets, \
flatten_ids, \
has_depth_channel, \
renders, \
alphas, \
render_normals, \
Expand Down Expand Up @@ -462,7 +464,8 @@ rasterize_to_pixels_2dgs_bwd(
const at::Tensor &v_render_distort, // [..., image_height, image_width, 1]
const at::Tensor &v_render_median, // [..., image_height, image_width, 1]
// options
bool absgrad
bool absgrad,
bool has_depth_channel // Fix #863: last channel is depth
) {
DEVICE_GUARD(means2d);
CHECK_INPUT(means2d);
Expand Down Expand Up @@ -518,6 +521,7 @@ rasterize_to_pixels_2dgs_bwd(
tile_size, \
tile_offsets, \
flatten_ids, \
has_depth_channel, \
render_colors, \
render_alphas, \
last_ids, \
Expand Down
2 changes: 2 additions & 0 deletions gsplat/cuda/csrc/Rasterization.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ void launch_rasterize_to_pixels_2dgs_fwd_kernel(
// intersections
const at::Tensor tile_offsets, // [..., tile_height, tile_width]
const at::Tensor flatten_ids, // [n_isects]
bool has_depth_channel, // Fix #863: last channel is depth
// outputs
at::Tensor renders, // [..., image_height, image_width, channels]
at::Tensor alphas, // [..., image_height, image_width, 1]
Expand Down Expand Up @@ -163,6 +164,7 @@ void launch_rasterize_to_pixels_2dgs_bwd_kernel(
// intersections
const at::Tensor tile_offsets, // [..., tile_height, tile_width]
const at::Tensor flatten_ids, // [n_isects]
bool has_depth_channel, // Fix #863: last channel is depth
// forward outputs
const at::Tensor render_colors, // [..., image_height, image_width, CDIM]
const at::Tensor render_alphas, // [..., image_height, image_width, 1]
Expand Down
29 changes: 29 additions & 0 deletions gsplat/cuda/csrc/RasterizeToPixels2DGSBwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ __global__ void rasterize_to_pixels_2dgs_bwd_kernel(
const uint32_t N, // number of gaussians
const uint32_t n_isects, // number of ray-primitive intersections.
const bool packed, // whether the input tensors are packed
const bool has_depth_channel, // Fix #863: last color channel is depth
// fwd inputs
const vec2
*__restrict__ means2d, // Projected Gaussian means. [..., N, 2] if
Expand Down Expand Up @@ -484,6 +485,17 @@ __global__ void rasterize_to_pixels_2dgs_bwd_kernel(
for (uint32_t k = 0; k < CDIM; ++k) {
v_rgb_local[k] += fac * v_render_c[k];
}
// Fix #863: depth channel grad goes to v_w_M, not v_colors
if (has_depth_channel) {
float isect_d =
1.0f / (w_M.x * px + w_M.y * py + w_M.z);
float d2 = isect_d * isect_d;
float v_id = fac * v_render_c[CDIM - 1];
v_w_M_local.x += v_id * (-px * d2);
v_w_M_local.y += v_id * (-py * d2);
v_w_M_local.z += v_id * (-d2);
v_rgb_local[CDIM - 1] -= v_id; // undo center-depth grad
}

// contribution from this pixel to alpha
// we have d(alpha)/d(c_i) = c_i * G_i * T + [grad contribution
Expand All @@ -494,6 +506,13 @@ __global__ void rasterize_to_pixels_2dgs_bwd_kernel(
v_alpha += (rgbs_batch[t * CDIM + k] * T - buffer[k] * ra) *
v_render_c[k];
}
// Fix #863: fix depth channel's contribution to v_alpha
if (has_depth_channel) {
float isect_d =
1.0f / (w_M.x * px + w_M.y * py + w_M.z);
v_alpha += (isect_d - rgbs_batch[t * CDIM + CDIM - 1])
* T * v_render_c[CDIM - 1];
}

/*
* d(normal_out) / d(rgb) and d(normal_out) / d(alpha)
Expand Down Expand Up @@ -620,6 +639,13 @@ __global__ void rasterize_to_pixels_2dgs_bwd_kernel(
for (uint32_t k = 0; k < CDIM; ++k) {
buffer[k] += rgbs_batch[t * CDIM + k] * fac;
}
// Fix #863: fix depth channel in buffer accumulator
if (has_depth_channel) {
float isect_d =
1.0f / (w_M.x * px + w_M.y * py + w_M.z);
buffer[CDIM - 1] +=
(isect_d - rgbs_batch[t * CDIM + CDIM - 1]) * fac;
}

/**
* Update the cumulative "later" gaussian contributions, used in derivatives of
Expand Down Expand Up @@ -722,6 +748,7 @@ void launch_rasterize_to_pixels_2dgs_bwd_kernel(
// intersections
const at::Tensor tile_offsets, // [..., tile_height, tile_width]
const at::Tensor flatten_ids, // [n_isects]
const bool has_depth_channel, // Fix #863: last channel is depth
// forward outputs
const at::Tensor render_colors, // [..., image_height, image_width, CDIM]
const at::Tensor render_alphas, // [..., image_height, image_width, 1]
Expand Down Expand Up @@ -786,6 +813,7 @@ void launch_rasterize_to_pixels_2dgs_bwd_kernel(
N,
n_isects,
packed,
has_depth_channel,
reinterpret_cast<vec2 *>(means2d.data_ptr<float>()),
ray_transforms.data_ptr<float>(),
colors.data_ptr<float>(),
Expand Down Expand Up @@ -842,6 +870,7 @@ void launch_rasterize_to_pixels_2dgs_bwd_kernel(
const uint32_t tile_size, \
const at::Tensor tile_offsets, \
const at::Tensor flatten_ids, \
bool has_depth_channel, \
const at::Tensor render_colors, \
const at::Tensor render_alphas, \
const at::Tensor last_ids, \
Expand Down
15 changes: 13 additions & 2 deletions gsplat/cuda/csrc/RasterizeToPixels2DGSFwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ __global__ void rasterize_to_pixels_2dgs_fwd_kernel(
const uint32_t N, // number of gaussians
const uint32_t n_isects, // number of ray-primitive intersections.
const bool packed, // whether the input tensors are packed
const bool has_depth_channel, // Fix #863: last color channel is depth
const vec2
*__restrict__ means2d, // Projected Gaussian means. [..., N, 2] if
// packed is False, [nnz, 2] if packed is True.
Expand Down Expand Up @@ -404,6 +405,13 @@ __global__ void rasterize_to_pixels_2dgs_fwd_kernel(
for (uint32_t k = 0; k < CDIM; ++k) {
pix_out[k] += c_ptr[k] * vis;
}
if (has_depth_channel) {
// Fix #863: replace packed Gaussian center depth
// with the Inria 2DGS ray-splat intersection depth.
const float isect_depth = s.x * w_M.x + s.y * w_M.y + w_M.z;
pix_out[CDIM - 1] +=
(isect_depth - c_ptr[CDIM - 1]) * vis;
}

const float *n_ptr = normals + g * 3;
#pragma unroll
Expand All @@ -413,7 +421,7 @@ __global__ void rasterize_to_pixels_2dgs_fwd_kernel(

if (render_distort != nullptr) {
// the last channel of colors is depth
const float depth = c_ptr[CDIM - 1];
const float depth = s.x * w_M.x + s.y * w_M.y + w_M.z;
// in nerfacc, loss_bi_0 = weights * t_mids *
// exclusive_sum(weights)
const float distort_bi_0 = vis * depth * (1.0f - T);
Expand All @@ -426,7 +434,7 @@ __global__ void rasterize_to_pixels_2dgs_fwd_kernel(

// compute median depth
if (T > 0.5) {
median_depth = c_ptr[CDIM - 1];
median_depth = s.x * w_M.x + s.y * w_M.y + w_M.z;
median_idx = batch_start + t;
}

Expand Down Expand Up @@ -482,6 +490,7 @@ void launch_rasterize_to_pixels_2dgs_fwd_kernel(
// intersections
const at::Tensor tile_offsets, // [..., tile_height, tile_width]
const at::Tensor flatten_ids, // [n_isects]
const bool has_depth_channel, // Fix #863: last channel is depth
// outputs
at::Tensor renders, // [..., image_height, image_width, channels]
at::Tensor alphas, // [..., image_height, image_width]
Expand Down Expand Up @@ -529,6 +538,7 @@ void launch_rasterize_to_pixels_2dgs_fwd_kernel(
N,
n_isects,
packed,
has_depth_channel,
reinterpret_cast<vec2 *>(means2d.data_ptr<float>()),
ray_transforms.data_ptr<float>(),
colors.data_ptr<float>(),
Expand Down Expand Up @@ -571,6 +581,7 @@ void launch_rasterize_to_pixels_2dgs_fwd_kernel(
uint32_t tile_size, \
const at::Tensor tile_offsets, \
const at::Tensor flatten_ids, \
bool has_depth_channel, \
at::Tensor renders, \
at::Tensor alphas, \
at::Tensor render_normals, \
Expand Down
4 changes: 2 additions & 2 deletions gsplat/cuda/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,8 @@ TORCH_LIBRARY(gsplat, m) {
m.def("projection_2dgs_packed_fwd(Tensor means, Tensor quats, Tensor scales, Tensor viewmats, Tensor Ks, int image_width, int image_height, float near_plane, float far_plane, float radius_clip) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("projection_2dgs_packed_bwd(Tensor means, Tensor quats, Tensor scales, Tensor viewmats, Tensor Ks, int image_width, int image_height, Tensor batch_ids, Tensor camera_ids, Tensor gaussian_ids, Tensor ray_transforms, Tensor v_means2d, Tensor v_depths, Tensor v_ray_transforms, Tensor v_normals, bool viewmats_requires_grad, bool sparse_grad) -> (Tensor, Tensor, Tensor, Tensor)");

m.def("rasterize_to_pixels_2dgs_fwd(Tensor means2d, Tensor ray_transforms, Tensor colors, Tensor opacities, Tensor normals, Tensor? backgrounds, Tensor? masks, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("rasterize_to_pixels_2dgs_bwd(Tensor means2d, Tensor ray_transforms, Tensor colors, Tensor opacities, Tensor normals, Tensor densify, Tensor? backgrounds, Tensor? masks, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids, Tensor render_colors, Tensor render_alphas, Tensor last_ids, Tensor median_ids, Tensor v_render_colors, Tensor v_render_alphas, Tensor v_render_normals, Tensor v_render_distort, Tensor v_render_median, bool absgrad) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("rasterize_to_pixels_2dgs_fwd(Tensor means2d, Tensor ray_transforms, Tensor colors, Tensor opacities, Tensor normals, Tensor? backgrounds, Tensor? masks, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids, bool has_depth_channel) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("rasterize_to_pixels_2dgs_bwd(Tensor means2d, Tensor ray_transforms, Tensor colors, Tensor opacities, Tensor normals, Tensor densify, Tensor? backgrounds, Tensor? masks, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids, Tensor render_colors, Tensor render_alphas, Tensor last_ids, Tensor median_ids, Tensor v_render_colors, Tensor v_render_alphas, Tensor v_render_normals, Tensor v_render_distort, Tensor v_render_median, bool absgrad, bool has_depth_channel) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("rasterize_to_indices_2dgs(int range_start, int range_end, Tensor transmittances, Tensor means2d, Tensor ray_transforms, Tensor opacities, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids) -> (Tensor, Tensor)");
#endif

Expand Down
6 changes: 4 additions & 2 deletions gsplat/cuda/include/Ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,8 @@ rasterize_to_pixels_2dgs_fwd(
int64_t tile_size,
// intersections
const at::Tensor &tile_offsets, // [..., tile_height, tile_width]
const at::Tensor &flatten_ids // [n_isects]
const at::Tensor &flatten_ids, // [n_isects]
bool has_depth_channel // Fix #863: last channel is depth
);
std::tuple<
at::Tensor,
Expand Down Expand Up @@ -542,7 +543,8 @@ rasterize_to_pixels_2dgs_bwd(
const at::Tensor &v_render_distort, // [..., image_height, image_width, 1]
const at::Tensor &v_render_median, // [..., image_height, image_width, 1]
// options
bool absgrad
bool absgrad,
bool has_depth_channel // Fix #863: last channel is depth
);

std::tuple<at::Tensor, at::Tensor> rasterize_to_indices_2dgs(
Expand Down
1 change: 1 addition & 0 deletions gsplat/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2254,6 +2254,7 @@ def rasterization_2dgs(
packed=packed,
absgrad=absgrad,
distloss=distloss,
has_depth_channel=render_mode_has_depth_channel(render_mode),
)
render_normals_from_depth = None
if render_mode_has_expected_depth(render_mode):
Expand Down