diff --git a/CHANGELOG.md b/CHANGELOG.md index a8976a91e..2dc54b5d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `OverrideDecay`, a late-stage decay override usable on both `ComposableScheduler` and `SequentialScheduler` via an `override_decay` field. When `current >= override_decay.start`, the main schedule is interrupted mid-flight and the LR decays from the value the main schedule would have produced at `start` to a target LR over `duration` (linear or cosine). `SequentialScheduler` additionally warns that `t_max` is ignored once the override becomes active. - `OLMO_RICH_LOGGING` can now explicitly enable *or* disable rich console logging (`0`/`false`/`no`/`off` disables it); previously setting it to any value only force-enabled rich logging. - `init_distributed()` now bootstraps a minimal single-process environment (`RANK=0`, `WORLD_SIZE=1`, `MASTER_ADDR`/`MASTER_PORT`) when launch env vars are absent, so scripts can be run directly (without `torchrun`) for single-process debugging. +- Added `MultiGroupDistributedDataParallel` (`olmo_core.nn.parallel`), a data-parallel wrapper that accumulates gradients into flat bucket views and supports per-parameter process groups (`param_process_group_fn`), overlapped bucketed all-reduce (finalized via `finalize_grad_reduce()`), and optional fp32 gradient accumulation/reduction. ### Fixed diff --git a/docs/source/nn/index.rst b/docs/source/nn/index.rst index d98ac56b2..4408f1b13 100644 --- a/docs/source/nn/index.rst +++ b/docs/source/nn/index.rst @@ -15,5 +15,6 @@ layer_norm lm_head moe + parallel rope transformer diff --git a/docs/source/nn/parallel.rst b/docs/source/nn/parallel.rst new file mode 100644 index 000000000..834e1da9e --- /dev/null +++ b/docs/source/nn/parallel.rst @@ -0,0 +1,6 @@ +``nn.parallel`` +=============== + +.. automodule:: olmo_core.nn.parallel + :members: + :member-order: bysource diff --git a/src/olmo_core/nn/parallel/__init__.py b/src/olmo_core/nn/parallel/__init__.py new file mode 100644 index 000000000..8764ea150 --- /dev/null +++ b/src/olmo_core/nn/parallel/__init__.py @@ -0,0 +1,7 @@ +""" +Data-parallel wrappers. +""" + +from .distributed import MultiGroupDistributedDataParallel + +__all__ = ["MultiGroupDistributedDataParallel"] diff --git a/src/olmo_core/nn/parallel/distributed.py b/src/olmo_core/nn/parallel/distributed.py new file mode 100644 index 000000000..8f58192f5 --- /dev/null +++ b/src/olmo_core/nn/parallel/distributed.py @@ -0,0 +1,715 @@ +# mypy: allow-untyped-defs + +import logging +import os +from collections import OrderedDict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional + +import torch +import torch.distributed as dist +from torch.nn.modules import Module + +if dist.is_available(): + from torch.distributed.distributed_c10d import ReduceOp, _get_default_group + from torch.distributed.utils import _verify_param_shape_across_processes +else: + _get_default_group = None + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + + +__all__ = ["MultiGroupDistributedDataParallel"] + +logger = logging.getLogger(__name__) + + +@dataclass +class _GradBucket: + process_group: Any + storage_dtype: torch.dtype + comm_dtype: torch.dtype + params: list[torch.nn.Parameter] + ranges: list[tuple[int, int]] + numel: int + flat_storage: torch.Tensor + flat_comm: Optional[torch.Tensor] + + +class MultiGroupDistributedDataParallel(Module): + """ + A data-parallel module wrapper that all-reduces gradients across one or more process groups. + + Unlike :class:`torch.nn.parallel.DistributedDataParallel`, gradients are accumulated into + flat, pre-allocated **bucket views** (each parameter's ``.grad`` is a view into a contiguous + bucket), and different parameters may be assigned to different process groups via + :data:`param_process_group_fn` — useful when subsets of the model are replicated over + different device meshes. Bucketed all-reduces are launched (overlapped with the backward + pass) as each bucket fills, and :meth:`finalize_grad_reduce` must be called after + ``loss.backward()`` to wait for them to complete. + + Optionally, gradients can be accumulated and/or reduced in fp32 for numerical stability with + low-precision parameters. + + Usage:: + + ddp = MultiGroupDistributedDataParallel(model) + out = ddp(inputs) + out.loss.backward() + ddp.finalize_grad_reduce() # required: waits for the overlapped all-reduces + optimizer.step() + + :param module: The module to wrap. All parameters must be materialized (no meta/uninitialized + params) and the module must not have buffers when using :data:`param_process_group_fn`. + :param init_sync: If ``True``, broadcast parameters from rank 0 within each process group at + construction so all ranks start from identical weights. + :param process_group: The default process group. Defaults to the global group. + :param bucket_cap_mb: Soft cap (in MiB) on the size of each gradient bucket. + :param param_process_group_fn: Optional ``(name, param) -> process_group`` mapping assigning + parameters to process groups. Defaults to a single (default) process group for all params. + :param accumulate_grads_in_fp32: Accumulate gradients into an fp32 buffer (requires + :data:`reduce_grads_in_fp32`). + :param reduce_grads_in_fp32: All-reduce gradients in fp32 regardless of parameter dtype. + """ + + def __init__( + self, + module, + dim=0, + init_sync=True, + process_group=None, + bucket_cap_mb=None, + param_process_group_fn=None, + accumulate_grads_in_fp32=False, + reduce_grads_in_fp32=False, + ): + super().__init__() + + if process_group is None: + if _get_default_group is None: + self.process_group = None + else: + self.process_group = _get_default_group() + else: + self.process_group = process_group + + if hasattr(module, "_ddp_params_and_buffers_to_ignore"): + self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore) + else: + self.parameters_to_ignore = set() + + self._module_parameters = [ + p for n, p in module.named_parameters() if n not in self.parameters_to_ignore + ] + + self._param_to_name = {p: n for n, p in module.named_parameters()} + + # this is the order to launch grad reduce + self._reversed_module_parameters = list(reversed(self._module_parameters)) + + if not any(p.requires_grad for p in self._module_parameters): + raise RuntimeError( + "DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.", + ) + + is_multi_device_module = len({p.device for p in self._module_parameters}) > 1 + if is_multi_device_module: + raise NotImplementedError( + "DistributedDataParallel with parameters on multiple devices is not supported yet." + ) + distinct_device_types = { + p.device.type for p in self._module_parameters if p.device is not None + } + + self.device_type = next(iter(distinct_device_types)) + + self.dim = dim + self.module = module + self.device = next(iter(self._module_parameters)).device + self.require_backward_grad_sync = True + self.require_forward_param_sync = True + self.overlap_grad_reduce = True + + # Multi-process-group support. + if param_process_group_fn is None: + # default to single process group + def param_process_group_fn(_name, _param): + return self.process_group + + self._param_process_group_fn = param_process_group_fn + + self._accumulate_grads_in_fp32 = accumulate_grads_in_fp32 + self._reduce_grads_in_fp32 = reduce_grads_in_fp32 + + if self._accumulate_grads_in_fp32 and not self._reduce_grads_in_fp32: + raise ValueError("accumulate_grads_in_fp32 requires reduce_grads_in_fp32 to be True") + + # Check that a module does not have Uninitialized parameters + for param in self._module_parameters: + if isinstance(param, torch.nn.parameter.UninitializedParameter): + raise RuntimeError( + "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. " + "Run a dummy forward pass to correctly initialize the modules", + ) + # disable meta device parameters + if param.device.type == "meta": + raise RuntimeError( + "Modules with meta device parameters can't be used with `DistributedDataParallel`. " + "Please initialize all parameters before wrapping with DistributedDataParallel.", + ) + # used for intra-node param sync and inter-node sync as well + self.broadcast_bucket_size = int(250 * 1024 * 1024) + + # reduction bucket size + if bucket_cap_mb is None: + # default case (bucket cap is 250 MiB) + bucket_cap_mb = 250 + + self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) + + # Whether to perform input tensor CPU to GPU copies on a side-stream + self.use_side_stream_for_tensor_copies = ( + os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1" + ) + + # build param <-> process group mapping + self.process_group_to_params: Dict[Any, List[torch.nn.Parameter]] = {} + self.param_to_process_group: Dict[torch.nn.Parameter, Any] = {} + + named_buffers = list(self.module.named_buffers()) + if len(named_buffers) > 0: + raise NotImplementedError( + "DDP with param_process_group_fn does not support buffers yet." + ) + + for name, param in self.module.named_parameters(): + if name in self.parameters_to_ignore: + continue + pg = self._param_process_group_fn(name, param) + + if pg is None: + pg = self.process_group + + self.param_to_process_group[param] = pg + if pg not in self.process_group_to_params: + self.process_group_to_params[pg] = [] + self.process_group_to_params[pg].append(param) + + if init_sync: + self.init_sync() + + self._comm_hooks: list[tuple[object, object]] = [] + + self._grad_buckets: list[_GradBucket] = [] + self._param_to_bucket_idx: Dict[torch.nn.Parameter, int] = {} + self._param_to_bucket_view: Dict[torch.nn.Parameter, torch.Tensor] = {} + self._bucket_ready_count: list[int] = [] + self._grad_reduce_hooks: list[tuple[Any, int]] = [] + self._grad_views_need_rebind = False + self._warned_grad_view_rebind = False + self._forwards_since_finalize = 0 + + self._build_grad_buckets() + self._bind_bucket_views(zero_buffers=True, reason="initialization") + + self._fp32_acc_hooks = [] + if self._accumulate_grads_in_fp32: + for p in module.parameters(): + if not p.requires_grad: + continue + self._fp32_acc_hooks.append( + p.register_post_accumulate_grad_hook(self._fp32_post_grad_acc_hook) + ) + + # Register the AccumulateGrad post hooks that drive this wrapper's + # own bucket readiness/all-reduce path. + self._accum_grad_hooks: list[RemovableHandle] = [] + + self._param_grad_ready: OrderedDict[torch.nn.Parameter, bool] = OrderedDict() + self._next_reduce_bucket_idx = 0 + + # the hook that controls gradient allreduce + self._register_accum_grad_hook() + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # nn.Module logic first + except AttributeError: + return getattr(self.module, name) + + def __getitem__(self, key: int) -> Any: + return self.module.__getitem__(key) # type: ignore[operator] + + def _get_storage_dtype_for_param(self, param: torch.nn.Parameter) -> torch.dtype: + if self._accumulate_grads_in_fp32: + return torch.float32 + return param.dtype + + def _get_comm_dtype_for_storage(self, storage_dtype: torch.dtype) -> torch.dtype: + if self._reduce_grads_in_fp32: + return torch.float32 + return storage_dtype + + def _get_param_grad_buffer(self, param: torch.nn.Parameter) -> Optional[torch.Tensor]: + if self._accumulate_grads_in_fp32: + return getattr(param, "_main_grad_fp32", None) + return param.grad + + def _set_param_grad_buffer( + self, param: torch.nn.Parameter, buffer: Optional[torch.Tensor] + ) -> None: + if self._accumulate_grads_in_fp32: + param._main_grad_fp32 = buffer # type: ignore[attr-defined] + else: + param.grad = buffer + + def _is_expected_grad_view(self, current: torch.Tensor, expected: torch.Tensor) -> bool: + return ( + current.dtype == expected.dtype + and current.device == expected.device + and current.shape == expected.shape + and current.stride() == expected.stride() + and current.data_ptr() == expected.data_ptr() + ) + + def _build_grad_buckets(self) -> None: + params_in_reduce_order = [p for p in self._reversed_module_parameters if p.requires_grad] + + current_params: list[torch.nn.Parameter] = [] + current_ranges: list[tuple[int, int]] = [] + current_numel = 0 + current_bucket_bytes = 0 + current_process_group = None + current_storage_dtype: Optional[torch.dtype] = None + current_comm_dtype: Optional[torch.dtype] = None + + def flush_current_bucket() -> None: + nonlocal current_params + nonlocal current_ranges + nonlocal current_numel + nonlocal current_bucket_bytes + nonlocal current_process_group + nonlocal current_storage_dtype + nonlocal current_comm_dtype + + if not current_params: + return + + assert current_storage_dtype is not None + assert current_comm_dtype is not None + assert current_process_group is not None + + flat_storage = torch.zeros( + current_numel, device=self.device, dtype=current_storage_dtype + ) + flat_comm = None + if current_comm_dtype != current_storage_dtype: + flat_comm = torch.empty(current_numel, device=self.device, dtype=current_comm_dtype) + + bucket_idx = len(self._grad_buckets) + self._grad_buckets.append( + _GradBucket( + process_group=current_process_group, + storage_dtype=current_storage_dtype, + comm_dtype=current_comm_dtype, + params=list(current_params), + ranges=list(current_ranges), + numel=current_numel, + flat_storage=flat_storage, + flat_comm=flat_comm, + ) + ) + + for param, (start, end) in zip(current_params, current_ranges): + self._param_to_bucket_idx[param] = bucket_idx + self._param_to_bucket_view[param] = flat_storage[start:end].view_as(param) + + current_params = [] + current_ranges = [] + current_numel = 0 + current_bucket_bytes = 0 + current_process_group = None + current_storage_dtype = None + current_comm_dtype = None + + for param in params_in_reduce_order: + process_group = self.param_to_process_group[param] + storage_dtype = self._get_storage_dtype_for_param(param) + comm_dtype = self._get_comm_dtype_for_storage(storage_dtype) + param_bytes = param.numel() * torch.empty((), dtype=comm_dtype).element_size() + + should_flush = current_params and ( + process_group is not current_process_group + or storage_dtype != current_storage_dtype + or comm_dtype != current_comm_dtype + or (current_bucket_bytes + param_bytes > self.bucket_bytes_cap) + ) + if should_flush: + flush_current_bucket() + + start = current_numel + end = start + param.numel() + current_params.append(param) + current_ranges.append((start, end)) + current_numel = end + current_bucket_bytes += param_bytes + current_process_group = process_group + current_storage_dtype = storage_dtype + current_comm_dtype = comm_dtype + + flush_current_bucket() + self._bucket_ready_count = [0 for _ in self._grad_buckets] + + def _bind_bucket_views(self, *, zero_buffers: bool, reason: str) -> None: + if zero_buffers: + for bucket in self._grad_buckets: + bucket.flat_storage.zero_() + + rebound_count = 0 + none_count = 0 + for param, expected_view in self._param_to_bucket_view.items(): + current = self._get_param_grad_buffer(param) + if current is None: + none_count += 1 + elif not self._is_expected_grad_view(current, expected_view): + raise RuntimeError( + "Detected an external gradient tensor replacement that breaks bucket views. " + "This mode requires grads to remain bucket views." + ) + + if current is not expected_view: + self._set_param_grad_buffer(param, expected_view) + rebound_count += 1 + + # In fp32-accum mode, .grad must stay ephemeral and be consumed by the post-acc hook. + if self._accumulate_grads_in_fp32: + param.grad = None + + self._grad_views_need_rebind = False + if rebound_count > 0 and reason != "initialization": + if not self._warned_grad_view_rebind: + logger.warning( + f"Rebound {rebound_count} gradient bucket view(s) because buffers were detached " + f"(reason={reason}, none={none_count})." + ) + self._warned_grad_view_rebind = True + + def _ensure_grad_views_bound(self, *, allow_none_rebind: bool, where: str) -> None: + if self._grad_views_need_rebind: + if self._forwards_since_finalize > 0: + raise RuntimeError( + f"Gradient buckets were marked for rebind in {where} after the training step started. " + "This usually indicates set_to_none=True between micro-batches." + ) + self._bind_bucket_views(zero_buffers=True, reason=f"flagged@{where}") + return + + has_none_binding = False + for param, expected_view in self._param_to_bucket_view.items(): + current = self._get_param_grad_buffer(param) + if current is None: + has_none_binding = True + continue + if not self._is_expected_grad_view(current, expected_view): + raise RuntimeError( + f"Gradient view integrity check failed in {where}: gradient tensor was replaced " + "with a non-bucket-view tensor." + ) + + if has_none_binding: + if not allow_none_rebind: + raise RuntimeError( + f"Found None gradient buffers in {where}. This usually means set_to_none=True ran " + "at an unexpected time. Rebind before backward by running a forward pass." + ) + if self._forwards_since_finalize > 0 and not self._grad_views_need_rebind: + raise RuntimeError( + f"Found detached gradient buffers in {where} after this training step already started. " + "Rebinding now would lose accumulated grads. This usually indicates an unexpected " + "set_to_none=True between micro-batches." + ) + self._bind_bucket_views(zero_buffers=True, reason=f"none@{where}") + + def _fp32_post_grad_acc_hook(self, param: torch.Tensor): + g = param.grad + if g is None: + return + + expected_view = self._param_to_bucket_view[param] + main_grad = getattr(param, "_main_grad_fp32", None) + if main_grad is None: + raise RuntimeError( + "FP32 grad bucket view is missing during backward. " + "Likely caused by set_to_none=True after forward began." + ) + if not self._is_expected_grad_view(main_grad, expected_view): + raise RuntimeError( + "FP32 grad buffer is not the expected bucket view. " + "External grad buffer replacement is not supported in bucket-view mode." + ) + + main_grad.add_(g) + param.grad = None + + def _launch_bucket_all_reduce(self, bucket_idx: int) -> None: + if self._comm_hooks: + raise NotImplementedError("Comm hooks are not implemented in bucket-view mode.") + + bucket = self._grad_buckets[bucket_idx] + world_size = bucket.process_group.size() + + if bucket.storage_dtype == bucket.comm_dtype: + tensor_for_reduce = bucket.flat_storage + tensor_for_reduce.div_(world_size) + else: + assert bucket.flat_comm is not None + bucket.flat_comm.copy_(bucket.flat_storage) + bucket.flat_comm.div_(world_size) + tensor_for_reduce = bucket.flat_comm + + handle = torch.distributed.all_reduce( + tensor_for_reduce, op=ReduceOp.SUM, group=bucket.process_group, async_op=True + ) + self._grad_reduce_hooks.append((handle, bucket_idx)) + + def _maybe_kick_start_all_reduce(self): + while self._next_reduce_bucket_idx < len(self._grad_buckets): + bucket = self._grad_buckets[self._next_reduce_bucket_idx] + if self._bucket_ready_count[self._next_reduce_bucket_idx] < len(bucket.params): + break + + self._launch_bucket_all_reduce(self._next_reduce_bucket_idx) + self._next_reduce_bucket_idx += 1 + + def _register_accum_grad_hook(self): + def notify_grad_ready( + param, + ): + if not self.require_backward_grad_sync: + return + + if self._param_grad_ready[param]: + return + + self._param_grad_ready[param] = True + bucket_idx = self._param_to_bucket_idx[param] + self._bucket_ready_count[bucket_idx] += 1 + + # do this in backward + if self.overlap_grad_reduce: + self._maybe_kick_start_all_reduce() + + # otherwise, leave the all-reduce to finalize_grad_reduce + + for index, param in enumerate(self._module_parameters): + if not param.requires_grad: + continue + + # set up param order + self._param_grad_ready[param] = False + + # NOTE: in order to ensure param grads reduce always happen in the same order, + # instead of launching all-reduce in accumulate_grad_hook + # it only notifies the grad is ready + # and the actual all-reduce is kicked off in _maybe_kick_start_all_reduce + # based on what grads are ready + self._accum_grad_hooks.append( + param.register_post_accumulate_grad_hook(notify_grad_ready) + ) + + def finalize_grad_reduce(self): + # Grad buffers should already be bound before backward starts. If they are detached here, + # we cannot safely rebind without risking silent corruption. + self._ensure_grad_views_bound(allow_none_rebind=False, where="finalize_grad_reduce") + + # in some cases (eg, imbalance moe routing), some params may not have grads, and their + # post_accumulate_grad_hook is never called, so their grad_ready is never set to True. + if self._next_reduce_bucket_idx < len(self._grad_buckets): + for param in self._param_grad_ready.keys(): + if not self._param_grad_ready[param]: + self._param_grad_ready[param] = True + bucket_idx = self._param_to_bucket_idx[param] + self._bucket_ready_count[bucket_idx] += 1 + + # Keep missing grads explicitly zero in the bucket view. + self._param_to_bucket_view[param].zero_() + + self._maybe_kick_start_all_reduce() + + # now all grad reduce should have been launched + assert self._next_reduce_bucket_idx == len(self._grad_buckets), ( + f"Not all bucket all-reduce operations were launched: " + f"{self._next_reduce_bucket_idx} vs {len(self._grad_buckets)}" + ) + + for idx, (handle, bucket_idx) in enumerate(self._grad_reduce_hooks): + handle.wait() + bucket = self._grad_buckets[bucket_idx] + if bucket.flat_comm is not None: + bucket.flat_storage.copy_(bucket.flat_comm) + self._grad_reduce_hooks = [] + + self._next_reduce_bucket_idx = 0 + + # mark all grads as not ready + for key in self._param_grad_ready.keys(): + self._param_grad_ready[key] = False + for bucket_idx in range(len(self._bucket_ready_count)): + self._bucket_ready_count[bucket_idx] = 0 + self._forwards_since_finalize = 0 + self._sync_module_logical_grads_from_anchor() + + def init_sync(self): + for process_group, parameters in self.process_group_to_params.items(): + # Verify model equivalence. + _verify_param_shape_across_processes(process_group, parameters) + + for param in parameters: + dist.broadcast( + param.data, + src=dist.get_global_rank(process_group, 0), + group=process_group, + async_op=False, + ) + + def __getstate__(self): + # TODO: review if this works with multi-process-group DDP + raise NotImplementedError("DDP serialization is not implemented.") + + def __setstate__(self, state): + # TODO: review if this works with multi-process-group DDP + raise NotImplementedError("DDP serialization is not implemented.") + # If serializable, then the process group should be the default one + + @contextmanager + def no_sync(self): + r""" + Context manager to disable gradient synchronizations across DDP processes. + + Within this context, gradients will be accumulated on module + variables, which will later be synchronized in the first + forward-backward pass exiting the context. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg) + >>> with ddp.no_sync(): + >>> for input in inputs: + >>> ddp(input).backward() # no synchronization, accumulate grads + >>> ddp(another_input).backward() # synchronize grads + + .. warning:: + The forward pass should be included inside the context manager, or + else gradients will still be synchronized. + """ + old_require_backward_grad_sync = self.require_backward_grad_sync + self.require_backward_grad_sync = False + try: + yield + finally: + self.require_backward_grad_sync = old_require_backward_grad_sync + + def _pre_forward(self, *inputs, **kwargs): + self._ensure_grad_views_bound(allow_none_rebind=True, where="forward") + self._forwards_since_finalize += 1 + return inputs, kwargs + + def _post_forward(self, output): + return output + + def forward(self, *inputs, **kwargs): + with torch.autograd.profiler.record_function("MultiGroupDistributedDataParallel.forward"): + inputs, kwargs = self._pre_forward(*inputs, **kwargs) + output = self.module(*inputs, **kwargs) + + return self._post_forward(output) + + def train(self, mode=True): + super().train(mode) + return self + + def _zero_module_logical_grads(self, set_to_none: bool) -> None: + zero_logical_grads = getattr(self.module, "zero_fp8_weight_store_grads", None) + if zero_logical_grads is None: + zero_logical_grads = getattr(self.module, "zero_mxfp8_expert_weight_grads", None) + if zero_logical_grads is not None: + zero_logical_grads(set_to_none=set_to_none) + + def _sync_module_logical_grads_from_anchor(self) -> None: + sync_logical_grads = getattr( + self.module, + "sync_fp8_weight_store_grads_from_anchor", + None, + ) + if sync_logical_grads is None: + sync_logical_grads = getattr( + self.module, + "sync_mxfp8_expert_weight_grads_from_anchor", + None, + ) + if sync_logical_grads is not None: + sync_logical_grads() + + def zero_grad(self, set_to_none: bool = True): + if not set_to_none: + # Fast path for bucket-view mode: zero bucket storage once and keep view bindings. + if self._grad_views_need_rebind: + self._bind_bucket_views(zero_buffers=True, reason="zero_grad") + else: + for bucket in self._grad_buckets: + bucket.flat_storage.zero_() + # In fp32-accum mode, .grad remains ephemeral. + if self._accumulate_grads_in_fp32: + for param in self._module_parameters: + if param.requires_grad: + param.grad = None + self._forwards_since_finalize = 0 + self._zero_module_logical_grads(set_to_none=False) + return + + super().zero_grad(set_to_none=True) + if self._accumulate_grads_in_fp32: + for param in self._module_parameters: + if not param.requires_grad: + continue + param._main_grad_fp32 = None # type: ignore[attr-defined] + + self._forwards_since_finalize = 0 + self._grad_views_need_rebind = True + self._zero_module_logical_grads(set_to_none=True) + + def set_main_grads_to_none(self): + if hasattr(self.module, "set_main_grads_to_none"): + self.module.set_main_grads_to_none() + else: + for param in self._module_parameters: + if not param.requires_grad: + continue + if hasattr(param, "_main_grad_fp32"): + param._main_grad_fp32 = None # type: ignore[attr-defined] + set_logical_main_grads_to_none = getattr( + self.module, + "set_fp8_weight_store_main_grads_to_none", + None, + ) + if set_logical_main_grads_to_none is None: + set_logical_main_grads_to_none = getattr( + self.module, + "set_mxfp8_expert_weight_main_grads_to_none", + None, + ) + if set_logical_main_grads_to_none is not None: + set_logical_main_grads_to_none() + self._grad_views_need_rebind = True + self._forwards_since_finalize = 0 + + def register_comm_hook(self, state: object, hook: Callable): + raise NotImplementedError + + @property + def _distributed_rank(self): + return dist.get_rank(self.process_group) diff --git a/src/test/nn/parallel/__init__.py b/src/test/nn/parallel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/test/nn/parallel/distributed_test.py b/src/test/nn/parallel/distributed_test.py new file mode 100644 index 000000000..50e19188d --- /dev/null +++ b/src/test/nn/parallel/distributed_test.py @@ -0,0 +1,115 @@ +import copy + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn + +from olmo_core.nn.parallel import MultiGroupDistributedDataParallel +from olmo_core.testing import BACKENDS, run_distributed_test +from olmo_core.utils import seed_all + + +class SimpleModel(nn.Module): + def __init__(self, d_in: int, d_hidden: int, d_out: int): + super().__init__() + self.fc1 = nn.Linear(d_in, d_hidden) + self.fc2 = nn.Linear(d_hidden, d_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc2(torch.relu(self.fc1(x))) + + +def _device_for_backend() -> torch.device: + if dist.get_backend() == "nccl": + device = torch.device(f"cuda:{dist.get_rank()}") + torch.cuda.set_device(device) + return device + return torch.device("cpu") + + +def _reference_grads(model: nn.Module, world_size: int): + """Manually all-reduce-average each parameter's local grad as the DDP reference.""" + grads = [] + for p in model.parameters(): + assert p.grad is not None + g = p.grad.detach().clone() + dist.all_reduce(g, op=dist.ReduceOp.SUM) + g /= world_size + grads.append(g) + return grads + + +def _run_grad_parity(d_in: int, d_hidden: int, d_out: int): + device = _device_for_backend() + rank, world_size = dist.get_rank(), dist.get_world_size() + + # Identical init across ranks, so init_sync isn't needed. + seed_all(0) + model = SimpleModel(d_in, d_hidden, d_out).to(device) + reference = copy.deepcopy(model) + ddp = MultiGroupDistributedDataParallel(model, init_sync=False) + + # Distinct per-rank batch (data parallelism). + torch.manual_seed(100 + rank) + x = torch.randn(4, d_in, device=device) + y = torch.randn(4, d_out, device=device) + + ((ddp(x) - y) ** 2).mean().backward() + ddp.finalize_grad_reduce() + + ((reference(x) - y) ** 2).mean().backward() + expected = _reference_grads(reference, world_size) + + for (name, p), g_ref in zip(ddp.module.named_parameters(), expected): + assert p.grad is not None, f"missing grad for {name}" + torch.testing.assert_close(p.grad, g_ref, rtol=1e-5, atol=1e-6) + + +def _run_no_sync_accumulation(d_in: int, d_hidden: int, d_out: int): + device = _device_for_backend() + rank, world_size = dist.get_rank(), dist.get_world_size() + + seed_all(0) + model = SimpleModel(d_in, d_hidden, d_out).to(device) + reference = copy.deepcopy(model) + ddp = MultiGroupDistributedDataParallel(model, init_sync=False) + + torch.manual_seed(100 + rank) + xa = torch.randn(4, d_in, device=device) + ya = torch.randn(4, d_out, device=device) + xb = torch.randn(4, d_in, device=device) + yb = torch.randn(4, d_out, device=device) + + # First micro-batch accumulates without syncing; the second (synced) triggers the reduce. + with ddp.no_sync(): + ((ddp(xa) - ya) ** 2).mean().backward() + ((ddp(xb) - yb) ** 2).mean().backward() + ddp.finalize_grad_reduce() + + # Reference: accumulate both micro-batch grads locally, then all-reduce-average. + ((reference(xa) - ya) ** 2).mean().backward() + ((reference(xb) - yb) ** 2).mean().backward() + expected = _reference_grads(reference, world_size) + + for (name, p), g_ref in zip(ddp.module.named_parameters(), expected): + assert p.grad is not None, f"missing grad for {name}" + torch.testing.assert_close(p.grad, g_ref, rtol=1e-5, atol=1e-6) + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_grad_parity(backend): + run_distributed_test( + _run_grad_parity, + backend=backend, + func_args=(16, 32, 8), + ) + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_no_sync_accumulation(backend): + run_distributed_test( + _run_no_sync_accumulation, + backend=backend, + func_args=(16, 32, 8), + )