diff --git a/src/torchmetrics/functional/image/arniqa.py b/src/torchmetrics/functional/image/arniqa.py index b5fbf556a1c..c4e084a0555 100644 --- a/src/torchmetrics/functional/image/arniqa.py +++ b/src/torchmetrics/functional/image/arniqa.py @@ -27,10 +27,6 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCHVISION_AVAILABLE -if _TORCHVISION_AVAILABLE: - from torchvision import transforms - from torchvision.models import resnet50 - _AVAILABLE_REGRESSOR_DATASETS = { "kadid10k": (1, 5), "koniq10k": (1, 100), @@ -65,6 +61,8 @@ def __init__(self, regressor_dataset: _TYPE_REGRESSOR_DATASET = "koniq10k") -> N " Either install as `pip install torchmetrics[image]` or `pip install torchvision`." ) + from torchvision.models import resnet50 + valid_regressor_datasets = _AVAILABLE_REGRESSOR_DATASETS.keys() if regressor_dataset not in valid_regressor_datasets: raise ValueError( @@ -116,6 +114,8 @@ def _preprocess_input(self, img: Tensor, normalize: bool = False) -> tuple[Tenso Obtains the half-scale version of the input image and applies normalization if needed. """ + from torchvision import transforms + h, w = img.shape[-2:] img_ds = transforms.Resize((h // 2, w // 2))(img) # get the half-scale version of the image if normalize: diff --git a/src/torchmetrics/functional/image/dists.py b/src/torchmetrics/functional/image/dists.py index f3a81e022ef..abd803a3c7d 100644 --- a/src/torchmetrics/functional/image/dists.py +++ b/src/torchmetrics/functional/image/dists.py @@ -47,8 +47,6 @@ if not _TORCHVISION_AVAILABLE: __doctest_skip__ = ["deep_image_structure_and_texture_similarity"] -else: - from torchvision.models import VGG16_Weights, vgg16 _PATH_WEIGHT_DISTS = Path(__file__).resolve().parent / "dists_models" / "weights.pt" @@ -91,6 +89,8 @@ def __init__(self, load_weights: bool = True) -> None: "DISTS requires torchvision to be installed. Please install it with `pip install torchvision`." ) + from torchvision.models import VGG16_Weights, vgg16 + vgg_pretrained_features = vgg16(weights=VGG16_Weights.DEFAULT).features self.stage1 = torch.nn.Sequential() self.stage2 = torch.nn.Sequential()