Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/torchmetrics/functional/image/arniqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@

from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
from torchvision import transforms
from torchvision.models import resnet50

_AVAILABLE_REGRESSOR_DATASETS = {
"kadid10k": (1, 5),
"koniq10k": (1, 100),
Expand Down Expand Up @@ -65,6 +61,8 @@ def __init__(self, regressor_dataset: _TYPE_REGRESSOR_DATASET = "koniq10k") -> N
" Either install as `pip install torchmetrics[image]` or `pip install torchvision`."
)

from torchvision.models import resnet50

valid_regressor_datasets = _AVAILABLE_REGRESSOR_DATASETS.keys()
if regressor_dataset not in valid_regressor_datasets:
raise ValueError(
Expand Down Expand Up @@ -116,6 +114,8 @@ def _preprocess_input(self, img: Tensor, normalize: bool = False) -> tuple[Tenso
Obtains the half-scale version of the input image and applies normalization if needed.

"""
from torchvision import transforms

h, w = img.shape[-2:]
img_ds = transforms.Resize((h // 2, w // 2))(img) # get the half-scale version of the image
if normalize:
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/image/dists.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@

if not _TORCHVISION_AVAILABLE:
__doctest_skip__ = ["deep_image_structure_and_texture_similarity"]
else:
from torchvision.models import VGG16_Weights, vgg16

_PATH_WEIGHT_DISTS = Path(__file__).resolve().parent / "dists_models" / "weights.pt"

Expand Down Expand Up @@ -91,6 +89,8 @@ def __init__(self, load_weights: bool = True) -> None:
"DISTS requires torchvision to be installed. Please install it with `pip install torchvision`."
)

from torchvision.models import VGG16_Weights, vgg16

vgg_pretrained_features = vgg16(weights=VGG16_Weights.DEFAULT).features
self.stage1 = torch.nn.Sequential()
self.stage2 = torch.nn.Sequential()
Expand Down
Loading