99import h5py
1010import numpy as np
1111import torch
12+ from .utils import load_file
1213from torchio import LabelMap , ScalarImage , Subject
1314from torchio .transforms .preprocessing import ToCanonical
1415from torchvision .transforms .functional import center_crop , gaussian_blur
1516
1617from diffdrr .data import read
18+ from diffdrr .detector import parse_intrinsic_matrix
1719from diffdrr .pose import RigidTransform
18- from diffdrr .utils import parse_intrinsic_matrix
19-
20- from .utils import load_file
2120
2221# %% ../notebooks/00_deepfluoro.ipynb 6
2322class DeepFluoroDataset (torch .utils .data .Dataset ):
@@ -64,8 +63,8 @@ def __init__(
6463 self .flip_z = RigidTransform (
6564 torch .tensor (
6665 [
67- [0 , 1 , 0 , 0 ],
6866 [- 1 , 0 , 0 , 0 ],
67+ [0 , - 1 , 0 , 0 ],
6968 [0 , 0 , - 1 , 0 ],
7069 [0 , 0 , 0 , 1 ],
7170 ]
@@ -94,12 +93,10 @@ def __getitem__(self, idx):
9493 pose = self .projections [f"{ idx :03d} /gt-poses/cam-to-pelvis-vol" ][:]
9594 pose = RigidTransform (torch .from_numpy (pose ))
9695 pose = (
97- self .rot_180
98- .compose (self .flip_z )
96+ self .flip_z
9997 .compose (self .world2camera .inverse ())
10098 .compose (pose )
10199 .compose (self .anatomical2world )
102- .compose (self .rot_180 )
103100 )
104101 if self .rot_180_for_up (idx ):
105102 img = torch .rot90 (img , k = 2 )
@@ -121,11 +118,11 @@ def parse_volume(subject, bone_attenuation_multiplier, labels):
121118 # Get all parts of the volume
122119 volume = subject ["vol/pixels" ][:]
123120 volume = np .swapaxes (volume , 0 , 2 ).copy ()
124- volume = torch .from_numpy (volume ).unsqueeze (0 ).flip (1 ).flip (2 )
121+ volume = torch .from_numpy (volume ).unsqueeze (0 ) # .flip(1).flip(2)
125122
126123 mask = subject ["vol-seg/image/pixels" ][:]
127124 mask = np .swapaxes (mask , 0 , 2 ).copy ()
128- mask = torch .from_numpy (mask ).unsqueeze (0 ).flip (1 ).flip (2 )
125+ mask = torch .from_numpy (mask ).unsqueeze (0 ) # .flip(1).flip(2)
129126
130127 affine = np .eye (4 )
131128 affine [:3 , :3 ] = subject ["vol/dir-mat" ][:]
@@ -150,10 +147,10 @@ def parse_volume(subject, bone_attenuation_multiplier, labels):
150147 anatomical2world = RigidTransform (
151148 torch .tensor (
152149 [
153- [1.0 , 0.0 , 0. 0 , - isocenter [0 ]],
154- [0.0 , 1.0 , 0. 0 , - isocenter [1 ]],
155- [0.0 , 0.0 , 1.0 , - isocenter [2 ]],
156- [0.0 , 0.0 , 0.0 , 1.0 ],
150+ [1 , 0 , 0 , - isocenter [0 ]],
151+ [0 , 1 , 0 , - isocenter [1 ]],
152+ [0 , 0 , 1 , - isocenter [2 ]],
153+ [0 , 0 , 0 , 1 ],
157154 ],
158155 dtype = torch .float32 ,
159156 )
@@ -169,24 +166,22 @@ def parse_volume(subject, bone_attenuation_multiplier, labels):
169166 label_def = defns ,
170167 fiducials = fiducials ,
171168 )
172- reorient = RigidTransform (torch .diag (torch .tensor ([- 1.0 , - 1.0 , 1.0 , 1.0 ])))
173- subject .fiducials = reorient (subject .fiducials )
174169
175170 return subject , anatomical2world
176171
177172
178173def parse_proj_params (f ):
179174 proj_params = f ["proj-params" ]
180175 extrinsic = torch .from_numpy (proj_params ["extrinsic" ][:])
181- world2camera = RigidTransform (extrinsic )
176+ camera2world = RigidTransform (extrinsic )
182177 intrinsic = torch .from_numpy (proj_params ["intrinsic" ][:])
183178 num_cols = proj_params ["num-cols" ][()]
184179 num_rows = proj_params ["num-rows" ][()]
185180 proj_col_spacing = float (proj_params ["pixel-col-spacing" ][()])
186181 proj_row_spacing = float (proj_params ["pixel-row-spacing" ][()])
187182 return (
188183 intrinsic ,
189- world2camera ,
184+ camera2world ,
190185 num_cols ,
191186 num_rows ,
192187 proj_col_spacing ,
@@ -200,7 +195,7 @@ def load(id_number, bone_attenuation_multiplier, labels):
200195 # Load dataset parameters
201196 (
202197 intrinsic ,
203- world2camera ,
198+ camera2world ,
204199 num_cols ,
205200 num_rows ,
206201 proj_col_spacing ,
@@ -227,13 +222,15 @@ def load(id_number, bone_attenuation_multiplier, labels):
227222 ][id_number - 1 ]
228223 subject = f [subject_id ]
229224 projections = subject ["projections" ]
230- subject , anatomical2world = parse_volume (subject , bone_attenuation_multiplier , labels )
225+ subject , anatomical2world = parse_volume (
226+ subject , bone_attenuation_multiplier , labels
227+ )
231228
232229 return (
233230 subject ,
234231 projections ,
235232 anatomical2world ,
236- world2camera ,
233+ camera2world ,
237234 focal_len ,
238235 int (num_rows ),
239236 int (num_cols ),
0 commit comments