File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 9999 get_optimizer_state_dict ,
100100 set_optimizer_state_dict ,
101101)
102- from torch .distributed .fsdp import (
103- fully_shard ,
104- )
102+
103+ try :
104+ from torch .distributed .fsdp import (
105+ fully_shard ,
106+ )
107+ except ImportError :
108+ fully_shard = None # type: ignore[assignment]
105109from torch .distributed .optim import (
106110 ZeroRedundancyOptimizer ,
107111)
@@ -849,6 +853,15 @@ def single_model_finetune(
849853 if self .is_distributed :
850854 torch .cuda .set_device (LOCAL_RANK )
851855 if self .zero_stage >= 2 :
856+ if fully_shard is None :
857+ raise RuntimeError (
858+ "training.zero_stage>=2 requires FSDP2, which is only "
859+ "available in PyTorch >= 2.6 "
860+ "(``torch.distributed.fsdp.fully_shard``). "
861+ f"Current PyTorch is { torch .__version__ } . "
862+ "Please upgrade PyTorch, or set training.zero_stage "
863+ "to 0 or 1 to stay on the DDP / ZeRO-1 path."
864+ )
852865 # FSDP2 does NOT broadcast params (unlike DDP constructor).
853866 # Ensure all ranks share identical weights before sharding.
854867 for p in self .wrapper .parameters ():
You can’t perform that action at this time.
0 commit comments