Skip to content

Commit a2e3c54

Browse files
committed
Merge branch 'development_robert' of github.com:Hendrik-code/TPTBox into development_robert
2 parents 8f6b780 + b17456f commit a2e3c54

7 files changed

Lines changed: 302 additions & 74 deletions

File tree

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import time
2+
3+
import elasticdeform
4+
import numpy as np
5+
from numpy.typing import NDArray
6+
7+
from TPTBox import NII
8+
9+
10+
def deformed_nii(
11+
nii_dic: dict[str, NII],
12+
sigma: float | None = None,
13+
points=None,
14+
deform_factor=1.0,
15+
deform_padding=10,
16+
normalize=True,
17+
joint_normalize=False,
18+
) -> dict[str, NII]:
19+
"""
20+
Deform a dictionary of NII objects using random grid deformation. Requires elasticdeform. 'pip install elasticdeform'
21+
22+
IMPORTANT: Normalize your image data to 0,1. The .seg property of NII shows if this is a segmentation. (NII is form our TPTBox and is a wrapper for nibable)
23+
24+
This function takes a dictionary of NII objects and applies random grid deformation to each object
25+
using specified deformation parameters or, if not provided, random parameters generated based on
26+
the `deform_factor`. The deformed objects are returned as a dictionary.
27+
28+
Args:
29+
arr_dic (dict[str, NII]): A dictionary containing NII objects to be deformed.
30+
sigma (float, optional): The standard deviation of the deformation field. If not provided,
31+
it will be generated based on the `deform_factor`.
32+
points (int, optional): The number of control points for the deformation grid. If not provided,
33+
it will be generated based on the `deform_factor`.
34+
deform_factor (float, optional): A factor used to determine the deformation parameters if
35+
`sigma` and `points` are not specified. Larger values result in stronger deformations.
36+
deform_padding (int, optional): The padding added to the deformed objects to avoid edge artifacts.
37+
verbose (bool, optional): If True, enable verbose logging. Default is True.
38+
39+
Returns:
40+
dict[str, NII]: A dictionary where keys correspond to the input dictionary keys, and values
41+
correspond to the deformed NII objects.
42+
43+
Example:
44+
# Deform a dictionary of NII objects using default deformation parameters
45+
deformed_data = deformed_NII(arr_dic)
46+
47+
# Deform a dictionary of NII objects with specific deformation parameters
48+
sigma = 1.0
49+
points = 20
50+
deformed_data = deformed_NII(arr_dic, sigma=sigma, points=points)
51+
"""
52+
if sigma is None or points is None:
53+
sigma, points = get_random_deform_parameter(deform_factor=deform_factor)
54+
55+
print("deformation parameter sigma = ", round(sigma, 4), "; n_points = ", points)
56+
t = time.time()
57+
values = list(nii_dic.values())
58+
# Deform
59+
if joint_normalize:
60+
max_v = max([img.max() for img in nii_dic.values() if not img.seg])
61+
nii_dic = {k: img if img.seg else img.set_dtype(np.float32) / max_v for k, img in nii_dic.items()}
62+
elif normalize:
63+
nii_dic = {k: img if img.seg else img.set_dtype(np.float32).normalize() for k, img in nii_dic.items()}
64+
else:
65+
nii_dic = {k: img if img.seg else img.set_dtype(np.float32) for k, img in nii_dic.items()}
66+
assert sigma is not None
67+
p = deform_padding
68+
out: list[NDArray] = elasticdeform.deform_random_grid(
69+
[pad(v.get_array(), p=p) for v in values],
70+
sigma=sigma, # type: ignore
71+
points=points,
72+
order=[0 if v.seg else 3 for v in values], # type: ignore
73+
)
74+
out2: dict[str, NII] = {}
75+
for (k, nii), arr in zip(nii_dic.items(), out, strict=True):
76+
out2[k] = nii.set_array(arr[p:-p, p:-p, p:-p])
77+
print("Deformation took", round(time.time() - t, 1), "Seconds")
78+
return out2
79+
80+
81+
def pad(arr, p=10):
82+
return np.pad(arr, p, mode="reflect")
83+
84+
85+
def get_random_deform_parameter(deform_factor: float = 1):
86+
"""
87+
Generate random deformation parameters for use in 3D deformation.
88+
89+
This function generates random values for the deformation parameters, including 'sigma' and 'points',
90+
based on the specified deformation factor. These parameters are used for 3D deformation operations.
91+
92+
Args:
93+
deform_factor (float, optional): A factor to control the strength of deformation. Default is 1.
94+
95+
Returns:
96+
tuple[float, int]: A tuple containing the generated 'sigma' (float) and 'points' (int) parameters.
97+
98+
Example:
99+
# Generate random deformation parameters with a deformation factor of 1
100+
sigma, points = get_random_deform_parameter()
101+
102+
# Generate random deformation parameters with a deformation factor of 2
103+
sigma, points = get_random_deform_parameter(deform_factor=2)
104+
"""
105+
sigma = 2 + np.random.uniform() * 2.5 # 1,5 - 4.5
106+
min_points = 3
107+
max_points = 17
108+
if sigma < 2:
109+
max_points = 17
110+
elif sigma < 1.7:
111+
max_points = 16
112+
elif sigma < 2.1:
113+
max_points = 15
114+
elif sigma < 2.3:
115+
max_points = 14
116+
elif sigma < 2.5:
117+
max_points = 13
118+
elif sigma < 2.6:
119+
max_points = 12
120+
elif sigma < 2.7:
121+
max_points = 11
122+
elif sigma < 2.8:
123+
max_points = 10
124+
elif sigma < 3:
125+
max_points = 9
126+
elif sigma < 3.5:
127+
max_points = 8
128+
elif sigma < 4.0:
129+
max_points = 7
130+
elif sigma < 4.3:
131+
max_points = 6
132+
else:
133+
max_points = 5
134+
points = np.random.randint(max_points - min_points + 1) + min_points
135+
# Stronger
136+
sigma *= deform_factor
137+
# points *= deform_factor
138+
points = max(round(points), 1)
139+
return (sigma, points)

TPTBox/core/nii_wrapper.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import traceback
55
import warnings
66
import zlib
7+
from collections import deque
78
from collections.abc import Sequence
89
from enum import Enum
910
from math import ceil, floor
@@ -937,6 +938,34 @@ def resample_from_to(self, to_vox_map:Image_Reference|Has_Grid|tuple[SHAPE,AFFIN
937938
if isinstance(mapping,Has_Grid) and mapping.assert_affine(self,raise_error=False,origin_tolerance=0.000001,error_tolerance=0.000001,shape_tolerance=0):
938939
log.print(f"resample_from_to skipped; already in space: {self}",verbose=verbose)
939940
return self if inplace else self.copy()
941+
942+
#m1 = mapping.make_empty_POI().reorient(self.orientation)
943+
#if m1.assert_affine(self,raise_error=False,origin_tolerance=0.000001,error_tolerance=0.000001,shape_tolerance=0):
944+
# log.print(f"resample_from_to only need reorientation; {self.orientation}",verbose=verbose)
945+
# return self.reorient(mapping.orientation,inplace=inplace)
946+
#if self.orientation == mapping.orientation and self.zoom == mapping.zoom:
947+
# shift = (np.array(self.origin) - np.array(m1.origin)) / np.array(m1.zoom)
948+
# if np.allclose(shift, np.round(shift), atol=1e-6):
949+
# self = self.reorient(mapping.orientation,inplace=inplace) # noqa: PLW0642
950+
# shift = (np.array(self.origin) - np.array(mapping.origin)) / np.array(mapping.zoom)
951+
# shift = np.round(shift).astype(int)
952+
# src_shape = np.array(mapping.shape)
953+
# dst_shape = np.array(self.shape)
954+
# # padding before = how much dst starts before src
955+
# pad_before = np.maximum(-shift, 0)
956+
#
957+
# # where src ends inside dst
958+
# src_end_in_dst = shift + src_shape
959+
# # padding after = remaining dst size after src
960+
# pad_after = np.maximum(dst_shape - src_end_in_dst, 0)
961+
# pad = tuple((int(b), int(a)) for b, a in zip(pad_before, pad_after))
962+
# ret = self.apply_pad(pad, mode=mode)
963+
#
964+
# log.print(f"resample_from_to only needs padding/cropping {pad}, ",verbose=verbose,)
965+
# ret.assert_affine(mapping,raise_error=False,origin_tolerance=0.000001,error_tolerance=0.000001,shape_tolerance=0)
966+
# return ret
967+
968+
940969
assert mapping is not None
941970
log.print(f"resample_from_to: {self} to {mapping}",verbose=verbose)
942971
if order is None:
@@ -1729,7 +1758,7 @@ def truncate_labels_beyond_reference(
17291758
):
17301759
return self.truncate_labels_beyond_reference_(idx,not_beyond,fill,axis,inclusion,inplace=inplace)
17311760

1732-
def infect(self: NII, reference_mask: NII, inplace=False,verbose=True,axis:int|str|None=None,max_depth=None):
1761+
def infect(self: NII, reference_mask: NII, inplace=False,verbose=True,axis:int|str|None=None,max_depth=None, _do_crop=True):
17331762
"""
17341763
Expands labels from self_mask into regions of reference_mask == 1 via breadth-first diffusion.
17351764
@@ -1742,8 +1771,14 @@ def infect(self: NII, reference_mask: NII, inplace=False,verbose=True,axis:int|s
17421771
ndarray: Updated label mask.
17431772
"""
17441773
self.assert_affine(reference_mask)
1745-
self_mask = self.compute_surface_mask().get_seg_array().copy()
1746-
self_mask_org = self.get_seg_array().copy()
1774+
if _do_crop:
1775+
crop = reference_mask.compute_crop(0,5)
1776+
s = self.apply_crop(crop)
1777+
reference_mask = reference_mask.apply_crop(crop)
1778+
else:
1779+
s = self
1780+
self_mask = s.compute_surface_mask().get_seg_array().copy()
1781+
self_mask_org = s.get_seg_array().copy()
17471782
ref_mask = np.clip(reference_mask.get_seg_array(), 0, 1)
17481783
ref_mask[self_mask_org != 0] = 0
17491784
searched = np.clip(self_mask,0,1).astype(np.uint8)
@@ -1763,13 +1798,14 @@ def infect(self: NII, reference_mask: NII, inplace=False,verbose=True,axis:int|s
17631798
else:
17641799
raise NotImplementedError(axis)
17651800

1766-
search = []
1801+
search = deque()
17671802
coords = np.where(self_mask != 0)
17681803
def _add_idx(x,y,z,v,d):
17691804
for x1,y1,z1 in kernel:
17701805
a = x+x1
17711806
b = y+y1
17721807
c = z+z1
1808+
17731809
if a < 0 or b < 0 or c < 0:
17741810
continue
17751811
if a >= self_mask.shape[0] or b >= self_mask.shape[1] or c >= self_mask.shape[2]:
@@ -1782,28 +1818,37 @@ def _add_idx(x,y,z,v,d):
17821818
def _infect(a,b,c,v,d):
17831819
if d-1 == max_depth:
17841820
return
1785-
if searched[a,b,c] != 0:
1821+
if searched[x,y,z] != 0:
17861822
return
1787-
if ref_mask[a,b,c] == 0:
1823+
if ref_mask[x,y,z] == 0:
17881824
return
17891825
#print(a,b,c)
17901826
searched[a,b,c] = 1
17911827
self_mask[a,b,c] = v
1792-
_add_idx(x,y,z,v,d)
1828+
_add_idx(a,b,c,v,d)
17931829

17941830
from tqdm import tqdm
17951831
for x,y,z in tqdm(zip(coords[0],coords[1],coords[2]),total=len(coords[0]),disable=not verbose,desc="Collecting Surface"):
17961832
_add_idx(x,y,z,self_mask[x,y,z],0)
17971833
while len(search) != 0:
17981834
search2 = search
1799-
search = []
1800-
for x,y,z,v,d in tqdm(search2,disable=not verbose,desc="infect"):
1835+
search = deque()
1836+
for _ in tqdm(range(len(search2)),disable=not verbose,desc="infect"):
1837+
x,y,z,v,d = search2.popleft()
18011838
_infect(x,y,z,v,d+1)
18021839
self_mask[self_mask == 0] = self_mask_org[self_mask == 0]
1840+
if _do_crop:
1841+
if inplace:
1842+
self[crop] = self_mask
1843+
return self
1844+
else:
1845+
arr = self.get_array()
1846+
arr[crop] = self_mask
1847+
self_mask = arr
18031848
return self.set_array(self_mask,inplace=inplace)
18041849

1805-
def infect_(self: NII, reference_mask: NII,verbose=True,axis:int|str|None=None):
1806-
return self.infect(reference_mask, inplace=True,verbose=verbose,axis=axis)
1850+
def infect_(self: NII, reference_mask: NII,verbose=True,axis:int|str|None=None,_do_crop=True):
1851+
return self.infect(reference_mask, inplace=True,verbose=verbose,axis=axis,_do_crop=_do_crop)
18071852

18081853
def map_labels(self, label_map:LABEL_MAP , verbose:logging=True, inplace=False):
18091854
"""

TPTBox/core/np_utils.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -598,16 +598,8 @@ def np_calc_crop_around_centerpoint(
598598

599599
cutout_coords_slices = tuple([slice(cutout_coords[i], cutout_coords[i + 1]) for i in range(0, n_dim * 2, 2)])
600600
arr_cut = arr[cutout_coords_slices]
601-
arr_cut = np.pad(
602-
arr_cut,
603-
tuple(padding),
604-
)
605-
return (
606-
arr_cut,
607-
cutout_coords_slices,
608-
tuple(padding),
609-
# tuple([slice(padding[i][0], padding[i][1]) for i in range(n_dim)]),
610-
)
601+
arr_cut = np.pad(arr_cut, tuple(padding))
602+
return (arr_cut, cutout_coords_slices, tuple(padding))
611603

612604

613605
def np_bbox_binary(img: np.ndarray, px_dist: int | Sequence[int] | np.ndarray = 0, raise_error=True) -> tuple[slice, ...]:

0 commit comments

Comments
 (0)