-
Notifications
You must be signed in to change notification settings - Fork 828
Added median depth rasterization for 3dgs [fwd only] #930
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -840,7 +840,8 @@ def rasterize_to_pixels( | |
| masks: Optional[Tensor] = None, # [..., tile_height, tile_width] | ||
| packed: bool = False, | ||
| absgrad: bool = False, | ||
| ) -> Tuple[Tensor, Tensor]: | ||
| return_median: bool = False, | ||
| ) -> Tuple[Tensor, ...]: | ||
| """Rasterizes Gaussians to pixels. | ||
|
|
||
| Args: | ||
|
|
@@ -857,12 +858,20 @@ def rasterize_to_pixels( | |
| masks: Optional tile mask to skip rendering GS to masked tiles. [..., tile_height, tile_width]. Default: None. | ||
| packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False. | ||
| absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False. | ||
| return_median: If True, additionally returns per-pixel median depth, defined | ||
| as the value of the last channel of ``colors`` for the Gaussian whose | ||
| contribution causes the accumulated transmittance to drop across 0.5. | ||
| If the threshold is never crossed the last contributing Gaussian's value | ||
| is returned; if no Gaussian hits the pixel the value is 0. This output | ||
| is non-differentiable. Default: False. | ||
|
|
||
| Returns: | ||
| A tuple: | ||
|
|
||
| - **Rendered colors**. [..., image_height, image_width, channels] | ||
| - **Rendered alphas**. [..., image_height, image_width, 1] | ||
| - **Rendered median depth** (only when ``return_median=True``). | ||
| [..., image_height, image_width, 1] | ||
| """ | ||
|
|
||
| image_dims = means2d.shape[:-2] | ||
|
|
@@ -941,7 +950,7 @@ def rasterize_to_pixels( | |
| tile_width * tile_size >= image_width | ||
| ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}" | ||
|
|
||
| render_colors, render_alphas = _RasterizeToPixels.apply( | ||
| render_colors, render_alphas, render_median, _median_ids = _RasterizeToPixels.apply( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we are not returning |
||
| means2d.contiguous(), | ||
| conics.contiguous(), | ||
| colors.contiguous(), | ||
|
|
@@ -958,6 +967,8 @@ def rasterize_to_pixels( | |
|
|
||
| if padded_channels > 0: | ||
| render_colors = render_colors[..., :-padded_channels] | ||
| if return_median: | ||
| return render_colors, render_alphas, render_median | ||
| return render_colors, render_alphas | ||
|
|
||
|
|
||
|
|
@@ -1724,10 +1735,14 @@ def forward( | |
| isect_offsets: Tensor, # [..., tile_height, tile_width] | ||
| flatten_ids: Tensor, # [n_isects] | ||
| absgrad: bool, | ||
| ) -> Tuple[Tensor, Tensor]: | ||
| render_colors, render_alphas, last_ids = _make_lazy_cuda_func( | ||
| "rasterize_to_pixels_3dgs_fwd" | ||
| )( | ||
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: | ||
| ( | ||
| render_colors, | ||
| render_alphas, | ||
| last_ids, | ||
| render_median, | ||
| median_ids, | ||
| ) = _make_lazy_cuda_func("rasterize_to_pixels_3dgs_fwd")( | ||
| means2d, | ||
| conics, | ||
| colors, | ||
|
|
@@ -1758,15 +1773,20 @@ def forward( | |
| ctx.tile_size = tile_size | ||
| ctx.absgrad = absgrad | ||
|
|
||
| # Median depth outputs are forward-only (no backward). | ||
| ctx.mark_non_differentiable(render_median, median_ids) | ||
|
|
||
| # double to float | ||
| render_alphas = render_alphas.float() | ||
| return render_colors, render_alphas | ||
| return render_colors, render_alphas, render_median, median_ids | ||
|
|
||
| @staticmethod | ||
| def backward( | ||
| ctx, | ||
| v_render_colors: Tensor, # [..., H, W, 3] | ||
| v_render_alphas: Tensor, # [..., H, W, 1] | ||
| v_render_median: Tensor, # [..., H, W, 1], unused (non-differentiable) | ||
| v_median_ids: Tensor, # [..., H, W], unused (non-differentiable) | ||
| ): | ||
| ( | ||
| means2d, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,7 +39,8 @@ namespace gsplat { | |
| // 3DGS | ||
| //////////////////////////////////////////////////// | ||
|
|
||
| std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_3dgs_fwd( | ||
| std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> | ||
| rasterize_to_pixels_3dgs_fwd( | ||
| // Gaussian parameters | ||
| const at::Tensor &means2d, // [..., N, 2] or [nnz, 2] | ||
| const at::Tensor &conics, // [..., N, 3] or [nnz, 3] | ||
|
|
@@ -85,6 +86,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_3dgs_fwd( | |
| last_ids_dims.append({image_height, image_width}); | ||
| at::Tensor last_ids = at::empty(last_ids_dims, opt.dtype(at::kInt)); | ||
|
|
||
| at::DimVector median_ids_dims(image_dims); | ||
| median_ids_dims.append({image_height, image_width}); | ||
| at::Tensor median_ids = at::empty(median_ids_dims, opt.dtype(at::kInt)); | ||
|
|
||
| at::DimVector render_median_dims(image_dims); | ||
| render_median_dims.append({image_height, image_width, 1}); | ||
| at::Tensor render_median = at::empty(render_median_dims, opt); | ||
|
|
||
| #define __LAUNCH_KERNEL__(N) \ | ||
| case N: \ | ||
| launch_rasterize_to_pixels_3dgs_fwd_kernel<N>( \ | ||
|
|
@@ -101,7 +110,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_3dgs_fwd( | |
| flatten_ids, \ | ||
| renders, \ | ||
| alphas, \ | ||
| last_ids \ | ||
| last_ids, \ | ||
| render_median, \ | ||
| median_ids \ | ||
|
Comment on lines
+114
to
+115
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should make them optional tensors, to make sure we don't introduce any perf regression when median depth is not needed. |
||
| ); \ | ||
| break; | ||
|
|
||
|
|
@@ -115,7 +126,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_3dgs_fwd( | |
| } | ||
| #undef __LAUNCH_KERNEL__ | ||
|
|
||
| return std::make_tuple(renders, alphas, last_ids); | ||
| return std::make_tuple(renders, alphas, last_ids, render_median, median_ids); | ||
| } | ||
|
|
||
| std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,7 +59,13 @@ __global__ void rasterize_to_pixels_3dgs_fwd_kernel( | |
| scalar_t | ||
| *__restrict__ render_colors, // [I, image_height, image_width, CDIM] | ||
| scalar_t *__restrict__ render_alphas, // [I, image_height, image_width, 1] | ||
| int32_t *__restrict__ last_ids // [I, image_height, image_width] | ||
| int32_t *__restrict__ last_ids, // [I, image_height, image_width] | ||
| scalar_t | ||
| *__restrict__ render_median, // [I, image_height, image_width, 1] | ||
| // depth of the Gaussian that causes | ||
| // transmittance to cross 0.5, or the | ||
| // last-hit depth as a fallback. | ||
| int32_t *__restrict__ median_ids // [I, image_height, image_width] | ||
|
Comment on lines
+63
to
+68
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Related to the optional tensors: we want to allow them to be nullptr when not needed, and in the function we skip computing and writing out the median depth so we don't waste compute when they are not needed. |
||
| ) { | ||
| // each thread draws one pixel, but also timeshares caching gaussians in a | ||
| // shared tile | ||
|
|
@@ -75,6 +81,8 @@ __global__ void rasterize_to_pixels_3dgs_fwd_kernel( | |
| render_colors += image_id * image_height * image_width * CDIM; | ||
| render_alphas += image_id * image_height * image_width; | ||
| last_ids += image_id * image_height * image_width; | ||
| render_median += image_id * image_height * image_width; | ||
| median_ids += image_id * image_height * image_width; | ||
| if (backgrounds != nullptr) { | ||
| backgrounds += image_id * CDIM; | ||
| } | ||
|
|
@@ -131,6 +139,14 @@ __global__ void rasterize_to_pixels_3dgs_fwd_kernel( | |
| // index of most recent gaussian to write to this thread's pixel | ||
| uint32_t cur_idx = 0; | ||
|
|
||
| // median depth tracking: record the depth of the Gaussian whose | ||
| // contribution drops T across 0.5. If T never crosses 0.5 fall back to the | ||
| // last hit depth (0.0 if no hits at all). | ||
| float median_depth = 0.0f; | ||
| uint32_t median_idx = 0u; | ||
| float last_hit_depth = 0.0f; | ||
| bool median_found = false; | ||
|
|
||
| // collect and process batches of gaussians | ||
| // each thread loads one gaussian at a time before rasterizing its | ||
| // designated pixel | ||
|
|
@@ -190,6 +206,17 @@ __global__ void rasterize_to_pixels_3dgs_fwd_kernel( | |
| } | ||
| cur_idx = batch_start + t; | ||
|
|
||
| // Track last hit depth (depth is stored as the last channel of | ||
| // colors, consistent with the existing D/ED plumbing and 2DGS). | ||
| const float dep = c_ptr[CDIM - 1]; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This silently assumes the render mode is ED or D. We should assert this in rendering.py as a safeguard.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, for the best performance, I think you can push this line and the next into the |
||
| last_hit_depth = dep; | ||
| // Median depth: first Gaussian that causes T to drop below 0.5. | ||
| if (!median_found && T > 0.5f && next_T <= 0.5f) { | ||
| median_depth = dep; | ||
| median_idx = batch_start + t; | ||
| median_found = true; | ||
| } | ||
|
|
||
| T = next_T; | ||
| } | ||
| } | ||
|
|
@@ -209,6 +236,14 @@ __global__ void rasterize_to_pixels_3dgs_fwd_kernel( | |
| } | ||
| // index in bin of last gaussian in this pixel | ||
| last_ids[pix_id] = static_cast<int32_t>(cur_idx); | ||
|
|
||
| // Median depth: if transmittance never crossed 0.5 fall back to the | ||
| // last contributing Gaussian's depth (0.0 if there were no hits). | ||
| if (!median_found) { | ||
| median_depth = last_hit_depth; | ||
| } | ||
| render_median[pix_id] = median_depth; | ||
| median_ids[pix_id] = static_cast<int32_t>(median_idx); | ||
|
Comment on lines
+240
to
+246
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Skip write if median is not required ( |
||
| } | ||
| } | ||
|
|
||
|
|
@@ -229,9 +264,11 @@ void launch_rasterize_to_pixels_3dgs_fwd_kernel( | |
| const at::Tensor tile_offsets, // [..., tile_height, tile_width] | ||
| const at::Tensor flatten_ids, // [n_isects] | ||
| // outputs | ||
| at::Tensor renders, // [..., image_height, image_width, channels] | ||
| at::Tensor alphas, // [..., image_height, image_width] | ||
| at::Tensor last_ids // [..., image_height, image_width] | ||
| at::Tensor renders, // [..., image_height, image_width, channels] | ||
| at::Tensor alphas, // [..., image_height, image_width] | ||
| at::Tensor last_ids, // [..., image_height, image_width] | ||
| at::Tensor render_median, // [..., image_height, image_width, 1] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To be consistent with the naming convention in gsplat, the suggested naming is |
||
| at::Tensor median_ids // [..., image_height, image_width] | ||
| ) { | ||
| bool packed = means2d.dim() == 2; | ||
|
|
||
|
|
@@ -286,7 +323,9 @@ void launch_rasterize_to_pixels_3dgs_fwd_kernel( | |
| flatten_ids.data_ptr<int32_t>(), | ||
| renders.data_ptr<float>(), | ||
| alphas.data_ptr<float>(), | ||
| last_ids.data_ptr<int32_t>() | ||
| last_ids.data_ptr<int32_t>(), | ||
| render_median.data_ptr<float>(), | ||
| median_ids.data_ptr<int32_t>() | ||
|
Comment on lines
-289
to
+328
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. update for optional |
||
| ); | ||
| } | ||
|
|
||
|
|
@@ -308,7 +347,9 @@ void launch_rasterize_to_pixels_3dgs_fwd_kernel( | |
| const at::Tensor flatten_ids, \ | ||
| at::Tensor renders, \ | ||
| at::Tensor alphas, \ | ||
| at::Tensor last_ids \ | ||
| at::Tensor last_ids, \ | ||
| at::Tensor render_median, \ | ||
| at::Tensor median_ids \ | ||
| ); | ||
|
|
||
| GSPLAT_FOR_EACH(__INS__, GSPLAT_NUM_CHANNELS) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"the value of the last channel of
colors" is weird to read. Maybe just say something like "expect the last channel to be depth"