Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,8 @@ def rasterize_to_pixels(
masks: Optional[Tensor] = None, # [..., tile_height, tile_width]
packed: bool = False,
absgrad: bool = False,
) -> Tuple[Tensor, Tensor]:
return_median: bool = False,
) -> Tuple[Tensor, ...]:
"""Rasterizes Gaussians to pixels.

Args:
Expand All @@ -857,12 +858,20 @@ def rasterize_to_pixels(
masks: Optional tile mask to skip rendering GS to masked tiles. [..., tile_height, tile_width]. Default: None.
packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False.
absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False.
return_median: If True, additionally returns per-pixel median depth, defined
as the value of the last channel of ``colors`` for the Gaussian whose
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"the value of the last channel of colors " is weird to read. Maybe just say something like "expect the last channel to be depth"

contribution causes the accumulated transmittance to drop across 0.5.
If the threshold is never crossed the last contributing Gaussian's value
is returned; if no Gaussian hits the pixel the value is 0. This output
is non-differentiable. Default: False.

Returns:
A tuple:

- **Rendered colors**. [..., image_height, image_width, channels]
- **Rendered alphas**. [..., image_height, image_width, 1]
- **Rendered median depth** (only when ``return_median=True``).
[..., image_height, image_width, 1]
"""

image_dims = means2d.shape[:-2]
Expand Down Expand Up @@ -941,7 +950,7 @@ def rasterize_to_pixels(
tile_width * tile_size >= image_width
), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}"

render_colors, render_alphas = _RasterizeToPixels.apply(
render_colors, render_alphas, render_median, _median_ids = _RasterizeToPixels.apply(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we are not returning median ids we might just remove it entirely?

means2d.contiguous(),
conics.contiguous(),
colors.contiguous(),
Expand All @@ -958,6 +967,8 @@ def rasterize_to_pixels(

if padded_channels > 0:
render_colors = render_colors[..., :-padded_channels]
if return_median:
return render_colors, render_alphas, render_median
return render_colors, render_alphas


Expand Down Expand Up @@ -1724,10 +1735,14 @@ def forward(
isect_offsets: Tensor, # [..., tile_height, tile_width]
flatten_ids: Tensor, # [n_isects]
absgrad: bool,
) -> Tuple[Tensor, Tensor]:
render_colors, render_alphas, last_ids = _make_lazy_cuda_func(
"rasterize_to_pixels_3dgs_fwd"
)(
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
(
render_colors,
render_alphas,
last_ids,
render_median,
median_ids,
) = _make_lazy_cuda_func("rasterize_to_pixels_3dgs_fwd")(
means2d,
conics,
colors,
Expand Down Expand Up @@ -1758,15 +1773,20 @@ def forward(
ctx.tile_size = tile_size
ctx.absgrad = absgrad

# Median depth outputs are forward-only (no backward).
ctx.mark_non_differentiable(render_median, median_ids)

# double to float
render_alphas = render_alphas.float()
return render_colors, render_alphas
return render_colors, render_alphas, render_median, median_ids

@staticmethod
def backward(
ctx,
v_render_colors: Tensor, # [..., H, W, 3]
v_render_alphas: Tensor, # [..., H, W, 1]
v_render_median: Tensor, # [..., H, W, 1], unused (non-differentiable)
v_median_ids: Tensor, # [..., H, W], unused (non-differentiable)
):
(
means2d,
Expand Down
17 changes: 14 additions & 3 deletions gsplat/cuda/csrc/Rasterization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ namespace gsplat {
// 3DGS
////////////////////////////////////////////////////

std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_3dgs_fwd(
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
rasterize_to_pixels_3dgs_fwd(
// Gaussian parameters
const at::Tensor &means2d, // [..., N, 2] or [nnz, 2]
const at::Tensor &conics, // [..., N, 3] or [nnz, 3]
Expand Down Expand Up @@ -85,6 +86,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_3dgs_fwd(
last_ids_dims.append({image_height, image_width});
at::Tensor last_ids = at::empty(last_ids_dims, opt.dtype(at::kInt));

at::DimVector median_ids_dims(image_dims);
median_ids_dims.append({image_height, image_width});
at::Tensor median_ids = at::empty(median_ids_dims, opt.dtype(at::kInt));

at::DimVector render_median_dims(image_dims);
render_median_dims.append({image_height, image_width, 1});
at::Tensor render_median = at::empty(render_median_dims, opt);

#define __LAUNCH_KERNEL__(N) \
case N: \
launch_rasterize_to_pixels_3dgs_fwd_kernel<N>( \
Expand All @@ -101,7 +110,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_3dgs_fwd(
flatten_ids, \
renders, \
alphas, \
last_ids \
last_ids, \
render_median, \
median_ids \
Comment on lines +114 to +115
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should make them optional tensors -- to make sure we dont introduce any perf regression when median depth is not needed.

); \
break;

Expand All @@ -115,7 +126,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_3dgs_fwd(
}
#undef __LAUNCH_KERNEL__

return std::make_tuple(renders, alphas, last_ids);
return std::make_tuple(renders, alphas, last_ids, render_median, median_ids);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
Expand Down
8 changes: 5 additions & 3 deletions gsplat/cuda/csrc/Rasterization.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ void launch_rasterize_to_pixels_3dgs_fwd_kernel(
const at::Tensor tile_offsets, // [..., tile_height, tile_width]
const at::Tensor flatten_ids, // [n_isects]
// outputs
at::Tensor renders, // [..., image_height, image_width, channels]
at::Tensor alphas, // [..., image_height, image_width]
at::Tensor last_ids // [..., image_height, image_width]
at::Tensor renders, // [..., image_height, image_width, channels]
at::Tensor alphas, // [..., image_height, image_width]
at::Tensor last_ids, // [..., image_height, image_width]
at::Tensor render_median, // [..., image_height, image_width, 1]
at::Tensor median_ids // [..., image_height, image_width]
);

template <uint32_t CDIM>
Expand Down
53 changes: 47 additions & 6 deletions gsplat/cuda/csrc/RasterizeToPixels3DGSFwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@ __global__ void rasterize_to_pixels_3dgs_fwd_kernel(
scalar_t
*__restrict__ render_colors, // [I, image_height, image_width, CDIM]
scalar_t *__restrict__ render_alphas, // [I, image_height, image_width, 1]
int32_t *__restrict__ last_ids // [I, image_height, image_width]
int32_t *__restrict__ last_ids, // [I, image_height, image_width]
scalar_t
*__restrict__ render_median, // [I, image_height, image_width, 1]
// depth of the Gaussian that causes
// transmittance to cross 0.5, or the
// last-hit depth as a fallback.
int32_t *__restrict__ median_ids // [I, image_height, image_width]
Comment on lines +63 to +68
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

related to optional: we want to allow them to be null ptr when not needed. and in the function we skip compute median depth and write out so we dont waste compute if they are not needed

) {
// each thread draws one pixel, but also timeshares caching gaussians in a
// shared tile
Expand All @@ -75,6 +81,8 @@ __global__ void rasterize_to_pixels_3dgs_fwd_kernel(
render_colors += image_id * image_height * image_width * CDIM;
render_alphas += image_id * image_height * image_width;
last_ids += image_id * image_height * image_width;
render_median += image_id * image_height * image_width;
median_ids += image_id * image_height * image_width;
if (backgrounds != nullptr) {
backgrounds += image_id * CDIM;
}
Expand Down Expand Up @@ -131,6 +139,14 @@ __global__ void rasterize_to_pixels_3dgs_fwd_kernel(
// index of most recent gaussian to write to this thread's pixel
uint32_t cur_idx = 0;

// median depth tracking: record the depth of the Gaussian whose
// contribution drops T across 0.5. If T never crosses 0.5 fall back to the
// last hit depth (0.0 if no hits at all).
float median_depth = 0.0f;
uint32_t median_idx = 0u;
float last_hit_depth = 0.0f;
bool median_found = false;

// collect and process batches of gaussians
// each thread loads one gaussian at a time before rasterizing its
// designated pixel
Expand Down Expand Up @@ -190,6 +206,17 @@ __global__ void rasterize_to_pixels_3dgs_fwd_kernel(
}
cur_idx = batch_start + t;

// Track last hit depth (depth is stored as the last channel of
// colors, consistent with the existing D/ED plumbing and 2DGS).
const float dep = c_ptr[CDIM - 1];
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This silently assume render mode is in ED or D. Should assert this to protect in the rendering.py

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also for the best performance, I think you can push this link and next into the if (!median_found && T > 0.5f && next_T <= 0.5f) { condition?

last_hit_depth = dep;
// Median depth: first Gaussian that causes T to drop below 0.5.
if (!median_found && T > 0.5f && next_T <= 0.5f) {
median_depth = dep;
median_idx = batch_start + t;
median_found = true;
}

T = next_T;
}
}
Expand All @@ -209,6 +236,14 @@ __global__ void rasterize_to_pixels_3dgs_fwd_kernel(
}
// index in bin of last gaussian in this pixel
last_ids[pix_id] = static_cast<int32_t>(cur_idx);

// Median depth: if transmittance never crossed 0.5 fall back to the
// last contributing Gaussian's depth (0.0 if there were no hits).
if (!median_found) {
median_depth = last_hit_depth;
}
render_median[pix_id] = median_depth;
median_ids[pix_id] = static_cast<int32_t>(median_idx);
Comment on lines +240 to +246
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Skip write if median is not required (render_median is a null ptr)

}
}

Expand All @@ -229,9 +264,11 @@ void launch_rasterize_to_pixels_3dgs_fwd_kernel(
const at::Tensor tile_offsets, // [..., tile_height, tile_width]
const at::Tensor flatten_ids, // [n_isects]
// outputs
at::Tensor renders, // [..., image_height, image_width, channels]
at::Tensor alphas, // [..., image_height, image_width]
at::Tensor last_ids // [..., image_height, image_width]
at::Tensor renders, // [..., image_height, image_width, channels]
at::Tensor alphas, // [..., image_height, image_width]
at::Tensor last_ids, // [..., image_height, image_width]
at::Tensor render_median, // [..., image_height, image_width, 1]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to be consistent on the naming convension in gsplat, suggested naming is median_depths

at::Tensor median_ids // [..., image_height, image_width]
) {
bool packed = means2d.dim() == 2;

Expand Down Expand Up @@ -286,7 +323,9 @@ void launch_rasterize_to_pixels_3dgs_fwd_kernel(
flatten_ids.data_ptr<int32_t>(),
renders.data_ptr<float>(),
alphas.data_ptr<float>(),
last_ids.data_ptr<int32_t>()
last_ids.data_ptr<int32_t>(),
render_median.data_ptr<float>(),
median_ids.data_ptr<int32_t>()
Comment on lines -289 to +328
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update for optional

);
}

Expand All @@ -308,7 +347,9 @@ void launch_rasterize_to_pixels_3dgs_fwd_kernel(
const at::Tensor flatten_ids, \
at::Tensor renders, \
at::Tensor alphas, \
at::Tensor last_ids \
at::Tensor last_ids, \
at::Tensor render_median, \
at::Tensor median_ids \
);

GSPLAT_FOR_EACH(__INS__, GSPLAT_NUM_CHANNELS)
Expand Down
2 changes: 1 addition & 1 deletion gsplat/cuda/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ TORCH_LIBRARY(gsplat, m) {
m.def("projection_ewa_3dgs_packed_fwd(Tensor means, Tensor? covars, Tensor? quats, Tensor? scales, Tensor? opacities, Tensor viewmats, Tensor Ks, int image_width, int image_height, float eps2d, float near_plane, float far_plane, float radius_clip, bool calc_compensations, int camera_model) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("projection_ewa_3dgs_packed_bwd(Tensor means, Tensor? covars, Tensor? quats, Tensor? scales, Tensor viewmats, Tensor Ks, int image_width, int image_height, float eps2d, int camera_model, Tensor batch_ids, Tensor camera_ids, Tensor gaussian_ids, Tensor conics, Tensor? compensations, Tensor v_means2d, Tensor v_depths, Tensor v_conics, Tensor? v_compensations, bool viewmats_requires_grad, bool sparse_grad) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");

m.def("rasterize_to_pixels_3dgs_fwd(Tensor means2d, Tensor conics, Tensor colors, Tensor opacities, Tensor? backgrounds, Tensor? masks, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids) -> (Tensor, Tensor, Tensor)");
m.def("rasterize_to_pixels_3dgs_fwd(Tensor means2d, Tensor conics, Tensor colors, Tensor opacities, Tensor? backgrounds, Tensor? masks, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("rasterize_to_pixels_3dgs_bwd(Tensor means2d, Tensor conics, Tensor colors, Tensor opacities, Tensor? backgrounds, Tensor? masks, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids, Tensor render_alphas, Tensor last_ids, Tensor v_render_colors, Tensor v_render_alphas, bool absgrad) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("rasterize_to_indices_3dgs(int range_start, int range_end, Tensor transmittances, Tensor means2d, Tensor conics, Tensor opacities, int image_width, int image_height, int tile_size, Tensor tile_offsets, Tensor flatten_ids) -> (Tensor, Tensor)");
#endif
Expand Down
3 changes: 2 additions & 1 deletion gsplat/cuda/include/Ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ std::tuple<at::Tensor, at::Tensor> quat_scale_to_covar_preci_bwd(
);

// Rasterize 3D Gaussian to pixels
std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_3dgs_fwd(
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
rasterize_to_pixels_3dgs_fwd(
// Gaussian parameters
const at::Tensor &means2d, // [..., N, 2] or [nnz, 2]
const at::Tensor &conics, // [..., N, 3] or [nnz, 3]
Expand Down
Loading
Loading