Skip to content

Commit 8a5aca5

Browse files
committed
Simplify patch rendering
1 parent c2bb01a commit 8a5aca5

2 files changed

Lines changed: 10 additions & 18 deletions

File tree

diffdrr/drr.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -194,16 +194,12 @@ def render(
194194
**kwargs,
195195
)
196196
else:
197-
n_points = target.shape[1] // self.n_patches
198197
partials = []
199-
for idx in range(self.n_patches):
200-
partial = self.renderer(
201-
density,
202-
source,
203-
target[:, idx * n_points : (idx + 1) * n_points],
204-
img[..., idx * n_points : (idx + 1) * n_points],
205-
**kwargs,
206-
)
198+
for t, i in zip(
199+
target.chunk(self.n_patches, dim=1),
200+
img.chunk(self.n_patches, dim=-1),
201+
):
202+
partial = self.renderer(density, source, t, i, **kwargs)
207203
partials.append(partial)
208204
img = torch.cat(partials, dim=-1)
209205

notebooks/api/00_drr.ipynb

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -313,16 +313,12 @@
313313
" **kwargs,\n",
314314
" )\n",
315315
" else:\n",
316-
" n_points = target.shape[1] // self.n_patches\n",
317316
" partials = []\n",
318-
" for idx in range(self.n_patches):\n",
319-
" partial = self.renderer(\n",
320-
" density,\n",
321-
" source,\n",
322-
" target[:, idx * n_points : (idx + 1) * n_points],\n",
323-
" img[..., idx * n_points : (idx + 1) * n_points],\n",
324-
" **kwargs,\n",
325-
" )\n",
317+
" for t, i in zip(\n",
318+
" target.chunk(self.n_patches, dim=1),\n",
319+
" img.chunk(self.n_patches, dim=-1),\n",
320+
" ):\n",
321+
" partial = self.renderer(density, source, t, i, **kwargs)\n",
326322
" partials.append(partial)\n",
327323
" img = torch.cat(partials, dim=-1)\n",
328324
"\n",

0 commit comments

Comments
 (0)