diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 2d24a4119..6b9ae2604 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -2768,6 +2768,7 @@ def rasterize_to_pixels_2dgs( packed: bool = False, absgrad: bool = False, distloss: bool = False, + has_depth_channel: bool = False, ) -> Tuple[Tensor, Tensor]: """Rasterize Gaussians to pixels. @@ -2873,6 +2874,7 @@ def rasterize_to_pixels_2dgs( flatten_ids.contiguous(), absgrad, distloss, + has_depth_channel, ) if padded_channels > 0: @@ -2989,6 +2991,7 @@ def forward( flatten_ids: Tensor, absgrad: bool, distloss: bool, + has_depth_channel: bool, ) -> Tuple[Tensor, Tensor]: ( render_colors, @@ -3011,6 +3014,7 @@ def forward( tile_size, isect_offsets, flatten_ids, + has_depth_channel, ) ctx.save_for_backward( @@ -3034,6 +3038,7 @@ def forward( ctx.tile_size = tile_size ctx.absgrad = absgrad ctx.distloss = distloss + ctx.has_depth_channel = has_depth_channel # double to float render_alphas = render_alphas.float() @@ -3108,6 +3113,7 @@ def backward( v_render_distort.contiguous(), v_render_median.contiguous(), absgrad, + ctx.has_depth_channel, ) torch.cuda.synchronize() if absgrad: @@ -3136,4 +3142,5 @@ def backward( None, # flatten_ids None, # absgrad None, # distloss + None, # has_depth_channel ) diff --git a/gsplat/cuda/csrc/Rasterization.cpp b/gsplat/cuda/csrc/Rasterization.cpp index 5b402a674..085bbf080 100644 --- a/gsplat/cuda/csrc/Rasterization.cpp +++ b/gsplat/cuda/csrc/Rasterization.cpp @@ -330,7 +330,8 @@ rasterize_to_pixels_2dgs_fwd( int64_t tile_size, // intersections const at::Tensor &tile_offsets, // [..., tile_height, tile_width] - const at::Tensor &flatten_ids // [n_isects] + const at::Tensor &flatten_ids, // [n_isects] + bool has_depth_channel // Fix #863: last channel is depth ) { DEVICE_GUARD(means2d); CHECK_INPUT(means2d); @@ -394,6 +395,7 @@ rasterize_to_pixels_2dgs_fwd( tile_size, \ tile_offsets, \ flatten_ids, \ + has_depth_channel, \ renders, \ alphas, \ render_normals, \ @@ -462,7 +464,8 @@ rasterize_to_pixels_2dgs_bwd( const at::Tensor &v_render_distort, // [..., image_height, image_width, 1] const at::Tensor &v_render_median, // [..., image_height, image_width, 1] // options - bool absgrad + bool absgrad, + bool has_depth_channel // Fix #863: last channel is depth ) { DEVICE_GUARD(means2d); CHECK_INPUT(means2d); @@ -518,6 +521,7 @@ rasterize_to_pixels_2dgs_bwd( tile_size, \ tile_offsets, \ flatten_ids, \ + has_depth_channel, \ render_colors, \ render_alphas, \ last_ids, \ diff --git a/gsplat/cuda/csrc/Rasterization.h b/gsplat/cuda/csrc/Rasterization.h index 9397044af..d08cd46ff 100644 --- a/gsplat/cuda/csrc/Rasterization.h +++ b/gsplat/cuda/csrc/Rasterization.h @@ -136,6 +136,7 @@ void launch_rasterize_to_pixels_2dgs_fwd_kernel( // intersections const at::Tensor tile_offsets, // [..., tile_height, tile_width] const at::Tensor flatten_ids, // [n_isects] + bool has_depth_channel, // Fix #863: last channel is depth // outputs at::Tensor renders, // [..., image_height, image_width, channels] at::Tensor alphas, // [..., image_height, image_width, 1] @@ -163,6 +164,7 @@ void launch_rasterize_to_pixels_2dgs_bwd_kernel( // ray_crossions const at::Tensor tile_offsets, // [..., tile_height, tile_width] const at::Tensor flatten_ids, // [n_isects] + bool has_depth_channel, // Fix #863: last channel is depth // forward outputs const at::Tensor render_colors, // [..., image_height, image_width, CDIM] const at::Tensor render_alphas, // [..., image_height, image_width, 1] diff --git a/gsplat/cuda/csrc/RasterizeToPixels2DGSBwd.cu b/gsplat/cuda/csrc/RasterizeToPixels2DGSBwd.cu index 6a6451032..cf2020edd 100644 --- a/gsplat/cuda/csrc/RasterizeToPixels2DGSBwd.cu +++ b/gsplat/cuda/csrc/RasterizeToPixels2DGSBwd.cu @@ -41,6 +41,7 @@ __global__ void rasterize_to_pixels_2dgs_bwd_kernel( const uint32_t N, // number of gaussians const uint32_t n_isects, // number of ray-primitive intersections. const bool packed, // whether the input tensors are packed + const bool has_depth_channel, // Fix #863: last color channel is depth // fwd inputs const vec2 *__restrict__ means2d, // Projected Gaussian means. [..., N, 2] if @@ -484,6 +485,17 @@ __global__ void rasterize_to_pixels_2dgs_bwd_kernel( for (uint32_t k = 0; k < CDIM; ++k) { v_rgb_local[k] += fac * v_render_c[k]; } + // Fix #863: depth channel grad goes to v_w_M, not v_colors + if (has_depth_channel) { + float isect_d = + 1.0f / (w_M.x * px + w_M.y * py + w_M.z); + float d2 = isect_d * isect_d; + float v_id = fac * v_render_c[CDIM - 1]; + v_w_M_local.x += v_id * (-px * d2); + v_w_M_local.y += v_id * (-py * d2); + v_w_M_local.z += v_id * (-d2); + v_rgb_local[CDIM - 1] -= v_id; // undo center-depth grad + } // contribution from this pixel to alpha // we have d(alpha)/d(c_i) = c_i * G_i * T + [grad contribution @@ -494,6 +506,13 @@ __global__ void rasterize_to_pixels_2dgs_bwd_kernel( v_alpha += (rgbs_batch[t * CDIM + k] * T - buffer[k] * ra) * v_render_c[k]; } + // Fix #863: fix depth channel's contribution to v_alpha + if (has_depth_channel) { + float isect_d = + 1.0f / (w_M.x * px + w_M.y * py + w_M.z); + v_alpha += (isect_d - rgbs_batch[t * CDIM + CDIM - 1]) + * T * v_render_c[CDIM - 1]; + } /* * d(normal_out) / d(rgb) and d(normal_out) / d(alpha) @@ -620,6 +639,13 @@ __global__ void rasterize_to_pixels_2dgs_bwd_kernel( for (uint32_t k = 0; k < CDIM; ++k) { buffer[k] += rgbs_batch[t * CDIM + k] * fac; } + // Fix #863: fix depth channel in buffer accumulator + if (has_depth_channel) { + float isect_d = + 1.0f / (w_M.x * px + w_M.y * py + w_M.z); + buffer[CDIM - 1] += + (isect_d - rgbs_batch[t * CDIM + CDIM - 1]) * fac; + } /** * Update the cumulative "later" gaussian contributions, used in derivatives of @@ -722,6 +748,7 @@ void launch_rasterize_to_pixels_2dgs_bwd_kernel( // ray_crossions const at::Tensor tile_offsets, // [..., tile_height, tile_width] const at::Tensor flatten_ids, // [n_isects] + const bool has_depth_channel, // Fix #863: last channel is depth // forward outputs const at::Tensor render_colors, // [..., image_height, image_width, CDIM] const at::Tensor render_alphas, // [..., image_height, image_width, 1] @@ -786,6 +813,7 @@ void launch_rasterize_to_pixels_2dgs_bwd_kernel( N, n_isects, packed, + has_depth_channel, reinterpret_cast(means2d.data_ptr()), ray_transforms.data_ptr(), colors.data_ptr(), @@ -842,6 +870,7 @@ void launch_rasterize_to_pixels_2dgs_bwd_kernel( const uint32_t tile_size, \ const at::Tensor tile_offsets, \ const at::Tensor flatten_ids, \ + bool has_depth_channel, \ const at::Tensor render_colors, \ const at::Tensor render_alphas, \ const at::Tensor last_ids, \ diff --git a/gsplat/cuda/csrc/RasterizeToPixels2DGSFwd.cu b/gsplat/cuda/csrc/RasterizeToPixels2DGSFwd.cu index c6a2e06c1..85da3f0b8 100644 --- a/gsplat/cuda/csrc/RasterizeToPixels2DGSFwd.cu +++ b/gsplat/cuda/csrc/RasterizeToPixels2DGSFwd.cu @@ -43,6 +43,7 @@ __global__ void rasterize_to_pixels_2dgs_fwd_kernel( const uint32_t N, // number of gaussians const uint32_t n_isects, // number of ray-primitive intersections. const bool packed, // whether the input tensors are packed + const bool has_depth_channel, // Fix #863: last color channel is depth const vec2 *__restrict__ means2d, // Projected Gaussian means. [..., N, 2] if // packed is False, [nnz, 2] if packed is True. @@ -404,6 +405,13 @@ __global__ void rasterize_to_pixels_2dgs_fwd_kernel( for (uint32_t k = 0; k < CDIM; ++k) { pix_out[k] += c_ptr[k] * vis; } + if (has_depth_channel) { + // Fix #863: replace packed Gaussian center depth + // with the Inria 2DGS ray-splat intersection depth. + const float isect_depth = s.x * w_M.x + s.y * w_M.y + w_M.z; + pix_out[CDIM - 1] += + (isect_depth - c_ptr[CDIM - 1]) * vis; + } const float *n_ptr = normals + g * 3; #pragma unroll @@ -413,7 +421,7 @@ __global__ void rasterize_to_pixels_2dgs_fwd_kernel( if (render_distort != nullptr) { // the last channel of colors is depth - const float depth = c_ptr[CDIM - 1]; + const float depth = s.x * w_M.x + s.y * w_M.y + w_M.z; // in nerfacc, loss_bi_0 = weights * t_mids * // exclusive_sum(weights) const float distort_bi_0 = vis * depth * (1.0f - T); @@ -426,7 +434,7 @@ __global__ void rasterize_to_pixels_2dgs_fwd_kernel( // compute median depth if (T > 0.5) { - median_depth = c_ptr[CDIM - 1]; + median_depth = s.x * w_M.x + s.y * w_M.y + w_M.z; median_idx = batch_start + t; } @@ -482,6 +490,7 @@ void launch_rasterize_to_pixels_2dgs_fwd_kernel( // intersections const at::Tensor tile_offsets, // [..., tile_height, tile_width] const at::Tensor flatten_ids, // [n_isects] + const bool has_depth_channel, // Fix #863: last channel is depth // outputs at::Tensor renders, // [..., image_height, image_width, channels] at::Tensor alphas, // [..., image_height, image_width] @@ -529,6 +538,7 @@ void launch_rasterize_to_pixels_2dgs_fwd_kernel( N, n_isects, packed, + has_depth_channel, reinterpret_cast(means2d.data_ptr()), ray_transforms.data_ptr(), colors.data_ptr(), @@ -571,6 +581,7 @@ void launch_rasterize_to_pixels_2dgs_fwd_kernel( uint32_t tile_size, \ const at::Tensor tile_offsets, \ const at::Tensor flatten_ids, \ + bool has_depth_channel, \ at::Tensor renders, \ at::Tensor alphas, \ at::Tensor render_normals, \ diff --git a/gsplat/cuda/ext.cpp b/gsplat/cuda/ext.cpp index 7711a9d2f..662ebbba6 100644 --- a/gsplat/cuda/ext.cpp +++ b/gsplat/cuda/ext.cpp @@ -744,8 +744,8 @@ TORCH_LIBRARY(gsplat, m) { m.def("projection_2dgs_packed_fwd(Tensor means, Tensor quats, Tensor scales, Tensor viewmats, Tensor Ks, int image_width, int image_height, float near_plane, float far_plane, float radius_clip) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); m.def("projection_2dgs_packed_bwd(Tensor means, Tensor quats, Tensor scales, Tensor viewmats, Tensor Ks, int image_width, int image_height, Tensor batch_ids, Tensor camera_ids, Tensor gaussian_ids, Tensor ray_transforms, Tensor v_means2d, Tensor v_depths, Tensor v_ray_transforms, Tensor v_normals, bool viewmats_requires_grad, bool sparse_grad) -> (Tensor, Tensor, Tensor, Tensor)"); - m.def("rasterize_to_pixels_2dgs_fwd(Tensor means2d, Tensor ray_transforms, Tensor colors, Tensor opacities, Tensor normals, Tensor? backgrounds, Tensor? masks, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); - m.def("rasterize_to_pixels_2dgs_bwd(Tensor means2d, Tensor ray_transforms, Tensor colors, Tensor opacities, Tensor normals, Tensor densify, Tensor? backgrounds, Tensor? masks, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids, Tensor render_colors, Tensor render_alphas, Tensor last_ids, Tensor median_ids, Tensor v_render_colors, Tensor v_render_alphas, Tensor v_render_normals, Tensor v_render_distort, Tensor v_render_median, bool absgrad) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("rasterize_to_pixels_2dgs_fwd(Tensor means2d, Tensor ray_transforms, Tensor colors, Tensor opacities, Tensor normals, Tensor? backgrounds, Tensor? masks, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids, bool has_depth_channel) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("rasterize_to_pixels_2dgs_bwd(Tensor means2d, Tensor ray_transforms, Tensor colors, Tensor opacities, Tensor normals, Tensor densify, Tensor? backgrounds, Tensor? masks, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids, Tensor render_colors, Tensor render_alphas, Tensor last_ids, Tensor median_ids, Tensor v_render_colors, Tensor v_render_alphas, Tensor v_render_normals, Tensor v_render_distort, Tensor v_render_median, bool absgrad, bool has_depth_channel) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); m.def("rasterize_to_indices_2dgs(int range_start, int range_end, Tensor transmittances, Tensor means2d, Tensor ray_transforms, Tensor opacities, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids) -> (Tensor, Tensor)"); #endif diff --git a/gsplat/cuda/include/Ops.h b/gsplat/cuda/include/Ops.h index 27f02babc..9017f47e0 100644 --- a/gsplat/cuda/include/Ops.h +++ b/gsplat/cuda/include/Ops.h @@ -503,7 +503,8 @@ rasterize_to_pixels_2dgs_fwd( int64_t tile_size, // intersections const at::Tensor &tile_offsets, // [..., tile_height, tile_width] - const at::Tensor &flatten_ids // [n_isects] + const at::Tensor &flatten_ids, // [n_isects] + bool has_depth_channel // Fix #863: last channel is depth ); std::tuple< at::Tensor, @@ -542,7 +543,8 @@ rasterize_to_pixels_2dgs_bwd( const at::Tensor &v_render_distort, // [..., image_height, image_width, 1] const at::Tensor &v_render_median, // [..., image_height, image_width, 1] // options - bool absgrad + bool absgrad, + bool has_depth_channel // Fix #863: last channel is depth ); std::tuple rasterize_to_indices_2dgs( diff --git a/gsplat/rendering.py b/gsplat/rendering.py index fd008686e..b1b9211ae 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -2254,6 +2254,7 @@ def rasterization_2dgs( packed=packed, absgrad=absgrad, distloss=distloss, + has_depth_channel=render_mode_has_depth_channel(render_mode), ) render_normals_from_depth = None if render_mode_has_expected_depth(render_mode):