Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions ignite/distributed/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def auto_dataloader(dataset: Dataset, **kwargs: Any) -> DataLoader | _MpDeviceLo
Examples:
.. code-block:: python

import ignite.distribted as idist
import ignite.distributed as idist

train_loader = idist.auto_dataloader(
train_dataset,
Expand All @@ -76,9 +76,9 @@ def auto_dataloader(dataset: Dataset, **kwargs: Any) -> DataLoader | _MpDeviceLo
if "batch_size" in kwargs and kwargs["batch_size"] >= world_size:
kwargs["batch_size"] //= world_size

nproc = idist.get_nproc_per_node()
if "num_workers" in kwargs and kwargs["num_workers"] >= nproc:
kwargs["num_workers"] = (kwargs["num_workers"] + nproc - 1) // nproc
nprocs = idist.get_nprocs_per_node()
if "num_workers" in kwargs and kwargs["num_workers"] >= nprocs:
kwargs["num_workers"] = (kwargs["num_workers"] + nprocs - 1) // nprocs

if "batch_sampler" not in kwargs:
if isinstance(dataset, IterableDataset):
Expand Down Expand Up @@ -155,7 +155,7 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod
Args:
model: model to adapt.
sync_bn: if True, applies `torch convert_sync_batchnorm`_ to the model for native torch
distributed only. Default, False. Note, if using Nvidia/Apex, batchnorm conversion should be
distributed only. Default, False. Note, if using Nvidia/APex, batchnorm conversion should be
applied before calling ``amp.initialize``.
kwargs: kwargs to model's wrapping class: `torch DistributedDataParallel`_ or `torch DataParallel`_
if applicable. Please, make sure to use acceptable kwargs for given backend.
Expand All @@ -166,23 +166,24 @@ def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Mod
Examples:
.. code-block:: python

import ignite.distribted as idist
import ignite.distributed as idist

model = idist.auto_model(model)

In addition with NVidia/Apex, it can be used in the following way:
In addition with Nvidia/APex, it can be used in the following way:

.. code-block:: python

import ignite.distribted as idist
import ignite.distributed as idist

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model = idist.auto_model(model)

.. _torch DistributedDataParallel: https://pytorch.org/docs/stable/generated/torch.nn.parallel.
DistributedDataParallel.html
.. _torch DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
.. _torch convert_sync_batchnorm: https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#
.. _torch convert_sync_batchnorm: https://pytorch.org/docs/stable/generated/torch.nn.
SyncBatchNorm.html#
torch.nn.SyncBatchNorm.convert_sync_batchnorm

.. versionchanged:: 0.4.2
Expand Down Expand Up @@ -242,7 +243,7 @@ def auto_optim(optimizer: Optimizer, **kwargs: Any) -> Optimizer:
Internally, this method is a no-op for non-distributed and torch native distributed configurations.

For XLA distributed configuration, we create a new class that inherits from the provided optimizer.
The goal is to override the `step()` method with specific `xm.optimizer_step`_ implementation.
The goal is to override the ``step()`` method with specific `xm.optimizer_step`_ implementation.

For Horovod distributed configuration, the optimizer is wrapped with Horovod Distributed Optimizer and
its state is broadcast from rank 0 to all other processes.
Expand Down Expand Up @@ -285,7 +286,7 @@ def auto_optim(optimizer: Optimizer, **kwargs: Any) -> Optimizer:


class DistributedProxySampler(DistributedSampler):
"""Distributed sampler proxy to adapt user's sampler for distributed data parallelism configuration.
"""Distributed sampler proxy to adapt user's sampler for distributed data paralellism configuration.

Code is based on https://github.com/pytorch/pytorch/issues/23430#issuecomment-562350407

Expand Down Expand Up @@ -339,7 +340,7 @@ class _MpDeviceLoader:
# From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available
def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None:
self._loader = loader
# pyrefly: ignore [read-only]
# pyrely: ignore [read-only]
self._device = device
self._parallel_loader_kwargs = kwargs

Expand All @@ -356,4 +357,4 @@ def __init__(self, optimizer: Optimizer) -> None:
self.wrapped_optimizer = optimizer

def step(self, closure: Any = None) -> Any:
xm.optimizer_step(self.wrapped_optimizer, barrier=True)
xm.optimizer_step(self.wrapped_optimizer, barrier=True)