Skip to content

Commit 0c67edc

Browse files
committed
Implement grid-splatting reconstruction
1 parent 43c1c2c commit 0c67edc

2 files changed

Lines changed: 319 additions & 0 deletions

File tree

diffdrr/reconstruction.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/09_reconstruction.ipynb.
2+
3+
# %% auto #0
4+
__all__ = ['grid_splat', 'splat_points_nearest', 'splat_points_trilinear']
5+
6+
# %% ../notebooks/api/09_reconstruction.ipynb #d1a04243-4893-4467-af3f-9a7492e301bb
7+
import torch
8+
9+
10+
def grid_splat(
11+
points, # 3D coordinates in voxel space [0, D-1] x [0, H-1] x [0, W-1]
12+
values, # Intensity values for each point - (N,)
13+
size: tuple[int, int, int], # Dimensions of output volume - (D, H, W)
14+
mode: str = "bilinear", # "nearest" or "bilinear" (trilinear in 3D)
15+
):
16+
"""Splat 3D points into a volume (inverse of grid_sample)."""
17+
if mode == "nearest":
18+
return splat_points_nearest(points, values, size)
19+
elif mode == "bilinear":
20+
return splat_points_trilinear(points, values, size)
21+
else:
22+
raise ValueError(f"Unsupported mode: {mode}. Use 'nearest' or 'bilinear'")
23+
24+
25+
def splat_points_nearest(points, values, size):
26+
"""Splat 3D points using nearest-neighbor interpolation."""
27+
device = points.device
28+
D, H, W = size
29+
30+
# Initialize output
31+
volume = torch.zeros(D, H, W, device=device)
32+
counts = torch.zeros(D, H, W, device=device)
33+
34+
# Round to nearest voxel
35+
i = torch.round(points[:, 0]).long()
36+
j = torch.round(points[:, 1]).long()
37+
k = torch.round(points[:, 2]).long()
38+
39+
# Filter valid points (inside volume)
40+
valid = (i >= 0) & (i < D) & (j >= 0) & (j < H) & (k >= 0) & (k < W)
41+
42+
if not valid.any():
43+
return volume
44+
45+
# Apply filter
46+
i, j, k = i[valid], j[valid], k[valid]
47+
vals = values[valid]
48+
49+
# Accumulate values
50+
volume.index_put_((i, j, k), vals, accumulate=True)
51+
counts.index_put_((i, j, k), torch.ones_like(vals), accumulate=True)
52+
53+
# Normalize by count
54+
volume = volume / (counts + 1e-8)
55+
56+
return volume
57+
58+
59+
def splat_points_trilinear(points, values, size):
60+
"""Splat 3D points using trilinear interpolation."""
61+
device = points.device
62+
D, H, W = size
63+
64+
# Initialize output
65+
volume = torch.zeros(D, H, W, device=device)
66+
weights = torch.zeros(D, H, W, device=device)
67+
68+
# Get integer voxel indices (floor)
69+
i0 = torch.floor(points[:, 0]).long()
70+
j0 = torch.floor(points[:, 1]).long()
71+
k0 = torch.floor(points[:, 2]).long()
72+
73+
i1 = i0 + 1
74+
j1 = j0 + 1
75+
k1 = k0 + 1
76+
77+
# Compute fractional parts for interpolation
78+
fi = points[:, 0] - i0.float()
79+
fj = points[:, 1] - j0.float()
80+
fk = points[:, 2] - k0.float()
81+
82+
# Filter valid points (inside volume)
83+
valid = (i0 >= 0) & (i1 < D) & (j0 >= 0) & (j1 < H) & (k0 >= 0) & (k1 < W)
84+
85+
if not valid.any():
86+
return volume
87+
88+
# Apply filter
89+
i0, i1 = i0[valid], i1[valid]
90+
j0, j1 = j0[valid], j1[valid]
91+
k0, k1 = k0[valid], k1[valid]
92+
fi, fj, fk = fi[valid], fj[valid], fk[valid]
93+
vals = values[valid]
94+
95+
# Compute 8 corner weights (trilinear interpolation weights)
96+
w000 = (1 - fi) * (1 - fj) * (1 - fk)
97+
w001 = (1 - fi) * (1 - fj) * fk
98+
w010 = (1 - fi) * fj * (1 - fk)
99+
w011 = (1 - fi) * fj * fk
100+
w100 = fi * (1 - fj) * (1 - fk)
101+
w101 = fi * (1 - fj) * fk
102+
w110 = fi * fj * (1 - fk)
103+
w111 = fi * fj * fk
104+
105+
# Splat to 8 neighboring voxels
106+
volume.index_put_((i0, j0, k0), vals * w000, accumulate=True)
107+
volume.index_put_((i0, j0, k1), vals * w001, accumulate=True)
108+
volume.index_put_((i0, j1, k0), vals * w010, accumulate=True)
109+
volume.index_put_((i0, j1, k1), vals * w011, accumulate=True)
110+
volume.index_put_((i1, j0, k0), vals * w100, accumulate=True)
111+
volume.index_put_((i1, j0, k1), vals * w101, accumulate=True)
112+
volume.index_put_((i1, j1, k0), vals * w110, accumulate=True)
113+
volume.index_put_((i1, j1, k1), vals * w111, accumulate=True)
114+
115+
# Accumulate weights for normalization
116+
weights.index_put_((i0, j0, k0), w000, accumulate=True)
117+
weights.index_put_((i0, j0, k1), w001, accumulate=True)
118+
weights.index_put_((i0, j1, k0), w010, accumulate=True)
119+
weights.index_put_((i0, j1, k1), w011, accumulate=True)
120+
weights.index_put_((i1, j0, k0), w100, accumulate=True)
121+
weights.index_put_((i1, j0, k1), w101, accumulate=True)
122+
weights.index_put_((i1, j1, k0), w110, accumulate=True)
123+
weights.index_put_((i1, j1, k1), w111, accumulate=True)
124+
125+
# Normalize by total weight at each voxel
126+
volume = volume / (weights + 1e-8)
127+
128+
return volume
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "raw",
5+
"id": "635db9c7-4e76-4c79-a9b9-e7909cf710e8",
6+
"metadata": {},
7+
"source": [
8+
"---\n",
9+
"title: reconstruction\n",
10+
"subtitle: Differentiable backprojection methods\n",
11+
"skip_exec: true\n",
12+
"---"
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": null,
18+
"id": "c5c3c45a-bb72-4e36-8655-758ed18aba65",
19+
"metadata": {},
20+
"outputs": [],
21+
"source": [
22+
"#| default_exp reconstruction"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": null,
28+
"id": "8d1e3c62-8219-4bed-8739-31eef36d9238",
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"#| hide\n",
33+
"from nbdev.showdoc import *"
34+
]
35+
},
36+
{
37+
"cell_type": "code",
38+
"execution_count": null,
39+
"id": "d1a04243-4893-4467-af3f-9a7492e301bb",
40+
"metadata": {},
41+
"outputs": [],
42+
"source": [
43+
"#| export\n",
44+
"import torch\n",
45+
"\n",
46+
"\n",
47+
"def grid_splat(\n",
48+
" points, # 3D coordinates in voxel space [0, D-1] x [0, H-1] x [0, W-1]\n",
49+
" values, # Intensity values for each point - (N,)\n",
50+
" size: tuple[int, int, int], # Dimensions of output volume - (D, H, W)\n",
51+
" mode: str = \"bilinear\", # \"nearest\" or \"bilinear\" (trilinear in 3D)\n",
52+
"):\n",
53+
" \"\"\"Splat 3D points into a volume (inverse of grid_sample).\"\"\"\n",
54+
" if mode == \"nearest\":\n",
55+
" return splat_points_nearest(points, values, size)\n",
56+
" elif mode == \"bilinear\":\n",
57+
" return splat_points_trilinear(points, values, size)\n",
58+
" else:\n",
59+
" raise ValueError(f\"Unsupported mode: {mode}. Use 'nearest' or 'bilinear'\")\n",
60+
"\n",
61+
"\n",
62+
"def splat_points_nearest(points, values, size):\n",
63+
" \"\"\"Splat 3D points using nearest-neighbor interpolation.\"\"\"\n",
64+
" device = points.device\n",
65+
" D, H, W = size\n",
66+
"\n",
67+
" # Initialize output\n",
68+
" volume = torch.zeros(D, H, W, device=device)\n",
69+
" counts = torch.zeros(D, H, W, device=device)\n",
70+
"\n",
71+
" # Round to nearest voxel\n",
72+
" i = torch.round(points[:, 0]).long()\n",
73+
" j = torch.round(points[:, 1]).long()\n",
74+
" k = torch.round(points[:, 2]).long()\n",
75+
"\n",
76+
" # Filter valid points (inside volume)\n",
77+
" valid = (i >= 0) & (i < D) & (j >= 0) & (j < H) & (k >= 0) & (k < W)\n",
78+
"\n",
79+
" if not valid.any():\n",
80+
" return volume\n",
81+
"\n",
82+
" # Apply filter\n",
83+
" i, j, k = i[valid], j[valid], k[valid]\n",
84+
" vals = values[valid]\n",
85+
"\n",
86+
" # Accumulate values\n",
87+
" volume.index_put_((i, j, k), vals, accumulate=True)\n",
88+
" counts.index_put_((i, j, k), torch.ones_like(vals), accumulate=True)\n",
89+
"\n",
90+
" # Normalize by count\n",
91+
" volume = volume / (counts + 1e-8)\n",
92+
"\n",
93+
" return volume\n",
94+
"\n",
95+
"\n",
96+
"def splat_points_trilinear(points, values, size):\n",
97+
" \"\"\"Splat 3D points using trilinear interpolation.\"\"\"\n",
98+
" device = points.device\n",
99+
" D, H, W = size\n",
100+
"\n",
101+
" # Initialize output\n",
102+
" volume = torch.zeros(D, H, W, device=device)\n",
103+
" weights = torch.zeros(D, H, W, device=device)\n",
104+
"\n",
105+
" # Get integer voxel indices (floor)\n",
106+
" i0 = torch.floor(points[:, 0]).long()\n",
107+
" j0 = torch.floor(points[:, 1]).long()\n",
108+
" k0 = torch.floor(points[:, 2]).long()\n",
109+
"\n",
110+
" i1 = i0 + 1\n",
111+
" j1 = j0 + 1\n",
112+
" k1 = k0 + 1\n",
113+
"\n",
114+
" # Compute fractional parts for interpolation\n",
115+
" fi = points[:, 0] - i0.float()\n",
116+
" fj = points[:, 1] - j0.float()\n",
117+
" fk = points[:, 2] - k0.float()\n",
118+
"\n",
119+
" # Filter valid points (inside volume)\n",
120+
" valid = (i0 >= 0) & (i1 < D) & (j0 >= 0) & (j1 < H) & (k0 >= 0) & (k1 < W)\n",
121+
"\n",
122+
" if not valid.any():\n",
123+
" return volume\n",
124+
"\n",
125+
" # Apply filter\n",
126+
" i0, i1 = i0[valid], i1[valid]\n",
127+
" j0, j1 = j0[valid], j1[valid]\n",
128+
" k0, k1 = k0[valid], k1[valid]\n",
129+
" fi, fj, fk = fi[valid], fj[valid], fk[valid]\n",
130+
" vals = values[valid]\n",
131+
"\n",
132+
" # Compute 8 corner weights (trilinear interpolation weights)\n",
133+
" w000 = (1 - fi) * (1 - fj) * (1 - fk)\n",
134+
" w001 = (1 - fi) * (1 - fj) * fk\n",
135+
" w010 = (1 - fi) * fj * (1 - fk)\n",
136+
" w011 = (1 - fi) * fj * fk\n",
137+
" w100 = fi * (1 - fj) * (1 - fk)\n",
138+
" w101 = fi * (1 - fj) * fk\n",
139+
" w110 = fi * fj * (1 - fk)\n",
140+
" w111 = fi * fj * fk\n",
141+
"\n",
142+
" # Splat to 8 neighboring voxels\n",
143+
" volume.index_put_((i0, j0, k0), vals * w000, accumulate=True)\n",
144+
" volume.index_put_((i0, j0, k1), vals * w001, accumulate=True)\n",
145+
" volume.index_put_((i0, j1, k0), vals * w010, accumulate=True)\n",
146+
" volume.index_put_((i0, j1, k1), vals * w011, accumulate=True)\n",
147+
" volume.index_put_((i1, j0, k0), vals * w100, accumulate=True)\n",
148+
" volume.index_put_((i1, j0, k1), vals * w101, accumulate=True)\n",
149+
" volume.index_put_((i1, j1, k0), vals * w110, accumulate=True)\n",
150+
" volume.index_put_((i1, j1, k1), vals * w111, accumulate=True)\n",
151+
"\n",
152+
" # Accumulate weights for normalization\n",
153+
" weights.index_put_((i0, j0, k0), w000, accumulate=True)\n",
154+
" weights.index_put_((i0, j0, k1), w001, accumulate=True)\n",
155+
" weights.index_put_((i0, j1, k0), w010, accumulate=True)\n",
156+
" weights.index_put_((i0, j1, k1), w011, accumulate=True)\n",
157+
" weights.index_put_((i1, j0, k0), w100, accumulate=True)\n",
158+
" weights.index_put_((i1, j0, k1), w101, accumulate=True)\n",
159+
" weights.index_put_((i1, j1, k0), w110, accumulate=True)\n",
160+
" weights.index_put_((i1, j1, k1), w111, accumulate=True)\n",
161+
"\n",
162+
" # Normalize by total weight at each voxel\n",
163+
" volume = volume / (weights + 1e-8)\n",
164+
"\n",
165+
" return volume"
166+
]
167+
},
168+
{
169+
"cell_type": "code",
170+
"execution_count": null,
171+
"id": "d1701f75-7fa5-485e-add9-54746a5ee47b",
172+
"metadata": {},
173+
"outputs": [],
174+
"source": [
175+
"#| hide\n",
176+
"import nbdev\n",
177+
"\n",
178+
"nbdev.nbdev_export()"
179+
]
180+
}
181+
],
182+
"metadata": {
183+
"kernelspec": {
184+
"display_name": "python3",
185+
"language": "python",
186+
"name": "python3"
187+
}
188+
},
189+
"nbformat": 4,
190+
"nbformat_minor": 5
191+
}

0 commit comments

Comments
 (0)