diff --git a/emerging_optimizers/soap/rekls.py b/emerging_optimizers/soap/rekls.py index 55e0b433..5efbe061 100644 --- a/emerging_optimizers/soap/rekls.py +++ b/emerging_optimizers/soap/rekls.py @@ -13,14 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Callable, override + +import torch +from torch import distributed as dist +from torch import optim from torch.optim.optimizer import ParamsT + +if TYPE_CHECKING: + from typing import overload + from emerging_optimizers import mixin as opt_mixin -from emerging_optimizers import registry +from emerging_optimizers import registry, utils +from emerging_optimizers.scalar_optimizers import update_functions +from emerging_optimizers.soap import soap, soap_utils, tp_utils from emerging_optimizers.soap.soap import SOAP +from emerging_optimizers.utils import FP32MatmulPrecT, get_pg_rank, get_pg_size -__all__ = ["REKLS"] +__all__ = ["REKLS", "TpRekls"] @registry.register_optimizer("rekls") @@ -56,3 +68,218 @@ def __init__( use_eigh=True, use_kl_shampoo=True, ) + + +@registry.register_optimizer("tp_rekls") +class TpRekls(opt_mixin.WeightDecayMixin, optim.Optimizer): + """Tensor-parallel variant of :class:`REKLS`. + + Reimplemented from scratch (not inheriting from :class:`~emerging_optimizers.soap.soap.SOAP`) so the + tensor-parallel bookkeeping stays isolated. Eigenbases are not stored in optimizer state; they are + recomputed via :func:`~emerging_optimizers.soap.soap_utils.get_eigenbasis_eigh` from the kronecker + factors. Each step calls eigh twice — once on the pre-update L, R for the + :func:`~emerging_optimizers.soap.soap.update_kronecker_factors_kl_shampoo` correction, and once on + the post-update L, R for the gradient projection. + + State per parameter (one entry per rank): + - ``step`` + - ``exp_avg``, ``exp_avg_sq``: full-size tensors duplicated across ``tp_group`` ranks. 
``exp_avg`` + is rotated through the basis change between steps (project back via the pre-update eigenbasis, + then forward via the post-update eigenbasis), matching SOAP's + :func:`~emerging_optimizers.soap.soap.update_eigenbasis_and_exp_avgs`. ``exp_avg_sq`` is not + rotated, matching SOAP's eigh path. + - ``L``, ``R``: kronecker factor matrices, sharded along dimension 0 across ``tp_group``. + + Args: + params: Iterable of parameters to optimize or dicts defining parameter groups. + lr: Learning rate. + betas: Inner Adam betas ``(b1, b2)``. + shampoo_beta: Beta for the kronecker factor moving average. + eps: Inner Adam epsilon. + weight_decay: Weight decay coefficient. + weight_decay_method: See :class:`~emerging_optimizers.mixin.WeightDecayMixin`. + tp_group: Process group across which parameters and gradients are sharded. + fp32_matmul_prec: Precision for the optimizer-state GEMM operations. + + Note: + Sharding is configured per-parameter-group via ``partition_dim`` (an int in ``{0, 1}``, + or ``None`` for replicated parameters). Mixed-layout models should use one group per + distinct ``partition_dim``:: + + optimizer = TpRekls([ + {"params": column_parallel_params, "partition_dim": 0}, + {"params": row_parallel_params, "partition_dim": 1}, + {"params": replicated_params, "partition_dim": None}, + ], lr=1e-3, tp_group=tp_group) + + Groups without ``partition_dim`` use the default (``None`` → replicated, plain non-TP REKLS + step on each rank, no collectives, full-size ``L``/``R``). 
+ """ + + def __init__( + self, + params: ParamsT, + lr: float, + betas: tuple[float, float] = (0.9, 0.95), + shampoo_beta: float = 0.95, + eps: float = 1e-8, + weight_decay: float = 0.01, + *, + weight_decay_method: opt_mixin.WeightDecayT = "decoupled", + tp_group: dist.ProcessGroup, + fp32_matmul_prec: FP32MatmulPrecT = "high", + ) -> None: + self.tp_group = tp_group + self.tp_size = get_pg_size(tp_group) + self.tp_rank = get_pg_rank(tp_group) + + self.weight_decay_method = weight_decay_method + self.fp32_matmul_prec = fp32_matmul_prec + + defaults = { + "lr": lr, + "betas": betas, + "shampoo_beta": shampoo_beta, + "eps": eps, + "weight_decay": weight_decay, + "partition_dim": None, + } + super().__init__(params, defaults) + + @staticmethod + def _validate_partition_dim(partition_dim: int | None) -> int | None: + if partition_dim is not None and partition_dim not in (0, 1): + raise ValueError(f"partition_dim must be 0, 1, or None, got {partition_dim}") + return partition_dim + + @torch.no_grad() # type: ignore[misc] + def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None: + partition_dim = self._validate_partition_dim(group["partition_dim"]) + for p in group["params"]: + if skip_non_grad_params and p.grad is None: + continue + if p.dim() != 2: + raise TypeError("TpRekls is only supported for 2D tensors") + state = self.state[p] + if len(state) == 0: + m, n = p.shape + # Get full size of m, n if the parameter is tensor-parallel. + if partition_dim == 0: + m *= self.tp_size + elif partition_dim == 1: + n *= self.tp_size + + # Both dimensions must be divisible by tp_size for the L/R shards (each sharded + # along dim 0) to gather back to the full square shape via torch.cat. + if partition_dim is not None and (m % self.tp_size or n % self.tp_size): + raise ValueError( + f"TpRekls requires both dimensions to be divisible by tp_size={self.tp_size}; " + f"got full shape ({m}, {n}) for a parameter with partition_dim={partition_dim}." 
+ ) + + state["step"] = 0 + state["exp_avg"] = torch.zeros((m, n), dtype=torch.float32, device=p.device) + state["exp_avg_sq"] = torch.zeros((m, n), dtype=torch.float32, device=p.device) + # Match init_kronecker_factors in soap.py: default dtype (typically float32). + # L, R are sharded along dim 0 only when the param is tensor-parallel. + shard = self.tp_size if partition_dim is not None else 1 + state["L"] = torch.zeros((m // shard, m), device=p.device) + state["R"] = torch.zeros((n // shard, n), device=p.device) + + if TYPE_CHECKING: + + @overload + def step(self, closure: None = ...) -> None: ... + + @overload + def step(self, closure: Callable[[], float]) -> float: ... + + @torch.no_grad() # type: ignore[misc] + @override + def step(self, closure: None = None) -> None: + assert closure is None, "No support for closure" + for group in self.param_groups: + self._init_group(group) + + for group in self.param_groups: + partition_dim = self._validate_partition_dim(group["partition_dim"]) + for p in group["params"]: + if p.grad is None: + continue # pragma: no cover + + local_grad = p.grad.to(torch.float32) + state = self.state[p] + curr_iter_1_based = state["step"] + 1 + + # Apply weight decay before the gather so l2 mode propagates into full_grad. + self._apply_weight_decay_inplace(p, local_grad, group["lr"], group["weight_decay"]) + + if partition_dim is None: + # Replicated parameter: no all-gather, state is already full-size. + full_grad = local_grad + kronecker_factor_list = [state["L"], state["R"]] + else: + full_grad, kronecker_factor_list = tp_utils.all_gather_grad_and_kronecker_factors_tp( + kronecker_factor_list=[state["L"], state["R"]], + grad=local_grad, + partition_dim=partition_dim, + tp_group=self.tp_group, + ) + + # Apply shampoo beta bias correction. 
+ shampoo_beta = group["shampoo_beta"] + shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta**curr_iter_1_based) + + # KL-Shampoo correction needs the eigenbasis of the *pre-update* L, R; recompute it + # via eigh since we do not persist eigenbases across steps. + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + pre_eigenbasis_list = soap_utils.get_eigenbasis_eigh(kronecker_factor_list) + soap.update_kronecker_factors_kl_shampoo( + kronecker_factor_list, + full_grad, + shampoo_beta=shampoo_beta, + eigenbasis_list=pre_eigenbasis_list, + eps=group["eps"], + ) + + # Persist the updated local shard back into state — only needed for the TP path, + # since the replicated path updated state["L"], state["R"] in place via the alias. + if partition_dim is not None: + state["L"].copy_(kronecker_factor_list[0].chunk(self.tp_size, dim=0)[self.tp_rank]) + state["R"].copy_(kronecker_factor_list[1].chunk(self.tp_size, dim=0)[self.tp_rank]) + + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + # Rotate exp_avg from the pre-update eigenbasis to the post-update eigenbasis, + # and recompute the post-update eigenbasis via eigh. + eigenbasis_list, state["exp_avg"], state["exp_avg_sq"] = soap.update_eigenbasis_and_exp_avgs( + kronecker_factor_list=kronecker_factor_list, + eigenbasis_list=pre_eigenbasis_list, + exp_avg_sq=state["exp_avg_sq"], + exp_avg=state["exp_avg"], + use_eigh=True, + ) + + full_grad_projected = soap.precondition(full_grad, eigenbasis_list, dims=[[0], [0]]) + + # No matmul inside adam update. Put it under fp32_matmul_precision for code simplicity. 
+ full_adam_update = update_functions.calculate_laprop_update( + full_grad_projected, + state["exp_avg"], + state["exp_avg_sq"], + True, # correct_bias + group["betas"], + curr_iter_1_based, + group["eps"], + ) + + full_precond_update = soap.precondition(full_adam_update, eigenbasis_list, dims=[[0], [1]]) + + if partition_dim is None: + p.add_(full_precond_update, alpha=-group["lr"]) + else: + local_precond_update = full_precond_update.chunk(self.tp_size, dim=partition_dim)[self.tp_rank] + p.add_(local_precond_update, alpha=-group["lr"]) + + state["step"] += 1 + + return None diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py index 015a7358..6ffa47a8 100644 --- a/emerging_optimizers/soap/soap.py +++ b/emerging_optimizers/soap/soap.py @@ -380,7 +380,24 @@ def update_kronecker_factors_kl_shampoo( ) -> None: """Updates the kronecker factor matrices in place using KL-Shampoo correction. - Implement Kullback–Leibler Minimization from https://arxiv.org/pdf/2509.03378 + Implements the Kullback–Leibler minimization update from https://arxiv.org/pdf/2509.03378. + + For a gradient :math:`G \\in \\mathbb{R}^{m \\times n}`, current kronecker factors + :math:`L_t \\in \\mathbb{R}^{m \\times m}`, :math:`R_t \\in \\mathbb{R}^{n \\times n}`, and their + orthonormal eigenbases :math:`Q_L, Q_R`, the approximate eigenvalues in the current eigenbasis are + + .. math:: + \\Lambda_L = \\mathrm{diag}(Q_L^{\\top} L_t Q_L), \\quad + \\Lambda_R = \\mathrm{diag}(Q_R^{\\top} R_t Q_R) + + and the EMA update with momentum :math:`\\beta` (= ``shampoo_beta``) and exponent :math:`p` + (= ``eigval_exp``, default ``-1``) is + + .. 
math:: + L_{t+1} = \\beta\\, L_t + \\frac{1-\\beta}{n}\\, G\\, Q_R\\, \\mathrm{diag}(\\Lambda_R^{p})\\, Q_R^{\\top} G^{\\top} \\\\ + R_{t+1} = \\beta\\, R_t + \\frac{1-\\beta}{m}\\, G^{\\top}\\, Q_L\\, \\mathrm{diag}(\\Lambda_L^{p})\\, Q_L^{\\top} G + + Eigenvalues are clamped to ``eps`` from below before exponentiation for numerical stability. Args: kronecker_factor_list: List of preconditioner matrices (L and R) to update. @@ -388,7 +405,7 @@ def update_kronecker_factors_kl_shampoo( shampoo_beta: Momentum coefficient for updating preconditioners. eigenbasis_list: List of orthonormal eigenbases of the kronecker factor matrices eps: Small offset for numerical stability. - eigenval_exp: Exponent of the eigenvalues. + eigval_exp: Exponent applied to the (clamped) eigenvalues. Defaults to ``-1.0``. """ if grad.dim() != 2: raise TypeError("KL-Shampoo mathematical correction is only supported for 2D tensors") diff --git a/emerging_optimizers/soap/tp_utils.py b/emerging_optimizers/soap/tp_utils.py new file mode 100644 index 00000000..32840ce3 --- /dev/null +++ b/emerging_optimizers/soap/tp_utils.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from torch import distributed as dist + +from emerging_optimizers.utils import get_pg_size + + +@torch.no_grad() # type: ignore[misc] +def all_gather_grad_and_kronecker_factors_tp( + kronecker_factor_list: list[torch.Tensor], + grad: torch.Tensor, + partition_dim: int, + tp_group: dist.ProcessGroup, +) -> tuple[torch.Tensor, list[torch.Tensor]]: + """All-gathers a sharded gradient and its kronecker factors across the tensor parallel group. + + This is a simple implementation of tensor-parallel support. It assumes ``grad`` is sharded across the tensor + parallel group, with ``partition_dim`` indicating the dimension along which it was sharded. To save memory, the + kronecker factors are also sharded, but always along dimension 0 to keep the gather operation simple. + + Gradient and kronecker factors are both all-gathered to all tensor parallel ranks and returned so the caller + can pass them to ``update_kronecker_factors`` (and downstream eigenbasis computations) without further + communication. + + Args: + kronecker_factor_list: List of preconditioner matrices (L and R), each sharded along dimension 0 across + the tensor parallel group. + grad: Local shard of the gradient tensor, sharded along ``partition_dim`` across the tensor parallel group. + partition_dim: Dimension along which ``grad`` is sharded across the tensor parallel group. + tp_group: Tensor parallel process group used to all-gather ``grad`` and ``kronecker_factor_list``. + + Returns: + full_grad: Full (un-sharded) gradient tensor on every rank. + full_kronecker_factor_list: List of full (un-sharded) kronecker factor matrices ``[L, R]`` on every rank. 
+ """ + tp_size = get_pg_size(tp_group) + + grad_shards = [torch.empty_like(grad) for _ in range(tp_size)] + dist.all_gather(grad_shards, grad, group=tp_group) + full_grad = torch.cat(grad_shards, dim=partition_dim) + + full_kronecker_factor_list: list[torch.Tensor] = [] + for kronecker_factor in kronecker_factor_list: + factor_shards = [torch.empty_like(kronecker_factor) for _ in range(tp_size)] + dist.all_gather(factor_shards, kronecker_factor, group=tp_group) + full_kronecker_factor_list.append(torch.cat(factor_shards, dim=0)) + + return full_grad, full_kronecker_factor_list diff --git a/emerging_optimizers/utils/__init__.py b/emerging_optimizers/utils/__init__.py index aa5528af..68d97c56 100644 --- a/emerging_optimizers/utils/__init__.py +++ b/emerging_optimizers/utils/__init__.py @@ -21,7 +21,7 @@ from .sinkhorn_mapper import * -__all__ = ["fp32_matmul_precision", "FP32MatmulPrecT", "SinkhornMapper"] +__all__ = ["fp32_matmul_precision", "FP32MatmulPrecT", "SinkhornMapper", "get_pg_size", "get_pg_rank"] FP32MatmulPrecT = Literal["highest", "high", "medium"] @@ -39,3 +39,17 @@ def fp32_matmul_precision(precision: FP32MatmulPrecT = "highest") -> Generator[N yield finally: torch.set_float32_matmul_precision(prev_val) + + +def get_pg_size(group: torch.distributed.ProcessGroup | None = None) -> int: + """Get world size for a distributed group with fallback""" + if not torch.distributed.is_initialized() or group is None: + return 1 + return group.size() + + +def get_pg_rank(group: torch.distributed.ProcessGroup | None = None) -> int: + """Get rank for a distributed group with fallback""" + if not torch.distributed.is_initialized() or group is None: + return 0 + return group.rank() diff --git a/tests/ci/L0_Tests_CPU.sh b/tests/ci/L0_Tests_CPU.sh index 30c09b2b..44042f3c 100644 --- a/tests/ci/L0_Tests_CPU.sh +++ b/tests/ci/L0_Tests_CPU.sh @@ -17,12 +17,21 @@ mkdir -p test-results/tests/ error=0 for n in 8 4; do - torchrun --nproc_per_node=$n --no-python coverage run -p 
\ - tests/test_distributed_muon_utils_cpu.py \ - --xml_output_file="test-results/tests/test_distributed_muon_utils_cpu_n${n}.xml" \ - -v -2 || error=1 + for test in tests/test_distributed_*_cpu.py; do + report_base=$(basename "$test" .py) + torchrun --nproc_per_node=$n --no-python coverage run -p \ + "$test" \ + --xml_output_file="test-results/tests/${report_base}_n${n}.xml" \ + -v -2 || error=1 + done done +# Single-process sanity check for TpRekls — exercises the size-1 tp_group path. +torchrun --nproc_per_node=1 --no-python coverage run -p \ + tests/test_distributed_rekls_cpu.py \ + --xml_output_file="test-results/tests/test_distributed_rekls_cpu_n1.xml" \ + -v -2 || error=1 + for test in "tests/test_scalar_optimizers.py" "tests/test_procrustes_step.py"; do report_name="test-results/${test}.xml" coverage run -p --source=emerging_optimizers $test --device=cpu -v -2 --xml_output_file="$report_name" || error=1 diff --git a/tests/test_distributed_rekls_cpu.py b/tests/test_distributed_rekls_cpu.py new file mode 100644 index 00000000..096a3dd5 --- /dev/null +++ b/tests/test_distributed_rekls_cpu.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import sys + +import torch +from absl import flags, logging +from absl.testing import absltest, parameterized + +from emerging_optimizers.soap.rekls import REKLS, TpRekls +from emerging_optimizers.utils import get_pg_rank, get_pg_size + + +flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") +flags.DEFINE_integer("seed", None, "Random seed for reproducible tests") +FLAGS = flags.FLAGS + + +def setUpModule() -> None: + if FLAGS.seed is not None: + logging.info("Setting random seed to %d", FLAGS.seed) + torch.manual_seed(FLAGS.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(FLAGS.seed) + + +def tearDownModule() -> None: + torch.distributed.destroy_process_group() + + +class TpReklsCpuTest(parameterized.TestCase): + def setUp(self): + super().setUp() + self.tp_group = torch.distributed.group.WORLD + self.world_size = get_pg_size(self.tp_group) + self.rank = get_pg_rank(self.tp_group) + + def test_5steps_matches_non_distributed_rekls(self): + """Multiple TpRekls steps over multiple params (mixed partition_dim) must produce + bit-identical updates to non-distributed REKLS for every param at every step. + """ + # ``partition_dim=None`` exercises the replicated fallback path: the param group carries + # ``partition_dim=None``, so no all-gather, full-size L/R, full update applied directly. + params_config = [ + {"shape": (16, 32), "partition_dim": 0}, + {"shape": (32, 16), "partition_dim": 1}, + {"shape": (96, 200), "partition_dim": 0}, + {"shape": (96, 200), "partition_dim": 1}, + {"shape": (24, 40), "partition_dim": None}, + ] + for cfg in params_config: + if cfg["partition_dim"] is None: + continue + m, n = cfg["shape"] + assert m % self.world_size == 0 and n % self.world_size == 0, ( + f"shape {cfg['shape']} must be divisible by world size {self.world_size}" + ) + + # Initial param data — all-reduce so every rank starts from the same tensors. 
+ full_params_data = [] + for cfg in params_config: + d = torch.randn(cfg["shape"]) + torch.distributed.all_reduce(d, group=self.tp_group) + full_params_data.append(d) + + ref_params = [torch.nn.Parameter(d.clone()) for d in full_params_data] + ref_optimizer = REKLS(ref_params, lr=1e-3) + + tp_params = [] + for cfg, d in zip(params_config, full_params_data): + pd = cfg["partition_dim"] + if pd is None: + local_data = d.clone() + else: + local_data = d.chunk(self.world_size, dim=pd)[self.rank].contiguous() + tp_params.append(torch.nn.Parameter(local_data)) + + # One param group per distinct partition_dim — TpRekls reads partition_dim from group. + tp_param_groups: list[dict] = [] + for pd in (0, 1, None): + members = [tp_p for tp_p, cfg in zip(tp_params, params_config) if cfg["partition_dim"] == pd] + if members: + tp_param_groups.append({"params": members, "partition_dim": pd}) + tp_optimizer = TpRekls(tp_param_groups, lr=1e-3, tp_group=self.tp_group) + + num_steps = 5 + for _ in range(num_steps): + full_grads = [] + for cfg in params_config: + g = torch.randn(cfg["shape"]) + torch.distributed.all_reduce(g, group=self.tp_group) + full_grads.append(g) + + for ref_p, full_g in zip(ref_params, full_grads): + ref_p.grad = full_g.clone() + for tp_p, cfg, full_g in zip(tp_params, params_config, full_grads): + pd = cfg["partition_dim"] + if pd is None: + tp_p.grad = full_g.clone() + else: + tp_p.grad = full_g.chunk(self.world_size, dim=pd)[self.rank].contiguous() + + ref_optimizer.step() + tp_optimizer.step() + + for ref_p, tp_p, cfg in zip(ref_params, tp_params, params_config): + pd = cfg["partition_dim"] + if pd is None: + ref_local = ref_p.detach() + else: + ref_local = ref_p.detach().chunk(self.world_size, dim=pd)[self.rank] + torch.testing.assert_close( + tp_p.detach(), + ref_local, + atol=0, + rtol=0, + ) + + +if __name__ == "__main__": + torch.distributed.init_process_group(backend="gloo") + torch.set_float32_matmul_precision("highest") + + rank = 
get_pg_rank(torch.distributed.group.WORLD) + + for i, arg in enumerate(sys.argv): + if arg.startswith("--xml_output_file="): + base, ext = os.path.splitext(arg) + sys.argv[i] = f"{base}_rank{rank}{ext}" + break + + absltest.main() diff --git a/tests/test_distributed_soap_utils_cpu.py b/tests/test_distributed_soap_utils_cpu.py new file mode 100644 index 00000000..229a88de --- /dev/null +++ b/tests/test_distributed_soap_utils_cpu.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import sys + +import torch +from absl import flags, logging +from absl.testing import absltest, parameterized + +from emerging_optimizers.soap import soap, tp_utils +from emerging_optimizers.utils import get_pg_rank, get_pg_size + + +flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") +flags.DEFINE_integer("seed", None, "Random seed for reproducible tests") +FLAGS = flags.FLAGS + + +def setUpModule() -> None: + if FLAGS.seed is not None: + logging.info("Setting random seed to %d", FLAGS.seed) + torch.manual_seed(FLAGS.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(FLAGS.seed) + + +def tearDownModule() -> None: + torch.distributed.destroy_process_group() + + +class AllGatherGradAndKroneckerFactorsTpCpuTest(parameterized.TestCase): + def setUp(self): + super().setUp() + self.tp_group = torch.distributed.group.WORLD + self.world_size = get_pg_size(self.tp_group) + self.rank = get_pg_rank(self.tp_group) + + @parameterized.product( + shape=((16, 32), (32, 16), (96, 200)), + partition_dim=(0, 1), + ) + def test_matches_non_distributed(self, shape, partition_dim): + m, n = shape + # Grad is sharded along partition_dim; both kronecker factors are sharded along dim 0, + # so m (rows of L) and n (rows of R) must each be divisible by the group size. + assert m % self.world_size == 0 and n % self.world_size == 0, "shape must be divisible by world size" + full_grad = torch.randint(-5, 5, shape) + full_l = torch.randint(-5, 5, (m, m)) + full_r = torch.randint(-5, 5, (n, n)) + # All-reduce ensures that every rank starts from the same tensors. 
+ torch.distributed.all_reduce(full_grad, group=self.tp_group) + torch.distributed.all_reduce(full_l, group=self.tp_group) + torch.distributed.all_reduce(full_r, group=self.tp_group) + + local_grad = full_grad.chunk(self.world_size, dim=partition_dim)[self.rank].contiguous() + local_l = full_l.chunk(self.world_size, dim=0)[self.rank].contiguous() + local_r = full_r.chunk(self.world_size, dim=0)[self.rank].contiguous() + + gathered_grad, gathered_factors = tp_utils.all_gather_grad_and_kronecker_factors_tp( + kronecker_factor_list=[local_l, local_r], + grad=local_grad, + partition_dim=partition_dim, + tp_group=self.tp_group, + ) + + torch.testing.assert_close(gathered_grad, full_grad, atol=0, rtol=0) + torch.testing.assert_close(gathered_factors[0], full_l, atol=0, rtol=0) + torch.testing.assert_close(gathered_factors[1], full_r, atol=0, rtol=0) + + @parameterized.product( + shape=((16, 32), (32, 16), (96, 200)), + partition_dim=(0, 1), + ) + def test_updated_factors_match_non_distributed(self, shape, partition_dim): + m, n = shape + assert m % self.world_size == 0 and n % self.world_size == 0, "shape must be divisible by world size" + # lerp_ requires a floating-point dtype, but values stay integer-valued so both paths + # produce bit-identical results. 
+ full_grad = torch.randint(-5, 5, shape, dtype=torch.float32) + full_l = torch.randint(-5, 5, (m, m), dtype=torch.float32) + full_r = torch.randint(-5, 5, (n, n), dtype=torch.float32) + torch.distributed.all_reduce(full_grad, group=self.tp_group) + torch.distributed.all_reduce(full_l, group=self.tp_group) + torch.distributed.all_reduce(full_r, group=self.tp_group) + + shampoo_beta = 0.95 + + ref_l = full_l.clone() + ref_r = full_r.clone() + soap.update_kronecker_factors([ref_l, ref_r], full_grad, shampoo_beta=shampoo_beta) + + local_grad = full_grad.chunk(self.world_size, dim=partition_dim)[self.rank].contiguous() + local_l = full_l.chunk(self.world_size, dim=0)[self.rank].contiguous() + local_r = full_r.chunk(self.world_size, dim=0)[self.rank].contiguous() + + gathered_grad, gathered_factors = tp_utils.all_gather_grad_and_kronecker_factors_tp( + kronecker_factor_list=[local_l, local_r], + grad=local_grad, + partition_dim=partition_dim, + tp_group=self.tp_group, + ) + soap.update_kronecker_factors(gathered_factors, gathered_grad, shampoo_beta=shampoo_beta) + + torch.testing.assert_close(gathered_factors[0], ref_l, atol=0, rtol=0) + torch.testing.assert_close(gathered_factors[1], ref_r, atol=0, rtol=0) + + +if __name__ == "__main__": + torch.distributed.init_process_group(backend="gloo") + torch.set_float32_matmul_precision("highest") + + rank = get_pg_rank(torch.distributed.group.WORLD) + + for i, arg in enumerate(sys.argv): + if arg.startswith("--xml_output_file="): + base, ext = os.path.splitext(arg) + + # Attach rank to the output file name + sys.argv[i] = f"{base}_rank{rank}{ext}" + break + + absltest.main()