FSDP2 Support #3733
Conversation
    raise RuntimeError(
        "fully_shard (FSDP2) is not available. Please upgrade to PyTorch >= 2.0."
    )

The minimal PyTorch version in which FSDP2 (fully_shard) appeared is 2.6.
    with_score=self.score_function is not None,
    with_score_name=self.score_name is not None,
    with_global_step=global_step is not None,
    with_ext=bool(self.ext),

Let's rename this param to as_folder.
Pull request overview
Adds FSDP2 integration to Ignite’s distributed model wrapping and checkpointing utilities, enabling auto_model(..., use_fsdp=True) to use fully_shard and extending checkpoint save/load to support FSDP2 (including DCP directory checkpoints).
Changes:
- Add use_fsdp flag to ignite.distributed.auto.auto_model to apply FSDP2 fully_shard in native distributed contexts and reject sync_bn + use_fsdp.
- Extend ignite.handlers.checkpoint.Checkpoint to recognize FSDP2 (FSDPModule) for save/load, and introduce DCPSaver for torch.distributed.checkpoint directory-based saves.
- Add/extend tests for FSDP2 wrapping and checkpoint behavior (single-process and distributed).
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 10 comments.
Summary per file:

| File | Description |
|---|---|
| ignite/distributed/auto.py | Adds use_fsdp support via fully_shard, plus SyncBN incompatibility handling. |
| ignite/handlers/checkpoint.py | Adds FSDP2 save/load support, DCP directory load path, and new DCPSaver. |
| tests/ignite/distributed/test_auto.py | Adds pytest-distributed coverage for auto_model(..., use_fsdp=True) behavior. |
| tests/ignite/handlers/test_checkpoint.py | Adds distributed GPU coverage for FSDP2 checkpoint save/load. |
| tests/ignite/test_fsdp_smoke.py | Adds single-process regression/smoke tests around new FSDP2-related behavior. |
| tests/ignite/test_fsdp_distributed.py | Adds a standalone multiprocess script to exercise wrapping and checkpoint save in gloo. |
    from ignite.utils import _tree_apply2, _tree_map

-   __all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]
+   __all__ = ["Checkpoint", "DiskSaver", "DCPSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]
DCPSaver is added to ignite.handlers.checkpoint.__all__, but it is not re-exported from ignite/handlers/__init__.py (which currently exports Checkpoint, DiskSaver, ModelCheckpoint, etc.). If DCPSaver is intended as part of the public API (so users can from ignite.handlers import DCPSaver), it should be imported and included in that package __all__ as well.
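If it should be public, a minimal sketch of the re-export (the names are taken from the exports mentioned above; the exact import layout of ignite/handlers/__init__.py may differ):

```python
# ignite/handlers/__init__.py -- illustrative sketch, not the PR's actual diff
from ignite.handlers.checkpoint import BaseSaveHandler, Checkpoint, DCPSaver, DiskSaver, ModelCheckpoint

__all__ = [
    "Checkpoint",
    "DiskSaver",
    "DCPSaver",
    "ModelCheckpoint",
    "BaseSaveHandler",
    # ...existing exports unchanged
]
```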
    import torch.distributed as dist
    import torch.nn as nn

    sys.path.insert(0, "/pfss/mlde/workspaces/mlde_wsp_MazaheriA/tk27ryru/ignite")
sys.path.insert(0, "/pfss/.../ignite") hard-codes a developer-specific absolute path, which will break for everyone else and in CI. Remove this and rely on normal imports (run from repo root / editable install), or compute a repo-relative path from __file__ if a path tweak is truly needed for standalone execution.
Suggested change:

-   sys.path.insert(0, "/pfss/mlde/workspaces/mlde_wsp_MazaheriA/tk27ryru/ignite")
+   REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+   if REPO_ROOT not in sys.path:
+       sys.path.insert(0, REPO_ROOT)
@intelava let's remove all local paths like this one
    torch.cuda.set_device(rank)

    try:
        import ignite.distributed as idist
ignite.distributed as idist is imported but never used in this file. With Ruff/Pyflakes enabled, this will raise an unused-import error and fail linting.
Suggested change:

-   import ignite.distributed as idist
| """ | ||
| Distributed smoke test for FSDP wrapping via auto_model. | ||
| Must be run as a standalone script (not via pytest) using: | ||
| python tests/ignite/test_fsdp_distributed.py | ||
| """ |
There was a problem hiding this comment.
This file is placed under tests/ but is explicitly "not via pytest" and uses if __name__ == "__main__" with manual process management. That means it won't run in the normal test suite and can confuse contributors/CI expectations. Consider converting this into pytest-discoverable distributed tests (similar to tests/ignite/distributed/*) or moving it under a scripts/ or examples/ location with appropriate documentation.
    same device as the wrapped model to forward correctly.
    """
    from ignite.distributed.auto import auto_model
    import ignite.distributed as idist
import ignite.distributed as idist is unused, which will fail Ruff/Pyflakes (F401) if linting is enabled for tests/. Remove it or use it in an assertion (e.g., to document expected backend/device).
Suggested change:

-   import ignite.distributed as idist
    if use_fsdp:
        if not HAVE_FSDP2:
            raise RuntimeError(
                "fully_shard (FSDP2) is not available. Please upgrade to PyTorch >= 2.0."
            )
        if "mesh" not in kwargs:
            kwargs["mesh"] = init_device_mesh(device.type, (idist.get_world_size(),))
        logger.info("Apply torch FSDP2 (fully_shard) on model")
        model = fully_shard(model, **kwargs)
In the use_fsdp branch, DDP-specific kwargs like device_ids are no longer validated/blocked. Passing device_ids with use_fsdp=True will be forwarded into fully_shard(...) and raise a confusing TypeError. Consider explicitly rejecting device_ids (and any other known DDP-only kwargs) when use_fsdp=True to preserve the clearer error behavior.
    # In non-dist mode, FSDP is not instantiated so kwargs are irrelevant.
    result = auto_model(model, use_fsdp=True)
test_use_fsdp_kwargs_passthrough_non_distributed claims to verify extra-kwargs passthrough/ignoring, but it never actually passes any extra kwargs. Add at least one representative FSDP kwarg (e.g., reshard_after_forward=False) to ensure the behavior is exercised.
Suggested change:

-   # In non-dist mode, FSDP is not instantiated so kwargs are irrelevant.
-   result = auto_model(model, use_fsdp=True)
+   # In non-dist mode, FSDP is not instantiated, so representative FSDP kwargs
+   # such as ``reshard_after_forward`` should be accepted and ignored.
+   result = auto_model(model, use_fsdp=True, reshard_after_forward=False)
| os.environ["MASTER_ADDR"] = "localhost" | ||
| os.environ["MASTER_PORT"] = "12399" | ||
| os.environ["RANK"] = str(rank) | ||
| os.environ["WORLD_SIZE"] = str(world_size) |
There was a problem hiding this comment.
This script binds to fixed ports (12399/12400). That’s prone to flakes when the port is already in use (parallel CI, local dev). Prefer selecting a free port dynamically (e.g., bind to port 0 via socket and propagate it) or accept MASTER_PORT from the environment/CLI.
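A hedged sketch of the dynamic option (the _find_free_port helper is illustrative, not part of this PR):

```python
import os
import socket


def _find_free_port() -> int:
    # Bind to port 0 so the OS picks an unused port, then reuse that port number.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


# In the parent process, before spawning workers, so every rank sees the same value:
os.environ.setdefault("MASTER_PORT", str(_find_free_port()))
```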
    dist.init_process_group(backend, rank=rank, world_size=world_size)

    try:
        import ignite.distributed as idist
ignite.distributed as idist is imported but never used in this function. This will trigger an unused-import lint failure under Ruff/Pyflakes.
Suggested change:

-   import ignite.distributed as idist
    state_dicts = {}
    for k, obj in to_load.items():
        if isinstance(obj, FSDPModule):
            state_dicts[k] = get_model_state_dict(obj)
        else:
            state_dicts[k] = obj.state_dict()
    dcp.load(state_dicts, storage_reader=FileSystemReader(checkpoint_path))
    for k, obj in to_load.items():
        if isinstance(obj, FSDPModule):
            set_model_state_dict(obj, state_dicts[k])
        else:
            obj.load_state_dict(state_dicts[k])
The DCP-directory load path does not unwrap DataParallel/DistributedDataParallel objects. However _setup_checkpoint() unwraps them when saving, so the directory checkpoint keys will be for the underlying module, while here the template state_dict is taken from the wrapper (with module. prefixes). This mismatch will make dcp.load(...) and/or load_state_dict fail for DP/DDP models. Apply the same unwrapping logic as the non-directory path (use obj.module for DP/DDP both when building state_dicts and when calling load_state_dict).
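A minimal sketch of the unwrapping (it assumes the module-level imports already present in checkpoint.py, i.e. nn, dcp, FileSystemReader, FSDPModule and the state-dict helpers; the _unwrap helper name is just for the sketch and the PR's final code may differ):

```python
def _unwrap(obj):
    # Unwrap DataParallel/DistributedDataParallel so the template state_dict keys
    # match what _setup_checkpoint() saved (i.e. no "module." prefixes).
    if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return obj.module
    return obj


state_dicts = {}
for k, obj in to_load.items():
    obj = _unwrap(obj)
    state_dicts[k] = get_model_state_dict(obj) if isinstance(obj, FSDPModule) else obj.state_dict()

dcp.load(state_dicts, storage_reader=FileSystemReader(checkpoint_path))

for k, obj in to_load.items():
    obj = _unwrap(obj)
    if isinstance(obj, FSDPModule):
        set_model_state_dict(obj, state_dicts[k])
    else:
        obj.load_state_dict(state_dicts[k])
```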
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 10 comments.
    def test_save_and_load_ddp_wrapped_model(self) -> None:
        """Checkpoint must unwrap DataParallel and save the inner module's state."""
        from ignite.handlers import Checkpoint, DiskSaver

        torch.manual_seed(99)
        model = _SimpleLinear()

        # Only wrap with DataParallel if multiple GPUs are present; else bare model.
        if torch.cuda.device_count() > 1:
            wrapped = nn.DataParallel(model)
        else:
            # Simulate DP wrapping using a stub that exposes .module.
            class _FakeDP(nn.Module):
                def __init__(self, m: nn.Module) -> None:
                    super().__init__()
                    self.module = m

                def state_dict(self, **kw):  # type: ignore[override]
                    return self.module.state_dict(**kw)

                def load_state_dict(self, sd, **kw):  # type: ignore[override]
                    return self.module.load_state_dict(sd, **kw)

            wrapped = _FakeDP(model)
wrapped is created (DataParallel or FakeDP) but never used; the checkpoint is created with the plain model instead. This means the test isn’t actually exercising the “unwrap DataParallel” behavior described in the docstring. Either pass wrapped into Checkpoint({...}) / load_objects, or remove the unused wrapping setup.
Suggested change:

-   def test_save_and_load_ddp_wrapped_model(self) -> None:
-       """Checkpoint must unwrap DataParallel and save the inner module's state."""
-       from ignite.handlers import Checkpoint, DiskSaver
-       torch.manual_seed(99)
-       model = _SimpleLinear()
-       # Only wrap with DataParallel if multiple GPUs are present; else bare model.
-       if torch.cuda.device_count() > 1:
-           wrapped = nn.DataParallel(model)
-       else:
-           # Simulate DP wrapping using a stub that exposes .module.
-           class _FakeDP(nn.Module):
-               def __init__(self, m: nn.Module) -> None:
-                   super().__init__()
-                   self.module = m
-               def state_dict(self, **kw):  # type: ignore[override]
-                   return self.module.state_dict(**kw)
-               def load_state_dict(self, sd, **kw):  # type: ignore[override]
-                   return self.module.load_state_dict(sd, **kw)
-           wrapped = _FakeDP(model)
+   def test_save_and_load_plain_model_regression(self) -> None:
+       """Checkpoint save/load should remain correct for a plain nn.Module."""
+       from ignite.handlers import Checkpoint, DiskSaver
+       torch.manual_seed(99)
+       model = _SimpleLinear()
    REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    if REPO_ROOT not in sys.path:
        sys.path.insert(0, REPO_ROOT)
This module mutates sys.path at import time. Even if it’s intended as a standalone script, pytest will still import test_*.py modules during collection, so this can have unintended side effects on the rest of the test suite. Consider moving the sys.path manipulation under if __name__ == "__main__": (or renaming the file so pytest doesn’t collect it).
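A minimal sketch of the guard, keeping the existing REPO_ROOT logic but deferring it (the main() entry point is hypothetical; use whatever function the script actually calls):

```python
import os
import sys

if __name__ == "__main__":
    # Only touch sys.path when executed as a script, so pytest collection of this
    # module has no side effects on the rest of the suite.
    REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    if REPO_ROOT not in sys.path:
        sys.path.insert(0, REPO_ROOT)
    main()  # hypothetical entry point defined elsewhere in the script
```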
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    try:
        from ignite.distributed.auto import auto_model
        from torch.distributed._composable.fsdp import FSDPModule

        device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available() doesn’t guarantee there are at least world_size GPUs. With world_size=2 on a 1‑GPU machine this will call torch.cuda.set_device(1) and fail, and will also construct cuda:1 devices. Guard with torch.cuda.device_count() > rank (or force CPU when device_count < world_size).
Suggested change:

-   if torch.cuda.is_available():
-       torch.cuda.set_device(rank)
-   try:
-       from ignite.distributed.auto import auto_model
-       from torch.distributed._composable.fsdp import FSDPModule
-       device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
+   cuda_device_count = torch.cuda.device_count()
+   use_cuda = torch.cuda.is_available() and cuda_device_count >= world_size
+   if use_cuda:
+       torch.cuda.set_device(rank)
+   try:
+       from ignite.distributed.auto import auto_model
+       from torch.distributed._composable.fsdp import FSDPModule
+       device = torch.device(f"cuda:{rank}" if use_cuda else "cpu")
        from ignite.handlers.checkpoint import HAVE_FSDP2

        assert HAVE_FSDP2 is True, (
            "HAVE_FSDP2 is False — torch.distributed._composable.fsdp may be unavailable in this environment"
        )

    def test_fsdp2_symbols_importable_from_checkpoint_module(self) -> None:
        """FSDPModule, get_model_state_dict and set_model_state_dict must be reachable."""
These tests hard-assert HAVE_FSDP2 is True, but the project supports torch>=2.2 (pyproject.toml) where torch.distributed._composable.fsdp / FSDP2 may be absent. This will fail in supported environments; please change this to a conditional skip (e.g., pytest.skip/pytest.importorskip) when FSDP2 is unavailable, and only assert True when the symbol is present.
Suggested change:

-       from ignite.handlers.checkpoint import HAVE_FSDP2
-       assert HAVE_FSDP2 is True, (
-           "HAVE_FSDP2 is False — torch.distributed._composable.fsdp may be unavailable in this environment"
-       )
-   def test_fsdp2_symbols_importable_from_checkpoint_module(self) -> None:
-       """FSDPModule, get_model_state_dict and set_model_state_dict must be reachable."""
+       pytest.importorskip(
+           "torch.distributed._composable.fsdp",
+           reason="FSDP2 is unavailable in this supported torch environment",
+       )
+       from ignite.handlers.checkpoint import HAVE_FSDP2
+       assert HAVE_FSDP2 is True
+   def test_fsdp2_symbols_importable_from_checkpoint_module(self) -> None:
+       """FSDPModule, get_model_state_dict and set_model_state_dict must be reachable."""
+       pytest.importorskip(
+           "torch.distributed._composable.fsdp",
+           reason="FSDP2 is unavailable in this supported torch environment",
+       )
    def test_fsdp2_symbols_importable_from_checkpoint_module(self) -> None:
        """FSDPModule, get_model_state_dict and set_model_state_dict must be reachable."""
        # checkpoint.py imports these at module level inside the try/except;
        # confirm they are accessible from the installed torch build.
        from torch.distributed._composable.fsdp import FSDPModule  # noqa: F401
        from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict  # noqa: F401
This test imports FSDP2 symbols unconditionally. On supported installs without FSDP2, it will raise ImportError and fail the suite. Use pytest.importorskip (or guard with HAVE_FSDP2) so the test suite remains compatible with the repo’s stated torch version range.
        model = _SimpleLinear()
        # In non-dist mode, FSDP is not instantiated, so representative FSDP kwargs
        # such as ``reshard_after_forward`` should be accepted and ignored.
        result = auto_model(model, use_fsdp=True, reshard_after_forward=False)
        assert isinstance(result, nn.Module)
This assumes that extra FSDP-only kwargs are “silently ignored” in non-distributed mode, but auto_model(..., use_fsdp=True) can still wrap with DataParallel when multiple GPUs are visible, and DataParallel will error on unknown kwargs like reshard_after_forward. Either adjust auto_model to drop/partition kwargs when it falls back to DataParallel, or make this test robust by skipping/conditioning on torch.cuda.device_count() <= 1.
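One way to make the test robust, sketched per the second option (the skip condition is this comment's suggestion, not code from the PR, and it assumes the test module's existing pytest/torch/auto_model/_SimpleLinear imports):

```python
# Skip when auto_model's non-distributed path would fall back to DataParallel,
# since DataParallel rejects FSDP-only kwargs such as reshard_after_forward.
if torch.cuda.device_count() > 1:
    pytest.skip("DataParallel fallback would receive FSDP-only kwargs")

model = _SimpleLinear()
result = auto_model(model, use_fsdp=True, reshard_after_forward=False)
assert isinstance(result, nn.Module)
```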
        Default, False. When enabled, ``kwargs`` are forwarded to ``fully_shard()``, allowing
        control over ``reshard_after_forward``, ``mp_policy``, ``offload_policy``, etc.
        Note: FSDP2 does not support ``auto_wrap_policy``; manually call ``fully_shard()`` on
        submodules before passing the model to ``auto_model``. Requires PyTorch >= 2.0.
The docstring says FSDP2 “Requires PyTorch >= 2.0”, but the runtime error below says “upgrade to PyTorch >= 2.6”, and the project dependency is torch>=2.2 (pyproject.toml). Please align the documented minimum requirement and the raised error message with the actual feature availability so users aren’t misled.
Suggested change:

-   submodules before passing the model to ``auto_model``. Requires PyTorch >= 2.0.
+   submodules before passing the model to ``auto_model``. Requires PyTorch >= 2.2.
    if use_fsdp:
        if not HAVE_FSDP2:
            raise RuntimeError(
                "fully_shard (FSDP2) is not available. Please upgrade to PyTorch >= 2.6."
            )
        ddp_only_kwargs = {"device_ids", "output_device", "find_unused_parameters", "gradient_as_bucket_view"}
        bad_kwargs = ddp_only_kwargs & set(kwargs)
        if bad_kwargs:
            raise ValueError(
                f"Argument(s) {bad_kwargs} are DDP-only and cannot be used with use_fsdp=True."
            )
        if "mesh" not in kwargs:
            kwargs["mesh"] = init_device_mesh(device.type, (idist.get_world_size(),))
        logger.info("Apply torch FSDP2 (fully_shard) on model")
        model = fully_shard(model, **kwargs)
use_fsdp=True repurposes **kwargs to mean “fully_shard kwargs”, but in the non-distributed multi-GPU branch auto_model can still wrap with DataParallel and will forward the same **kwargs there, causing TypeError for FSDP-only args (e.g., reshard_after_forward, mp_policy). Consider splitting/clearing kwargs when use_fsdp=True and world_size==1 (or otherwise ensuring FSDP-only kwargs never reach the DataParallel path).
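A hedged sketch of one possible split (the kwarg names are the ones mentioned in this PR; the exact set would need to track fully_shard's signature):

```python
# Strip FSDP-only kwargs before auto_model falls back to DataParallel in a
# non-distributed, multi-GPU context, so they never reach nn.DataParallel.
fsdp_only_kwargs = {"mesh", "reshard_after_forward", "mp_policy", "offload_policy"}
if use_fsdp and idist.get_world_size() <= 1:
    kwargs = {k: v for k, v in kwargs.items() if k not in fsdp_only_kwargs}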
    if use_fsdp:
        if not HAVE_FSDP2:
            raise RuntimeError(
                "fully_shard (FSDP2) is not available. Please upgrade to PyTorch >= 2.6."
The RuntimeError message hard-codes “upgrade to PyTorch >= 2.6”, but the repo’s declared dependency is torch>=2.2 and FSDP2 availability is detected by imports. It would be more accurate to mention the missing modules/features (e.g., torch.distributed._composable.fsdp) rather than a specific version, or ensure the version claim matches the actual minimum where FSDP2 is supported.
| "fully_shard (FSDP2) is not available. Please upgrade to PyTorch >= 2.6." | |
| "fully_shard (FSDP2) is not available because the required PyTorch modules/features " | |
| "could not be imported: torch.distributed._composable.fsdp.fully_shard and " | |
| "torch.distributed.device_mesh.init_device_mesh." |
    if not HAVE_FSDP2:
        raise RuntimeError("DCPSaver requires PyTorch >= 2.0 with torch.distributed.checkpoint.")
DCPSaver/directory checkpoints raise “requires PyTorch >= 2.0”, while auto_model(use_fsdp=True) raises “upgrade to PyTorch >= 2.6” and the docs mention “>= 2.0”. Please unify these requirement/error messages (ideally based on actual feature detection) so users get consistent guidance about what torch version/build is needed.
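One way to unify them, sketched with a hypothetical module-level constant (not in the PR), so the guidance reflects feature detection rather than a hard-coded version number:

```python
# Hypothetical constant shared by auto_model and DCPSaver.
_FSDP2_UNAVAILABLE_MSG = (
    "FSDP2 support is unavailable: torch.distributed._composable.fsdp.fully_shard "
    "and/or torch.distributed.checkpoint could not be imported from this PyTorch build."
)

if not HAVE_FSDP2:
    raise RuntimeError(_FSDP2_UNAVAILABLE_MSG)
```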
Add FSDP2 support
Added FSDP2 support to ignite using the composable fully_shard API.
What's new
auto_model now has a use_fsdp parameter. When True in a distributed context it
applies fully_shard to the model instead of DDP. You can pass extra kwargs like
mp_policy and offload_policy directly. Combining use_fsdp=True with sync_bn=True
raises a ValueError right away since they don't work together.
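A hedged usage sketch under a distributed launcher (the model and kwarg values are just examples):

```python
import torch.nn as nn

import ignite.distributed as idist

# Inside a distributed context (e.g. launched via torchrun), use_fsdp=True applies
# FSDP2's fully_shard instead of DDP; extra kwargs are forwarded to fully_shard.
net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
model = idist.auto_model(net, use_fsdp=True, reshard_after_forward=True)
```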
The checkpoint handler now recognizes FSDP2 models and uses get_model_state_dict
and set_model_state_dict with full_state_dict=True so parameter gathering across
ranks is handled automatically.
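A corresponding checkpointing sketch (directory, n_saved and the restored filename are illustrative):

```python
from ignite.handlers import Checkpoint, DiskSaver

# Checkpoint detects the FSDP2-wrapped model and gathers a full state dict across
# ranks via get_model_state_dict / set_model_state_dict under the hood.
to_save = {"model": model}
handler = Checkpoint(to_save, DiskSaver("/tmp/fsdp2_ckpts", create_dir=True), n_saved=2)

# To restore later (path shown is hypothetical):
# Checkpoint.load_objects(to_load=to_save, checkpoint="/tmp/fsdp2_ckpts/model_1.pt")
```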
Tests
Added test_fsdp_smoke.py with 23 single-process tests covering the flag
behavior, checkpoint imports and edge cases. test_fsdp_distributed.py is a
standalone multi-process script that spawns 2 workers with gloo and tests
wrapping, forward pass and checkpoint save/load end to end.
If you have any further questions or suggestions, I would be happy to help!