Skip to content

Commit 9e97fe5

Browse files
authored
Upgrade vertebra pipeline (#1222)
Signed-off-by: Andres <diazandr3s@gmail.com> Signed-off-by: Andres <diazandr3s@gmail.com>
1 parent ed6e4a6 commit 9e97fe5

6 files changed

Lines changed: 48 additions & 32 deletions

File tree

docs/images/localization_spine.png

90 KB
Loading

docs/images/vertebra-pipeline.png

143 KB
Loading

sample-apps/radiology/lib/configs/localization_spine.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import lib.infers
1717
import lib.trainers
18-
from monai.networks.nets import UNet
18+
from monai.networks.nets import SegResNet
1919

2020
from monailabel.interfaces.config import TaskConfig
2121
from monailabel.interfaces.tasks.infer_v2 import InferTask
@@ -66,22 +66,22 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
6666
# Download PreTrained Model
6767
if strtobool(self.conf.get("use_pretrained_model", "true")):
6868
url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
69-
url = f"{url}/radiology_segmentation_unet_localization_spine.pt"
69+
url = f"{url}/radiology_segmentation_segresnet_localization_spine.pt"
7070
download_file(url, self.path[0])
7171

7272
self.target_spacing = (1.3, 1.3, 1.3) # target space for image
7373
# Setting ROI size should consider max width, height and depth of the images
7474
self.roi_size = (96, 96, 96) # sliding window size for train and infer
7575

7676
# Network
77-
self.network = UNet(
77+
self.network = SegResNet(
7878
spatial_dims=3,
7979
in_channels=1,
8080
out_channels=len(self.labels) + 1, # labels plus background,
81-
channels=(16, 32, 64, 128),
82-
strides=(2, 2, 2),
83-
num_res_units=2,
84-
dropout=0.2,
81+
init_filters=32,
82+
blocks_down=(1, 2, 2, 4),
83+
blocks_up=(1, 1, 1),
84+
dropout_prob=0.2,
8585
)
8686

8787
def infer(self) -> Union[InferTask, Dict[str, InferTask]]:

sample-apps/radiology/lib/configs/localization_vertebra.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import lib.infers
1717
import lib.trainers
18-
from monai.networks.nets import UNet
18+
from monai.networks.nets import SegResNet
1919

2020
from monailabel.interfaces.config import TaskConfig
2121
from monailabel.interfaces.tasks.infer_v2 import InferTask
@@ -66,22 +66,22 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
6666
# Download PreTrained Model
6767
if strtobool(self.conf.get("use_pretrained_model", "true")):
6868
url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
69-
url = f"{url}/radiology_segmentation_unet_localization_vertebra.pt"
69+
url = f"{url}/radiology_segmentation_segresnet_localization_vertebra.pt"
7070
download_file(url, self.path[0])
7171

7272
self.target_spacing = (1.3, 1.3, 1.3) # target space for image
7373
# Setting ROI size - This is for the image padding
7474
self.roi_size = (96, 96, 96)
7575

7676
# Network
77-
self.network = UNet(
77+
self.network = SegResNet(
7878
spatial_dims=3,
7979
in_channels=1,
8080
out_channels=len(self.labels) + 1, # labels plus background,
81-
channels=(16, 32, 64, 128),
82-
strides=(2, 2, 2),
83-
num_res_units=2,
84-
dropout=0.2,
81+
init_filters=32,
82+
blocks_down=(1, 2, 2, 4),
83+
blocks_up=(1, 1, 1),
84+
dropout_prob=0.2,
8585
)
8686

8787
def infer(self) -> Union[InferTask, Dict[str, InferTask]]:

sample-apps/radiology/lib/configs/segmentation_vertebra.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import lib.infers
1717
import lib.trainers
18-
from monai.networks.nets import UNet
18+
from monai.networks.nets import SegResNet
1919

2020
from monailabel.interfaces.config import TaskConfig
2121
from monailabel.interfaces.tasks.infer_v2 import InferTask
@@ -66,21 +66,21 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
6666
# Download PreTrained Model
6767
if strtobool(self.conf.get("use_pretrained_model", "true")):
6868
url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
69-
url = f"{url}/radiology_segmentation_unet_vertebra.pt"
69+
url = f"{url}/radiology_segmentation_segresnet_vertebra.pt"
7070
download_file(url, self.path[0])
7171

7272
self.target_spacing = (1.0, 1.0, 1.0) # target space for image
7373
self.roi_size = (128, 128, 96)
7474

7575
# Network
76-
self.network = UNet(
76+
self.network = SegResNet(
7777
spatial_dims=3,
7878
in_channels=2,
7979
out_channels=2,
80-
channels=(16, 32, 64, 128, 256),
81-
strides=(2, 2, 2, 2),
82-
num_res_units=2,
83-
dropout=0.2,
80+
init_filters=32,
81+
blocks_down=(1, 2, 2, 4),
82+
blocks_up=(1, 1, 1),
83+
dropout_prob=0.2,
8484
)
8585

8686
def infer(self) -> Union[InferTask, Dict[str, InferTask]]:

sample-apps/radiology/lib/infers/localization_vertebra.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Any, Callable, Sequence, Tuple
12+
from typing import Any, Callable, Sequence, Tuple, Union
1313

14-
from lib.transforms.transforms import VertebraLocalizationSegmentation
14+
import torch
15+
from lib.transforms.transforms import CacheObjectd, VertebraLocalizationSegmentation
1516
from monai.inferers import Inferer, SlidingWindowInferer
1617
from monai.transforms import (
1718
Activationsd,
1819
AsDiscreted,
20+
CropForegroundd,
1921
EnsureChannelFirstd,
2022
EnsureTyped,
2123
GaussianSmoothd,
@@ -64,32 +66,46 @@ def pre_transforms(self, data=None) -> Sequence[Callable]:
6466
LoadImaged(keys="image", reader="ITKReader"),
6567
EnsureTyped(keys="image", device=data.get("device") if data else None),
6668
EnsureChannelFirstd(keys="image"),
69+
CacheObjectd(keys="image"),
70+
Spacingd(keys="image", pixdim=self.target_spacing, allow_missing_keys=True),
71+
ScaleIntensityRanged(keys="image", a_min=-1000, a_max=1900, b_min=0.0, b_max=1.0, clip=True),
72+
GaussianSmoothd(keys="image", sigma=0.4),
73+
ScaleIntensityd(keys="image", minv=-1.0, maxv=1.0),
6774
]
6875
else:
69-
t = []
70-
71-
t.extend(
72-
[
73-
Spacingd(keys="image", pixdim=self.target_spacing),
76+
t = [
77+
EnsureChannelFirstd(keys="label"),
78+
CacheObjectd(keys="image"),
79+
Spacingd(keys=("image", "label"), pixdim=self.target_spacing),
7480
ScaleIntensityRanged(keys="image", a_min=-1000, a_max=1900, b_min=0.0, b_max=1.0, clip=True),
7581
GaussianSmoothd(keys="image", sigma=0.4),
7682
ScaleIntensityd(keys="image", minv=-1.0, maxv=1.0),
83+
CropForegroundd(keys=("image", "label"), source_key="label", margin=10),
7784
]
78-
)
85+
7986
return t
8087

8188
def inferer(self, data=None) -> Inferer:
8289
return SlidingWindowInferer(
83-
roi_size=self.roi_size, sw_batch_size=2, overlap=0.4, padding_mode="replicate", mode="gaussian"
90+
roi_size=self.roi_size,
91+
sw_batch_size=2,
92+
overlap=0.4,
93+
padding_mode="replicate",
94+
mode="gaussian",
95+
device=torch.device("cpu"), # Otherwise a rather big GPU (>45GB) is needed
8496
)
8597

98+
def inverse_transforms(self, data=None) -> Union[None, Sequence[Callable]]:
99+
return [] # Self-determine from the list of pre-transforms provided
100+
86101
def post_transforms(self, data=None) -> Sequence[Callable]:
87102
return [
88-
EnsureTyped(keys="pred", device=data.get("device") if data else None),
103+
# Otherwise a rather big GPU (>45GB) is needed
104+
EnsureTyped(keys="pred", device=torch.device("cpu")),
89105
Activationsd(keys="pred", softmax=True),
90106
AsDiscreted(keys="pred", argmax=True),
91107
KeepLargestConnectedComponentd(keys="pred"),
92-
Restored(keys="pred", ref_image="image"),
108+
Restored(keys="pred", ref_image="image_cached"),
93109
VertebraLocalizationSegmentation(keys="pred", result="result"),
94110
]
95111

0 commit comments

Comments
 (0)