From 75515f557b2364ed71cc4cc38cdb33fcd666d05d Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 23 Apr 2026 22:51:10 +0800 Subject: [PATCH 1/2] =?UTF-8?q?fix:=20fsdp=20unavailable=20in=20older=20ve?= =?UTF-8?q?rsion=20of=20pytorch=20(=E2=89=A62.5)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- deepmd/pt/train/training.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index f2b11d50e2..4147ae739c 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -103,9 +103,13 @@ get_optimizer_state_dict, set_optimizer_state_dict, ) -from torch.distributed.fsdp import ( - fully_shard, -) + +try: + from torch.distributed.fsdp import ( + fully_shard, + ) +except ImportError: + fully_shard = None # type: ignore[assignment] from torch.distributed.optim import ( ZeroRedundancyOptimizer, ) @@ -853,6 +857,15 @@ def single_model_finetune( if self.is_distributed: torch.cuda.set_device(LOCAL_RANK) if self.zero_stage >= 2: + if fully_shard is None: + raise RuntimeError( + "training.zero_stage>=2 requires FSDP2, which is only " + "available in PyTorch >= 2.6 " + "(``torch.distributed.fsdp.fully_shard``). " + f"Current PyTorch is {torch.__version__}. " + "Please upgrade PyTorch, or set training.zero_stage " + "to 0 or 1 to stay on the DDP / ZeRO-1 path." + ) # FSDP2 does NOT broadcast params (unlike DDP constructor). # Ensure all ranks share identical weights before sharding. for p in self.wrapper.parameters(): From 6aded53a9864564c5b71fa0030124e82a7f544d0 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sat, 25 Apr 2026 10:31:03 +0800 Subject: [PATCH 2/2] fixup --- deepmd/utils/argcheck.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index d69cc3dc89..401aa6c922 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3948,6 +3948,7 @@ def training_args( "but reduces optimizer memory to 1/N per GPU. " "2: FSDP2 stage-2, shards optimizer states and gradients; same communication " "volume as stage-1 but further reduces gradient memory to 1/N per GPU. " + "Stages 2 and 3 require FSDP2, which is available in PyTorch >= 2.6. " "Note: FSDP2 introduces DTensor dispatch overhead that can slow down " "models with many small layers; use torch.compile to mitigate. " "3: FSDP2 stage-3, shards parameters as well; maximum memory savings but "