Skip to content

Commit f99ebda

Browse files
authored
Add deepedit transforms (#2810)
* Add deepedit transforms Signed-off-by: Andres <diazandr3s@gmail.com> * Run unittests - autofix Signed-off-by: Andres <diazandr3s@gmail.com> * Update transform Signed-off-by: Andres <diazandr3s@gmail.com>
1 parent f981ad0 commit f99ebda

4 files changed

Lines changed: 275 additions & 0 deletions

File tree

monai/apps/deepedit/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.

monai/apps/deepedit/transforms.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import json
2+
import logging
3+
from typing import Dict, Hashable, Mapping, Tuple
4+
5+
import numpy as np
6+
7+
from monai.config import KeysCollection
8+
from monai.transforms.transform import MapTransform, Randomizable, Transform
9+
10+
logger = logging.getLogger(__name__)
11+
12+
from monai.utils import optional_import
13+
14+
distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
15+
16+
17+
class DiscardAddGuidanced(MapTransform):
18+
def __init__(
19+
self,
20+
keys: KeysCollection,
21+
probability: float = 1.0,
22+
allow_missing_keys: bool = False,
23+
):
24+
"""
25+
Discard positive and negative points randomly or Add the two channels for inference time
26+
27+
:param probability: Discard probability; For inference it will be always 1.0
28+
"""
29+
super().__init__(keys, allow_missing_keys)
30+
self.probability = probability
31+
32+
def _apply(self, image):
33+
if self.probability >= 1.0 or np.random.choice([True, False], p=[self.probability, 1 - self.probability]):
34+
signal = np.zeros((1, image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32)
35+
if image.shape[0] == 3:
36+
image[1] = signal
37+
image[2] = signal
38+
else:
39+
image = np.concatenate((image, signal, signal), axis=0)
40+
return image
41+
42+
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
43+
d: Dict = dict(data)
44+
for key in self.key_iterator(d):
45+
if key == "image":
46+
d[key] = self._apply(d[key])
47+
else:
48+
print("This transform only applies to the image")
49+
return d
50+
51+
52+
class ResizeGuidanceCustomd(Transform):
53+
"""
54+
Resize the guidance based on cropped vs resized image.
55+
"""
56+
57+
def __init__(
58+
self,
59+
guidance: str,
60+
ref_image: str,
61+
) -> None:
62+
self.guidance = guidance
63+
self.ref_image = ref_image
64+
65+
def __call__(self, data):
66+
d = dict(data)
67+
current_shape = d[self.ref_image].shape[1:]
68+
69+
factor = np.divide(current_shape, d["image_meta_dict"]["dim"][1:4])
70+
pos_clicks, neg_clicks = d["foreground"], d["background"]
71+
72+
pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else []
73+
neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else []
74+
75+
d[self.guidance] = [pos, neg]
76+
return d
77+
78+
79+
class ClickRatioAddRandomGuidanced(Randomizable, Transform):
80+
"""
81+
Add random guidance based on discrepancies that were found between label and prediction.
82+
Args:
83+
guidance: key to guidance source, shape (2, N, # of dim)
84+
discrepancy: key that represents discrepancies found between label and prediction, shape (2, C, D, H, W) or (2, C, H, W)
85+
probability: key that represents click/interaction probability, shape (1)
86+
fn_fp_click_ratio: ratio of clicks between FN and FP
87+
"""
88+
89+
def __init__(
90+
self,
91+
guidance: str = "guidance",
92+
discrepancy: str = "discrepancy",
93+
probability: str = "probability",
94+
fn_fp_click_ratio: Tuple[float, float] = (1.0, 1.0),
95+
):
96+
self.guidance = guidance
97+
self.discrepancy = discrepancy
98+
self.probability = probability
99+
self.fn_fp_click_ratio = fn_fp_click_ratio
100+
self._will_interact = None
101+
102+
def randomize(self, data=None):
103+
probability = data[self.probability]
104+
self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])
105+
106+
def find_guidance(self, discrepancy):
107+
distance = distance_transform_cdt(discrepancy).flatten()
108+
probability = np.exp(distance) - 1.0
109+
idx = np.where(discrepancy.flatten() > 0)[0]
110+
111+
if np.sum(discrepancy > 0) > 0:
112+
seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx]))
113+
dst = distance[seed]
114+
115+
g = np.asarray(np.unravel_index(seed, discrepancy.shape)).transpose().tolist()[0]
116+
g[0] = dst[0]
117+
return g
118+
return None
119+
120+
def add_guidance(self, discrepancy, will_interact):
121+
if not will_interact:
122+
return None, None
123+
124+
pos_discr = discrepancy[0]
125+
neg_discr = discrepancy[1]
126+
127+
can_be_positive = np.sum(pos_discr) > 0
128+
can_be_negative = np.sum(neg_discr) > 0
129+
130+
pos_prob = self.fn_fp_click_ratio[0] / (self.fn_fp_click_ratio[0] + self.fn_fp_click_ratio[1])
131+
neg_prob = self.fn_fp_click_ratio[1] / (self.fn_fp_click_ratio[0] + self.fn_fp_click_ratio[1])
132+
133+
correct_pos = self.R.choice([True, False], p=[pos_prob, neg_prob])
134+
135+
if can_be_positive and not can_be_negative:
136+
return self.find_guidance(pos_discr), None
137+
138+
if not can_be_positive and can_be_negative:
139+
return None, self.find_guidance(neg_discr)
140+
141+
if correct_pos and can_be_positive:
142+
return self.find_guidance(pos_discr), None
143+
144+
if not correct_pos and can_be_negative:
145+
return None, self.find_guidance(neg_discr)
146+
return None, None
147+
148+
def _apply(self, guidance, discrepancy):
149+
guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance
150+
guidance = json.loads(guidance) if isinstance(guidance, str) else guidance
151+
pos, neg = self.add_guidance(discrepancy, self._will_interact)
152+
if pos:
153+
guidance[0].append(pos)
154+
guidance[1].append([-1] * len(pos))
155+
if neg:
156+
guidance[0].append([-1] * len(neg))
157+
guidance[1].append(neg)
158+
159+
return json.dumps(np.asarray(guidance).astype(int).tolist())
160+
161+
def __call__(self, data):
162+
d = dict(data)
163+
guidance = d[self.guidance]
164+
discrepancy = d[self.discrepancy]
165+
self.randomize(data)
166+
d[self.guidance] = self._apply(guidance, discrepancy)
167+
return d

tests/min_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def run_testsuit():
3737
"test_csv_iterable_dataset",
3838
"test_dataset",
3939
"test_dataset_summary",
40+
"test_deepedit_transforms",
4041
"test_deepgrow_dataset",
4142
"test_deepgrow_interaction",
4243
"test_deepgrow_transforms",

tests/test_deepedit_transforms.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
from parameterized import parameterized
16+
17+
from monai.apps.deepedit.transforms import ClickRatioAddRandomGuidanced, DiscardAddGuidanced, ResizeGuidanceCustomd
18+
19+
IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]])
20+
LABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]])
21+
22+
DATA_1 = {
23+
"image": IMAGE,
24+
"label": LABEL,
25+
"image_meta_dict": {"dim": IMAGE.shape},
26+
"label_meta_dict": {},
27+
"foreground": [0, 0, 0],
28+
"background": [0, 0, 0],
29+
}
30+
31+
DISCARD_ADD_GUIDANCE_TEST_CASE = [
32+
{"image": IMAGE, "label": LABEL},
33+
DATA_1,
34+
(3, 1, 5, 5),
35+
]
36+
37+
DATA_2 = {
38+
"image": IMAGE,
39+
"label": LABEL,
40+
"guidance": np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]),
41+
"discrepancy": np.array(
42+
[
43+
[[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
44+
[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
45+
]
46+
),
47+
"probability": 1.0,
48+
}
49+
50+
CLICK_RATIO_ADD_RANDOM_GUIDANCE_TEST_CASE_1 = [
51+
{"guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability"},
52+
DATA_2,
53+
"[[[1, 0, 2, 2], [-1, -1, -1, -1]], [[-1, -1, -1, -1], [1, 0, 2, 1]]]",
54+
]
55+
56+
DATA_3 = {
57+
"image": np.arange(1000).reshape((1, 5, 10, 20)),
58+
"image_meta_dict": {"foreground_cropped_shape": (1, 10, 20, 40), "dim": [3, 512, 512, 128]},
59+
"guidance": [[[6, 10, 14], [8, 10, 14]], [[8, 10, 16]]],
60+
"foreground": [[10, 14, 6], [10, 14, 8]],
61+
"background": [[10, 16, 8]],
62+
}
63+
64+
RESIZE_GUIDANCE_TEST_CASE_1 = [
65+
{"ref_image": "image", "guidance": "guidance"},
66+
DATA_3,
67+
[[[0, 0, 0], [0, 0, 1]], [[0, 0, 1]]],
68+
]
69+
70+
71+
class TestDiscardAddGuidanced(unittest.TestCase):
72+
@parameterized.expand([DISCARD_ADD_GUIDANCE_TEST_CASE])
73+
def test_correct_results(self, arguments, input_data, expected_result):
74+
add_fn = DiscardAddGuidanced(arguments)
75+
result = add_fn(input_data)
76+
self.assertEqual(result["image"].shape, expected_result)
77+
78+
79+
class TestClickRatioAddRandomGuidanced(unittest.TestCase):
80+
@parameterized.expand([CLICK_RATIO_ADD_RANDOM_GUIDANCE_TEST_CASE_1])
81+
def test_correct_results(self, arguments, input_data, expected_result):
82+
seed = 0
83+
add_fn = ClickRatioAddRandomGuidanced(**arguments)
84+
add_fn.set_random_state(seed)
85+
result = add_fn(input_data)
86+
self.assertEqual(result[arguments["guidance"]], expected_result)
87+
88+
89+
class TestResizeGuidanced(unittest.TestCase):
90+
@parameterized.expand([RESIZE_GUIDANCE_TEST_CASE_1])
91+
def test_correct_results(self, arguments, input_data, expected_result):
92+
result = ResizeGuidanceCustomd(**arguments)(input_data)
93+
self.assertEqual(result[arguments["guidance"]], expected_result)
94+
95+
96+
if __name__ == "__main__":
97+
unittest.main()

0 commit comments

Comments
 (0)