From 253507b79b67a79f50aa20e3610ce983f8eb31fb Mon Sep 17 00:00:00 2001 From: xodn348 Date: Tue, 5 May 2026 15:14:40 +0000 Subject: [PATCH] fix(image): lazy-import torchvision in arniqa and dists to prevent circular import Module-level 'from torchvision import ...' blocks inside the 'if _TORCHVISION_AVAILABLE:' guard in arniqa.py and dists.py ran at import time, which triggered torchvision's own __init__ chain. In environments where torchvision is only partially initialised (e.g. Kaggle, or any scenario with a circular-import race in torchvision's _meta_registrations), importing torchmetrics raised an AttributeError before the user code ran. Move the torchvision symbols into the functions/methods that actually use them, matching the pattern already used by lpips.py. --- src/torchmetrics/functional/image/arniqa.py | 8 ++++---- src/torchmetrics/functional/image/dists.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/image/arniqa.py b/src/torchmetrics/functional/image/arniqa.py index b5fbf556a1c..c4e084a0555 100644 --- a/src/torchmetrics/functional/image/arniqa.py +++ b/src/torchmetrics/functional/image/arniqa.py @@ -27,10 +27,6 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCHVISION_AVAILABLE -if _TORCHVISION_AVAILABLE: - from torchvision import transforms - from torchvision.models import resnet50 - _AVAILABLE_REGRESSOR_DATASETS = { "kadid10k": (1, 5), "koniq10k": (1, 100), @@ -65,6 +61,8 @@ def __init__(self, regressor_dataset: _TYPE_REGRESSOR_DATASET = "koniq10k") -> N " Either install as `pip install torchmetrics[image]` or `pip install torchvision`." ) + from torchvision.models import resnet50 + valid_regressor_datasets = _AVAILABLE_REGRESSOR_DATASETS.keys() if regressor_dataset not in valid_regressor_datasets: raise ValueError( @@ -116,6 +114,8 @@ def _preprocess_input(self, img: Tensor, normalize: bool = False) -> tuple[Tenso Obtains the half-scale version of the input image and applies normalization if needed. """ + from torchvision import transforms + h, w = img.shape[-2:] img_ds = transforms.Resize((h // 2, w // 2))(img) # get the half-scale version of the image if normalize: diff --git a/src/torchmetrics/functional/image/dists.py b/src/torchmetrics/functional/image/dists.py index f3a81e022ef..abd803a3c7d 100644 --- a/src/torchmetrics/functional/image/dists.py +++ b/src/torchmetrics/functional/image/dists.py @@ -47,8 +47,6 @@ if not _TORCHVISION_AVAILABLE: __doctest_skip__ = ["deep_image_structure_and_texture_similarity"] -else: - from torchvision.models import VGG16_Weights, vgg16 _PATH_WEIGHT_DISTS = Path(__file__).resolve().parent / "dists_models" / "weights.pt" @@ -91,6 +89,8 @@ def __init__(self, load_weights: bool = True) -> None: "DISTS requires torchvision to be installed. Please install it with `pip install torchvision`." ) + from torchvision.models import VGG16_Weights, vgg16 + vgg_pretrained_features = vgg16(weights=VGG16_Weights.DEFAULT).features self.stage1 = torch.nn.Sequential() self.stage2 = torch.nn.Sequential()