Skip to content

Commit 267153e

Browse files
authored
Merge pull request #110 from Hendrik-code/development_robert
Speed up nnUnet inference
2 parents 3906165 + a4d1e6d commit 267153e

18 files changed

Lines changed: 986 additions & 104 deletions

TPTBox/core/bids_files.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def enumerate_subjects(self, sort: bool = False, shuffle: bool = False) -> list[
486486
return s
487487
return self.subjects.items() # type: ignore
488488

489-
def iter_subjects(self, sort: bool = False) -> list[tuple[str, Subject_Container]]:
489+
def iter_subjects(self, sort: bool = False, shuffle: bool = False) -> list[tuple[str, Subject_Container]]:
490490
"""Iterate over all subjects (alias for :meth:`enumerate_subjects` without shuffle).
491491
492492
Args:
@@ -498,6 +498,10 @@ def iter_subjects(self, sort: bool = False) -> list[tuple[str, Subject_Container
498498
"""
499499
if sort:
500500
return sorted(self.subjects.items())
501+
if shuffle:
502+
s = list(self.subjects.items())
503+
random.shuffle(s)
504+
return s
501505
return self.subjects.items() # type: ignore
502506

503507
def __len__(self):

TPTBox/core/nii_wrapper.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,33 +1180,36 @@ def resample_from_to(self, to_vox_map:Image_Reference|Has_Grid|tuple[SHAPE,AFFIN
11801180
if mapping.assert_affine(self,raise_error=False,origin_tolerance=0.000001,error_tolerance=0.000001,shape_tolerance=0):
11811181
log.print(f"resample_from_to skipped; already in space: {self}",verbose=verbose)
11821182
return self if inplace else self.copy()
1183-
11841183
m1 = mapping if mapping.orientation == self.orientation else mapping.make_empty_POI().reorient(self.orientation)
11851184
if m1.assert_affine(self,raise_error=False,origin_tolerance=0.00001,error_tolerance=0.00001,shape_tolerance=0):
11861185
log.print(f"resample_from_to only need reorientation; {self.orientation}",verbose=verbose)
11871186
ret = self.reorient(mapping.orientation,inplace=inplace)
11881187
ret.affine = mapping.affine #remove floating point error
11891188
return ret
1190-
if self.orientation == mapping.orientation and np.allclose(self.zoom , mapping.zoom, atol=1e-6):
1191-
shift = (np.array(self.origin) - np.array(m1.origin)) / np.array(m1.zoom)
1192-
if np.allclose(shift, np.round(shift), atol=1e-6):
1193-
s = self.reorient(mapping.orientation,inplace=inplace) # noqa: PLW0642
1194-
shift = (np.array(self.origin) - np.array(mapping.origin)) / np.array(mapping.zoom)
1195-
shift = np.round(shift).astype(int)
1196-
dst_shape = np.array(mapping.shape)
1197-
src_shape = np.array(s.shape)
1198-
# padding before = how much dst starts before src
1199-
pad_before = shift
1200-
# padding after = remaining dst size after src
1201-
pad_after = dst_shape-shift-src_shape
1202-
pad = tuple((int(b), int(a)) for b, a in zip(pad_before, pad_after))
1203-
ret = s.apply_pad(pad, mode=mode,inplace=inplace,verbose=verbose)
1204-
1189+
if np.allclose(self.zoom, m1.zoom, atol=1e-6):
1190+
s = self.reorient(mapping.orientation, inplace=inplace)
1191+
# Compute voxel offset directly from the affines after both
1192+
# images are in the same orientation. This is robust to axis
1193+
# permutations and flips.
1194+
voxel_offset = np.linalg.inv(mapping.affine) @ s.affine @ np.array([0, 0, 0, 1])
1195+
shift = np.round(voxel_offset[:3]).astype(int)
1196+
1197+
dst_shape = np.array(mapping.shape)
1198+
src_shape = np.array(s.shape)
1199+
# padding before = how much dst starts before src
1200+
pad_before = shift
1201+
# padding after = remaining dst size after src
1202+
pad_after = dst_shape - shift - src_shape
1203+
pad = tuple((int(b), int(a)) for b, a in zip(pad_before, pad_after))
1204+
try:
1205+
ret = s.apply_pad(pad,mode=mode,inplace=inplace,verbose=verbose)
12051206
valid = ret.assert_affine(mapping,raise_error=False,origin_tolerance=0.0001,error_tolerance=0.0001,shape_tolerance=0)
12061207
if valid:
12071208
log.print(f"resample_from_to only needs padding/cropping {pad}",verbose=verbose)
1208-
ret.affine = mapping.affine #remove floating point error
1209+
ret.affine = mapping.affine # remove floating point error
12091210
return ret
1211+
except ValueError as e:
1212+
log.warning("Padding failed.",e,verbose=verbose)
12101213

12111214

12121215
assert mapping is not None
@@ -2505,7 +2508,7 @@ def to_stl(
25052508
try:
25062509
verts, faces, normals, values = marching_cubes(seg_arr, gradient_direction="ascent", step_size=1)
25072510
except RuntimeError as e:
2508-
raise RuntimeError(str(e),f"{label=}, {self.unique()}, {out_path=}") from None
2511+
raise IndexError(str(e),f"{label=}, {self.unique()}, {out_path=}") from None
25092512
# Remove padding offset (since we padded by 1 voxel)
25102513
verts -= 1
25112514
# Apply bounding box offset (still voxel space)
@@ -2696,6 +2699,20 @@ def extract_label(self,label:int|Enum|Sequence[int]|Sequence[Enum]|None, keep_la
26962699
if keep_label:
26972700
seg_arr = seg_arr * self.get_seg_array()
26982701
return self.set_array(seg_arr,inplace=inplace)
2702+
def ravel(self,order:Literal["K", "A", "C", "F"] | None="C")->np.ndarray:
2703+
"""Return a contiguous flattened array.
2704+
2705+
A 1-D array, containing the elements of the input, is returned. A copy is made only if needed.
2706+
2707+
As of NumPy 1.10, the returned array will have the same type as the input array. (for example, a masked array will be returned for a masked array input)
2708+
2709+
Args:
2710+
order (Literal["K", "A", "C", "F"] | None, optional): The elements of a are read using this index order. ‘C’ means to index the elements in row-major, C-style order, with the last axis index changing fastest, back to the first axis index changing slowest. ‘F’ means to index the elements in column-major, Fortran-style order, with the first index changing fastest, and the last index changing slowest. Note that the ‘C’ and ‘F’ options take no account of the memory layout of the underlying array, and only refer to the order of axis indexing. ‘A’ means to read the elements in Fortran-like index order if a is Fortran contiguous in memory, C-like order otherwise. ‘K’ means to read the elements in the order they occur in memory, except for reversing the data when strides are negative. By default, ‘C’ index order is used. Defaults to "C".
2711+
2712+
Returns:
2713+
np.ndarray
2714+
"""
2715+
return self.get_array().ravel(order=order)
26992716
def extract_label_(self, label: int | Enum | Sequence[int] | Sequence[Enum], keep_label=False) -> Self:
27002717
"""In-place variant of `extract_label`."""
27012718
return self.extract_label(label,keep_label,inplace=True)

TPTBox/core/np_utils.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def np_count_nonzero(arr: np.ndarray) -> int:
160160
return np.count_nonzero(arr)
161161

162162

163-
def np_unique(arr: np.ndarray) -> list[int]:
163+
def old_np_unique(arr: np.ndarray) -> list[int]:
164164
"""Returns each existing label in the array (including zero!).
165165
166166
Uses cc3d statistics for unsigned-integer arrays for speed, and falls back
@@ -181,9 +181,52 @@ def np_unique(arr: np.ndarray) -> list[int]:
181181
return list(np.unique(arr))
182182

183183

184+
def np_unique(arr: np.ndarray) -> list[int]:
185+
"""Returns each existing label in the array (including zero!).
186+
187+
Uses cc3d statistics for unsigned-integer arrays for speed, and falls back
188+
to ``numpy.unique`` for other dtypes.
189+
190+
Args:
191+
arr (np.ndarray): Input label array.
192+
193+
Returns:
194+
list[int]: Sorted list of every distinct label value present in ``arr``,
195+
including 0 (background).
196+
"""
197+
if np.issubdtype(arr.dtype, np.unsignedinteger):
198+
# bincount is O(max_val) but ~5-10x faster than np.unique for dense label arrays
199+
max_val = int(arr.max())
200+
if max_val < 2**20: # ~1M labels threshold — bincount stays fast
201+
counts = np.bincount(arr.ravel())
202+
return list(np.where(counts > 0)[0])
203+
# For sparse label spaces fall back to np.unique
204+
return old_np_unique(arr)
205+
206+
184207
def np_unique_withoutzero(arr: UINTARRAY) -> list[int]:
185208
"""Returns each existing non-zero label in the array (excluding background zero).
186209
210+
Args:
211+
arr (UINTARRAY): Input unsigned-integer label array.
212+
213+
Returns:
214+
list[int]: Sorted list of every distinct label value present in ``arr``,
215+
excluding 0 (background).
216+
"""
217+
if np.issubdtype(arr.dtype, np.unsignedinteger):
218+
max_val = int(arr.max())
219+
if max_val == 0:
220+
return []
221+
if max_val < 2**20:
222+
counts = np.bincount(arr.ravel())
223+
return list(np.where(counts[1:] > 0)[0] + 1)
224+
return [i for i in np.unique(arr) if i != 0]
225+
226+
227+
def old_np_unique_withoutzero(arr: UINTARRAY) -> list[int]:
228+
"""Returns each existing non-zero label in the array (excluding background zero).
229+
187230
Args:
188231
arr (UINTARRAY): Input unsigned-integer label array.
189232

TPTBox/registration/_deformable/multilabel_segmentation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__( # noqa: C901
5454
poi_target_cms: POI | None = None,
5555
max_history=100,
5656
change_after_point_reg=lambda x, y, z, w: (x, y, z, w),
57+
tether_distance=1,
5758
**args,
5859
):
5960
"""Initialize a multi-stage registration pipeline from an atlas to a target image.
@@ -90,7 +91,7 @@ def __init__( # noqa: C901
9091
"be": ("BSplineBending", {"stride": 1}),
9192
"seg": "MSE",
9293
"Dice": "Dice",
93-
"Tether": Tether_Seg(delta=5),
94+
"Tether": Tether_Seg(delta=tether_distance),
9495
}
9596

9697
assert target_seg.seg, target_seg.seg
@@ -187,7 +188,7 @@ def __init__( # noqa: C901
187188
poi_cms = poi_cms.resample_from_to(atlas_seg_)
188189

189190
self.reg_point = Point_Registration(poi_target, poi_cms, verbose=False)
190-
atlas_reg = self.reg_point.transform_nii(atlas_seg_)
191+
atlas_reg = self.reg_point.transform_nii(atlas_seg_, c_val=0)
191192

192193
if not atlas_reg.is_segmentation_in_border():
193194
print("point registration ok")
@@ -204,7 +205,7 @@ def __init__( # noqa: C901
204205
target_img = target_img.apply_pad(resize_param) if target_img is not None else None
205206

206207
self.reg_point = Point_Registration(poi_target.resample_from_to(target_seg), poi_cms.resample_from_to(atlas_seg))
207-
atlas_reg = self.reg_point.transform_nii(atlas_seg)
208+
atlas_reg = self.reg_point.transform_nii(atlas_seg, c_val=0)
208209
atlas_img_reg = self.reg_point.transform_nii(atlas_img) if atlas_img is not None else None
209210

210211
if crop:

TPTBox/registration/_ridged_intensity/affine_deepali.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ def forward(
9292
target: torch.Tensor, # shape: (B, C, X, Y, Z)
9393
mask: torch.Tensor | None = None, # noqa: ARG002
9494
) -> torch.Tensor:
95-
w = max(target.shape[2:])
95+
w = min(target.shape[2:])
9696
com_fixed = center_of_mass_cc(target) # (B, C, 3)
9797
com_warped = center_of_mass_cc(source) # (B, C, 3)
9898

9999
l_com = torch.norm(com_fixed - com_warped, dim=-1) / w # (B, C)
100100

101101
# Zero out channels with small displacement (<10) or NaNs
102-
l_com = torch.where(l_com < self.delta, torch.zeros_like(l_com), l_com)
102+
l_com = torch.where(l_com * w < self.delta, torch.zeros_like(l_com), l_com)
103103
l_com = torch.nan_to_num(l_com, nan=0.0)
104104

105105
return l_com.mean() # type: ignore

TPTBox/segmentation/VibeSeg/inference_nnunet.py

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616
out_base = Path(__file__).parent.parent / "nnUNet/"
1717
_model_path_ = out_base / "nnUNet_results"
1818

19+
# Opt-in cache of loaded predictors (enable via cache_model=True). Keyed by model identity plus
20+
# the device/runtime settings that affect the loaded predictor, so repeated inference (e.g. a loop
21+
# over many files with the same model) reuses the in-memory model instead of reloading weights from
22+
# disk and re-uploading them to the GPU on every call.
23+
_model_cache: dict = {}
24+
1925

2026
def get_ds_info(idx: int, _model_path: str | Path | None = None, exit_one_fail: bool = True, logger=logger) -> dict:
2127
"""Load and return the ``dataset.json`` for the model with the given dataset index.
@@ -87,13 +93,16 @@ def run_inference_on_file(
8793
ddevice: Literal["cpu", "cuda", "mps"] = "cuda",
8894
model_path=None,
8995
step_size: float = 0.5,
90-
memory_base: int = 5000, # Base memory in MB, default is 5GB
91-
memory_factor: int = 160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB
92-
memory_max: int = 160000, # in MB, default is 160GB
96+
memory_base: float | None = None, # Base memory in MB, default is 5GB
97+
memory_factor: float | None = None, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB
98+
memory_max: int = 990000, # in MB, default is 990GB (so it is most likely ignored and replaced by Max Memory of the GPU)
9399
wait_till_gpu_percent_is_free: float = 0.1,
100+
tile_batch_size: int = 1,
94101
verbose: bool = True,
95102
auto_download: bool = False,
103+
cache_model: bool = False,
96104
_key_ResEnc: str = "__nnUNet*ResEnc",
105+
fail_on_missing_memory=False,
97106
logger=logger,
98107
) -> tuple[Image_Reference, np.ndarray | None]:
99108
"""Load a VibeSeg model and run inference on the supplied NIfTI images.
@@ -135,7 +144,18 @@ def run_inference_on_file(
135144
memory_max: Hard cap on assumed GPU memory in MB.
136145
wait_till_gpu_percent_is_free: Minimum free GPU fraction to require
137146
before starting inference.
147+
tile_batch_size: Number of sliding-window tiles to run per network
148+
forward pass. ``1`` (default) keeps the original per-tile behaviour;
149+
larger values batch tiles to better saturate the GPU at the cost of
150+
higher peak memory.
138151
verbose: Print progress information.
152+
cache_model: If ``True``, keep the loaded predictor in a process-wide
153+
cache and reuse it on subsequent calls with identical model and
154+
device/runtime settings. Avoids reloading weights from disk and
155+
re-uploading them to the GPU when segmenting many files in a loop, at
156+
the cost of holding the model in GPU memory between calls. The GPU
157+
cache is also left warm (no ``empty_cache``) so the allocator can
158+
reuse buffers across images.
139159
140160
Returns:
141161
A tuple ``(seg_nii, softmax_logits)`` where ``seg_nii`` is the
@@ -196,20 +216,44 @@ def run_inference_on_file(
196216
if "labels" in ds_info2:
197217
ds_info["labels_mapping"] = ds_info2["labels"]
198218

199-
nnunet = load_inf_model(
200-
nnunet_path,
201-
allow_non_final=True,
202-
use_folds=tuple(folds) if len(folds) != 5 else None,
203-
gpu=gpu,
204-
ddevice=ddevice,
205-
step_size=step_size,
206-
memory_base=memory_base,
207-
memory_factor=memory_factor,
208-
memory_max=memory_max,
209-
wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free,
219+
if memory_base is None:
220+
memory_base = float(ds_info.get("memory_base", 5000))
221+
if memory_factor is None:
222+
memory_factor = float(ds_info.get("memory_factor", 160))
223+
224+
use_folds_arg = tuple(folds) if len(folds) != 5 else None
225+
# Include every setting that changes the loaded predictor so a cache hit is always equivalent
226+
# to a fresh load; differing settings simply miss the cache and reload.
227+
cache_key = (
228+
str(nnunet_path),
229+
use_folds_arg,
230+
ddevice,
231+
gpu,
232+
step_size,
233+
memory_base,
234+
memory_factor,
235+
memory_max,
236+
wait_till_gpu_percent_is_free,
237+
tile_batch_size,
210238
)
211-
212-
# _unets[idx] = nnunet
239+
nnunet = _model_cache.get(cache_key) if cache_model else None
240+
if nnunet is None:
241+
nnunet = load_inf_model(
242+
nnunet_path,
243+
allow_non_final=True,
244+
use_folds=use_folds_arg,
245+
gpu=gpu,
246+
ddevice=ddevice,
247+
step_size=step_size,
248+
memory_base=memory_base,
249+
memory_factor=memory_factor,
250+
memory_max=memory_max,
251+
wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free,
252+
tile_batch_size=tile_batch_size,
253+
fail_on_missing_memory=fail_on_missing_memory,
254+
)
255+
if cache_model:
256+
_model_cache[cache_key] = nnunet
213257
if "orientation" in ds_info:
214258
orientation = ds_info["orientation"]
215259

@@ -315,9 +359,11 @@ def to_int(a: str, k: None | int = None):
315359
seg_nii.map_labels_(mapping)
316360
if out_file is not None and (not Path(out_file).exists() or override):
317361
seg_nii.set_dtype("smallest_uint").save(out_file)
318-
del nnunet
319-
320-
torch.cuda.empty_cache()
362+
if not cache_model:
363+
# When caching we keep the predictor alive (it stays referenced by _model_cache, so del
364+
# would not free it anyway) and leave the CUDA cache warm so the next image reuses buffers.
365+
del nnunet
366+
torch.cuda.empty_cache()
321367
return seg_nii, softmax_logits
322368

323369

TPTBox/segmentation/VibeSeg/vibeseg.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@
8484
72: "bone_other",
8585
}
8686

87+
defaults = {
88+
100: {"memory_base": 5500, "memory_factor": 25},
89+
}
90+
8791

8892
def run_vibeseg(
8993
i: Image_Reference,
@@ -113,6 +117,10 @@ def run_vibeseg(
113117
Returns:
114118
Segmentation ``NII`` saved at *out_seg*.
115119
"""
120+
if dataset_id in defaults:
121+
for k, v in defaults[dataset_id].items():
122+
if k not in args:
123+
args[k] = v
116124
return run_inference_on_file(
117125
dataset_id,
118126
[to_nii(i)] if not isinstance(i, (list, tuple)) else [to_nii(j) for j in i],

0 commit comments

Comments
 (0)