
Gate deep imports from torch.distributed #13673

Merged: DN6 merged 1 commit into huggingface:main from hlky:torch-distributed-gate on May 8, 2026

Conversation

@hlky (Contributor) commented May 1, 2026

What does this PR do?

Some PyTorch builds, such as ROCm on Windows, ship a non-functional torch.distributed: the module itself exists, but any deeper import, such as from torch.distributed.fsdp import CPUOffload, ShardingStrategy, fails.

This affects various tests that indirectly import from diffusers.training_utils:

tests\pipelines\flux2\test_pipeline_flux2.py:17: in <module>
    from ..test_pipelines_common import (
.venv\Lib\site-packages\_pytest\assertion\rewrite.py:197: in exec_module
    exec(co, module.__dict__)
tests\pipelines\test_pipelines_common.py:61: in <module>
    from ..models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict
.venv\Lib\site-packages\_pytest\assertion\rewrite.py:197: in exec_module
    exec(co, module.__dict__)
tests\models\transformers\test_models_transformer_flux.py:27: in <module>
    from ..testing_utils import (
tests\models\testing_utils\__init__.py:37: in <module>
    from .training import TrainingTesterMixin
tests\models\testing_utils\training.py:22: in <module>
    from diffusers.training_utils import EMAModel
src\diffusers\training_utils.py:18: in <module>
    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
.venv\Lib\site-packages\torch\distributed\fsdp\__init__.py:1: in <module>
    from ._flat_param import FlatParameter as FlatParameter
.venv\Lib\site-packages\torch\distributed\fsdp\_flat_param.py:31: in <module>
    from torch.testing._internal.distributed.fake_pg import FakeProcessGroup
.venv\Lib\site-packages\torch\testing\_internal\distributed\fake_pg.py:4: in <module>
    from torch._C._distributed_c10d import FakeProcessGroup
E   ModuleNotFoundError: No module named 'torch._C._distributed_c10d'; 'torch._C' is not a package

Currently, the imports are guarded with getattr(torch, "distributed", None) is not None. I have found no evidence that any PyTorch build completely omits distributed; the module appears to always exist, and torch.distributed.is_available() returns False when it is non-functional.
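
For illustration, this is how the two checks compare on an affected build (a sketch of the behavior described above, not output captured from a ROCm Windows machine):

    import torch

    # Old guard: true even on broken builds, because the torch.distributed
    # module object is always present.
    getattr(torch, "distributed", None) is not None   # -> True

    # New guard: false when the distributed backend is non-functional.
    torch.distributed.is_available()                  # -> False on e.g. ROCm Windows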

This PR replaces the guard with torch.distributed.is_available().
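
Applied to the import site in src/diffusers/training_utils.py, the guard becomes roughly the following (a sketch; the import line is the one from the traceback above, and the exact diff may differ):

    import torch

    if torch.distributed.is_available():
        # Only attempt the deep imports when the distributed backend is
        # actually usable; on broken builds they raise ModuleNotFoundError.
        from torch.distributed.fsdp import CPUOffload, ShardingStrategy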

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

github-actions bot added the size/S (PR with diff < 50 LOC) label on May 1, 2026
hlky force-pushed the torch-distributed-gate branch from 33b00d1 to 7afd122 on May 4, 2026 at 23:50
@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

DN6 merged commit 95c4339 into huggingface:main on May 8, 2026
13 of 15 checks passed

Labels

size/S (PR with diff < 50 LOC)
