Skip to content

Commit f77697b

Browse files
authored
Simplify patch rendering (#380)
* Make filter_intersections_outside_volume False by default * Simplify patch rendering
1 parent 37c96ec commit f77697b

4 files changed

Lines changed: 14 additions & 22 deletions

File tree

diffdrr/drr.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -194,16 +194,12 @@ def render(
194194
**kwargs,
195195
)
196196
else:
197-
n_points = target.shape[1] // self.n_patches
198197
partials = []
199-
for idx in range(self.n_patches):
200-
partial = self.renderer(
201-
density,
202-
source,
203-
target[:, idx * n_points : (idx + 1) * n_points],
204-
img[..., idx * n_points : (idx + 1) * n_points],
205-
**kwargs,
206-
)
198+
for t, i in zip(
199+
target.chunk(self.n_patches, dim=1),
200+
img.chunk(self.n_patches, dim=-1),
201+
):
202+
partial = self.renderer(density, source, t, i, **kwargs)
207203
partials.append(partial)
208204
img = torch.cat(partials, dim=-1)
209205

diffdrr/renderers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(
1515
self,
1616
mode: str = "nearest", # Interpolation mode for grid_sample
1717
stop_gradients_through_grid_sample: bool = False, # Apply torch.no_grad when calling grid_sample
18-
filter_intersections_outside_volume: bool = True, # Use alphamin/max to filter the intersections
18+
filter_intersections_outside_volume: bool = False, # Use alphamin/max to filter the intersections
1919
reducefn: str = "sum", # Function for combining samples along each ray
2020
eps: float = 1e-8, # Small constant to avoid div by zero errors
2121
):
@@ -51,7 +51,7 @@ def forward(
5151

5252
# Calculate the midpoint of every pair of adjacent intersections
5353
# These midpoints lie exclusively in a single voxel
54-
alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2
54+
alphamid = (alphas[..., :-1] + alphas[..., 1:]) / 2
5555

5656
# Get the XYZ coordinate of each midpoint (normalized to [-1, +1]^3)
5757
xyzs = _get_xyzs(alphamid, source, target, dims, self.eps)

notebooks/api/00_drr.ipynb

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -313,16 +313,12 @@
313313
" **kwargs,\n",
314314
" )\n",
315315
" else:\n",
316-
" n_points = target.shape[1] // self.n_patches\n",
317316
" partials = []\n",
318-
" for idx in range(self.n_patches):\n",
319-
" partial = self.renderer(\n",
320-
" density,\n",
321-
" source,\n",
322-
" target[:, idx * n_points : (idx + 1) * n_points],\n",
323-
" img[..., idx * n_points : (idx + 1) * n_points],\n",
324-
" **kwargs,\n",
325-
" )\n",
317+
" for t, i in zip(\n",
318+
" target.chunk(self.n_patches, dim=1),\n",
319+
" img.chunk(self.n_patches, dim=-1),\n",
320+
" ):\n",
321+
" partial = self.renderer(density, source, t, i, **kwargs)\n",
326322
" partials.append(partial)\n",
327323
" img = torch.cat(partials, dim=-1)\n",
328324
"\n",

notebooks/api/01_renderers.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@
117117
" self,\n",
118118
" mode: str = \"nearest\", # Interpolation mode for grid_sample\n",
119119
" stop_gradients_through_grid_sample: bool = False, # Apply torch.no_grad when calling grid_sample\n",
120-
" filter_intersections_outside_volume: bool = True, # Use alphamin/max to filter the intersections\n",
120+
" filter_intersections_outside_volume: bool = False, # Use alphamin/max to filter the intersections\n",
121121
" reducefn: str = \"sum\", # Function for combining samples along each ray\n",
122122
" eps: float = 1e-8, # Small constant to avoid div by zero errors\n",
123123
" ):\n",
@@ -153,7 +153,7 @@
153153
"\n",
154154
" # Calculate the midpoint of every pair of adjacent intersections\n",
155155
" # These midpoints lie exclusively in a single voxel\n",
156-
" alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2\n",
156+
" alphamid = (alphas[..., :-1] + alphas[..., 1:]) / 2\n",
157157
"\n",
158158
" # Get the XYZ coordinate of each midpoint (normalized to [-1, +1]^3)\n",
159159
" xyzs = _get_xyzs(alphamid, source, target, dims, self.eps)\n",

0 commit comments

Comments
 (0)