Skip to content

Commit 6187529

Browse files
virginiafdez, Virginia Fernandez, pre-commit-ci[bot], ericspod, coderabbitai[bot]
authored
8627 perceptual loss errors out after hitting the maximum number of downloads (#8652)
Fixes #8627. Moves the perceptual loss code to the MONAI repository https://github.com/Project-MONAI/perceptual-models and the checkpoints to Hugging Face. ### Description This PR changes and simplifies the torch.hub loading process and gets the models from the Hugging Face library. A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] Breaking change (fix or new feature that would cause existing functionality to change). - [N/A] New tests added to cover the changes. - [N/A] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com> Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk> Signed-off-by: Yun Liu <yunl@nvidia.com> Signed-off-by: Rafael Garcia-Dias <rafaelagd@gmail.com> Co-authored-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: Yun Liu <yunl@nvidia.com> Co-authored-by: Rafael Garcia-Dias <rafaelagd@gmail.com>
1 parent a6b672a commit 6187529

3 files changed

Lines changed: 69 additions & 23 deletions

File tree

monai/losses/perceptual.py

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,18 @@
1919
from monai.utils import optional_import
2020
from monai.utils.enums import StrEnum
2121

22+
# Valid model names that can be downloaded from the repository
23+
HF_MONAI_MODELS = frozenset(
24+
("medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets", "radimagenet_resnet50")
25+
)
26+
2227
LPIPS, _ = optional_import("lpips", name="LPIPS")
2328
torchvision, _ = optional_import("torchvision")
2429

2530

26-
class PercetualNetworkType(StrEnum):
31+
class PerceptualNetworkType(StrEnum):
32+
"""Types of neural networks that are supported by perceptual loss."""
33+
2734
alex = "alex"
2835
vgg = "vgg"
2936
squeeze = "squeeze"
@@ -81,7 +88,7 @@ class PerceptualLoss(nn.Module):
8188
def __init__(
8289
self,
8390
spatial_dims: int,
84-
network_type: str = PercetualNetworkType.alex,
91+
network_type: str = PerceptualNetworkType.alex,
8592
is_fake_3d: bool = True,
8693
fake_3d_ratio: float = 0.5,
8794
cache_dir: str | None = None,
@@ -95,18 +102,26 @@ def __init__(
95102
if spatial_dims not in [2, 3]:
96103
raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.")
97104

98-
if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type:
99-
raise ValueError(
100-
"MedicalNet networks are only compatible with ``spatial_dims=3``."
101-
"Argument is_fake_3d must be set to False."
102-
)
103-
104-
if channel_wise and "medicalnet_" not in network_type:
105+
network_type = network_type.lower()
106+
107+
# Strict validation for MedicalNet
108+
if "medicalnet_" in network_type:
109+
if spatial_dims == 2 or is_fake_3d:
110+
raise ValueError(
111+
"MedicalNet networks are only compatible with ``spatial_dims=3``. Argument is_fake_3d must be set to False."
112+
)
113+
if not channel_wise:
114+
warnings.warn(
115+
"MedicalNet networks supp, ort channel-wise loss. Consider setting channel_wise=True.", stacklevel=2
116+
)
117+
118+
# Channel-wise only for MedicalNet
119+
elif channel_wise:
105120
raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.")
106121

107-
if network_type.lower() not in list(PercetualNetworkType):
122+
if network_type.lower() not in list(PerceptualNetworkType):
108123
raise ValueError(
109-
f"Unrecognised criterion entered for Perceptual Loss. Must be one in: {', '.join(PercetualNetworkType)}"
124+
f"Unrecognised criterion entered for Perceptual Loss. Must be one in: {', '.join(PerceptualNetworkType)}"
110125
)
111126
if cache_dir:
112127
torch.hub.set_dir(cache_dir)
@@ -117,12 +132,16 @@ def __init__(
117132

118133
self.spatial_dims = spatial_dims
119134
self.perceptual_function: nn.Module
135+
136+
# If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used.
120137
if spatial_dims == 3 and is_fake_3d is False:
121138
self.perceptual_function = MedicalNetPerceptualSimilarity(
122-
net=network_type, verbose=False, channel_wise=channel_wise
139+
net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir
123140
)
124141
elif "radimagenet_" in network_type:
125-
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
142+
self.perceptual_function = RadImageNetPerceptualSimilarity(
143+
net=network_type, verbose=False, cache_dir=cache_dir
144+
)
126145
elif network_type == "resnet50":
127146
self.perceptual_function = TorchvisionModelPerceptualSimilarity(
128147
net=network_type,
@@ -131,7 +150,9 @@ def __init__(
131150
pretrained_state_dict_key=pretrained_state_dict_key,
132151
)
133152
else:
153+
# VGG, AlexNet and SqueezeNet are independently handled by LPIPS.
134154
self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)
155+
135156
self.is_fake_3d = is_fake_3d
136157
self.fake_3d_ratio = fake_3d_ratio
137158
self.channel_wise = channel_wise
@@ -203,22 +224,31 @@ class MedicalNetPerceptualSimilarity(nn.Module):
203224
"""
204225
Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
205226
Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
206-
"Warvito/MedicalNet-models".
227+
"Project-MONAI/perceptual-models".
207228
208229
Args:
209230
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
210231
Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
211232
verbose: if false, mute messages from torch Hub load function.
212233
channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
213-
Defaults to ``False``.
234+
Defaults to ``False``.
235+
cache_dir: path to cache directory to save the pretrained network weights.
214236
"""
215237

216238
def __init__(
217-
self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False
239+
self,
240+
net: str = "medicalnet_resnet10_23datasets",
241+
verbose: bool = False,
242+
channel_wise: bool = False,
243+
cache_dir: str | None = None,
218244
) -> None:
219245
super().__init__()
220-
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
221-
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True)
246+
if net not in HF_MONAI_MODELS:
247+
raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.")
248+
249+
self.model = torch.hub.load(
250+
"Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True
251+
)
222252
self.eval()
223253

224254
self.channel_wise = channel_wise
@@ -267,7 +297,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
267297
for i in range(input.shape[1]):
268298
l_idx = i * feats_per_ch
269299
r_idx = (i + 1) * feats_per_ch
270-
results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1)
300+
results[:, i, ...] = feats_diff[:, l_idx:r_idx, ...].sum(dim=1)
271301
else:
272302
results = feats_diff.sum(dim=1, keepdim=True)
273303

@@ -296,17 +326,22 @@ class RadImageNetPerceptualSimilarity(nn.Module):
296326
"""
297327
Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
298328
al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class
299-
uses torch Hub to download the networks from "Warvito/radimagenet-models".
329+
uses torch Hub to download the networks from "Project-MONAI/perceptual-models".
300330
301331
Args:
302332
net: {``"radimagenet_resnet50"``}
303333
Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``.
304334
verbose: if false, mute messages from torch Hub load function.
335+
cache_dir: path to cache directory to save the pretrained network weights.
305336
"""
306337

307-
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
338+
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None:
308339
super().__init__()
309-
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True)
340+
if net not in HF_MONAI_MODELS:
341+
raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.")
342+
self.model = torch.hub.load(
343+
"Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True
344+
)
310345
self.eval()
311346

312347
for param in self.parameters():

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import re
1717
import sys
1818
import warnings
19+
from typing import Any, cast
1920

2021
from packaging import version
2122
from setuptools import find_packages, setup
@@ -146,6 +147,6 @@ def get_cmds():
146147
cmdclass=get_cmds(),
147148
packages=find_packages(exclude=("docs", "examples", "tests", "tests.*")),
148149
zip_safe=False,
149-
package_data={"monai": ["py.typed", *jit_extension_source]}, # type: ignore[arg-type]
150+
package_data=cast(Any, {"monai": ["py.typed", *jit_extension_source]}),
150151
ext_modules=get_extensions(),
151152
)

tests/losses/test_perceptual_loss.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,16 @@ def test_medicalnet_on_2d_data(self, network_type):
116116
with self.assertRaises(ValueError):
117117
PerceptualLoss(spatial_dims=2, network_type=network_type)
118118

119+
@parameterized.expand(["squeeze", "alex", "vgg", "radimagenet_resnet50", "resnet50"])
120+
def test_channel_wise_with_non_medicalnet(self, network_type):
121+
with self.assertRaises(ValueError):
122+
PerceptualLoss(spatial_dims=2, network_type=network_type, channel_wise=True)
123+
124+
@parameterized.expand(["squeeze", "alex", "vgg", "radimagenet_resnet50", "resnet50"])
125+
def test_non_medicalnet_3d_without_fake_3d(self, network_type):
126+
with self.assertRaises(ValueError):
127+
PerceptualLoss(spatial_dims=3, network_type=network_type, is_fake_3d=False)
128+
119129

120130
if __name__ == "__main__":
121131
unittest.main()

0 commit comments

Comments
 (0)