93 changes: 74 additions & 19 deletions ignite/distributed/auto.py
@@ -17,6 +17,14 @@
from ignite.distributed.comp_models import xla as idist_xla
from ignite.utils import setup_logger

try:
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

HAVE_FSDP2 = True
except ImportError:
HAVE_FSDP2 = False

__all__ = ["auto_dataloader", "auto_model", "auto_optim", "DistributedProxySampler"]


@@ -141,24 +149,33 @@ def auto_dataloader(dataset: Dataset, **kwargs: Any) -> DataLoader | _MpDeviceLo
return dataloader


def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Module:
def auto_model(model: nn.Module, sync_bn: bool = False, use_fsdp: bool = False, **kwargs: Any) -> nn.Module:
"""Helper method to adapt provided model for non-distributed and distributed configurations (supporting
all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

Internally, we perform the following:

- send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
- wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
- wrap the model with `torch FSDP2 fully_shard`_ instead of DDP if ``use_fsdp=True`` and native torch
distributed is used with world size larger than 1.
- wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available.
- broadcast the initial variable states from rank 0 to all other processes if Horovod distributed framework is used.

Args:
model: model to adapt.
sync_bn: if True, applies `torch convert_sync_batchnorm`_ to the model for native torch
distributed only. Default, False. Note, if using Nvidia/Apex, batchnorm conversion should be
applied before calling ``amp.initialize``.
kwargs: kwargs to model's wrapping class: `torch DistributedDataParallel`_ or `torch DataParallel`_
if applicable. Please, make sure to use acceptable kwargs for given backend.
applied before calling ``amp.initialize``. Incompatible with ``use_fsdp=True``.
use_fsdp: if True, applies `torch FSDP2 fully_shard`_ to the model instead of wrapping with
``DistributedDataParallel`` for native torch distributed backends (NCCL, GLOO, MPI).
Default, False. When enabled, ``kwargs`` are forwarded to ``fully_shard()``, allowing
control over ``reshard_after_forward``, ``mp_policy``, ``offload_policy``, etc.
Note: FSDP2 does not support ``auto_wrap_policy``; manually call ``fully_shard()`` on
submodules before passing the model to ``auto_model``. Requires PyTorch >= 2.0.
Copilot AI Apr 16, 2026

The docstring says FSDP2 “Requires PyTorch >= 2.0”, but the runtime error below says “upgrade to PyTorch >= 2.6”, and the project dependency is torch>=2.2 (pyproject.toml). Please align the documented minimum requirement and the raised error message with the actual feature availability so users aren’t misled.

Suggested change:
- submodules before passing the model to ``auto_model``. Requires PyTorch >= 2.0.
+ submodules before passing the model to ``auto_model``. Requires PyTorch >= 2.2.
kwargs: kwargs forwarded to the wrapping class: `torch DistributedDataParallel`_,
`torch FSDP2 fully_shard`_ (when ``use_fsdp=True``), or `torch DataParallel`_
if applicable. Please, make sure to use acceptable kwargs for the given backend.

Returns:
torch.nn.Module
@@ -179,8 +196,26 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model = idist.auto_model(model)

To use FSDP2 with bf16 mixed precision:

.. code-block:: python

import torch
import ignite.distributed as idist
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

bf16_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
)
# Optionally shard submodules first:
for layer in model.layers:
fully_shard(layer)
model = idist.auto_model(model, use_fsdp=True, mp_policy=bf16_policy)

.. _torch DistributedDataParallel: https://pytorch.org/docs/stable/generated/torch.nn.parallel.
DistributedDataParallel.html
.. _torch FSDP2 fully_shard: https://pytorch.org/docs/stable/distributed.fsdp2.html
.. _torch DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
.. _torch convert_sync_batchnorm: https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#
torch.nn.SyncBatchNorm.convert_sync_batchnorm
@@ -192,9 +227,13 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod

.. versionchanged:: 0.4.3
Added kwargs to ``idist.auto_model``.

"""
logger = setup_logger(__name__ + ".auto_model")

if use_fsdp and sync_bn:
raise ValueError("Arguments use_fsdp and sync_bn are mutually exclusive. FSDP does not support SyncBatchNorm.")

# Put model's parameters to device if its parameters are not on the device
device = idist.device()
if not all([p.device == device for p in model.parameters()]):
@@ -204,23 +243,39 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod
if idist.get_world_size() > 1:
bnd = idist.backend()
if idist.has_native_dist_support and bnd in (idist_native.NCCL, idist_native.GLOO, idist_native.MPI):
if sync_bn:
logger.info("Convert batch norm to sync batch norm")
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

if torch.cuda.is_available():
if "device_ids" in kwargs:
raise ValueError(f"Argument kwargs should not contain 'device_ids', but got {kwargs}")

lrank = idist.get_local_rank()
logger.info(f"Apply torch DistributedDataParallel on model, device id: {lrank}")
kwargs["device_ids"] = [
lrank,
]
if use_fsdp:
if not HAVE_FSDP2:
raise RuntimeError(
"fully_shard (FSDP2) is not available. Please upgrade to PyTorch >= 2.6."
Copilot AI Apr 16, 2026

The RuntimeError message hard-codes “upgrade to PyTorch >= 2.6”, but the repo’s declared dependency is torch>=2.2 and FSDP2 availability is detected by imports. It would be more accurate to mention the missing modules/features (e.g., torch.distributed._composable.fsdp) rather than a specific version, or ensure the version claim matches the actual minimum where FSDP2 is supported.

Suggested change:
- "fully_shard (FSDP2) is not available. Please upgrade to PyTorch >= 2.6."
+ "fully_shard (FSDP2) is not available because the required PyTorch modules/features "
+ "could not be imported: torch.distributed._composable.fsdp.fully_shard and "
+ "torch.distributed.device_mesh.init_device_mesh."
)
ddp_only_kwargs = {"device_ids", "output_device", "find_unused_parameters", "gradient_as_bucket_view"}
bad_kwargs = ddp_only_kwargs & set(kwargs)
if bad_kwargs:
raise ValueError(
f"Argument(s) {bad_kwargs} are DDP-only and cannot be used with use_fsdp=True."
)
Comment on lines +248 to +256
Collaborator

The minimal PyTorch version in which FSDP2 (fully_shard) appeared is 2.6.

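If 2.6 is the intended floor, the gate could be made explicit next to the import probe. A minimal sketch (illustrative only, not part of this diff; the helper name is hypothetical):

import torch

def _torch_version_at_least(major: int, minor: int) -> bool:
    # Sketch only: parse "2.6.0+cu124" -> (2, 6), ignoring patch/build metadata.
    parts = torch.__version__.split(".")[:2]
    return (int(parts[0]), int(parts[1])) >= (major, minor)

HAVE_FSDP2 = False
if _torch_version_at_least(2, 6):
    try:
        from torch.distributed._composable.fsdp import fully_shard  # noqa: F401
        from torch.distributed.device_mesh import init_device_mesh  # noqa: F401

        HAVE_FSDP2 = True
    except ImportError:
        pass
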
if "mesh" not in kwargs:
kwargs["mesh"] = init_device_mesh(device.type, (idist.get_world_size(),))
logger.info("Apply torch FSDP2 (fully_shard) on model")
model = fully_shard(model, **kwargs)
Comment on lines +246 to +260
Copilot AI Apr 15, 2026

In the use_fsdp branch, DDP-specific kwargs like device_ids are no longer validated/blocked. device_ids passed with use_fsdp=True will be forwarded into fully_shard(...) and raise a confusing TypeError. Consider explicitly rejecting device_ids (and any other known DDP-only kwargs) when use_fsdp=True to preserve the clearer error behavior.
Comment on lines +246 to +260
Copilot AI Apr 16, 2026

use_fsdp=True repurposes **kwargs to mean “fully_shard kwargs”, but in the non-distributed multi-GPU branch auto_model can still wrap with DataParallel and will forward the same **kwargs there, causing TypeError for FSDP-only args (e.g., reshard_after_forward, mp_policy). Consider splitting/clearing kwargs when use_fsdp=True and world_size==1 (or otherwise ensuring FSDP-only kwargs never reach the DataParallel path).
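For illustration, a small kwargs split along the lines this comment suggests (a sketch, not part of this diff; the set of FSDP-only names is an assumption based on the fully_shard signature):

from typing import Any

# Assumed (not exhaustive) set of kwargs that only fully_shard() understands.
_FSDP_ONLY_KWARGS = {"mesh", "reshard_after_forward", "mp_policy", "offload_policy"}

def _split_wrapper_kwargs(kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    """Separate FSDP-only kwargs so a DataParallel fallback never receives them."""
    fsdp_kwargs = {k: v for k, v in kwargs.items() if k in _FSDP_ONLY_KWARGS}
    other_kwargs = {k: v for k, v in kwargs.items() if k not in _FSDP_ONLY_KWARGS}
    return fsdp_kwargs, other_kwargs
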
else:
logger.info("Apply torch DistributedDataParallel on model")
if sync_bn:
logger.info("Convert batch norm to sync batch norm")
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

if torch.cuda.is_available():
if "device_ids" in kwargs:
raise ValueError(f"Argument kwargs should not contain 'device_ids', but got {kwargs}")

lrank = idist.get_local_rank()
logger.info(f"Apply torch DistributedDataParallel on model, device id: {lrank}")
kwargs["device_ids"] = [
lrank,
]
else:
logger.info("Apply torch DistributedDataParallel on model")

model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
elif idist.has_hvd_support and bnd == idist_hvd.HOROVOD:
import horovod.torch as hvd

3 changes: 2 additions & 1 deletion ignite/handlers/__init__.py
@@ -1,6 +1,6 @@
from ignite.engine import Engine
from ignite.engine.events import Events
from ignite.handlers.checkpoint import Checkpoint, DiskSaver, ModelCheckpoint
from ignite.handlers.checkpoint import Checkpoint, DCPSaver, DiskSaver, ModelCheckpoint
from ignite.handlers.clearml_logger import ClearMLLogger
from ignite.handlers.early_stopping import EarlyStopping
from ignite.handlers.ema_handler import EMAHandler
@@ -48,6 +48,7 @@
"ModelCheckpoint",
"Checkpoint",
"DiskSaver",
"DCPSaver",
"Timer",
"EarlyStopping",
"TerminateOnNan",
118 changes: 113 additions & 5 deletions ignite/handlers/checkpoint.py
@@ -19,12 +19,22 @@
else:
HAVE_ZERO = False

try:
import torch.distributed.checkpoint as dcp
from torch.distributed._composable.fsdp import FSDPModule
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict

HAVE_FSDP2 = True
except ImportError:
HAVE_FSDP2 = False

import ignite.distributed as idist
from ignite.base import Serializable
from ignite.engine import Engine, Events, EventEnum
from ignite.utils import _tree_apply2, _tree_map

__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]
__all__ = ["Checkpoint", "DiskSaver", "DCPSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]
Copilot AI Apr 15, 2026

DCPSaver is added to ignite.handlers.checkpoint.__all__, but it is not re-exported from ignite/handlers/__init__.py (which currently exports Checkpoint, DiskSaver, ModelCheckpoint, etc.). If DCPSaver is intended as part of the public API (so users can from ignite.handlers import DCPSaver), it should be imported and included in that package __all__ as well.


class CheckpointEvents(EventEnum):
@@ -388,7 +398,7 @@ def __init__(
if n_saved is not None and n_saved < 1:
raise ValueError(f"n_saved must be a positive integer or None, got {n_saved}")
self.n_saved = n_saved
self.ext = "pt"
self.ext = "" if isinstance(self.save_handler, DCPSaver) else "pt"
self.global_step_transform = global_step_transform
self.filename_pattern = filename_pattern
self._saved: list["Checkpoint.Item"] = []
@@ -403,6 +413,7 @@ def _get_filename_pattern(self, global_step: int | None) -> str:
with_score=self.score_function is not None,
with_score_name=self.score_name is not None,
with_global_step=global_step is not None,
as_folder=not self.ext,
)
else:
filename_pattern = self.filename_pattern
@@ -531,6 +542,18 @@ def _setup_checkpoint(self) -> dict[str, Any]:
def func(obj: Any, **kwargs: Any) -> dict:
if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
obj = obj.module
elif HAVE_FSDP2 and isinstance(obj, FSDPModule):
if isinstance(self.save_handler, DCPSaver):
# Each rank saves its own shard — no consolidation needed
return get_model_state_dict(obj)
else:
# Consolidate full state dict to rank 0 for single-file saving
state_dict = get_model_state_dict(
obj, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
)
if idist.get_rank() != self.save_on_rank:
return {}
return state_dict
elif HAVE_ZERO and isinstance(obj, ZeroRedundancyOptimizer):
obj.consolidate_state_dict(to=self.save_on_rank)
if self.save_on_rank != idist.get_rank():
@@ -542,7 +565,11 @@ def func(obj: Any, **kwargs: Any) -> dict:

@staticmethod
def setup_filename_pattern(
with_prefix: bool = True, with_score: bool = True, with_score_name: bool = True, with_global_step: bool = True
with_prefix: bool = True,
with_score: bool = True,
with_score_name: bool = True,
with_global_step: bool = True,
as_folder: bool = False,
) -> str:
"""Helper method to get the default filename pattern for a checkpoint.

Expand All @@ -557,6 +584,8 @@ def setup_filename_pattern(
with_global_step: If True, ``{global_step}`` is added to the
filename pattern: ``...{name}_{global_step}...``.
At least one of ``with_score`` and ``with_global_step`` should be True.
as_folder: If True, the ``.{ext}`` suffix is omitted from the pattern, producing a
bare name suitable for directory-based (DCP) checkpoints. Default, False.

Examples:
.. code-block:: python
@@ -588,7 +617,8 @@ def setup_filename_pattern(
if with_prefix:
filename_pattern = "{filename_prefix}_" + filename_pattern

filename_pattern += ".{ext}"
if not as_folder:
filename_pattern += ".{ext}"
return filename_pattern

@staticmethod
@@ -649,6 +679,8 @@ def load_objects(to_load: Mapping, checkpoint: str | Mapping | Path, **kwargs: A
Note:
If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
`DataParallel`_, method ``load_state_dict`` will be applied to their internal wrapped model (``obj.module``).
If ``to_load`` contains FSDP2-sharded objects (``FSDPModule``), ``set_model_state_dict``
is used with ``full_state_dict=True`` so that all ranks correctly receive the sharded parameters.

.. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
torch.nn.parallel.DistributedDataParallel.html
@@ -660,13 +692,40 @@ def load_objects(to_load: Mapping, checkpoint: str | Mapping | Path, **kwargs: A
Checkpoint._check_objects(to_load, "load_state_dict")

if isinstance(checkpoint, (str, Path)):
checkpoint_obj = torch.load(checkpoint, weights_only=True)
checkpoint_path = Path(checkpoint)
if checkpoint_path.is_dir():
# DCP directory checkpoint — all ranks load their own shard
if not HAVE_FSDP2:
raise RuntimeError(
"Loading DCP directory checkpoints requires PyTorch >= 2.0 with torch.distributed.checkpoint."
)
state_dicts = {}
for k, obj in to_load.items():
if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
obj = obj.module
if isinstance(obj, FSDPModule):
state_dicts[k] = get_model_state_dict(obj)
else:
state_dicts[k] = obj.state_dict()
dcp.load(state_dicts, storage_reader=FileSystemReader(checkpoint_path))
for k, obj in to_load.items():
if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
obj = obj.module
if isinstance(obj, FSDPModule):
set_model_state_dict(obj, state_dicts[k])
else:
obj.load_state_dict(state_dicts[k])
Comment on lines +702 to +717
Copilot AI Apr 15, 2026

The DCP-directory load path does not unwrap DataParallel/DistributedDataParallel objects. However _setup_checkpoint() unwraps them when saving, so the directory checkpoint keys will be for the underlying module, while here the template state_dict is taken from the wrapper (with module. prefixes). This mismatch will make dcp.load(...) and/or load_state_dict fail for DP/DDP models. Apply the same unwrapping logic as the non-directory path (use obj.module for DP/DDP both when building state_dicts and when calling load_state_dict).
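For reference, the unwrapping this comment asks for can be factored into a tiny helper used both when building the state_dicts and when loading (a sketch, not part of this diff):

import torch.nn as nn

def _unwrap(obj: nn.Module) -> nn.Module:
    # DataParallel / DistributedDataParallel keep the wrapped model under .module.
    if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return obj.module
    return obj
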
return
checkpoint_obj = torch.load(checkpoint_path, weights_only=True)
else:
checkpoint_obj = checkpoint

def _load_object(obj: Any, chkpt_obj: Any) -> None:
if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
obj = obj.module
elif HAVE_FSDP2 and isinstance(obj, FSDPModule):
set_model_state_dict(obj, chkpt_obj, options=StateDictOptions(full_state_dict=True))
return

if isinstance(obj, torch.nn.Module):
obj.load_state_dict(chkpt_obj, **kwargs)
@@ -922,6 +981,55 @@ def remove(self, filename: str) -> None:
path.unlink()


class DCPSaver(BaseSaveHandler):
"""Handler that saves FSDP2 checkpoints using `torch.distributed.checkpoint`_ (DCP).

Unlike :class:`~ignite.handlers.DiskSaver`, every rank participates in saving so that
each rank writes only its own parameter shard. This avoids the memory spike of
consolidating the full model on rank 0, making it suitable for very large models.

Each checkpoint is saved as a subdirectory inside ``dirname``. The subdirectory name
is derived from the filename pattern of the parent :class:`~ignite.handlers.Checkpoint`
handler (same naming convention, just without the ``.pt`` extension).

Args:
dirname: base directory where checkpoint subdirectories will be created.
create_dir: if True, creates ``dirname`` if it does not exist. Default, True.
require_empty: if True, raises if ``dirname`` already contains checkpoint
subdirectories. Default, True.
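
Examples:
.. code-block:: python

    # Minimal usage sketch: ``trainer``, ``model`` and a distributed FSDP2
    # setup are assumed to already exist; the path is illustrative only.
    from ignite.engine import Events
    from ignite.handlers import Checkpoint, DCPSaver

    save_handler = DCPSaver("/tmp/dcp_checkpoints", require_empty=False)
    checkpoint = Checkpoint({"model": model}, save_handler, n_saved=2)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint)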

.. _torch.distributed.checkpoint: https://pytorch.org/docs/stable/distributed.checkpoint.html
"""

def __init__(self, dirname: str | Path, create_dir: bool = True, require_empty: bool = True):
if not HAVE_FSDP2:
raise RuntimeError("DCPSaver requires PyTorch >= 2.0 with torch.distributed.checkpoint.")
Comment on lines +1005 to +1006
Copilot AI Apr 16, 2026

DCPSaver/directory checkpoints raise “requires PyTorch >= 2.0”, while auto_model(use_fsdp=True) raises “upgrade to PyTorch >= 2.6” and the docs mention “>= 2.0”. Please unify these requirement/error messages (ideally based on actual feature detection) so users get consistent guidance about what torch version/build is needed.
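One way to keep the call sites consistent would be a single module-level message reused by auto_model and DCPSaver, driven by the same feature detection. A sketch (not part of this diff; the constant name and wording are hypothetical, with 2.6 taken from the reviewer's note above):

# Shared message for auto_model(use_fsdp=True) and DCPSaver when FSDP2 / DCP
# support cannot be imported from the installed PyTorch build.
_FSDP2_NOT_AVAILABLE_MSG = (
    "FSDP2 support is not available: could not import torch.distributed._composable.fsdp "
    "and torch.distributed.checkpoint. Please upgrade to PyTorch >= 2.6."
)
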
self.dirname = Path(dirname).expanduser()
if create_dir and not self.dirname.exists():
self.dirname.mkdir(parents=True)
if not self.dirname.exists():
raise ValueError(f"Directory path '{self.dirname}' is not found")
if require_empty:
subdirs = [p.name for p in self.dirname.iterdir() if p.is_dir()]
if subdirs:
raise ValueError(
f"Checkpoint directories {subdirs} are already present in '{dirname}'. "
"Pass require_empty=False to use this directory anyway."
)

def __call__(self, checkpoint: Mapping, filename: str, metadata: Mapping | None = None) -> None:
path = self.dirname / filename
path.mkdir(exist_ok=True)
dcp.save(checkpoint, storage_writer=FileSystemWriter(path))

def remove(self, filename: str) -> None:
import shutil

path = self.dirname / filename
if path.exists():
shutil.rmtree(path)


class ModelCheckpoint(Checkpoint):
"""ModelCheckpoint handler, inherits from :class:`~ignite.handlers.checkpoint.Checkpoint`, can be used
to periodically save objects to disk only. If needed to store checkpoints to