torch.distributed
1 parent ebaa187 commit 95c4339
1 file changed
src/diffusers/training_utils.py
@@ -14,10 +14,15 @@
 import torch.nn.functional as F
 
 
-if getattr(torch, "distributed", None) is not None:
+if torch.distributed.is_available():
     from torch.distributed.fsdp import CPUOffload, ShardingStrategy
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
     from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+else:
+    CPUOffload = None
+    ShardingStrategy = None
+    FSDP = None
+    transformer_auto_wrap_policy = None
 
 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
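Why the new guard matters: `torch.distributed` exists as a submodule attribute on every PyTorch build, so the old `getattr` check passes even on builds compiled without distributed support, and the FSDP imports can still fail. `torch.distributed.is_available()` reports whether distributed support was actually compiled in. A minimal sketch of the difference (illustrative only, not part of the commit):

import torch

# The submodule attribute is present on every build, so the old guard
# is effectively always True and does not protect the FSDP imports.
print(getattr(torch, "distributed", None) is not None)  # True

# is_available() is False when PyTorch was built without distributed
# support (e.g. some Windows/macOS wheels); in that case the imports
# are skipped and the None placeholders are bound instead, so callers
# can test e.g. `FSDP is None` before attempting to wrap a model.
print(torch.distributed.is_available())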