-
-
Notifications
You must be signed in to change notification settings - Fork 695
FSDP2 Support #3733
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
FSDP2 Support #3733
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,6 +17,14 @@ | |||||||||
| from ignite.distributed.comp_models import xla as idist_xla | ||||||||||
| from ignite.utils import setup_logger | ||||||||||
|
|
||||||||||
| try: | ||||||||||
| from torch.distributed._composable.fsdp import fully_shard | ||||||||||
| from torch.distributed.device_mesh import init_device_mesh | ||||||||||
|
|
||||||||||
| HAVE_FSDP2 = True | ||||||||||
| except ImportError: | ||||||||||
| HAVE_FSDP2 = False | ||||||||||
|
|
||||||||||
| __all__ = ["auto_dataloader", "auto_model", "auto_optim", "DistributedProxySampler"] | ||||||||||
|
|
||||||||||
|
|
||||||||||
|
|
@@ -141,24 +149,33 @@ def auto_dataloader(dataset: Dataset, **kwargs: Any) -> DataLoader | _MpDeviceLo | |||||||||
| return dataloader | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Module: | ||||||||||
| def auto_model(model: nn.Module, sync_bn: bool = False, use_fsdp: bool = False, **kwargs: Any) -> nn.Module: | ||||||||||
| """Helper method to adapt provided model for non-distributed and distributed configurations (supporting | ||||||||||
| all available backends from :meth:`~ignite.distributed.utils.available_backends()`). | ||||||||||
|
|
||||||||||
| Internally, we perform the following: | ||||||||||
|
|
||||||||||
| - send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device. | ||||||||||
| - wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1. | ||||||||||
| - wrap the model with `torch FSDP2 fully_shard`_ instead of DDP if ``use_fsdp=True`` and native torch | ||||||||||
| distributed is used with world size larger than 1. | ||||||||||
| - wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available. | ||||||||||
| - broadcast the initial variable states from rank 0 to all other processes if Horovod distributed framework is used. | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
| model: model to adapt. | ||||||||||
| sync_bn: if True, applies `torch convert_sync_batchnorm`_ to the model for native torch | ||||||||||
| distributed only. Default, False. Note, if using Nvidia/Apex, batchnorm conversion should be | ||||||||||
| applied before calling ``amp.initialize``. | ||||||||||
| kwargs: kwargs to model's wrapping class: `torch DistributedDataParallel`_ or `torch DataParallel`_ | ||||||||||
| if applicable. Please, make sure to use acceptable kwargs for given backend. | ||||||||||
| applied before calling ``amp.initialize``. Incompatible with ``use_fsdp=True``. | ||||||||||
| use_fsdp: if True, applies `torch FSDP2 fully_shard`_ to the model instead of wrapping with | ||||||||||
| ``DistributedDataParallel`` for native torch distributed backends (NCCL, GLOO, MPI). | ||||||||||
| Default, False. When enabled, ``kwargs`` are forwarded to ``fully_shard()``, allowing | ||||||||||
| control over ``reshard_after_forward``, ``mp_policy``, ``offload_policy``, etc. | ||||||||||
| Note: FSDP2 does not support ``auto_wrap_policy``; manually call ``fully_shard()`` on | ||||||||||
| submodules before passing the model to ``auto_model``. Requires PyTorch >= 2.6. | ||||||||||
| kwargs: kwargs forwarded to the wrapping class: `torch DistributedDataParallel`_, | ||||||||||
| `torch FSDP2 fully_shard`_ (when ``use_fsdp=True``), or `torch DataParallel`_ | ||||||||||
| if applicable. Please, make sure to use acceptable kwargs for the given backend. | ||||||||||
|
|
||||||||||
| Returns: | ||||||||||
| torch.nn.Module | ||||||||||
|
|
@@ -179,8 +196,26 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod | |||||||||
| model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level) | ||||||||||
| model = idist.auto_model(model) | ||||||||||
|
|
||||||||||
| To use FSDP2 with bf16 mixed precision: | ||||||||||
|
|
||||||||||
| .. code-block:: python | ||||||||||
|
|
||||||||||
| import torch | ||||||||||
| import ignite.distributed as idist | ||||||||||
| from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy | ||||||||||
|
|
||||||||||
| bf16_policy = MixedPrecisionPolicy( | ||||||||||
| param_dtype=torch.bfloat16, | ||||||||||
| reduce_dtype=torch.bfloat16, | ||||||||||
| ) | ||||||||||
| # Optionally shard submodules first: | ||||||||||
| for layer in model.layers: | ||||||||||
| fully_shard(layer) | ||||||||||
| model = idist.auto_model(model, use_fsdp=True, mp_policy=bf16_policy) | ||||||||||
|
|
||||||||||
| .. _torch DistributedDataParallel: https://pytorch.org/docs/stable/generated/torch.nn.parallel. | ||||||||||
| DistributedDataParallel.html | ||||||||||
| .. _torch FSDP2 fully_shard: https://pytorch.org/docs/stable/distributed.fsdp2.html | ||||||||||
| .. _torch DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html | ||||||||||
| .. _torch convert_sync_batchnorm: https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html# | ||||||||||
| torch.nn.SyncBatchNorm.convert_sync_batchnorm | ||||||||||
|
|
@@ -192,9 +227,13 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod | |||||||||
|
|
||||||||||
| .. versionchanged:: 0.4.3 | ||||||||||
| Added kwargs to ``idist.auto_model``. | ||||||||||
|
|
||||||||||
| """ | ||||||||||
| logger = setup_logger(__name__ + ".auto_model") | ||||||||||
|
|
||||||||||
| if use_fsdp and sync_bn: | ||||||||||
| raise ValueError("Arguments use_fsdp and sync_bn are mutually exclusive. FSDP does not support SyncBatchNorm.") | ||||||||||
|
|
||||||||||
| # Put model's parameters to device if its parameters are not on the device | ||||||||||
| device = idist.device() | ||||||||||
| if not all([p.device == device for p in model.parameters()]): | ||||||||||
|
|
@@ -204,23 +243,39 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod | |||||||||
| if idist.get_world_size() > 1: | ||||||||||
| bnd = idist.backend() | ||||||||||
| if idist.has_native_dist_support and bnd in (idist_native.NCCL, idist_native.GLOO, idist_native.MPI): | ||||||||||
| if sync_bn: | ||||||||||
| logger.info("Convert batch norm to sync batch norm") | ||||||||||
| model = nn.SyncBatchNorm.convert_sync_batchnorm(model) | ||||||||||
|
|
||||||||||
| if torch.cuda.is_available(): | ||||||||||
| if "device_ids" in kwargs: | ||||||||||
| raise ValueError(f"Argument kwargs should not contain 'device_ids', but got {kwargs}") | ||||||||||
|
|
||||||||||
| lrank = idist.get_local_rank() | ||||||||||
| logger.info(f"Apply torch DistributedDataParallel on model, device id: {lrank}") | ||||||||||
| kwargs["device_ids"] = [ | ||||||||||
| lrank, | ||||||||||
| ] | ||||||||||
| if use_fsdp: | ||||||||||
| if not HAVE_FSDP2: | ||||||||||
| raise RuntimeError( | ||||||||||
| "fully_shard (FSDP2) is not available. Please upgrade to PyTorch >= 2.6." | ||||||||||
|
||||||||||
| "fully_shard (FSDP2) is not available. Please upgrade to PyTorch >= 2.6." | |
| "fully_shard (FSDP2) is not available because the required PyTorch modules/features " | |
| "could not be imported: torch.distributed._composable.fsdp.fully_shard and " | |
| "torch.distributed.device_mesh.init_device_mesh." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minimal version of pytorch when fsdp2 appeared is 2.6
Copilot
AI
Apr 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the use_fsdp branch, DDP-specific kwargs like device_ids are no longer validated/blocked. Passing device_ids with use_fsdp=True will be forwarded into fully_shard(...) and raise a confusing TypeError. Consider explicitly rejecting device_ids (and any other known DDP-only kwargs) when use_fsdp=True to preserve the clearer error behavior.
Copilot
AI
Apr 16, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use_fsdp=True repurposes **kwargs to mean “fully_shard kwargs”, but in the non-distributed multi-GPU branch auto_model can still wrap with DataParallel and will forward the same **kwargs there, causing TypeError for FSDP-only args (e.g., reshard_after_forward, mp_policy). Consider splitting/clearing kwargs when use_fsdp=True and world_size==1 (or otherwise ensuring FSDP-only kwargs never reach the DataParallel path).
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,12 +19,22 @@ | |
| else: | ||
| HAVE_ZERO = False | ||
|
|
||
| try: | ||
| import torch.distributed.checkpoint as dcp | ||
| from torch.distributed._composable.fsdp import FSDPModule | ||
| from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter | ||
| from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict | ||
|
|
||
| HAVE_FSDP2 = True | ||
| except ImportError: | ||
| HAVE_FSDP2 = False | ||
|
|
||
| import ignite.distributed as idist | ||
| from ignite.base import Serializable | ||
| from ignite.engine import Engine, Events, EventEnum | ||
| from ignite.utils import _tree_apply2, _tree_map | ||
|
|
||
| __all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"] | ||
| __all__ = ["Checkpoint", "DiskSaver", "DCPSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"] | ||
|
||
|
|
||
|
|
||
| class CheckpointEvents(EventEnum): | ||
|
|
@@ -388,7 +398,7 @@ def __init__( | |
| if n_saved is not None and n_saved < 1: | ||
| raise ValueError(f"n_saved must be a positive integer or None, got {n_saved}") | ||
| self.n_saved = n_saved | ||
| self.ext = "pt" | ||
| self.ext = "" if isinstance(self.save_handler, DCPSaver) else "pt" | ||
| self.global_step_transform = global_step_transform | ||
| self.filename_pattern = filename_pattern | ||
| self._saved: list["Checkpoint.Item"] = [] | ||
|
|
@@ -403,6 +413,7 @@ def _get_filename_pattern(self, global_step: int | None) -> str: | |
| with_score=self.score_function is not None, | ||
| with_score_name=self.score_name is not None, | ||
| with_global_step=global_step is not None, | ||
| as_folder=not self.ext, | ||
| ) | ||
| else: | ||
| filename_pattern = self.filename_pattern | ||
|
|
@@ -531,6 +542,18 @@ def _setup_checkpoint(self) -> dict[str, Any]: | |
| def func(obj: Any, **kwargs: Any) -> dict: | ||
| if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): | ||
| obj = obj.module | ||
| elif HAVE_FSDP2 and isinstance(obj, FSDPModule): | ||
| if isinstance(self.save_handler, DCPSaver): | ||
| # Each rank saves its own shard — no consolidation needed | ||
| return get_model_state_dict(obj) | ||
| else: | ||
| # Consolidate full state dict to rank 0 for single-file saving | ||
| state_dict = get_model_state_dict( | ||
| obj, options=StateDictOptions(full_state_dict=True, cpu_offload=True) | ||
| ) | ||
| if idist.get_rank() != self.save_on_rank: | ||
| return {} | ||
| return state_dict | ||
| elif HAVE_ZERO and isinstance(obj, ZeroRedundancyOptimizer): | ||
| obj.consolidate_state_dict(to=self.save_on_rank) | ||
| if self.save_on_rank != idist.get_rank(): | ||
|
|
@@ -542,7 +565,11 @@ def func(obj: Any, **kwargs: Any) -> dict: | |
|
|
||
| @staticmethod | ||
| def setup_filename_pattern( | ||
| with_prefix: bool = True, with_score: bool = True, with_score_name: bool = True, with_global_step: bool = True | ||
| with_prefix: bool = True, | ||
| with_score: bool = True, | ||
| with_score_name: bool = True, | ||
| with_global_step: bool = True, | ||
| as_folder: bool = False, | ||
| ) -> str: | ||
| """Helper method to get the default filename pattern for a checkpoint. | ||
|
|
||
|
|
@@ -557,6 +584,8 @@ def setup_filename_pattern( | |
| with_global_step: If True, ``{global_step}`` is added to the | ||
| filename pattern: ``...{name}_{global_step}...``. | ||
| At least one of ``with_score`` and ``with_global_step`` should be True. | ||
| as_folder: If True, the ``.{ext}`` suffix is omitted from the pattern, producing a | ||
| bare name suitable for directory-based (DCP) checkpoints. Default, False. | ||
|
|
||
| Examples: | ||
| .. code-block:: python | ||
|
|
@@ -588,7 +617,8 @@ def setup_filename_pattern( | |
| if with_prefix: | ||
| filename_pattern = "{filename_prefix}_" + filename_pattern | ||
|
|
||
| filename_pattern += ".{ext}" | ||
| if not as_folder: | ||
| filename_pattern += ".{ext}" | ||
| return filename_pattern | ||
|
|
||
| @staticmethod | ||
|
|
@@ -649,6 +679,8 @@ def load_objects(to_load: Mapping, checkpoint: str | Mapping | Path, **kwargs: A | |
| Note: | ||
| If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or | ||
| `DataParallel`_, method ``load_state_dict`` will applied to their internal wrapped model (``obj.module``). | ||
| If ``to_load`` contains FSDP2-sharded objects (``FSDPModule``), ``set_model_state_dict`` | ||
| is used with ``full_state_dict=True`` so that all ranks correctly receive the sharded parameters. | ||
|
|
||
| .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/ | ||
| torch.nn.parallel.DistributedDataParallel.html | ||
|
|
@@ -660,13 +692,40 @@ def load_objects(to_load: Mapping, checkpoint: str | Mapping | Path, **kwargs: A | |
| Checkpoint._check_objects(to_load, "load_state_dict") | ||
|
|
||
| if isinstance(checkpoint, (str, Path)): | ||
| checkpoint_obj = torch.load(checkpoint, weights_only=True) | ||
| checkpoint_path = Path(checkpoint) | ||
| if checkpoint_path.is_dir(): | ||
| # DCP directory checkpoint — all ranks load their own shard | ||
| if not HAVE_FSDP2: | ||
| raise RuntimeError( | ||
| "Loading DCP directory checkpoints requires PyTorch >= 2.0 with torch.distributed.checkpoint." | ||
| ) | ||
| state_dicts = {} | ||
| for k, obj in to_load.items(): | ||
| if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): | ||
| obj = obj.module | ||
| if isinstance(obj, FSDPModule): | ||
| state_dicts[k] = get_model_state_dict(obj) | ||
| else: | ||
| state_dicts[k] = obj.state_dict() | ||
| dcp.load(state_dicts, storage_reader=FileSystemReader(checkpoint_path)) | ||
| for k, obj in to_load.items(): | ||
| if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): | ||
| obj = obj.module | ||
| if isinstance(obj, FSDPModule): | ||
| set_model_state_dict(obj, state_dicts[k]) | ||
| else: | ||
| obj.load_state_dict(state_dicts[k]) | ||
|
Comment on lines
+702
to
+717
|
||
| return | ||
| checkpoint_obj = torch.load(checkpoint_path, weights_only=True) | ||
| else: | ||
| checkpoint_obj = checkpoint | ||
|
|
||
| def _load_object(obj: Any, chkpt_obj: Any) -> None: | ||
| if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): | ||
| obj = obj.module | ||
| elif HAVE_FSDP2 and isinstance(obj, FSDPModule): | ||
| set_model_state_dict(obj, chkpt_obj, options=StateDictOptions(full_state_dict=True)) | ||
| return | ||
|
|
||
| if isinstance(obj, torch.nn.Module): | ||
| obj.load_state_dict(chkpt_obj, **kwargs) | ||
|
|
@@ -922,6 +981,55 @@ def remove(self, filename: str) -> None: | |
| path.unlink() | ||
|
|
||
|
|
||
| class DCPSaver(BaseSaveHandler): | ||
| """Handler that saves FSDP2 checkpoints using `torch.distributed.checkpoint`_ (DCP). | ||
|
|
||
| Unlike :class:`~ignite.handlers.DiskSaver`, every rank participates in saving so that | ||
| each rank writes only its own parameter shard. This avoids the memory spike of | ||
| consolidating the full model on rank 0, making it suitable for very large models. | ||
|
|
||
| Each checkpoint is saved as a subdirectory inside ``dirname``. The subdirectory name | ||
| is derived from the filename pattern of the parent :class:`~ignite.handlers.Checkpoint` | ||
| handler (same naming convention, just without the ``.pt`` extension). | ||
|
|
||
| Args: | ||
| dirname: base directory where checkpoint subdirectories will be created. | ||
| create_dir: if True, creates ``dirname`` if it does not exist. Default, True. | ||
| require_empty: if True, raises if ``dirname`` already contains checkpoint | ||
| subdirectories. Default, True. | ||
|
|
||
| .. _torch.distributed.checkpoint: https://pytorch.org/docs/stable/distributed.checkpoint.html | ||
| """ | ||
|
|
||
| def __init__(self, dirname: str | Path, create_dir: bool = True, require_empty: bool = True): | ||
| if not HAVE_FSDP2: | ||
| raise RuntimeError("DCPSaver requires PyTorch >= 2.0 with torch.distributed.checkpoint.") | ||
|
Comment on lines
+1005
to
+1006
|
||
| self.dirname = Path(dirname).expanduser() | ||
| if create_dir and not self.dirname.exists(): | ||
| self.dirname.mkdir(parents=True) | ||
| if not self.dirname.exists(): | ||
| raise ValueError(f"Directory path '{self.dirname}' is not found") | ||
| if require_empty: | ||
| subdirs = [p.name for p in self.dirname.iterdir() if p.is_dir()] | ||
| if subdirs: | ||
| raise ValueError( | ||
| f"Checkpoint directories {subdirs} are already present in '{dirname}'. " | ||
| "Pass require_empty=False to use this directory anyway." | ||
| ) | ||
|
|
||
| def __call__(self, checkpoint: Mapping, filename: str, metadata: Mapping | None = None) -> None: | ||
| path = self.dirname / filename | ||
| path.mkdir(exist_ok=True) | ||
| dcp.save(checkpoint, storage_writer=FileSystemWriter(path)) | ||
|
|
||
| def remove(self, filename: str) -> None: | ||
| import shutil | ||
|
|
||
| path = self.dirname / filename | ||
| if path.exists(): | ||
| shutil.rmtree(path) | ||
|
|
||
|
|
||
| class ModelCheckpoint(Checkpoint): | ||
| """ModelCheckpoint handler, inherits from :class:`~ignite.handlers.checkpoint.Checkpoint`, can be used | ||
| to periodically save objects to disk only. If needed to store checkpoints to | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring says FSDP2 “Requires PyTorch >= 2.0”, but the runtime error below says “upgrade to PyTorch >= 2.6”, and the project dependency is
torch>=2.2(pyproject.toml). Please align the documented minimum requirement and the raised error message with the actual feature availability so users aren’t misled.