fix(image): lazy-import torchvision in arniqa and dists modules #3373

Open
xodn348 wants to merge 1 commit into Lightning-AI:master from xodn348:fix/arniqa-lazy-torchvision-import

@xodn348 xodn348 commented May 5, 2026

Summary

arniqa.py and dists.py imported symbols from torchvision at module load
time inside an if _TORCHVISION_AVAILABLE: / else: guard. Because the if
branch runs whenever torchvision is installed, importing
torchmetrics.functional.image immediately triggered torchvision's own
__init__ chain. In environments where that initialisation is circular
(e.g. Kaggle's GPU runtime, or any build where
torchvision._meta_registrations tries to import torchvision.extension
before torchvision itself is fully loaded), the eager import raised

AttributeError: partially initialized module 'torchvision' has no attribute 'extension'

This broke from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
and every other top-level import of torchmetrics.
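The failure mode described above can be reproduced without torchvision. The following is a minimal sketch, assuming a throwaway package named fakevision (a hypothetical stand-in, not torchvision's real layout) whose submodule reads an attribute of the package before the package has finished executing its __init__:

```python
import importlib
import sys
import tempfile
from pathlib import Path


def reproduce_partial_init_error():
    """Import a throwaway package whose submodule reads an attribute too early."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg = Path(tmp) / "fakevision"  # hypothetical stand-in for torchvision
        pkg.mkdir()
        # __init__ pulls in _meta before `extension` has been bound.
        (pkg / "__init__.py").write_text(
            "from fakevision import _meta\n"
            "extension = object()\n"
        )
        # _meta re-enters the half-built package and reads the missing name.
        (pkg / "_meta.py").write_text(
            "import fakevision\n"
            "fakevision.extension\n"
        )
        sys.path.append(tmp)
        importlib.invalidate_caches()
        try:
            importlib.import_module("fakevision")
        except AttributeError as err:
            return err
        finally:
            # Leave the interpreter as we found it.
            sys.path.remove(tmp)
            for name in [m for m in sys.modules if m.split(".")[0] == "fakevision"]:
                del sys.modules[name]
    return None


error = reproduce_partial_init_error()
print(type(error).__name__, error)
```

Any code that imports such a package eagerly inherits the failure, which is why deferring the import to first use keeps unrelated imports working.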

The fix moves the from torchvision import ... / from torchvision.models import ... statements inside the class constructors and methods that actually
use them, matching the lazy-import pattern that lpips.py already uses
(import torchvision inside _get_tv_model_features). The if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError guard already runs first, so
behaviour is unchanged when torchvision is absent.

Issue

Fixes #3314

Local verification

$ cd /tmp/torchmetrics
$ ruff check src/torchmetrics/functional/image/arniqa.py \
             src/torchmetrics/functional/image/dists.py
All checks passed!

$ python3 -m pytest tests/unittests/image/test_arniqa.py::test_error_on_wrong_init \
                    tests/unittests/image/test_arniqa.py::test_error_on_wrong_input_shape \
                    tests/unittests/image/test_arniqa.py::test_error_on_wrong_normalize_value \
                    tests/unittests/image/test_arniqa.py::test_check_for_backprop -v
============================= test session starts ==============================
platform linux -- Python 3.11.15, pytest-9.0.3, pluggy-1.6.0
plugins: rerunfailures-16.1, plus-0.8.1, xdist-3.8.0, doctestplus-1.7.1
collected 4 items

tests/unittests/image/test_arniqa.py::test_error_on_wrong_init PASSED   [ 25%]
tests/unittests/image/test_arniqa.py::test_error_on_wrong_input_shape PASSED   [ 50%]
tests/unittests/image/test_arniqa.py::test_error_on_wrong_normalize_value PASSED   [ 75%]
tests/unittests/image/test_arniqa.py::test_check_for_backprop PASSED    [100%]

======================== 4 passed, 8 warnings in 4.39s =========================

$ python3 -c "
import builtins
import sys
import traceback

# Drop cached modules so the import below is actually executed
# (a no-op in a fresh interpreter).
for mod in list(sys.modules):
    if 'torchmetrics' in mod or 'torchvision' in mod:
        del sys.modules[mod]

orig = builtins.__import__
eager = []

def tracing(name, g=None, l=None, fl=(), lv=0):
    # Record torchvision imports (top-level or submodule) issued from a
    # frame that belongs to arniqa.py or dists.py.
    if name == 'torchvision' or name.startswith('torchvision.'):
        # [:-1] drops only this tracing frame and keeps the importing frame.
        for fr in reversed(traceback.extract_stack()[:-1]):
            fname = fr.filename or ''
            if 'torchmetrics' in fname and ('arniqa' in fname or 'dists' in fname):
                eager.append(name)
                break
    return orig(name, g, l, fl, lv)

builtins.__import__ = tracing
try:
    import torchmetrics.functional.image
finally:
    builtins.__import__ = orig

print('eager torchvision imports from arniqa/dists:', eager)
assert not eager, 'still importing torchvision eagerly!'
print('PASS')
"
eager torchvision imports from arniqa/dists: []
PASS
=== LOCAL_TEST_PASSED ===
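A coarser cross-check is to import the package in a fresh interpreter and inspect sys.modules afterwards. The helper below is a sketch; imports_eagerly is a hypothetical name, and the demonstration uses stdlib modules so it runs without torchmetrics installed. Unlike the tracer above, it reports an eager import from anywhere in the package, not only from arniqa.py or dists.py.

```python
import subprocess
import sys


def imports_eagerly(package: str, dependency: str) -> bool:
    """Return True if importing `package` in a fresh interpreter loads `dependency`."""
    code = (
        "import importlib, sys\n"
        f"importlib.import_module({package!r})\n"
        f"print({dependency!r} in sys.modules)\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        check=True,  # raise if the import itself fails
    )
    # The child prints True or False on its last line of output.
    return result.stdout.strip().splitlines()[-1] == "True"


# Stdlib demonstration: json imports json.decoder at load time, but not colorsys.
loads_decoder = imports_eagerly("json", "json.decoder")
loads_colorsys = imports_eagerly("json", "colorsys")
print(loads_decoder, loads_colorsys)
```

For this PR the analogous call would be imports_eagerly("torchmetrics.functional.image", "torchvision"), assuming both packages are installed; its result is not shown here because it was not part of the recorded verification.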

Risk

The change only affects import time: symbols that were previously imported once
at module load are now imported on first use inside their respective
constructor/method. Python caches modules in sys.modules, so after the first
call the cost per call is negligible (a sys.modules lookup plus a local name
binding). The lpips.py module already uses this pattern, which suggests it is
safe within this code base. Users who do not have torchvision installed will
continue to see the existing ModuleNotFoundError on first instantiation.
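The caching claim can be checked with a small sketch. It uses the stdlib module json as a stand-in for torchvision, and lookup_inside_function is a hypothetical name; the timing is printed rather than asserted because it depends on the machine.

```python
import sys
import timeit


def lookup_inside_function():
    # After the first call this statement resolves through the sys.modules
    # cache and rebinds a local name; the module body is not re-executed.
    import json

    return json


first = lookup_inside_function()
second = lookup_inside_function()

# Both calls return the single cached module object.
same_object = first is second is sys.modules["json"]

# Average cost of one call once the module is cached.
calls = 100_000
per_call_seconds = timeit.timeit(lookup_inside_function, number=calls) / calls
print(same_object, f"{per_call_seconds * 1e9:.0f} ns per call")
```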



…rcular import

Module-level 'from torchvision import ...' blocks inside the
'if _TORCHVISION_AVAILABLE:' guard in arniqa.py and dists.py ran at
import time, which triggered torchvision's own __init__ chain.  In
environments where torchvision is only partially initialised (e.g.
Kaggle, or any scenario with a circular-import race in torchvision's
_meta_registrations), importing torchmetrics raised an AttributeError
before the user code ran.

Move the torchvision symbols into the functions/methods that actually
use them, matching the pattern already used by lpips.py.

Development

Successfully merging this pull request may close these issues.

Failed to import from torchmetrics.image.lpip
