Skip to content

Commit 95c4339

Browse files
authored
Gate deep imports from torch.distributed (#13673)
1 parent ebaa187 commit 95c4339

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

src/diffusers/training_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@
1414
import torch.nn.functional as F
1515

1616

17-
if getattr(torch, "distributed", None) is not None:
17+
if torch.distributed.is_available():
1818
from torch.distributed.fsdp import CPUOffload, ShardingStrategy
1919
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
2020
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
21+
else:
22+
CPUOffload = None
23+
ShardingStrategy = None
24+
FSDP = None
25+
transformer_auto_wrap_policy = None
2126

2227
from .models import UNet2DConditionModel
2328
from .pipelines import DiffusionPipeline

0 commit comments

Comments
 (0)