diff --git a/ignite/distributed/auto.py b/ignite/distributed/auto.py index 501e57fc762a..4cf05f1bb71a 100644 --- a/ignite/distributed/auto.py +++ b/ignite/distributed/auto.py @@ -17,6 +17,14 @@ from ignite.distributed.comp_models import xla as idist_xla from ignite.utils import setup_logger +try: + from torch.distributed._composable.fsdp import fully_shard + from torch.distributed.device_mesh import init_device_mesh + + HAVE_FSDP2 = True +except ImportError: + HAVE_FSDP2 = False + __all__ = ["auto_dataloader", "auto_model", "auto_optim", "DistributedProxySampler"] @@ -141,7 +149,7 @@ def auto_dataloader(dataset: Dataset, **kwargs: Any) -> DataLoader | _MpDeviceLo return dataloader -def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Module: +def auto_model(model: nn.Module, sync_bn: bool = False, use_fsdp: bool = False, **kwargs: Any) -> nn.Module: """Helper method to adapt provided model for non-distributed and distributed configurations (supporting all available backends from :meth:`~ignite.distributed.utils.available_backends()`). @@ -149,6 +157,8 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod - send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device. - wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1. + - wrap the model with `torch FSDP2 fully_shard`_ instead of DDP if ``use_fsdp=True`` and native torch + distributed is used with world size larger than 1. - wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available. - broadcast the initial variable states from rank 0 to all other processes if Horovod distributed framework is used. @@ -156,9 +166,16 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod model: model to adapt. 
sync_bn: if True, applies `torch convert_sync_batchnorm`_ to the model for native torch distributed only. Default, False. Note, if using Nvidia/Apex, batchnorm conversion should be - applied before calling ``amp.initialize``. - kwargs: kwargs to model's wrapping class: `torch DistributedDataParallel`_ or `torch DataParallel`_ - if applicable. Please, make sure to use acceptable kwargs for given backend. + applied before calling ``amp.initialize``. Incompatible with ``use_fsdp=True``. + use_fsdp: if True, applies `torch FSDP2 fully_shard`_ to the model instead of wrapping with + ``DistributedDataParallel`` for native torch distributed backends (NCCL, GLOO, MPI). + Default, False. When enabled, ``kwargs`` are forwarded to ``fully_shard()``, allowing + control over ``reshard_after_forward``, ``mp_policy``, ``offload_policy``, etc. + Note: FSDP2 does not support ``auto_wrap_policy``; manually call ``fully_shard()`` on + submodules before passing the model to ``auto_model``. Requires PyTorch >= 2.6. + kwargs: kwargs forwarded to the wrapping class: `torch DistributedDataParallel`_, + `torch FSDP2 fully_shard`_ (when ``use_fsdp=True``), or `torch DataParallel`_ + if applicable. Please, make sure to use acceptable kwargs for the given backend. Returns: torch.nn.Module @@ -179,8 +196,26 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level) model = idist.auto_model(model) + To use FSDP2 with bf16 mixed precision: + + .. code-block:: python + + import torch + import ignite.distributed as idist + from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy + + bf16_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + ) + # Optionally shard submodules first: + for layer in model.layers: + fully_shard(layer) + model = idist.auto_model(model, use_fsdp=True, mp_policy=bf16_policy) + .. 
_torch DistributedDataParallel: https://pytorch.org/docs/stable/generated/torch.nn.parallel. DistributedDataParallel.html + .. _torch FSDP2 fully_shard: https://pytorch.org/docs/stable/distributed.fsdp2.html .. _torch DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html .. _torch convert_sync_batchnorm: https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html# torch.nn.SyncBatchNorm.convert_sync_batchnorm @@ -192,9 +227,13 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod .. versionchanged:: 0.4.3 Added kwargs to ``idist.auto_model``. + """ logger = setup_logger(__name__ + ".auto_model") + if use_fsdp and sync_bn: + raise ValueError("Arguments use_fsdp and sync_bn are mutually exclusive. FSDP does not support SyncBatchNorm.") + # Put model's parameters to device if its parameters are not on the device device = idist.device() if not all([p.device == device for p in model.parameters()]): @@ -204,23 +243,39 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod if idist.get_world_size() > 1: bnd = idist.backend() if idist.has_native_dist_support and bnd in (idist_native.NCCL, idist_native.GLOO, idist_native.MPI): - if sync_bn: - logger.info("Convert batch norm to sync batch norm") - model = nn.SyncBatchNorm.convert_sync_batchnorm(model) - - if torch.cuda.is_available(): - if "device_ids" in kwargs: - raise ValueError(f"Argument kwargs should not contain 'device_ids', but got {kwargs}") - - lrank = idist.get_local_rank() - logger.info(f"Apply torch DistributedDataParallel on model, device id: {lrank}") - kwargs["device_ids"] = [ - lrank, - ] + if use_fsdp: + if not HAVE_FSDP2: + raise RuntimeError( + "fully_shard (FSDP2) is not available. Please upgrade to PyTorch >= 2.6." 
+ ) + ddp_only_kwargs = {"device_ids", "output_device", "find_unused_parameters", "gradient_as_bucket_view"} + bad_kwargs = ddp_only_kwargs & set(kwargs) + if bad_kwargs: + raise ValueError( + f"Argument(s) {bad_kwargs} are DDP-only and cannot be used with use_fsdp=True." + ) + if "mesh" not in kwargs: + kwargs["mesh"] = init_device_mesh(device.type, (idist.get_world_size(),)) + logger.info("Apply torch FSDP2 (fully_shard) on model") + model = fully_shard(model, **kwargs) else: - logger.info("Apply torch DistributedDataParallel on model") + if sync_bn: + logger.info("Convert batch norm to sync batch norm") + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + + if torch.cuda.is_available(): + if "device_ids" in kwargs: + raise ValueError(f"Argument kwargs should not contain 'device_ids', but got {kwargs}") + + lrank = idist.get_local_rank() + logger.info(f"Apply torch DistributedDataParallel on model, device id: {lrank}") + kwargs["device_ids"] = [ + lrank, + ] + else: + logger.info("Apply torch DistributedDataParallel on model") - model = torch.nn.parallel.DistributedDataParallel(model, **kwargs) + model = torch.nn.parallel.DistributedDataParallel(model, **kwargs) elif idist.has_hvd_support and bnd == idist_hvd.HOROVOD: import horovod.torch as hvd diff --git a/ignite/handlers/__init__.py b/ignite/handlers/__init__.py index 0f0ee506e827..1cfd0611ac91 100644 --- a/ignite/handlers/__init__.py +++ b/ignite/handlers/__init__.py @@ -1,6 +1,6 @@ from ignite.engine import Engine from ignite.engine.events import Events -from ignite.handlers.checkpoint import Checkpoint, DiskSaver, ModelCheckpoint +from ignite.handlers.checkpoint import Checkpoint, DCPSaver, DiskSaver, ModelCheckpoint from ignite.handlers.clearml_logger import ClearMLLogger from ignite.handlers.early_stopping import EarlyStopping from ignite.handlers.ema_handler import EMAHandler @@ -48,6 +48,7 @@ "ModelCheckpoint", "Checkpoint", "DiskSaver", + "DCPSaver", "Timer", "EarlyStopping", "TerminateOnNan", 
diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index eafae5fe4198..d5fc273b6872 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -19,12 +19,22 @@ else: HAVE_ZERO = False +try: + import torch.distributed.checkpoint as dcp + from torch.distributed._composable.fsdp import FSDPModule + from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict + + HAVE_FSDP2 = True +except ImportError: + HAVE_FSDP2 = False + import ignite.distributed as idist from ignite.base import Serializable from ignite.engine import Engine, Events, EventEnum from ignite.utils import _tree_apply2, _tree_map -__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"] +__all__ = ["Checkpoint", "DiskSaver", "DCPSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"] class CheckpointEvents(EventEnum): @@ -388,7 +398,7 @@ def __init__( if n_saved is not None and n_saved < 1: raise ValueError(f"n_saved must be a positive integer or None, got {n_saved}") self.n_saved = n_saved - self.ext = "pt" + self.ext = "" if isinstance(self.save_handler, DCPSaver) else "pt" self.global_step_transform = global_step_transform self.filename_pattern = filename_pattern self._saved: list["Checkpoint.Item"] = [] @@ -403,6 +413,7 @@ def _get_filename_pattern(self, global_step: int | None) -> str: with_score=self.score_function is not None, with_score_name=self.score_name is not None, with_global_step=global_step is not None, + as_folder=not self.ext, ) else: filename_pattern = self.filename_pattern @@ -531,6 +542,18 @@ def _setup_checkpoint(self) -> dict[str, Any]: def func(obj: Any, **kwargs: Any) -> dict: if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): obj = obj.module + elif HAVE_FSDP2 and isinstance(obj, FSDPModule): + if isinstance(self.save_handler, 
DCPSaver): + # Each rank saves its own shard — no consolidation needed + return get_model_state_dict(obj) + else: + # Consolidate full state dict to rank 0 for single-file saving + state_dict = get_model_state_dict( + obj, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ) + if idist.get_rank() != self.save_on_rank: + return {} + return state_dict elif HAVE_ZERO and isinstance(obj, ZeroRedundancyOptimizer): obj.consolidate_state_dict(to=self.save_on_rank) if self.save_on_rank != idist.get_rank(): @@ -542,7 +565,11 @@ def func(obj: Any, **kwargs: Any) -> dict: @staticmethod def setup_filename_pattern( - with_prefix: bool = True, with_score: bool = True, with_score_name: bool = True, with_global_step: bool = True + with_prefix: bool = True, + with_score: bool = True, + with_score_name: bool = True, + with_global_step: bool = True, + as_folder: bool = False, ) -> str: """Helper method to get the default filename pattern for a checkpoint. @@ -557,6 +584,8 @@ def setup_filename_pattern( with_global_step: If True, ``{global_step}`` is added to the filename pattern: ``...{name}_{global_step}...``. At least one of ``with_score`` and ``with_global_step`` should be True. + as_folder: If True, the ``.{ext}`` suffix is omitted from the pattern, producing a + bare name suitable for directory-based (DCP) checkpoints. Default, False. Examples: .. code-block:: python @@ -588,7 +617,8 @@ def setup_filename_pattern( if with_prefix: filename_pattern = "{filename_prefix}_" + filename_pattern - filename_pattern += ".{ext}" + if not as_folder: + filename_pattern += ".{ext}" return filename_pattern @staticmethod @@ -649,6 +679,8 @@ def load_objects(to_load: Mapping, checkpoint: str | Mapping | Path, **kwargs: A Note: If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or `DataParallel`_, method ``load_state_dict`` will applied to their internal wrapped model (``obj.module``). 
+ If ``to_load`` contains FSDP2-sharded objects (``FSDPModule``), ``set_model_state_dict`` + is used with ``full_state_dict=True`` so that all ranks correctly receive the sharded parameters. .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/ torch.nn.parallel.DistributedDataParallel.html @@ -660,13 +692,40 @@ def load_objects(to_load: Mapping, checkpoint: str | Mapping | Path, **kwargs: A Checkpoint._check_objects(to_load, "load_state_dict") if isinstance(checkpoint, (str, Path)): - checkpoint_obj = torch.load(checkpoint, weights_only=True) + checkpoint_path = Path(checkpoint) + if checkpoint_path.is_dir(): + # DCP directory checkpoint — all ranks load their own shard + if not HAVE_FSDP2: + raise RuntimeError( + "Loading DCP directory checkpoints requires PyTorch >= 2.0 with torch.distributed.checkpoint." + ) + state_dicts = {} + for k, obj in to_load.items(): + if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): + obj = obj.module + if isinstance(obj, FSDPModule): + state_dicts[k] = get_model_state_dict(obj) + else: + state_dicts[k] = obj.state_dict() + dcp.load(state_dicts, storage_reader=FileSystemReader(checkpoint_path)) + for k, obj in to_load.items(): + if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): + obj = obj.module + if isinstance(obj, FSDPModule): + set_model_state_dict(obj, state_dicts[k]) + else: + obj.load_state_dict(state_dicts[k]) + return + checkpoint_obj = torch.load(checkpoint_path, weights_only=True) else: checkpoint_obj = checkpoint def _load_object(obj: Any, chkpt_obj: Any) -> None: if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): obj = obj.module + elif HAVE_FSDP2 and isinstance(obj, FSDPModule): + set_model_state_dict(obj, chkpt_obj, options=StateDictOptions(full_state_dict=True)) + return if isinstance(obj, torch.nn.Module): obj.load_state_dict(chkpt_obj, **kwargs) @@ -922,6 +981,55 @@ def remove(self, filename: str) -> None: path.unlink() 
+class DCPSaver(BaseSaveHandler): + """Handler that saves FSDP2 checkpoints using `torch.distributed.checkpoint`_ (DCP). + + Unlike :class:`~ignite.handlers.DiskSaver`, every rank participates in saving so that + each rank writes only its own parameter shard. This avoids the memory spike of + consolidating the full model on rank 0, making it suitable for very large models. + + Each checkpoint is saved as a subdirectory inside ``dirname``. The subdirectory name + is derived from the filename pattern of the parent :class:`~ignite.handlers.Checkpoint` + handler (same naming convention, just without the ``.pt`` extension). + + Args: + dirname: base directory where checkpoint subdirectories will be created. + create_dir: if True, creates ``dirname`` if it does not exist. Default, True. + require_empty: if True, raises if ``dirname`` already contains checkpoint + subdirectories. Default, True. + + .. _torch.distributed.checkpoint: https://pytorch.org/docs/stable/distributed.checkpoint.html + """ + + def __init__(self, dirname: str | Path, create_dir: bool = True, require_empty: bool = True): + if not HAVE_FSDP2: + raise RuntimeError("DCPSaver requires PyTorch >= 2.0 with torch.distributed.checkpoint.") + self.dirname = Path(dirname).expanduser() + if create_dir and not self.dirname.exists(): + self.dirname.mkdir(parents=True) + if not self.dirname.exists(): + raise ValueError(f"Directory path '{self.dirname}' is not found") + if require_empty: + subdirs = [p.name for p in self.dirname.iterdir() if p.is_dir()] + if subdirs: + raise ValueError( + f"Checkpoint directories {subdirs} are already present in '{dirname}'. " + "Pass require_empty=False to use this directory anyway." 
+ ) + + def __call__(self, checkpoint: Mapping, filename: str, metadata: Mapping | None = None) -> None: + path = self.dirname / filename + path.mkdir(exist_ok=True) + dcp.save(checkpoint, storage_writer=FileSystemWriter(path)) + + def remove(self, filename: str) -> None: + import shutil + + path = self.dirname / filename + if path.exists(): + shutil.rmtree(path) + + class ModelCheckpoint(Checkpoint): """ModelCheckpoint handler, inherits from :class:`~ignite.handlers.checkpoint.Checkpoint`, can be used to periodically save objects to disk only. If needed to store checkpoints to diff --git a/tests/ignite/distributed/test_auto.py b/tests/ignite/distributed/test_auto.py index b26c4a7417d3..88b264a1c30c 100644 --- a/tests/ignite/distributed/test_auto.py +++ b/tests/ignite/distributed/test_auto.py @@ -231,6 +231,48 @@ def test_auto_methods_nccl(distributed_context_single_node_nccl): auto_model(nn.Linear(1, 1), device_ids=[0]) +def _test_auto_model_fsdp(model, ws, device): + try: + from torch.distributed._composable.fsdp import FSDPModule + except ImportError: + pytest.skip("FSDP2 not available in this PyTorch version") + + wrapped = auto_model(model, use_fsdp=True) + if ws > 1 and idist.has_native_dist_support and idist.backend() in ("nccl", "gloo"): + assert isinstance(wrapped, FSDPModule), f"Expected FSDPModule, got {type(wrapped)}" + else: + assert isinstance(wrapped, nn.Module) + + assert all(p.device.type == torch.device(device).type for p in wrapped.parameters()), ( + f"{[p.device.type for p in wrapped.parameters()]} vs {torch.device(device).type}" + ) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif("WORLD_SIZE" not in os.environ, reason="Skip if WORLD_SIZE not in env vars") +def test_auto_model_fsdp_gloo(distributed_context_single_node_gloo): + ws = distributed_context_single_node_gloo["world_size"] + device = idist.device() + _test_auto_model_fsdp(nn.Linear(10, 
10), ws, device) + _test_auto_model_fsdp(nn.Sequential(nn.Linear(20, 100), nn.ReLU(), nn.Linear(100, 10)), ws, device) + + # sync_bn + use_fsdp must raise + with pytest.raises(ValueError, match=r"use_fsdp and sync_bn are mutually exclusive"): + auto_model(nn.Sequential(nn.Linear(20, 100), nn.BatchNorm1d(100)), sync_bn=True, use_fsdp=True) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") +@pytest.mark.skipif("WORLD_SIZE" not in os.environ, reason="Skip if WORLD_SIZE not in env vars") +def test_auto_model_fsdp_nccl_cuda(distributed_context_single_node_nccl): + ws = distributed_context_single_node_nccl["world_size"] + device = idist.device() + _test_auto_model_fsdp(nn.Linear(10, 10), ws, device) + _test_auto_model_fsdp(nn.Sequential(nn.Linear(20, 100), nn.ReLU(), nn.Linear(100, 10)), ws, device) + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support") @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index 10daa4ebfe65..1eaa910bf7c3 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1249,6 +1249,71 @@ def _test_checkpoint_load_objects_ddp(device): Checkpoint.load_objects(to_load, checkpoint) +def _test_checkpoint_with_fsdp(device, dirname): + from ignite.handlers.checkpoint import HAVE_FSDP2 + + if not HAVE_FSDP2 or "cuda" not in device.type: + return + + from torch.distributed._composable.fsdp import fully_shard + + torch.manual_seed(0) + model = DummyModel().to(device) + fully_shard(model) + to_save = {"model": model} + + saver = DiskSaver(str(dirname), create_dir=True, require_empty=False) + checkpointer = Checkpoint(to_save, saver) + engine = Engine(lambda e, b: 
None) + engine.state = State(epoch=0, iteration=0) + checkpointer(engine) + + # Rank 0 should have written a checkpoint with the full model weights + if idist.get_rank() == 0: + ckpt_path = list(dirname.glob("model_*.pt")) + assert len(ckpt_path) == 1 + saved = torch.load(ckpt_path[0], map_location="cpu") + # Saved state dict keys should match the unwrapped model's keys (no FSDP prefix) + assert set(saved.keys()) == set(DummyModel().state_dict().keys()) + + +def _test_checkpoint_load_objects_fsdp(device): + from ignite.handlers.checkpoint import HAVE_FSDP2 + + if not HAVE_FSDP2 or "cuda" not in device.type: + return + + from torch.distributed._composable.fsdp import fully_shard + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict + + def _full_state(m): + return { + k: v.clone() + for k, v in get_model_state_dict( + m, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ).items() + } + + torch.manual_seed(0) + model = DummyModel().to(device) + fully_shard(model) + original_state = _full_state(model) + + # Perturb weights, then reload the original state dict via Checkpoint.load_objects + perturbed = {k: v + 99.0 for k, v in original_state.items()} + set_model_state_dict(model, perturbed, options=StateDictOptions(full_state_dict=True)) + + Checkpoint.load_objects({"model": model}, original_state) + + # After loading, weights should match the original + after = _full_state(model) + if idist.get_rank() == 0: + for k in original_state: + assert torch.allclose(original_state[k].cpu(), after[k].cpu()), ( + f"Mismatch on param '{k}' after FSDP2 checkpoint load" + ) + + def _test_checkpoint_with_ZeRO(device, dirname, local_rank): from torch.distributed.optim import ZeroRedundancyOptimizer @@ -1284,6 +1349,8 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo, dirname, _test_save_model_optimizer_lr_scheduler_with_state_dict(device, rank_zero_dirname / "2", 
just_on_zero_rank=True) _test_checkpoint_with_ddp(device) _test_checkpoint_load_objects_ddp(device) + _test_checkpoint_with_fsdp(device, rank_zero_dirname / "fsdp_save") + _test_checkpoint_load_objects_fsdp(device) from ignite.handlers.checkpoint import HAVE_ZERO @@ -1301,6 +1368,8 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl, get_rank_zero_di _test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname / "2", just_on_zero_rank=True) _test_checkpoint_with_ddp(device=device) _test_checkpoint_load_objects_ddp(device=device) + _test_checkpoint_with_fsdp(device, dirname / "fsdp_save") + _test_checkpoint_load_objects_fsdp(device) @pytest.mark.distributed diff --git a/tests/ignite/test_fsdp_distributed.py b/tests/ignite/test_fsdp_distributed.py new file mode 100644 index 000000000000..57562c8f37ff --- /dev/null +++ b/tests/ignite/test_fsdp_distributed.py @@ -0,0 +1,188 @@ +""" +Distributed smoke test for FSDP wrapping via auto_model. +Must be run as a standalone script (not via pytest) using: + python tests/ignite/test_fsdp_distributed.py +""" +from __future__ import annotations + +import os +import sys + +import torch +import torch.distributed as dist +import torch.nn as nn + +REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + + +def run_worker(rank: int, world_size: int, backend: str, results: dict) -> None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12399" + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank) + + dist.init_process_group(backend, rank=rank, world_size=world_size) + + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + + try: + from ignite.distributed.auto import auto_model + from torch.distributed._composable.fsdp import FSDPModule + + device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu") + + # Test 1: 
use_fsdp=True applies FSDP2 when world_size > 1 + model = nn.Linear(10, 10).to(device) + wrapped = auto_model(model, use_fsdp=True) + assert isinstance(wrapped, FSDPModule), f"[rank {rank}] Expected FSDPModule, got {type(wrapped).__name__}" + results[f"rank{rank}_fsdp_wrap"] = True + + # Test 2: use_fsdp=True + sync_bn=True raises ValueError (all ranks) + try: + auto_model(nn.Linear(5, 5), use_fsdp=True, sync_bn=True) + results[f"rank{rank}_valueerror"] = False # Should not reach here + except ValueError: + results[f"rank{rank}_valueerror"] = True + + # Test 3: forward pass through FSDP-wrapped model works + x = torch.randn(4, 10, device=device) + out = wrapped(x) + assert out.shape == (4, 10), f"[rank {rank}] Unexpected output shape: {out.shape}" + results[f"rank{rank}_forward"] = True + + dist.barrier() + + except Exception as e: + results[f"rank{rank}_error"] = str(e) + raise + finally: + dist.destroy_process_group() + + +def run_checkpoint_worker(rank: int, world_size: int, backend: str, tmpdir: str, results: dict) -> None: + """Test FSDP checkpoint save/load in distributed context.""" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12400" + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank) + + dist.init_process_group(backend, rank=rank, world_size=world_size) + + try: + from ignite.handlers import Checkpoint, DiskSaver + from ignite.engine import Engine, Events + from ignite.engine.engine import State + from torch.distributed._composable.fsdp import fully_shard + + torch.manual_seed(0) + device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu") + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()).to(device) + fully_shard(model) + + save_dir = os.path.join(tmpdir, "fsdp_ckpt") + os.makedirs(save_dir, exist_ok=True) + dist.barrier() + + checkpointer = Checkpoint( + {"model": model}, + 
DiskSaver(save_dir, create_dir=False, require_empty=False), + ) + engine = Engine(lambda e, b: None) + engine.state = State(epoch=0, iteration=0) + checkpointer(engine) + + # Only rank 0 should have a non-empty file + dist.barrier() + + if rank == 0: + import pathlib + ckpt_files = list(pathlib.Path(save_dir).glob("model_*.pt")) + assert len(ckpt_files) == 1, f"Expected 1 checkpoint, found {ckpt_files}" + saved = torch.load(ckpt_files[0], map_location="cpu") + # Saved keys must match unwrapped model + expected_keys = set(nn.Sequential(nn.Linear(8, 8), nn.ReLU()).state_dict().keys()) + assert set(saved.keys()) == expected_keys, ( + f"Checkpoint keys mismatch. Got {set(saved.keys())}, expected {expected_keys}" + ) + results["ckpt_save"] = True + else: + results[f"rank{rank}_ckpt_skip"] = True + + dist.barrier() + + except Exception as e: + results[f"rank{rank}_ckpt_error"] = str(e) + raise + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + import multiprocessing as mp + import tempfile + + manager = mp.Manager() + results: dict = manager.dict() + + print("=" * 60) + print("Test 1: auto_model FSDP wrapping (gloo, 2 processes)") + print("=" * 60) + + processes = [] + for rank in range(2): + p = mp.Process(target=run_worker, args=(rank, 2, "gloo", results)) + p.start() + processes.append(p) + + exit_codes = [] + for p in processes: + p.join() + exit_codes.append(p.exitcode) + + if all(ec == 0 for ec in exit_codes): + print("PASSED: auto_model FSDP wrapping with gloo") + for k, v in sorted(results.items()): + print(f" {k}: {v}") + else: + print("FAILED: auto_model FSDP wrapping with gloo") + for k, v in sorted(results.items()): + print(f" {k}: {v}") + sys.exit(1) + + print() + print("=" * 60) + print("Test 2: FSDP checkpoint save (gloo, 2 processes)") + print("=" * 60) + + results2: dict = manager.dict() + with tempfile.TemporaryDirectory() as tmpdir: + processes2 = [] + for rank in range(2): + p = mp.Process(target=run_checkpoint_worker, 
args=(rank, 2, "gloo", tmpdir, results2)) + p.start() + processes2.append(p) + + exit_codes2 = [] + for p in processes2: + p.join() + exit_codes2.append(p.exitcode) + + if all(ec == 0 for ec in exit_codes2): + print("PASSED: FSDP checkpoint save with gloo") + for k, v in sorted(results2.items()): + print(f" {k}: {v}") + else: + print("FAILED: FSDP checkpoint save with gloo") + for k, v in sorted(results2.items()): + print(f" {k}: {v}") + sys.exit(1) + + print() + print("All distributed FSDP tests PASSED") diff --git a/tests/ignite/test_fsdp_smoke.py b/tests/ignite/test_fsdp_smoke.py new file mode 100644 index 000000000000..5f8580fade21 --- /dev/null +++ b/tests/ignite/test_fsdp_smoke.py @@ -0,0 +1,405 @@ +""" +Smoke tests for FSDP support changes in: + - ignite/distributed/auto.py (auto_model use_fsdp parameter) + - ignite/handlers/checkpoint.py (HAVE_FSDP2 flag and FSDP checkpoint branch) + +All tests run in single-process / non-distributed mode. +""" +from __future__ import annotations + +import tempfile +from pathlib import Path + +import pytest +import torch +import torch.nn as nn + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class _SimpleLinear(nn.Module): + """Tiny model used across tests.""" + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(4, 2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc(x) + + +class _ModelWithBN(nn.Module): + """Model with a BatchNorm layer – relevant for the sync_bn conflict check.""" + + def __init__(self) -> None: + super().__init__() + self.bn = nn.BatchNorm1d(4) + self.fc = nn.Linear(4, 2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc(self.bn(x)) + + +# --------------------------------------------------------------------------- +# Section 1: auto_model / use_fsdp +# 
# ---------------------------------------------------------------------------


class TestAutoModelFSDPFlag:
    """Tests for the use_fsdp parameter added to auto_model."""

    # ------------------------------------------------------------------
    # 1a. ValueError when use_fsdp=True and sync_bn=True simultaneously
    # ------------------------------------------------------------------

    def test_use_fsdp_and_sync_bn_raises_value_error(self) -> None:
        """use_fsdp=True and sync_bn=True must be mutually exclusive."""
        from ignite.distributed.auto import auto_model

        model = _SimpleLinear()
        with pytest.raises(ValueError, match="mutually exclusive"):
            auto_model(model, use_fsdp=True, sync_bn=True)

    def test_use_fsdp_and_sync_bn_with_bn_model_raises_value_error(self) -> None:
        """Mutual exclusion holds even when the model actually has BatchNorm layers."""
        from ignite.distributed.auto import auto_model

        model = _ModelWithBN()
        with pytest.raises(ValueError, match="mutually exclusive"):
            auto_model(model, use_fsdp=True, sync_bn=True)

    def test_use_fsdp_sync_bn_false_does_not_raise(self) -> None:
        """use_fsdp=True without sync_bn=True must NOT raise."""
        from ignite.distributed.auto import auto_model

        model = _SimpleLinear()
        # In non-distributed mode FSDP wrapping is skipped; the call must succeed.
        returned = auto_model(model, use_fsdp=True, sync_bn=False)
        assert returned is not None

    # ------------------------------------------------------------------
    # 1b. Non-distributed fallback: model returned unchanged (or DataParallel)
    # ------------------------------------------------------------------

    def test_use_fsdp_non_distributed_returns_nn_module(self) -> None:
        """In non-distributed context auto_model must return an nn.Module."""
        from ignite.distributed.auto import auto_model

        model = _SimpleLinear()
        result = auto_model(model, use_fsdp=True)
        assert isinstance(result, nn.Module)

    def test_use_fsdp_false_non_distributed_returns_nn_module(self) -> None:
        """Baseline: use_fsdp=False still returns an nn.Module in non-dist mode."""
        from ignite.distributed.auto import auto_model

        model = _SimpleLinear()
        result = auto_model(model, use_fsdp=False)
        assert isinstance(result, nn.Module)

    def test_use_fsdp_default_is_false(self) -> None:
        """use_fsdp must default to False (no behaviour change from old callers)."""
        import inspect

        from ignite.distributed.auto import auto_model

        sig = inspect.signature(auto_model)
        assert "use_fsdp" in sig.parameters, "use_fsdp parameter missing from auto_model"
        assert sig.parameters["use_fsdp"].default is False

    def test_use_fsdp_non_distributed_preserves_model_output(self) -> None:
        """Model wrapped via auto_model(use_fsdp=True) must still forward correctly.

        Note: on a machine with >1 GPU, auto_model wraps with DataParallel (non-dist
        multi-GPU path), which moves model parameters to CUDA. Inputs must be on the
        same device as the wrapped model to forward correctly.
        """
        from ignite.distributed.auto import auto_model

        torch.manual_seed(0)
        model = _SimpleLinear()
        wrapped = auto_model(model, use_fsdp=True)

        # Determine the device the model ended up on after wrapping.
        device = next(wrapped.parameters()).device
        x = torch.ones(2, 4, device=device)
        actual = wrapped(x)
        assert actual.shape == (2, 2), f"Unexpected output shape: {actual.shape}"
        assert actual.dtype == torch.float32

    def test_sync_bn_without_fsdp_non_distributed_does_not_raise(self) -> None:
        """sync_bn=True alone (no use_fsdp) must remain valid in non-dist mode."""
        from ignite.distributed.auto import auto_model

        model = _ModelWithBN()
        # No distributed backend, so SyncBN conversion is skipped – must not raise.
        result = auto_model(model, sync_bn=True)
        assert isinstance(result, nn.Module)


# ---------------------------------------------------------------------------
# Section 2: checkpoint.py FSDP imports
# ---------------------------------------------------------------------------


class TestCheckpointFSDPImports:
    """Verify the HAVE_FSDP2 flag and associated symbols are importable."""

    def test_have_fsdp_is_true(self) -> None:
        """HAVE_FSDP2 must be True when torch.distributed._composable.fsdp is present."""
        from ignite.handlers.checkpoint import HAVE_FSDP2

        assert HAVE_FSDP2 is True, (
            "HAVE_FSDP2 is False — torch.distributed._composable.fsdp may be unavailable in this environment"
        )

    def test_fsdp2_symbols_importable_from_checkpoint_module(self) -> None:
        """FSDPModule, get_model_state_dict and set_model_state_dict must be reachable."""
        # checkpoint.py imports these at module level inside the try/except;
        # confirm they are accessible from the installed torch build.
        from torch.distributed._composable.fsdp import FSDPModule  # noqa: F401
        from torch.distributed.checkpoint.state_dict import (  # noqa: F401
            StateDictOptions,
            get_model_state_dict,
            set_model_state_dict,
        )

    def test_checkpoint_module_imports_cleanly(self) -> None:
        """The checkpoint module itself must import without errors."""
        import importlib

        import ignite.handlers.checkpoint as chkpt_mod  # noqa: F401

        importlib.reload(chkpt_mod)  # force re-import to surface any top-level errors


# ---------------------------------------------------------------------------
# Section 3: Non-FSDP checkpoint save/load regression
# ---------------------------------------------------------------------------


class TestCheckpointNonFSDPRegression:
    """Ensure the existing (non-FSDP) checkpoint save/load path is intact."""

    def test_save_and_load_plain_model(self) -> None:
        """Save a plain nn.Module checkpoint and reload it successfully."""
        from ignite.engine import Engine, Events
        from ignite.handlers import Checkpoint, DiskSaver

        torch.manual_seed(42)
        model = _SimpleLinear()
        original_weight = model.fc.weight.data.clone()

        trainer = Engine(lambda e, b: None)

        with tempfile.TemporaryDirectory() as tmpdir:
            saver = DiskSaver(tmpdir, create_dir=False, require_empty=False)
            handler = Checkpoint({"model": model}, saver, n_saved=1)
            trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
            trainer.run([0], max_epochs=1)

            saved_files = list(Path(tmpdir).glob("*.pt"))
            assert len(saved_files) == 1, f"Expected 1 checkpoint file, got {saved_files}"

            # Corrupt the model weights, then restore from checkpoint.
            model.fc.weight.data.fill_(0.0)
            assert not torch.allclose(model.fc.weight.data, original_weight)

            checkpoint = torch.load(saved_files[0], weights_only=True)
            Checkpoint.load_objects({"model": model}, checkpoint)
            assert torch.allclose(model.fc.weight.data, original_weight), (
                "Loaded weights do not match original — regression in non-FSDP load path"
            )

    def test_save_and_load_plain_model_via_filepath_string(self) -> None:
        """load_objects also accepts a filepath string — verify this regression path."""
        from ignite.engine import Engine, Events
        from ignite.handlers import Checkpoint, DiskSaver

        torch.manual_seed(7)
        model = _SimpleLinear()
        original_weight = model.fc.weight.data.clone()

        trainer = Engine(lambda e, b: None)

        with tempfile.TemporaryDirectory() as tmpdir:
            saver = DiskSaver(tmpdir, create_dir=False, require_empty=False)
            handler = Checkpoint({"model": model}, saver, n_saved=1)
            trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
            trainer.run([0], max_epochs=1)

            saved_files = list(Path(tmpdir).glob("*.pt"))
            assert len(saved_files) == 1

            model.fc.weight.data.fill_(0.0)

            # Load via string path instead of dict.
            Checkpoint.load_objects({"model": model}, str(saved_files[0]))
            assert torch.allclose(model.fc.weight.data, original_weight)

    def test_save_and_load_ddp_wrapped_model(self) -> None:
        """Checkpoint must unwrap DataParallel and save the inner module's state."""
        from ignite.handlers import Checkpoint, DiskSaver

        torch.manual_seed(99)
        model = _SimpleLinear()

        # Only wrap with DataParallel if multiple GPUs are present; else bare model.
        if torch.cuda.device_count() > 1:
            wrapped = nn.DataParallel(model)
        else:
            # Simulate DP wrapping using a stub that exposes .module.
            class _FakeDP(nn.Module):
                def __init__(self, m: nn.Module) -> None:
                    super().__init__()
                    self.module = m

                def state_dict(self, **kw):  # type: ignore[override]
                    return self.module.state_dict(**kw)

                def load_state_dict(self, sd, **kw):  # type: ignore[override]
                    return self.module.load_state_dict(sd, **kw)

            wrapped = _FakeDP(model)

        original_weight = model.fc.weight.data.clone()

        with tempfile.TemporaryDirectory() as tmpdir:
            saver = DiskSaver(tmpdir, create_dir=False, require_empty=False)
            # Checkpoint the WRAPPED model so the DataParallel-unwrapping path in
            # Checkpoint._setup_checkpoint is actually exercised (the saved state
            # must use the inner module's keys, loadable into the plain model).
            handler = Checkpoint({"model": wrapped}, saver, n_saved=1)
            from ignite.engine import Engine, Events

            trainer = Engine(lambda e, b: None)
            trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
            trainer.run([0], max_epochs=1)

            saved_files = list(Path(tmpdir).glob("*.pt"))
            assert len(saved_files) == 1

            model.fc.weight.data.fill_(0.0)
            checkpoint = torch.load(saved_files[0], weights_only=True)
            Checkpoint.load_objects({"model": model}, checkpoint)
            assert torch.allclose(model.fc.weight.data, original_weight)

    def test_load_objects_from_dict_checkpoint(self) -> None:
        """Checkpoint.load_objects must work when passed a raw state-dict mapping.

        Note: torch.nn.Module.state_dict() returns tensors that SHARE storage with
        the model's parameters (they are views, not copies). We must deepcopy the
        state_dict before mutating the model to avoid corrupting the saved snapshot.
        """
        import copy

        torch.manual_seed(5)
        model = _SimpleLinear()
        # Use deepcopy so the saved snapshot is independent of the live model.
        saved_sd = copy.deepcopy(model.state_dict())
        original_weight = saved_sd["fc.weight"].clone()

        model.fc.weight.data.fill_(99.0)

        from ignite.handlers import Checkpoint

        Checkpoint.load_objects({"model": model}, {"model": saved_sd})
        assert torch.allclose(model.fc.weight.data, original_weight), (
            "load_objects did not restore weights from dict checkpoint"
        )

    def test_load_objects_single_key_direct_state_dict(self) -> None:
        """When to_load has one key absent from checkpoint, load_objects falls
        back to treating the whole checkpoint as the state_dict directly.

        Note: torch.nn.Module.state_dict() shares storage with live parameters;
        use deepcopy before mutating the model.
        """
        import copy

        torch.manual_seed(3)
        model = _SimpleLinear()
        saved_sd = copy.deepcopy(model.state_dict())
        original_weight = saved_sd["fc.weight"].clone()

        model.fc.weight.data.fill_(99.0)

        from ignite.handlers import Checkpoint

        # Pass the bare state_dict without the "model" key wrapper.
        Checkpoint.load_objects({"model": model}, saved_sd)
        assert torch.allclose(model.fc.weight.data, original_weight), (
            "load_objects did not restore weights from single-key direct state_dict"
        )


# ---------------------------------------------------------------------------
# Section 4: Edge-case / boundary checks
# ---------------------------------------------------------------------------


class TestEdgeCases:
    """Additional edge cases identified during change analysis."""

    def test_use_fsdp_true_sync_bn_false_explicitly(self) -> None:
        """Explicit sync_bn=False with use_fsdp=True must not raise."""
        from ignite.distributed.auto import auto_model

        model = _SimpleLinear()
        result = auto_model(model, use_fsdp=True, sync_bn=False)
        assert isinstance(result, nn.Module)

    def test_auto_model_empty_model_use_fsdp(self) -> None:
        """auto_model with use_fsdp=True on a model with no parameters must not crash."""
        from ignite.distributed.auto import auto_model

        class _Empty(nn.Module):
            def forward(self, x):  # type: ignore[override]
                return x

        model = _Empty()
        result = auto_model(model, use_fsdp=True)
        assert isinstance(result, nn.Module)

    def test_use_fsdp_kwargs_passthrough_non_distributed(self) -> None:
        """Extra kwargs must be silently ignored in non-distributed mode
        (they would be forwarded to FSDP constructor only when world_size > 1)."""
        from ignite.distributed.auto import auto_model

        model = _SimpleLinear()
        # In non-dist mode, FSDP is not instantiated, so representative FSDP kwargs
        # such as ``reshard_after_forward`` should be accepted and ignored.
        result = auto_model(model, use_fsdp=True, reshard_after_forward=False)
        assert isinstance(result, nn.Module)

    def test_have_fsdp_flag_is_boolean(self) -> None:
        """HAVE_FSDP2 must be a plain Python bool (not None or a module)."""
        from ignite.handlers.checkpoint import HAVE_FSDP2

        assert isinstance(HAVE_FSDP2, bool)

    def test_setup_checkpoint_returns_dict_for_plain_model(self) -> None:
        """_setup_checkpoint on a non-FSDP model must return a non-empty dict."""
        from ignite.handlers.checkpoint import Checkpoint

        model = _SimpleLinear()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        chkpt = Checkpoint(
            to_save={"model": model, "optimizer": optimizer},
            save_handler=lambda chk, fn, meta=None: None,
        )
        result = chkpt._setup_checkpoint()
        assert isinstance(result, dict)
        assert "model" in result
        assert "optimizer" in result

    def test_value_error_message_mentions_fsdp_and_sync_bn(self) -> None:
        """Error message must contain enough context for users to understand
        what went wrong — mention of both FSDP and SyncBatchNorm."""
        from ignite.distributed.auto import auto_model

        model = _SimpleLinear()
        with pytest.raises(ValueError) as exc_info:
            auto_model(model, use_fsdp=True, sync_bn=True)

        msg = str(exc_info.value).lower()
        assert "fsdp" in msg or "fully" in msg, f"Error message lacks FSDP mention: {exc_info.value}"
        assert "sync" in msg or "batchnorm" in msg or "bn" in msg, (
            f"Error message lacks SyncBN mention: {exc_info.value}"
        )

    def test_checkpoint_load_objects_invalid_type_raises_type_error(self) -> None:
        """Passing an invalid checkpoint type must raise TypeError (regression guard)."""
        from ignite.handlers import Checkpoint

        model = _SimpleLinear()
        with pytest.raises(TypeError):
            Checkpoint.load_objects({"model": model}, 12345)  # type: ignore[arg-type]