Skip to content

Commit f81d7d0

Browse files
authored
Merge pull request #111 from Hendrik-code/optimization
Optimization
2 parents 267153e + a47f632 commit f81d7d0

22 files changed

Lines changed: 834 additions & 218 deletions

TPTBox/core/bids_files.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,15 +1750,15 @@ def filter(
17501750
"""
17511751
if self._flatten:
17521752
assert isinstance(self.candidates, list)
1753-
for bids_file in self.candidates.copy():
1754-
if not bids_file.do_filter(key, filter_fun, required=required):
1755-
self.candidates.remove(bids_file)
1753+
# list comprehension is O(n); the old copy()+list.remove() loop was O(n^2)
1754+
self.candidates = [f for f in self.candidates if f.do_filter(key, filter_fun, required=required)]
17561755
else:
17571756
assert isinstance(self.candidates, dict)
1758-
for sequences, bids_files in self.candidates.copy().items():
1759-
# print(sequences, list(bids_file.do_filter(key, filter_fun, required=required) for bids_file in bids_files))
1760-
if not any(bids_file.do_filter(key, filter_fun, required=required) for bids_file in bids_files):
1761-
self.candidates.pop(sequences)
1757+
self.candidates = {
1758+
seq: bids_files
1759+
for seq, bids_files in self.candidates.items()
1760+
if any(f.do_filter(key, filter_fun, required=required) for f in bids_files)
1761+
}
17621762

17631763
def filter_format(self, filter_fun: list[str] | str | typing.Callable[[str | object], bool]) -> None:
17641764
"""Keep only files whose format label satisfies *filter_fun*.
@@ -1807,15 +1807,15 @@ def filter_non_existence(
18071807
"""
18081808
if self._flatten:
18091809
assert isinstance(self.candidates, list)
1810-
for bids_file in self.candidates.copy():
1811-
if bids_file.do_filter(key, filter_fun, required=required):
1812-
self.candidates.remove(bids_file)
1810+
# list comprehension is O(n); the old copy()+list.remove() loop was O(n^2)
1811+
self.candidates = [f for f in self.candidates if not f.do_filter(key, filter_fun, required=required)]
18131812
else:
18141813
assert isinstance(self.candidates, dict)
1815-
for sequences, bids_files in self.candidates.copy().items():
1816-
# print(sequences, list(bids_file.do_filter(key, filter_fun, required=required) for bids_file in bids_files))
1817-
if any(bids_file.do_filter(key, filter_fun, required=required) for bids_file in bids_files):
1818-
self.candidates.pop(sequences)
1814+
self.candidates = {
1815+
seq: bids_files
1816+
for seq, bids_files in self.candidates.items()
1817+
if not any(f.do_filter(key, filter_fun, required=required) for f in bids_files)
1818+
}
18191819

18201820
def filter_dixon_only_inphase(self) -> None:
18211821
"""Remove Dixon files that are fat, water, out-of-phase, or difference images.

TPTBox/core/nii_poi_abstract.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,15 @@ def global_to_local(self, x: COORDINATE) -> tuple:
482482
a = self.rotation.T @ (np.array(x) - self.origin) / np.array(self.zoom)
483483
return tuple(round(float(v), 7) for v in a)
484484

485+
def global_to_local_arr(self, coords: np.ndarray) -> np.ndarray:
    """Vectorized :meth:`global_to_local` for an ``(N, 3)`` array of world coordinates.

    Every row ``x`` is mapped to ``rotation.T @ (x - origin) / zoom`` and rounded to
    7 decimals, i.e. the same result as calling ``global_to_local`` on each row, but
    computed for all rows with a single matrix product.

    Args:
        coords: Array-like of world coordinates, one point per row. An empty input
            (for example ``[]``) is accepted and yields an empty result.

    Returns:
        np.ndarray: Float array of local coordinates with one row per input point,
        rounded to 7 decimals. Empty input returns shape ``(0, D)`` where ``D`` is
        the length of ``self.origin``.
    """
    origin = np.asarray(self.origin, dtype=float)
    pts = np.asarray(coords, dtype=float)
    if pts.size == 0:
        # A bare empty sequence has shape (0,), which cannot broadcast against the
        # origin; return a well-formed empty point set instead of raising.
        return np.empty((0, origin.shape[-1]), dtype=float)
    # Row-vector form of rotation.T @ (x - origin): (x - origin) @ rotation.
    local = (pts - origin) @ np.asarray(self.rotation) / np.asarray(self.zoom)
    return np.round(local, 7)
493+
485494
def local_to_global(self, x: COORDINATE) -> tuple:
486495
"""Convert voxel (local) coordinates to world (RAS/LPS) coordinates.
487496

TPTBox/core/nii_wrapper.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
np_filter_connected_components,
4444
np_get_connected_components_center_of_mass,
4545
np_is_empty,
46+
np_isin,
4647
np_map_labels,
4748
np_map_labels_based_on_majority_label_mask_overlap,
4849
np_point_coordinates,
@@ -2092,10 +2093,11 @@ def truncate_labels_beyond_reference_(
20922093
flip = self.orientation[axis_] != axis # Check orientation for flipping
20932094
# Get the array data
20942095
np_array = self.get_array()
2095-
np_array_cond = self.extract_label(idx).get_seg_array()
2096+
# both masks come directly from np_array via np_isin (avoids two extract_label round-trips)
2097+
np_array_cond = np_isin(np_array, idx)
20962098

20972099
# Find the lowest point (smallest index) along the axis where `not_beyond` exists
2098-
threshold = np.where(self.extract_label(not_beyond).get_seg_array() == 1)
2100+
threshold = np.where(np_isin(np_array, not_beyond))
20992101
if len(threshold[axis_]) == 0:
21002102
return self if inplace else self.copy()
21012103
flip_up = flip
@@ -2115,7 +2117,7 @@ def truncate_labels_beyond_reference_(
21152117
mask = np.broadcast_to(mask, self.shape)
21162118

21172119
# Replace values of `idx` with `fill` in the masked region
2118-
np_array = np.where((np_array_cond == 1) & mask, fill, np_array)
2120+
np_array = np.where(np_array_cond & mask, fill, np_array)
21192121

21202122
# Update the NIfTI object with the modified array
21212123
return self.set_array(np_array, inplace=inplace)
@@ -2253,7 +2255,8 @@ def map_labels(self, label_map:LABEL_MAP , verbose:logging=True, inplace=False)
22532255
If inplace is True, returns the current NIfTI image object with mapped labels. Otherwise, returns a new NIfTI image object with mapped labels.
22542256
"""
22552257
data_orig = self.get_seg_array()
2256-
labels_before = [v for v in np_unique(data_orig) if v > 0]
2258+
# the before/after np_unique scans are only used for the verbose log line; skip them otherwise
2259+
labels_before = [v for v in np_unique(data_orig) if v > 0] if verbose else None
22572260
# enforce keys to be str to support both str and int
22582261
label_map_ = {
22592262
(v_name2idx[k] if k in v_name2idx else int(k)): (
@@ -2263,15 +2266,16 @@ def map_labels(self, label_map:LABEL_MAP , verbose:logging=True, inplace=False)
22632266
}
22642267
log.print("label_map_ =", label_map_, verbose=verbose)
22652268
data = np_map_labels(data_orig, label_map_)
2266-
labels_after = [v for v in np_unique(data) if v > 0]
2267-
log.print(
2268-
"N =",
2269-
len(label_map_),
2270-
"labels reassigned, before labels: ",
2271-
labels_before,
2272-
" after: ",
2273-
labels_after,verbose=verbose
2274-
)
2269+
if verbose:
2270+
labels_after = [v for v in np_unique(data) if v > 0]
2271+
log.print(
2272+
"N =",
2273+
len(label_map_),
2274+
"labels reassigned, before labels: ",
2275+
labels_before,
2276+
" after: ",
2277+
labels_after,verbose=verbose
2278+
)
22752279
nii = data.astype(np.uint16), self.affine, self.header
22762280
if inplace:
22772281
self.nii = nii
@@ -2685,19 +2689,22 @@ def extract_label(self,label:int|Enum|Sequence[int]|Sequence[Enum]|None, keep_la
26852689
seg_arr = self.get_seg_array()
26862690

26872691
if isinstance(label, Sequence):
2688-
label_int:list[int] = [idx.value if isinstance(idx,Enum) else idx for idx in label]
2689-
assert 0 not in label_int, 'Zero label does not make sense. This is the background'
2690-
seg_arr = np_extract_label(seg_arr, label_int, to_label=1, inplace=True)
2692+
labels:int|list[int] = [idx.value if isinstance(idx,Enum) else idx for idx in label]
2693+
assert 0 not in labels, 'Zero label does not make sense. This is the background'
26912694
else:
26922695
if isinstance(label,Enum):
26932696
label = label.value
26942697
if isinstance(label,str):
26952698
label = int(label)
26962699

26972700
assert label != 0, 'Zero label does not make sense. This is the background'
2698-
seg_arr = np_extract_label(seg_arr, label, to_label=1, inplace=True)
2701+
labels = label
26992702
if keep_label:
2700-
seg_arr = seg_arr * self.get_seg_array()
2703+
# keep the original label values where in `labels`, zero everywhere else.
2704+
# single get_seg_array() copy + one np_isin mask (faster than extract + a second copy/multiply)
2705+
seg_arr[~np_isin(seg_arr, labels)] = 0
2706+
else:
2707+
seg_arr = np_extract_label(seg_arr, labels, to_label=1, inplace=True)
27012708
return self.set_array(seg_arr,inplace=inplace)
27022709
def ravel(self,order:Literal["K", "A", "C", "F"] | None="C")->np.ndarray:
27032710
"""Return a contiguous flattened array.
@@ -2719,15 +2726,17 @@ def extract_label_(self, label: int | Enum | Sequence[int] | Sequence[Enum], kee
27192726
def remove_labels(self,label:int|Enum|Sequence[int]|Sequence[Enum], inplace=False, verbose:logging=True, removed_to_label=0) -> Self:
    """Remove one or more labels from this segmentation.

    Every voxel whose value is one of the given labels is rewritten to
    ``removed_to_label``; all other voxels keep their value.

    Args:
        label: A single label or a sequence of labels. Items may be Enum members
            (their ``.value`` is used) or nested lists/tuples of labels, which are
            flattened.
        inplace: Forwarded to ``set_array``.
        verbose: Forwarded to ``set_array``.
        removed_to_label: Value written in place of the removed labels. Defaults to 0.

    Returns:
        Self: The result of ``set_array`` with the relabelled array.
    """
    assert label != 0, 'Zero label does not make sens. This is the background'
    if not isinstance(label,Sequence):
        label = [label] # type: ignore
    flat: list[int] = []
    for entry in label:
        # Treat a plain item as a one-element group so both cases share one path.
        group = entry if isinstance(entry, (list, tuple)) else (entry,)
        flat.extend(g.value if isinstance(g, Enum) else g for g in group)
    seg_arr = self.get_seg_array()
    if flat:
        # A single np_map_labels call replaces the former one-masking-pass-per-label loop.
        # Skipped entirely when there is nothing to remove, so an empty sequence is a no-op.
        seg_arr = np_map_labels(seg_arr, dict.fromkeys(flat, removed_to_label))
    return self.set_array(seg_arr,inplace=inplace, verbose=verbose)
27322741
def remove_labels_(self, label: int | Enum | Sequence[int] | Sequence[Enum], removed_to_label=0, verbose: logging = True) -> Self:
27332742
"""In-place variant of `remove_labels`."""

TPTBox/core/np_utils.py

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,40 @@
3636
INTARRAY = Union[UINTARRAY, NDArray[INT]]
3737

3838

39+
def np_isin(arr: np.ndarray, labels, invert: bool = False) -> np.ndarray:
    """Fast ``np.isin`` for non-negative integer label arrays via a boolean lookup table.

    When ``arr`` has an unsigned-integer dtype and is tested against two or more
    non-negative integer labels, membership is resolved with a single ``lut[arr]``
    gather instead of the general ``np.isin`` algorithm. All other inputs take an
    exact simpler path:

    * no labels: a constant mask,
    * one label: ``arr == label``,
    * empty ``arr``, non-unsigned dtype, non-integer or negative labels, or a value
      range of ``2**20`` or more: plain ``np.isin``.

    Args:
        arr (np.ndarray): Input array.
        labels: A single label or any iterable of labels (list, tuple, set, range,
            ndarray, ...). Arrays are flattened, as ``np.isin`` does.
        invert (bool, optional): If True, return the complement (equivalent to
            ``np.isin(arr, labels, invert=True)``). Defaults to False.

    Returns:
        np.ndarray: Boolean mask, same shape as ``arr``.
    """
    # Normalise `labels` to a flat Python list so that len()/indexing below are safe
    # for scalars, 0-d/N-d arrays and unordered containers alike.
    if isinstance(labels, np.ndarray):
        labels = labels.ravel().tolist()
    elif isinstance(labels, (str, bytes)):
        labels = [labels]
    else:
        try:
            labels = list(labels)
        except TypeError:  # not iterable -> a single scalar label
            labels = [labels]

    if len(labels) == 0:
        return np.ones(arr.shape, dtype=bool) if invert else np.zeros(arr.shape, dtype=bool)
    if len(labels) == 1:
        res = arr == labels[0]
        return ~res if invert else res

    # The lookup table is only valid for true integer labels: floats cannot index an
    # array and bools would be interpreted as a mask rather than as positions.
    integral = all(isinstance(x, (int, np.integer)) and not isinstance(x, bool) for x in labels)
    if (
        integral
        and arr.size > 0  # arr.max() is undefined for empty arrays
        and np.issubdtype(arr.dtype, np.unsignedinteger)
        and min(labels) >= 0
    ):
        m = max(int(arr.max()), int(max(labels))) + 1
        if m < 2**20:  # keep the lookup table small
            lut = np.zeros(m, dtype=bool)
            lut[np.asarray(labels, dtype=np.intp)] = True
            res = lut[arr]
            return ~res if invert else res
    return np.isin(arr, labels, invert=invert)
71+
72+
3973
def np_extract_label(
4074
arr: np.ndarray,
4175
label: int | list[int],
@@ -69,7 +103,7 @@ def np_extract_label(
69103

70104
if isinstance(label, list):
71105
assert 0 not in label, "label 0 is not supported in list mode"
72-
arr_msk = np.isin(arr, label)
106+
arr_msk = np_isin(arr, label)
73107
arr[arr_msk] = to_label
74108
arr[~arr_msk] = 0
75109
return arr
@@ -125,10 +159,12 @@ def np_volume(arr: UINTARRAY, include_zero: bool = False) -> dict[int, int]:
125159
Returns:
126160
dict[int, int]: Mapping from label value to number of voxels with that label.
127161
"""
162+
# np.bincount wins decisively when there are many labels (e.g. connected-component maps);
163+
# cc3d statistics is faster for the few-label case typical of anatomical segmentations.
164+
counts = np.bincount(arr.ravel()) if int(arr.max()) > 256 else cc3dstatistics(arr, use_crop=not include_zero)["voxel_counts"]
128165
if include_zero:
129-
return {idx: i for idx, i in dict(enumerate(cc3dstatistics(arr, use_crop=False)["voxel_counts"])).items() if i > 0}
130-
else:
131-
return {idx: i for idx, i in dict(enumerate(cc3dstatistics(arr)["voxel_counts"])).items() if i > 0 and idx != 0}
166+
return {idx: i for idx, i in enumerate(counts) if i > 0}
167+
return {idx: i for idx, i in enumerate(counts) if i > 0 and idx != 0}
132168

133169

134170
def np_is_empty(arr: UINTARRAY | INTARRAY) -> bool:
@@ -253,8 +289,8 @@ def np_center_of_mass(arr: UINTARRAY) -> dict[int, COORDINATE]:
253289
"""
254290
stats = cc3dstatistics(arr, use_crop=False)
255291
# Does not use the other calls for speed reasons
256-
unique = [idx for idx, i in enumerate(stats["voxel_counts"]) if i > 0 and idx != 0]
257-
return {idx: v for idx, v in enumerate(stats["centroids"]) if idx in unique}
292+
vc = stats["voxel_counts"]
293+
return {idx: v for idx, v in enumerate(stats["centroids"]) if idx != 0 and vc[idx] > 0}
258294

259295

260296
def np_bounding_boxes(arr: UINTARRAY) -> dict[int, tuple[slice, slice, slice]]:
@@ -270,8 +306,8 @@ def np_bounding_boxes(arr: UINTARRAY) -> dict[int, tuple[slice, slice, slice]]:
270306
"""
271307
stats = cc3dstatistics(arr)
272308
# Does not use the other calls for speed reasons
273-
unique = [idx for idx, i in enumerate(stats["voxel_counts"]) if i > 0 and idx != 0]
274-
return {idx: v for idx, v in enumerate(stats["bounding_boxes"]) if idx in unique}
309+
vc = stats["voxel_counts"]
310+
return {idx: v for idx, v in enumerate(stats["bounding_boxes"]) if idx != 0 and vc[idx] > 0}
275311

276312

277313
def np_contacts(arr: UINTARRAY, connectivity: int) -> dict[tuple[int, int], int]:
@@ -383,14 +419,14 @@ def np_erode_msk_euclid(arr: np.ndarray, n_pixel: int = 3, use_crop=True, labels
383419
if use_crop:
384420
arr_bin = arr.copy()
385421
if labels is not None:
386-
arr_bin[np.isin(arr_bin, labels, invert=True)] = 0
422+
arr_bin[np_isin(arr_bin, labels, invert=True)] = 0
387423
crop = np_bbox_binary(arr_bin, px_dist=1 + n_pixel, raise_error=False)
388424
arrc = arr[crop]
389425
else:
390426
arrc = arr
391427
if labels is not None:
392428
arrc = arrc.copy()
393-
arrc[np.isin(arrc, labels, invert=True)] = 0
429+
arrc[np_isin(arrc, labels, invert=True)] = 0
394430

395431
if mask is not None:
396432
mask = mask.copy()
@@ -426,17 +462,18 @@ def np_dilate_msk_euclid(arr: np.ndarray, n_pixel: int = 3, use_crop=True, label
426462
427463
Assigns each newly covered voxel to the nearest existing label.
428464
"""
465+
arr_bin = arr.copy()
466+
if labels is not None:
467+
arr_bin[np_isin(arr_bin, labels, invert=True)] = 0
468+
429469
if use_crop:
430-
arr_bin = arr.copy()
431-
if labels is not None:
432-
arr_bin[np.isin(arr_bin, labels, invert=True)] = 0
433470
crop = np_bbox_binary(arr_bin, px_dist=1 + n_pixel, raise_error=False)
434471
arrc = arr[crop]
435472
else:
436473
arrc = arr
437474
if labels is not None:
438475
arrc = arrc.copy()
439-
arrc[np.isin(arr_bin, labels, invert=True)] = 0
476+
arrc[np_isin(arr_bin, labels, invert=True)] = 0
440477
if mask is not None:
441478
mask[mask != 0] = 1
442479
if use_crop:
@@ -500,7 +537,7 @@ def np_dilate_msk(
500537
if use_crop:
501538
# try:
502539
arr_bin = arr.copy()
503-
arr_bin[np.isin(arr_bin, labels, invert=True)] = 0
540+
arr_bin[np_isin(arr_bin, labels, invert=True)] = 0
504541
crop = np_bbox_binary(arr_bin, px_dist=1 + n_pixel, raise_error=False)
505542
arrc = arr[crop]
506543
else:
@@ -521,8 +558,7 @@ def np_dilate_msk(
521558
out = arrc
522559
for _ in range(n_pixel):
523560
for i in labels:
524-
data = out.copy()
525-
data[i != data] = 0
561+
data = out == i # boolean mask; _binary_dilation casts to bool anyway, so this is exact and avoids a full copy
526562
if use_crop:
527563
lcrop = np_bbox_binary(data, px_dist=2 + n_pixel, raise_error=False)
528564
data = data[lcrop]
@@ -575,7 +611,7 @@ def np_erode_msk(
575611
labels: list[int] = _to_labels(arr, label_ref)
576612

577613
if use_crop:
578-
crop = np_bbox_binary(np.isin(arr, labels, invert=False), px_dist=1 + n_pixel, raise_error=False)
614+
crop = np_bbox_binary(np_isin(arr, labels, invert=False), px_dist=1 + n_pixel, raise_error=False)
579615
arrc = arr[crop]
580616
else:
581617
arrc = arr
@@ -703,9 +739,16 @@ def np_bbox_binary(img: np.ndarray, px_dist: int | Sequence[int] | np.ndarray =
703739
assert len(px_dist) == n, f"dimension mismatch, got img shape {shp} and px_dist {px_dist}"
704740

705741
bbox: list[float] = []
706-
for ax in itertools.combinations(reversed(range(n)), n - 1):
707-
nonzero = np.any(a=img, axis=ax)
708-
bbox.extend(np.where(nonzero)[0][[0, -1]]) # type: ignore
742+
if n == 3:
743+
# 2 full passes instead of 3: two axis extents come from a shared 2D projection (cheap),
744+
# only the third axis needs a second full reduction.
745+
p = np.any(img, axis=2)
746+
for nonzero in (np.any(p, axis=1), np.any(p, axis=0), np.any(img, axis=(0, 1))):
747+
bbox.extend(np.where(nonzero)[0][[0, -1]]) # type: ignore
748+
else:
749+
for ax in itertools.combinations(reversed(range(n)), n - 1):
750+
nonzero = np.any(a=img, axis=ax)
751+
bbox.extend(np.where(nonzero)[0][[0, -1]]) # type: ignore
709752
out: tuple[slice, ...] = tuple(
710753
slice(
711754
max(bbox[i] - px_dist[i // 2], 0),
@@ -867,7 +910,7 @@ def np_connected_components(
867910
labels: Sequence[int] = _to_labels(arr, label_ref)
868911
if include_zero:
869912
arr[arr == 0] = arr.max() + 1
870-
arr[np.isin(arr, labels, invert=True)] = 0
913+
arr[np_isin(arr, labels, invert=True)] = 0
871914
cc_map, n = _connected_components(arr, connectivity=connectivity, return_N=True)
872915
return cc_map, n
873916

@@ -952,7 +995,7 @@ def np_filter_connected_components(
952995

953996
arr2 = arr.copy()
954997
labels: Sequence[int] = _to_labels(arr, label_ref)
955-
arr2[np.isin(arr2, labels, invert=True)] = 0 # type:ignore
998+
arr2[np_isin(arr2, labels, invert=True)] = 0 # type:ignore
956999

9571000
labels_out, n = _connected_components(arr2, connectivity=connectivity, return_N=True)
9581001
largest_k_components_org = largest_k_components

0 commit comments

Comments
 (0)