From 428839701bbaf00fa689c249a1fda0cca18260d4 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Tue, 5 May 2026 19:53:57 -0700 Subject: [PATCH 01/16] add grad and kronecker factors gather function Signed-off-by: Hao Wu --- emerging_optimizers/soap/tp_soap.py | 62 ++++++++++++++++ emerging_optimizers/utils/__init__.py | 16 ++++- tests/test_distributed_soap_utils_cpu.py | 90 ++++++++++++++++++++++++ 3 files changed, 167 insertions(+), 1 deletion(-) create mode 100644 emerging_optimizers/soap/tp_soap.py create mode 100644 tests/test_distributed_soap_utils_cpu.py diff --git a/emerging_optimizers/soap/tp_soap.py b/emerging_optimizers/soap/tp_soap.py new file mode 100644 index 00000000..c168bff1 --- /dev/null +++ b/emerging_optimizers/soap/tp_soap.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import distributed as dist + +from emerging_optimizers.utils import get_pg_size + + +@torch.no_grad() # type: ignore[misc] +def all_gather_grad_and_kronecker_factors_tp( + kronecker_factor_list: list[torch.Tensor], + grad: torch.Tensor, + partition_dim: int, + tp_group: dist.ProcessGroup, +) -> tuple[torch.Tensor, list[torch.Tensor]]: + """All-gathers a sharded gradient and its kronecker factors across the tensor parallel group. + + This is a simple implementation to support tensor parallel. 
It assumes grad is sharded among tensor parallel domain + with partition_dim indicating the dimension it was sharded. To save memory, kronecker factors are also sharded + but always along dimension 0 to make gather operation easy. + + Gradient and kronecker factors are both all-gathered to all tensor parallel ranks and returned so the caller + can pass them to ``update_kronecker_factors`` (and downstream eigenbasis computations) without further + communication. + + Args: + kronecker_factor_list: List of preconditioner matrices (L and R), each sharded along dimension 0 across + the tensor parallel group. + grad: Local shard of the gradient tensor, sharded along ``partition_dim`` across the tensor parallel group. + partition_dim: Dimension along which ``grad`` is sharded across the tensor parallel group. + tp_group: Tensor parallel process group used to all-gather ``grad`` and ``kronecker_factor_list``. + + Returns: + full_grad: Full (un-sharded) gradient tensor on every rank. + full_kronecker_factor_list: List of full (un-sharded) kronecker factor matrices ``[L, R]`` on every rank. 
+ """ + tp_size = get_pg_size(tp_group) + + grad_shards = [torch.empty_like(grad) for _ in range(tp_size)] + dist.all_gather(grad_shards, grad, group=tp_group) + full_grad = torch.cat(grad_shards, dim=partition_dim) + + full_kronecker_factor_list: list[torch.Tensor] = [] + for kronecker_factor in kronecker_factor_list: + factor_shards = [torch.empty_like(kronecker_factor) for _ in range(tp_size)] + dist.all_gather(factor_shards, kronecker_factor, group=tp_group) + full_kronecker_factor_list.append(torch.cat(factor_shards, dim=0)) + + return full_grad, full_kronecker_factor_list diff --git a/emerging_optimizers/utils/__init__.py b/emerging_optimizers/utils/__init__.py index aa5528af..68d97c56 100644 --- a/emerging_optimizers/utils/__init__.py +++ b/emerging_optimizers/utils/__init__.py @@ -21,7 +21,7 @@ from .sinkhorn_mapper import * -__all__ = ["fp32_matmul_precision", "FP32MatmulPrecT", "SinkhornMapper"] +__all__ = ["fp32_matmul_precision", "FP32MatmulPrecT", "SinkhornMapper", "get_pg_size", "get_pg_rank"] FP32MatmulPrecT = Literal["highest", "high", "medium"] @@ -39,3 +39,17 @@ def fp32_matmul_precision(precision: FP32MatmulPrecT = "highest") -> Generator[N yield finally: torch.set_float32_matmul_precision(prev_val) + + +def get_pg_size(group: torch.distributed.ProcessGroup | None = None) -> int: + """Get world size for a distributed group with fallback""" + if not torch.distributed.is_initialized() or group is None: + return 1 + return group.size() + + +def get_pg_rank(group: torch.distributed.ProcessGroup | None = None) -> int: + """Get rank for a distributed group with fallback""" + if not torch.distributed.is_initialized() or group is None: + return 0 + return group.rank() diff --git a/tests/test_distributed_soap_utils_cpu.py b/tests/test_distributed_soap_utils_cpu.py new file mode 100644 index 00000000..7681adf2 --- /dev/null +++ b/tests/test_distributed_soap_utils_cpu.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys + +import torch +from absl import flags, logging +from absl.testing import absltest, parameterized + +from emerging_optimizers.soap import tp_soap + + +flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") +flags.DEFINE_integer("seed", None, "Random seed for reproducible tests") +FLAGS = flags.FLAGS + + +def setUpModule() -> None: + if FLAGS.seed is not None: + logging.info("Setting random seed to %d", FLAGS.seed) + torch.manual_seed(FLAGS.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(FLAGS.seed) + + +class AllGatherGradAndKroneckerFactorsTpCpuTest(parameterized.TestCase): + @parameterized.product( + shape=((16, 32), (32, 16), (96, 200)), + partition_dim=(0, 1), + ) + def test_matches_non_distributed(self, shape, partition_dim): + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + m, n = shape + # Grad is sharded along partition_dim; both kronecker factors are sharded along dim 0, + # so m (rows of L) and n (rows of R) must each be divisible by the group size. + assert m % world_size == 0 and n % world_size == 0, "shape must be divisible by world size" + full_grad = torch.randint(-5, 5, shape) + full_l = torch.randint(-5, 5, (m, m)) + full_r = torch.randint(-5, 5, (n, n)) + # All-reduce ensures that every rank starts from the same tensors. 
+ torch.distributed.all_reduce(full_grad) + torch.distributed.all_reduce(full_l) + torch.distributed.all_reduce(full_r) + + local_grad = full_grad.chunk(world_size, dim=partition_dim)[rank].contiguous() + local_l = full_l.chunk(world_size, dim=0)[rank].contiguous() + local_r = full_r.chunk(world_size, dim=0)[rank].contiguous() + + gathered_grad, gathered_factors = tp_soap.all_gather_grad_and_kronecker_factors_tp( + kronecker_factor_list=[local_l, local_r], + grad=local_grad, + partition_dim=partition_dim, + tp_group=torch.distributed.group.WORLD, + ) + + torch.testing.assert_close(gathered_grad, full_grad, atol=0, rtol=0) + torch.testing.assert_close(gathered_factors[0], full_l, atol=0, rtol=0) + torch.testing.assert_close(gathered_factors[1], full_r, atol=0, rtol=0) + + +if __name__ == "__main__": + torch.distributed.init_process_group(backend="gloo") + torch.set_float32_matmul_precision("highest") + + rank = torch.distributed.get_rank() + + for i, arg in enumerate(sys.argv): + if arg.startswith("--xml_output_file="): + base, ext = os.path.splitext(arg) + + # Attach rank to the output file name + sys.argv[i] = f"{base}_rank{rank}{ext}" + break + + absltest.main() + + torch.distributed.destroy_process_group() From 66ff8d813170aa40e34c6f7f1a9986fe05b0113b Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Tue, 5 May 2026 20:16:10 -0700 Subject: [PATCH 02/16] update tests Signed-off-by: Hao Wu --- tests/test_distributed_soap_utils_cpu.py | 58 ++++++++++++++++++++---- 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/tests/test_distributed_soap_utils_cpu.py b/tests/test_distributed_soap_utils_cpu.py index 7681adf2..ac0e5dd3 100644 --- a/tests/test_distributed_soap_utils_cpu.py +++ b/tests/test_distributed_soap_utils_cpu.py @@ -19,7 +19,8 @@ from absl import flags, logging from absl.testing import absltest, parameterized -from emerging_optimizers.soap import tp_soap +from emerging_optimizers.soap import soap, tp_soap +from emerging_optimizers.utils import 
get_pg_rank, get_pg_size flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") @@ -36,17 +37,21 @@ def setUpModule() -> None: class AllGatherGradAndKroneckerFactorsTpCpuTest(parameterized.TestCase): + def setUp(self): + super().setUp() + self.tp_group = torch.distributed.group.WORLD + self.world_size = get_pg_size(self.tp_group) + self.rank = get_pg_rank(self.tp_group) + @parameterized.product( shape=((16, 32), (32, 16), (96, 200)), partition_dim=(0, 1), ) def test_matches_non_distributed(self, shape, partition_dim): - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() m, n = shape # Grad is sharded along partition_dim; both kronecker factors are sharded along dim 0, # so m (rows of L) and n (rows of R) must each be divisible by the group size. - assert m % world_size == 0 and n % world_size == 0, "shape must be divisible by world size" + assert m % self.world_size == 0 and n % self.world_size == 0, "shape must be divisible by world size" full_grad = torch.randint(-5, 5, shape) full_l = torch.randint(-5, 5, (m, m)) full_r = torch.randint(-5, 5, (n, n)) @@ -55,21 +60,58 @@ def test_matches_non_distributed(self, shape, partition_dim): torch.distributed.all_reduce(full_l) torch.distributed.all_reduce(full_r) - local_grad = full_grad.chunk(world_size, dim=partition_dim)[rank].contiguous() - local_l = full_l.chunk(world_size, dim=0)[rank].contiguous() - local_r = full_r.chunk(world_size, dim=0)[rank].contiguous() + local_grad = full_grad.chunk(self.world_size, dim=partition_dim)[self.rank].contiguous() + local_l = full_l.chunk(self.world_size, dim=0)[self.rank].contiguous() + local_r = full_r.chunk(self.world_size, dim=0)[self.rank].contiguous() gathered_grad, gathered_factors = tp_soap.all_gather_grad_and_kronecker_factors_tp( kronecker_factor_list=[local_l, local_r], grad=local_grad, partition_dim=partition_dim, - tp_group=torch.distributed.group.WORLD, + tp_group=self.tp_group, ) 
torch.testing.assert_close(gathered_grad, full_grad, atol=0, rtol=0) torch.testing.assert_close(gathered_factors[0], full_l, atol=0, rtol=0) torch.testing.assert_close(gathered_factors[1], full_r, atol=0, rtol=0) + @parameterized.product( + shape=((16, 32), (32, 16), (96, 200)), + partition_dim=(0, 1), + ) + def test_updated_factors_match_non_distributed(self, shape, partition_dim): + m, n = shape + assert m % self.world_size == 0 and n % self.world_size == 0, "shape must be divisible by world size" + # lerp_ requires a floating-point dtype, but values stay integer-valued so both paths + # produce bit-identical results. + full_grad = torch.randint(-5, 5, shape, dtype=torch.float32) + full_l = torch.randint(-5, 5, (m, m), dtype=torch.float32) + full_r = torch.randint(-5, 5, (n, n), dtype=torch.float32) + torch.distributed.all_reduce(full_grad) + torch.distributed.all_reduce(full_l) + torch.distributed.all_reduce(full_r) + + shampoo_beta = 0.95 + + ref_l = full_l.clone() + ref_r = full_r.clone() + soap.update_kronecker_factors([ref_l, ref_r], full_grad, shampoo_beta=shampoo_beta) + + local_grad = full_grad.chunk(self.world_size, dim=partition_dim)[self.rank].contiguous() + local_l = full_l.chunk(self.world_size, dim=0)[self.rank].contiguous() + local_r = full_r.chunk(self.world_size, dim=0)[self.rank].contiguous() + + gathered_grad, gathered_factors = tp_soap.all_gather_grad_and_kronecker_factors_tp( + kronecker_factor_list=[local_l, local_r], + grad=local_grad, + partition_dim=partition_dim, + tp_group=self.tp_group, + ) + soap.update_kronecker_factors(gathered_factors, gathered_grad, shampoo_beta=shampoo_beta) + + torch.testing.assert_close(gathered_factors[0], ref_l, atol=0, rtol=0) + torch.testing.assert_close(gathered_factors[1], ref_r, atol=0, rtol=0) + if __name__ == "__main__": torch.distributed.init_process_group(backend="gloo") From cfd2cc59fcfe91af4e57ecbca3743983862e2ffc Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 6 May 2026 10:45:24 -0700 Subject: 
[PATCH 03/16] add tp rekls Signed-off-by: Hao Wu --- emerging_optimizers/soap/rekls.py | 210 +++++++++++++++++- .../soap/{tp_soap.py => tp_utils.py} | 0 tests/test_distributed_soap_utils_cpu.py | 6 +- 3 files changed, 211 insertions(+), 5 deletions(-) rename emerging_optimizers/soap/{tp_soap.py => tp_utils.py} (100%) diff --git a/emerging_optimizers/soap/rekls.py b/emerging_optimizers/soap/rekls.py index 55e0b433..1eff6ba3 100644 --- a/emerging_optimizers/soap/rekls.py +++ b/emerging_optimizers/soap/rekls.py @@ -13,14 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Callable, override + +import torch +from torch import distributed as dist +from torch import optim from torch.optim.optimizer import ParamsT + +if TYPE_CHECKING: + from typing import overload + from emerging_optimizers import mixin as opt_mixin -from emerging_optimizers import registry +from emerging_optimizers import registry, utils +from emerging_optimizers.scalar_optimizers import update_functions +from emerging_optimizers.soap import soap, soap_utils, tp_utils from emerging_optimizers.soap.soap import SOAP +from emerging_optimizers.utils import FP32MatmulPrecT, get_pg_rank, get_pg_size -__all__ = ["REKLS"] +__all__ = ["REKLS", "TpRekls"] @registry.register_optimizer("rekls") @@ -56,3 +68,197 @@ def __init__( use_eigh=True, use_kl_shampoo=True, ) + + +@registry.register_optimizer("tp_rekls") +class TpRekls(opt_mixin.WeightDecayMixin, optim.Optimizer): + """Tensor-parallel variant of :class:`REKLS`. + + Reimplemented from scratch (not inheriting from :class:`~emerging_optimizers.soap.soap.SOAP`) so the + tensor-parallel bookkeeping stays isolated. Eigenbases are not stored in optimizer state; they are + recomputed via :func:`~emerging_optimizers.soap.soap_utils.get_eigenbasis_eigh` from the kronecker + factors. 
Each step calls eigh twice — once on the pre-update L, R for the + :func:`~emerging_optimizers.soap.soap.update_kronecker_factors_kl_shampoo` correction, and once on + the post-update L, R for the gradient projection. + + State per parameter (one entry per rank): + - ``step`` + - ``exp_avg``, ``exp_avg_sq``: full-size tensors duplicated across ``tp_group`` ranks. The + eigenbasis-rotation step that SOAP performs in + :func:`~emerging_optimizers.soap.soap.update_eigenbasis_and_exp_avgs` is skipped because the + previous eigenbasis is not retained — these moments are interpreted as living in the most + recent eigenbasis. This is a deliberate simplification consistent with REKLS' premise that + the eigenbasis evolves slowly between steps. + - ``L``, ``R``: kronecker factor matrices, sharded along dimension 0 across ``tp_group``. + + Each step issues exactly one collective: an all-gather of the local gradient and ``L``/``R`` shards + via :func:`~emerging_optimizers.soap.tp_utils.all_gather_grad_and_kronecker_factors_tp`. + + Args: + params: Iterable of parameters to optimize or dicts defining parameter groups. + lr: Learning rate. + betas: Inner Adam betas ``(b1, b2)``. + shampoo_beta: Beta for the kronecker factor moving average. + eps: Inner Adam epsilon. + weight_decay: Weight decay coefficient. + weight_decay_method: See :class:`~emerging_optimizers.mixin.WeightDecayMixin`. + tp_group: Process group across which parameters and gradients are sharded. + fp32_matmul_prec: Precision for the optimizer-state GEMM operations. + + Note: + Each parameter must carry a ``partition_dim`` attribute (an int in ``{0, 1}``) describing the + dimension along which it is sharded across ``tp_group``. This matches the megatron-lm + convention. Parameters without a ``partition_dim`` attribute will raise :class:`TypeError`. 
+ """ + + def __init__( + self, + params: ParamsT, + lr: float, + betas: tuple[float, float] = (0.9, 0.95), + shampoo_beta: float = 0.95, + eps: float = 1e-8, + weight_decay: float = 0.01, + *, + weight_decay_method: opt_mixin.WeightDecayT = "decoupled", + tp_group: dist.ProcessGroup, + fp32_matmul_prec: FP32MatmulPrecT = "high", + ) -> None: + self.tp_group = tp_group + self.tp_size = get_pg_size(tp_group) + self.tp_rank = get_pg_rank(tp_group) + + self.weight_decay_method = weight_decay_method + self.fp32_matmul_prec = fp32_matmul_prec + + defaults = { + "lr": lr, + "betas": betas, + "shampoo_beta": shampoo_beta, + "eps": eps, + "weight_decay": weight_decay, + } + super().__init__(params, defaults) + + @staticmethod + def _get_partition_dim(p: torch.Tensor) -> int: + partition_dim = getattr(p, "partition_dim", None) + if partition_dim is None: + raise TypeError( + f"TpRekls requires each parameter to carry a 'partition_dim' attribute " + f"(megatron-lm convention); got a parameter of shape {tuple(p.shape)} without one." 
+ ) + if partition_dim not in (0, 1): + raise ValueError(f"partition_dim must be 0 or 1, got {partition_dim}") + return partition_dim + + def _full_shape(self, p: torch.Tensor) -> tuple[int, int]: + m, n = p.shape + partition_dim = self._get_partition_dim(p) + if partition_dim == 0: + m *= self.tp_size + else: + n *= self.tp_size + return m, n + + @torch.no_grad() # type: ignore[misc] + def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None: + for p in group["params"]: + if skip_non_grad_params and p.grad is None: + continue + if p.dim() != 2: + raise TypeError("TpRekls is only supported for 2D tensors") + state = self.state[p] + if len(state) == 0: + m, n = self._full_shape(p) + state["step"] = 0 + state["exp_avg"] = torch.zeros((m, n), dtype=torch.float32, device=p.device) + state["exp_avg_sq"] = torch.zeros((m, n), dtype=torch.float32, device=p.device) + # Match init_kronecker_factors in soap.py: default dtype (typically float32). + state["L"] = torch.zeros((m // self.tp_size, m), device=p.device) + state["R"] = torch.zeros((n // self.tp_size, n), device=p.device) + + if TYPE_CHECKING: + + @overload + def step(self, closure: None = ...) -> None: ... + + @overload + def step(self, closure: Callable[[], float]) -> float: ... + + @torch.no_grad() # type: ignore[misc] + @override + def step(self, closure: Callable[[], float] | None = None) -> float | None: + if closure is None: + loss = None + else: + loss = closure() + for group in self.param_groups: + self._init_group(group) + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue # pragma: no cover + + local_grad = p.grad.to(torch.float32) + state = self.state[p] + partition_dim = self._get_partition_dim(p) + curr_iter_1_based = state["step"] + 1 + + # Apply weight decay before the gather so l2 mode propagates into full_grad. 
+ self._apply_weight_decay_inplace(p, local_grad, group["lr"], group["weight_decay"]) + + full_grad, full_factors = tp_utils.all_gather_grad_and_kronecker_factors_tp( + kronecker_factor_list=[state["L"], state["R"]], + grad=local_grad, + partition_dim=partition_dim, + tp_group=self.tp_group, + ) + + # Apply shampoo beta bias correction. + shampoo_beta = group["shampoo_beta"] + shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta**curr_iter_1_based) + + # KL-Shampoo correction needs the eigenbasis of the *pre-update* L, R; recompute it + # via eigh since we do not persist eigenbases across steps. + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + pre_eigenbasis_list = soap_utils.get_eigenbasis_eigh(full_factors) + soap.update_kronecker_factors_kl_shampoo( + full_factors, + full_grad, + shampoo_beta=shampoo_beta, + eigenbasis_list=pre_eigenbasis_list, + eps=group["eps"], + ) + + # Persist the updated local shard back into state. + state["L"].copy_(full_factors[0].chunk(self.tp_size, dim=0)[self.tp_rank]) + state["R"].copy_(full_factors[1].chunk(self.tp_size, dim=0)[self.tp_rank]) + + # Eigenbasis from the updated L, R is what gets used for the gradient projection. + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + eigenbasis_list = soap_utils.get_eigenbasis_eigh(full_factors) + + full_grad_projected = soap.precondition(full_grad, eigenbasis_list, dims=[[0], [0]]) + + # No matmul inside adam update. Put it under fp32_matmul_precision for code simplicity. 
+ full_adam_update = update_functions.calculate_adam_update( + full_grad_projected, + state["exp_avg"], + state["exp_avg_sq"], + group["betas"], + True, # correct_bias + False, # nesterov + curr_iter_1_based, + group["eps"], + ) + + full_precond_update = soap.precondition(full_adam_update, eigenbasis_list, dims=[[0], [1]]) + + local_precond_update = full_precond_update.chunk(self.tp_size, dim=partition_dim)[self.tp_rank] + p.add_(local_precond_update, alpha=-group["lr"]) + + state["step"] += 1 + + return loss diff --git a/emerging_optimizers/soap/tp_soap.py b/emerging_optimizers/soap/tp_utils.py similarity index 100% rename from emerging_optimizers/soap/tp_soap.py rename to emerging_optimizers/soap/tp_utils.py diff --git a/tests/test_distributed_soap_utils_cpu.py b/tests/test_distributed_soap_utils_cpu.py index ac0e5dd3..00fc300e 100644 --- a/tests/test_distributed_soap_utils_cpu.py +++ b/tests/test_distributed_soap_utils_cpu.py @@ -19,7 +19,7 @@ from absl import flags, logging from absl.testing import absltest, parameterized -from emerging_optimizers.soap import soap, tp_soap +from emerging_optimizers.soap import soap, tp_utils from emerging_optimizers.utils import get_pg_rank, get_pg_size @@ -64,7 +64,7 @@ def test_matches_non_distributed(self, shape, partition_dim): local_l = full_l.chunk(self.world_size, dim=0)[self.rank].contiguous() local_r = full_r.chunk(self.world_size, dim=0)[self.rank].contiguous() - gathered_grad, gathered_factors = tp_soap.all_gather_grad_and_kronecker_factors_tp( + gathered_grad, gathered_factors = tp_utils.all_gather_grad_and_kronecker_factors_tp( kronecker_factor_list=[local_l, local_r], grad=local_grad, partition_dim=partition_dim, @@ -101,7 +101,7 @@ def test_updated_factors_match_non_distributed(self, shape, partition_dim): local_l = full_l.chunk(self.world_size, dim=0)[self.rank].contiguous() local_r = full_r.chunk(self.world_size, dim=0)[self.rank].contiguous() - gathered_grad, gathered_factors = 
tp_soap.all_gather_grad_and_kronecker_factors_tp( + gathered_grad, gathered_factors = tp_utils.all_gather_grad_and_kronecker_factors_tp( kronecker_factor_list=[local_l, local_r], grad=local_grad, partition_dim=partition_dim, From 72a632ed05449efb95c9ad50c72b12b8d02cd3af Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 6 May 2026 11:04:02 -0700 Subject: [PATCH 04/16] add test Signed-off-by: Hao Wu --- emerging_optimizers/soap/rekls.py | 22 +++-- tests/test_distributed_rekls_cpu.py | 120 ++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 8 deletions(-) create mode 100644 tests/test_distributed_rekls_cpu.py diff --git a/emerging_optimizers/soap/rekls.py b/emerging_optimizers/soap/rekls.py index 1eff6ba3..0d24647c 100644 --- a/emerging_optimizers/soap/rekls.py +++ b/emerging_optimizers/soap/rekls.py @@ -83,12 +83,11 @@ class TpRekls(opt_mixin.WeightDecayMixin, optim.Optimizer): State per parameter (one entry per rank): - ``step`` - - ``exp_avg``, ``exp_avg_sq``: full-size tensors duplicated across ``tp_group`` ranks. The - eigenbasis-rotation step that SOAP performs in - :func:`~emerging_optimizers.soap.soap.update_eigenbasis_and_exp_avgs` is skipped because the - previous eigenbasis is not retained — these moments are interpreted as living in the most - recent eigenbasis. This is a deliberate simplification consistent with REKLS' premise that - the eigenbasis evolves slowly between steps. + - ``exp_avg``, ``exp_avg_sq``: full-size tensors duplicated across ``tp_group`` ranks. ``exp_avg`` + is rotated through the basis change between steps (project back via the pre-update eigenbasis, + then forward via the post-update eigenbasis), matching SOAP's + :func:`~emerging_optimizers.soap.soap.update_eigenbasis_and_exp_avgs`. ``exp_avg_sq`` is not + rotated, matching SOAP's eigh path. - ``L``, ``R``: kronecker factor matrices, sharded along dimension 0 across ``tp_group``. 
Each step issues exactly one collective: an all-gather of the local gradient and ``L``/``R`` shards @@ -236,9 +235,16 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: state["L"].copy_(full_factors[0].chunk(self.tp_size, dim=0)[self.tp_rank]) state["R"].copy_(full_factors[1].chunk(self.tp_size, dim=0)[self.tp_rank]) - # Eigenbasis from the updated L, R is what gets used for the gradient projection. with utils.fp32_matmul_precision(self.fp32_matmul_prec): - eigenbasis_list = soap_utils.get_eigenbasis_eigh(full_factors) + # Rotate exp_avg from the pre-update eigenbasis to the post-update eigenbasis, + # and recompute the post-update eigenbasis via eigh. + eigenbasis_list, state["exp_avg"], state["exp_avg_sq"] = soap.update_eigenbasis_and_exp_avgs( + kronecker_factor_list=full_factors, + eigenbasis_list=pre_eigenbasis_list, + exp_avg_sq=state["exp_avg_sq"], + exp_avg=state["exp_avg"], + use_eigh=True, + ) full_grad_projected = soap.precondition(full_grad, eigenbasis_list, dims=[[0], [0]]) diff --git a/tests/test_distributed_rekls_cpu.py b/tests/test_distributed_rekls_cpu.py new file mode 100644 index 00000000..bee55b1c --- /dev/null +++ b/tests/test_distributed_rekls_cpu.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import sys + +import torch +from absl import flags, logging +from absl.testing import absltest, parameterized + +from emerging_optimizers.soap.rekls import REKLS, TpRekls +from emerging_optimizers.utils import get_pg_rank, get_pg_size + + +flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") +flags.DEFINE_integer("seed", None, "Random seed for reproducible tests") +FLAGS = flags.FLAGS + + +def setUpModule() -> None: + if FLAGS.seed is not None: + logging.info("Setting random seed to %d", FLAGS.seed) + torch.manual_seed(FLAGS.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(FLAGS.seed) + + +class TpReklsCpuTest(parameterized.TestCase): + def setUp(self): + super().setUp() + self.tp_group = torch.distributed.group.WORLD + self.world_size = get_pg_size(self.tp_group) + self.rank = get_pg_rank(self.tp_group) + + def test_5steps_matches_non_distributed_rekls(self): + """Multiple TpRekls steps over multiple params (mixed partition_dim) must produce + bit-identical updates to non-distributed REKLS for every param at every step. + """ + params_config = [ + {"shape": (16, 32), "partition_dim": 0}, + {"shape": (32, 16), "partition_dim": 1}, + {"shape": (96, 200), "partition_dim": 0}, + {"shape": (96, 200), "partition_dim": 1}, + ] + for cfg in params_config: + m, n = cfg["shape"] + assert m % self.world_size == 0 and n % self.world_size == 0, ( + f"shape {cfg['shape']} must be divisible by world size {self.world_size}" + ) + + # Initial param data — all-reduce so every rank starts from the same tensors. 
+ full_params_data = [] + for cfg in params_config: + d = torch.randn(cfg["shape"]) + torch.distributed.all_reduce(d) + full_params_data.append(d) + + ref_params = [torch.nn.Parameter(d.clone()) for d in full_params_data] + ref_optimizer = REKLS(ref_params, lr=1e-3) + + tp_params = [] + for cfg, d in zip(params_config, full_params_data): + local_data = d.chunk(self.world_size, dim=cfg["partition_dim"])[self.rank].contiguous() + local_param = torch.nn.Parameter(local_data) + local_param.partition_dim = cfg["partition_dim"] + tp_params.append(local_param) + tp_optimizer = TpRekls(tp_params, lr=1e-3, tp_group=self.tp_group) + + num_steps = 5 + for _ in range(num_steps): + full_grads = [] + for cfg in params_config: + g = torch.randn(cfg["shape"]) + torch.distributed.all_reduce(g) + full_grads.append(g) + + for ref_p, full_g in zip(ref_params, full_grads): + ref_p.grad = full_g.clone() + for tp_p, cfg, full_g in zip(tp_params, params_config, full_grads): + tp_p.grad = full_g.chunk(self.world_size, dim=cfg["partition_dim"])[self.rank].contiguous() + + ref_optimizer.step() + tp_optimizer.step() + + for ref_p, tp_p, cfg in zip(ref_params, tp_params, params_config): + ref_local = ref_p.detach().chunk(self.world_size, dim=cfg["partition_dim"])[self.rank] + torch.testing.assert_close( + tp_p.detach(), + ref_local, + atol=0, + rtol=0, + ) + + +if __name__ == "__main__": + torch.distributed.init_process_group(backend="gloo") + torch.set_float32_matmul_precision("highest") + + rank = torch.distributed.get_rank() + + for i, arg in enumerate(sys.argv): + if arg.startswith("--xml_output_file="): + base, ext = os.path.splitext(arg) + sys.argv[i] = f"{base}_rank{rank}{ext}" + break + + absltest.main() + + torch.distributed.destroy_process_group() From 172a41e5719fb4b8b53417bbbf3c168125baec84 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 6 May 2026 11:20:58 -0700 Subject: [PATCH 05/16] add fallback and test Signed-off-by: Hao Wu --- emerging_optimizers/soap/rekls.py | 72 
++++++++++++++++++------------- tests/ci/L0_Tests_CPU.sh | 17 ++++++-- 2 files changed, 54 insertions(+), 35 deletions(-) diff --git a/emerging_optimizers/soap/rekls.py b/emerging_optimizers/soap/rekls.py index 0d24647c..796b83fc 100644 --- a/emerging_optimizers/soap/rekls.py +++ b/emerging_optimizers/soap/rekls.py @@ -105,9 +105,11 @@ class TpRekls(opt_mixin.WeightDecayMixin, optim.Optimizer): fp32_matmul_prec: Precision for the optimizer-state GEMM operations. Note: - Each parameter must carry a ``partition_dim`` attribute (an int in ``{0, 1}``) describing the - dimension along which it is sharded across ``tp_group``. This matches the megatron-lm - convention. Parameters without a ``partition_dim`` attribute will raise :class:`TypeError`. + A parameter is treated as tensor-parallel iff it carries a ``partition_dim`` attribute + (an int in ``{0, 1}``) describing the dimension along which it is sharded across + ``tp_group``. This matches the megatron-lm convention. Parameters without + ``partition_dim`` are treated as replicated and updated with the plain (non-TP) REKLS step + on each rank — no collectives, full-size ``L``/``R``. """ def __init__( @@ -140,26 +142,15 @@ def __init__( super().__init__(params, defaults) @staticmethod - def _get_partition_dim(p: torch.Tensor) -> int: + def _get_partition_dim(p: torch.Tensor) -> int | None: + """Returns ``p.partition_dim`` if set, else ``None`` (param is treated as replicated).""" partition_dim = getattr(p, "partition_dim", None) if partition_dim is None: - raise TypeError( - f"TpRekls requires each parameter to carry a 'partition_dim' attribute " - f"(megatron-lm convention); got a parameter of shape {tuple(p.shape)} without one." 
- ) + return None if partition_dim not in (0, 1): raise ValueError(f"partition_dim must be 0 or 1, got {partition_dim}") return partition_dim - def _full_shape(self, p: torch.Tensor) -> tuple[int, int]: - m, n = p.shape - partition_dim = self._get_partition_dim(p) - if partition_dim == 0: - m *= self.tp_size - else: - n *= self.tp_size - return m, n - @torch.no_grad() # type: ignore[misc] def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None: for p in group["params"]: @@ -169,13 +160,22 @@ def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None: raise TypeError("TpRekls is only supported for 2D tensors") state = self.state[p] if len(state) == 0: - m, n = self._full_shape(p) + partition_dim = self._get_partition_dim(p) + m, n = p.shape + if partition_dim == 0: + m *= self.tp_size + elif partition_dim == 1: + n *= self.tp_size + # When partition_dim is None: param is replicated, m and n are already full. + state["step"] = 0 state["exp_avg"] = torch.zeros((m, n), dtype=torch.float32, device=p.device) state["exp_avg_sq"] = torch.zeros((m, n), dtype=torch.float32, device=p.device) # Match init_kronecker_factors in soap.py: default dtype (typically float32). - state["L"] = torch.zeros((m // self.tp_size, m), device=p.device) - state["R"] = torch.zeros((n // self.tp_size, n), device=p.device) + # L, R are sharded along dim 0 only when the param is tensor-parallel. + shard = self.tp_size if partition_dim is not None else 1 + state["L"] = torch.zeros((m // shard, m), device=p.device) + state["R"] = torch.zeros((n // shard, n), device=p.device) if TYPE_CHECKING: @@ -208,12 +208,17 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # Apply weight decay before the gather so l2 mode propagates into full_grad. 
self._apply_weight_decay_inplace(p, local_grad, group["lr"], group["weight_decay"]) - full_grad, full_factors = tp_utils.all_gather_grad_and_kronecker_factors_tp( - kronecker_factor_list=[state["L"], state["R"]], - grad=local_grad, - partition_dim=partition_dim, - tp_group=self.tp_group, - ) + if partition_dim is None: + # Replicated parameter: no all-gather, state is already full-size. + full_grad = local_grad + full_factors = [state["L"], state["R"]] + else: + full_grad, full_factors = tp_utils.all_gather_grad_and_kronecker_factors_tp( + kronecker_factor_list=[state["L"], state["R"]], + grad=local_grad, + partition_dim=partition_dim, + tp_group=self.tp_group, + ) # Apply shampoo beta bias correction. shampoo_beta = group["shampoo_beta"] @@ -231,9 +236,11 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: eps=group["eps"], ) - # Persist the updated local shard back into state. - state["L"].copy_(full_factors[0].chunk(self.tp_size, dim=0)[self.tp_rank]) - state["R"].copy_(full_factors[1].chunk(self.tp_size, dim=0)[self.tp_rank]) + # Persist the updated local shard back into state — only needed for the TP path, + # since the replicated path updated state["L"], state["R"] in place via the alias. 
+ if partition_dim is not None: + state["L"].copy_(full_factors[0].chunk(self.tp_size, dim=0)[self.tp_rank]) + state["R"].copy_(full_factors[1].chunk(self.tp_size, dim=0)[self.tp_rank]) with utils.fp32_matmul_precision(self.fp32_matmul_prec): # Rotate exp_avg from the pre-update eigenbasis to the post-update eigenbasis, @@ -262,8 +269,11 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: full_precond_update = soap.precondition(full_adam_update, eigenbasis_list, dims=[[0], [1]]) - local_precond_update = full_precond_update.chunk(self.tp_size, dim=partition_dim)[self.tp_rank] - p.add_(local_precond_update, alpha=-group["lr"]) + if partition_dim is None: + p.add_(full_precond_update, alpha=-group["lr"]) + else: + local_precond_update = full_precond_update.chunk(self.tp_size, dim=partition_dim)[self.tp_rank] + p.add_(local_precond_update, alpha=-group["lr"]) state["step"] += 1 diff --git a/tests/ci/L0_Tests_CPU.sh b/tests/ci/L0_Tests_CPU.sh index 30c09b2b..44042f3c 100644 --- a/tests/ci/L0_Tests_CPU.sh +++ b/tests/ci/L0_Tests_CPU.sh @@ -17,12 +17,21 @@ mkdir -p test-results/tests/ error=0 for n in 8 4; do - torchrun --nproc_per_node=$n --no-python coverage run -p \ - tests/test_distributed_muon_utils_cpu.py \ - --xml_output_file="test-results/tests/test_distributed_muon_utils_cpu_n${n}.xml" \ - -v -2 || error=1 + for test in tests/test_distributed_*_cpu.py; do + report_base=$(basename "$test" .py) + torchrun --nproc_per_node=$n --no-python coverage run -p \ + "$test" \ + --xml_output_file="test-results/tests/${report_base}_n${n}.xml" \ + -v -2 || error=1 + done done +# Single-process sanity check for TpRekls — exercises the size-1 tp_group path. 
+torchrun --nproc_per_node=1 --no-python coverage run -p \ + tests/test_distributed_rekls_cpu.py \ + --xml_output_file="test-results/tests/test_distributed_rekls_cpu_n1.xml" \ + -v -2 || error=1 + for test in "tests/test_scalar_optimizers.py" "tests/test_procrustes_step.py"; do report_name="test-results/${test}.xml" coverage run -p --source=emerging_optimizers $test --device=cpu -v -2 --xml_output_file="$report_name" || error=1 From 14c7e84d8248d9f57a95fc1da015a6e377063d70 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 6 May 2026 11:23:26 -0700 Subject: [PATCH 06/16] update doc Signed-off-by: Hao Wu --- emerging_optimizers/soap/soap.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py index dfe265a1..89aece76 100644 --- a/emerging_optimizers/soap/soap.py +++ b/emerging_optimizers/soap/soap.py @@ -380,7 +380,24 @@ def update_kronecker_factors_kl_shampoo( ) -> None: """Updates the kronecker factor matrices in place using KL-Shampoo correction. - Implement Kullback–Leibler Minimization from https://arxiv.org/pdf/2509.03378 + Implements the Kullback–Leibler minimization update from https://arxiv.org/pdf/2509.03378. + + For a gradient :math:`G \\in \\mathbb{R}^{m \\times n}`, current kronecker factors + :math:`L_t \\in \\mathbb{R}^{m \\times m}`, :math:`R_t \\in \\mathbb{R}^{n \\times n}`, and their + orthonormal eigenbases :math:`Q_L, Q_R`, the approximate eigenvalues in the current eigenbasis are + + .. math:: + \\Lambda_L = \\mathrm{diag}(Q_L^{\\top} L_t Q_L), \\quad + \\Lambda_R = \\mathrm{diag}(Q_R^{\\top} R_t Q_R) + + and the EMA update with momentum :math:`\\beta` (= ``shampoo_beta``) and exponent :math:`p` + (= ``eigval_exp``, default ``-1``) is + + .. 
math:: + L_{t+1} = \\beta\\, L_t + \\frac{1-\\beta}{n}\\, G\\, Q_R\\, \\mathrm{diag}(\\Lambda_R^{p})\\, Q_R^{\\top} G^{\\top} \\\\ + R_{t+1} = \\beta\\, R_t + \\frac{1-\\beta}{m}\\, G^{\\top}\\, Q_L\\, \\mathrm{diag}(\\Lambda_L^{p})\\, Q_L^{\\top} G + + Eigenvalues are clamped to ``eps`` from below before exponentiation for numerical stability. Args: kronecker_factor_list: List of preconditioner matrices (L and R) to update. @@ -388,7 +405,7 @@ def update_kronecker_factors_kl_shampoo( shampoo_beta: Momentum coefficient for updating preconditioners. eigenbasis_list: List of orthonormal eigenbases of the kronecker factor matrices eps: Small offset for numerical stability. - eigenval_exp: Exponent of the eigenvalues. + eigval_exp: Exponent applied to the (clamped) eigenvalues. Defaults to ``-1.0``. """ if grad.dim() != 2: raise TypeError("KL-Shampoo mathematical correction is only supported for 2D tensors") From 72349ef77c1ce55f7562847caacd77406b51cf8a Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 6 May 2026 11:34:15 -0700 Subject: [PATCH 07/16] fix couple tiny issues Signed-off-by: Hao Wu --- emerging_optimizers/soap/tp_utils.py | 2 +- tests/test_distributed_rekls_cpu.py | 32 ++++++++++++++++++------ tests/test_distributed_soap_utils_cpu.py | 14 +++++------ 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/emerging_optimizers/soap/tp_utils.py b/emerging_optimizers/soap/tp_utils.py index c168bff1..32840ce3 100644 --- a/emerging_optimizers/soap/tp_utils.py +++ b/emerging_optimizers/soap/tp_utils.py @@ -30,7 +30,7 @@ def all_gather_grad_and_kronecker_factors_tp( This is a simple implementation to support tensor parallel. It assumes grad is sharded among tensor parallel domain with partition_dim indicating the dimension it was sharded. To save memory, kronecker factors are also sharded - but always long dimension 0 to make gather operation easy. + but always along dimension 0 to make gather operation easy. 
Gradient and kronecker factors are both all-gathered to all tensor parallel ranks and returned so the caller can pass them to ``update_kronecker_factors`` (and downstream eigenbasis computations) without further diff --git a/tests/test_distributed_rekls_cpu.py b/tests/test_distributed_rekls_cpu.py index bee55b1c..b5f88693 100644 --- a/tests/test_distributed_rekls_cpu.py +++ b/tests/test_distributed_rekls_cpu.py @@ -47,13 +47,18 @@ def test_5steps_matches_non_distributed_rekls(self): """Multiple TpRekls steps over multiple params (mixed partition_dim) must produce bit-identical updates to non-distributed REKLS for every param at every step. """ + # ``partition_dim=None`` exercises the replicated fallback path: param has no partition_dim + # attribute, no all-gather, full-size L/R, full update applied directly. params_config = [ {"shape": (16, 32), "partition_dim": 0}, {"shape": (32, 16), "partition_dim": 1}, {"shape": (96, 200), "partition_dim": 0}, {"shape": (96, 200), "partition_dim": 1}, + {"shape": (24, 40), "partition_dim": None}, ] for cfg in params_config: + if cfg["partition_dim"] is None: + continue m, n = cfg["shape"] assert m % self.world_size == 0 and n % self.world_size == 0, ( f"shape {cfg['shape']} must be divisible by world size {self.world_size}" @@ -63,7 +68,7 @@ def test_5steps_matches_non_distributed_rekls(self): full_params_data = [] for cfg in params_config: d = torch.randn(cfg["shape"]) - torch.distributed.all_reduce(d) + torch.distributed.all_reduce(d, group=self.tp_group) full_params_data.append(d) ref_params = [torch.nn.Parameter(d.clone()) for d in full_params_data] @@ -71,9 +76,14 @@ def test_5steps_matches_non_distributed_rekls(self): tp_params = [] for cfg, d in zip(params_config, full_params_data): - local_data = d.chunk(self.world_size, dim=cfg["partition_dim"])[self.rank].contiguous() + pd = cfg["partition_dim"] + if pd is None: + local_data = d.clone() + else: + local_data = d.chunk(self.world_size, dim=pd)[self.rank].contiguous() 
local_param = torch.nn.Parameter(local_data) - local_param.partition_dim = cfg["partition_dim"] + if pd is not None: + local_param.partition_dim = pd tp_params.append(local_param) tp_optimizer = TpRekls(tp_params, lr=1e-3, tp_group=self.tp_group) @@ -82,19 +92,27 @@ def test_5steps_matches_non_distributed_rekls(self): full_grads = [] for cfg in params_config: g = torch.randn(cfg["shape"]) - torch.distributed.all_reduce(g) + torch.distributed.all_reduce(g, group=self.tp_group) full_grads.append(g) for ref_p, full_g in zip(ref_params, full_grads): ref_p.grad = full_g.clone() for tp_p, cfg, full_g in zip(tp_params, params_config, full_grads): - tp_p.grad = full_g.chunk(self.world_size, dim=cfg["partition_dim"])[self.rank].contiguous() + pd = cfg["partition_dim"] + if pd is None: + tp_p.grad = full_g.clone() + else: + tp_p.grad = full_g.chunk(self.world_size, dim=pd)[self.rank].contiguous() ref_optimizer.step() tp_optimizer.step() for ref_p, tp_p, cfg in zip(ref_params, tp_params, params_config): - ref_local = ref_p.detach().chunk(self.world_size, dim=cfg["partition_dim"])[self.rank] + pd = cfg["partition_dim"] + if pd is None: + ref_local = ref_p.detach() + else: + ref_local = ref_p.detach().chunk(self.world_size, dim=pd)[self.rank] torch.testing.assert_close( tp_p.detach(), ref_local, @@ -107,7 +125,7 @@ def test_5steps_matches_non_distributed_rekls(self): torch.distributed.init_process_group(backend="gloo") torch.set_float32_matmul_precision("highest") - rank = torch.distributed.get_rank() + rank = get_pg_rank(torch.distributed.group.WORLD) for i, arg in enumerate(sys.argv): if arg.startswith("--xml_output_file="): diff --git a/tests/test_distributed_soap_utils_cpu.py b/tests/test_distributed_soap_utils_cpu.py index 00fc300e..8cf86fac 100644 --- a/tests/test_distributed_soap_utils_cpu.py +++ b/tests/test_distributed_soap_utils_cpu.py @@ -56,9 +56,9 @@ def test_matches_non_distributed(self, shape, partition_dim): full_l = torch.randint(-5, 5, (m, m)) full_r = 
torch.randint(-5, 5, (n, n)) # All-reduce ensures that every rank starts from the same tensors. - torch.distributed.all_reduce(full_grad) - torch.distributed.all_reduce(full_l) - torch.distributed.all_reduce(full_r) + torch.distributed.all_reduce(full_grad, group=self.tp_group) + torch.distributed.all_reduce(full_l, group=self.tp_group) + torch.distributed.all_reduce(full_r, group=self.tp_group) local_grad = full_grad.chunk(self.world_size, dim=partition_dim)[self.rank].contiguous() local_l = full_l.chunk(self.world_size, dim=0)[self.rank].contiguous() @@ -87,9 +87,9 @@ def test_updated_factors_match_non_distributed(self, shape, partition_dim): full_grad = torch.randint(-5, 5, shape, dtype=torch.float32) full_l = torch.randint(-5, 5, (m, m), dtype=torch.float32) full_r = torch.randint(-5, 5, (n, n), dtype=torch.float32) - torch.distributed.all_reduce(full_grad) - torch.distributed.all_reduce(full_l) - torch.distributed.all_reduce(full_r) + torch.distributed.all_reduce(full_grad, group=self.tp_group) + torch.distributed.all_reduce(full_l, group=self.tp_group) + torch.distributed.all_reduce(full_r, group=self.tp_group) shampoo_beta = 0.95 @@ -117,7 +117,7 @@ def test_updated_factors_match_non_distributed(self, shape, partition_dim): torch.distributed.init_process_group(backend="gloo") torch.set_float32_matmul_precision("highest") - rank = torch.distributed.get_rank() + rank = get_pg_rank(torch.distributed.group.WORLD) for i, arg in enumerate(sys.argv): if arg.startswith("--xml_output_file="): From 258545e9975d24c7595b0731fd1e14a4f308b990 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 6 May 2026 13:08:26 -0700 Subject: [PATCH 08/16] disable closure Signed-off-by: Hao Wu --- emerging_optimizers/soap/rekls.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/emerging_optimizers/soap/rekls.py b/emerging_optimizers/soap/rekls.py index 796b83fc..6b3bc5f9 100644 --- a/emerging_optimizers/soap/rekls.py +++ b/emerging_optimizers/soap/rekls.py @@ 
-188,10 +188,7 @@ def step(self, closure: Callable[[], float]) -> float: ... @torch.no_grad() # type: ignore[misc] @override def step(self, closure: Callable[[], float] | None = None) -> float | None: - if closure is None: - loss = None - else: - loss = closure() + assert closure is None, "No support for closure" for group in self.param_groups: self._init_group(group) @@ -277,4 +274,4 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: state["step"] += 1 - return loss + return None From fc173bcd120de7f95f563ddd541b331b66ac4d8b Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 6 May 2026 13:16:10 -0700 Subject: [PATCH 09/16] move destroy process group to module tear down Signed-off-by: Hao Wu --- tests/test_distributed_rekls_cpu.py | 6 ++++-- tests/test_distributed_soap_utils_cpu.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_distributed_rekls_cpu.py b/tests/test_distributed_rekls_cpu.py index b5f88693..8f76b7da 100644 --- a/tests/test_distributed_rekls_cpu.py +++ b/tests/test_distributed_rekls_cpu.py @@ -36,6 +36,10 @@ def setUpModule() -> None: torch.cuda.manual_seed_all(FLAGS.seed) +def tearDownModule() -> None: + torch.distributed.destroy_process_group() + + class TpReklsCpuTest(parameterized.TestCase): def setUp(self): super().setUp() @@ -134,5 +138,3 @@ def test_5steps_matches_non_distributed_rekls(self): break absltest.main() - - torch.distributed.destroy_process_group() diff --git a/tests/test_distributed_soap_utils_cpu.py b/tests/test_distributed_soap_utils_cpu.py index 8cf86fac..229a88de 100644 --- a/tests/test_distributed_soap_utils_cpu.py +++ b/tests/test_distributed_soap_utils_cpu.py @@ -36,6 +36,10 @@ def setUpModule() -> None: torch.cuda.manual_seed_all(FLAGS.seed) +def tearDownModule() -> None: + torch.distributed.destroy_process_group() + + class AllGatherGradAndKroneckerFactorsTpCpuTest(parameterized.TestCase): def setUp(self): super().setUp() @@ -128,5 +132,3 @@ def 
test_updated_factors_match_non_distributed(self, shape, partition_dim): break absltest.main() - - torch.distributed.destroy_process_group() From e90c7e4f8883899154c1ffbd90a089eebd8bcb97 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 6 May 2026 13:39:41 -0700 Subject: [PATCH 10/16] narrow step overload Signed-off-by: Hao Wu --- emerging_optimizers/soap/rekls.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/emerging_optimizers/soap/rekls.py b/emerging_optimizers/soap/rekls.py index 6b3bc5f9..bb141286 100644 --- a/emerging_optimizers/soap/rekls.py +++ b/emerging_optimizers/soap/rekls.py @@ -187,7 +187,7 @@ def step(self, closure: Callable[[], float]) -> float: ... @torch.no_grad() # type: ignore[misc] @override - def step(self, closure: Callable[[], float] | None = None) -> float | None: + def step(self, closure: None = None) -> None: assert closure is None, "No support for closure" for group in self.param_groups: self._init_group(group) From 83165974e277ed0f2477dbdbcfb733545d3a6c12 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 6 May 2026 15:28:58 -0700 Subject: [PATCH 11/16] generalize partition_dim to be group setting Signed-off-by: Hao Wu --- emerging_optimizers/soap/rekls.py | 35 +++++++++++++++-------------- tests/test_distributed_rekls_cpu.py | 14 +++++++----- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/emerging_optimizers/soap/rekls.py b/emerging_optimizers/soap/rekls.py index bb141286..970da0bf 100644 --- a/emerging_optimizers/soap/rekls.py +++ b/emerging_optimizers/soap/rekls.py @@ -90,9 +90,6 @@ class TpRekls(opt_mixin.WeightDecayMixin, optim.Optimizer): rotated, matching SOAP's eigh path. - ``L``, ``R``: kronecker factor matrices, sharded along dimension 0 across ``tp_group``. - Each step issues exactly one collective: an all-gather of the local gradient and ``L``/``R`` shards - via :func:`~emerging_optimizers.soap.tp_utils.all_gather_grad_and_kronecker_factors_tp`. 
- Args: params: Iterable of parameters to optimize or dicts defining parameter groups. lr: Learning rate. @@ -105,11 +102,18 @@ class TpRekls(opt_mixin.WeightDecayMixin, optim.Optimizer): fp32_matmul_prec: Precision for the optimizer-state GEMM operations. Note: - A parameter is treated as tensor-parallel iff it carries a ``partition_dim`` attribute - (an int in ``{0, 1}``) describing the dimension along which it is sharded across - ``tp_group``. This matches the megatron-lm convention. Parameters without - ``partition_dim`` are treated as replicated and updated with the plain (non-TP) REKLS step - on each rank — no collectives, full-size ``L``/``R``. + Sharding is configured per-parameter-group via ``partition_dim`` (an int in ``{0, 1}``, + or ``None`` for replicated parameters). Mixed-layout models should use one group per + distinct ``partition_dim``:: + + optimizer = TpRekls([ + {"params": column_parallel_params, "partition_dim": 0}, + {"params": row_parallel_params, "partition_dim": 1}, + {"params": replicated_params, "partition_dim": None}, + ], lr=1e-3, tp_group=tp_group) + + Groups without ``partition_dim`` use the default (``None`` → replicated, plain non-TP REKLS + step on each rank, no collectives, full-size ``L``/``R``). 
""" def __init__( @@ -138,21 +142,19 @@ def __init__( "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay, + "partition_dim": None, } super().__init__(params, defaults) @staticmethod - def _get_partition_dim(p: torch.Tensor) -> int | None: - """Returns ``p.partition_dim`` if set, else ``None`` (param is treated as replicated).""" - partition_dim = getattr(p, "partition_dim", None) - if partition_dim is None: - return None - if partition_dim not in (0, 1): - raise ValueError(f"partition_dim must be 0 or 1, got {partition_dim}") + def _validate_partition_dim(partition_dim: int | None) -> int | None: + if partition_dim is not None and partition_dim not in (0, 1): + raise ValueError(f"partition_dim must be 0, 1, or None, got {partition_dim}") return partition_dim @torch.no_grad() # type: ignore[misc] def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None: + partition_dim = self._validate_partition_dim(group["partition_dim"]) for p in group["params"]: if skip_non_grad_params and p.grad is None: continue @@ -160,7 +162,6 @@ def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None: raise TypeError("TpRekls is only supported for 2D tensors") state = self.state[p] if len(state) == 0: - partition_dim = self._get_partition_dim(p) m, n = p.shape if partition_dim == 0: m *= self.tp_size @@ -193,13 +194,13 @@ def step(self, closure: None = None) -> None: self._init_group(group) for group in self.param_groups: + partition_dim = self._validate_partition_dim(group["partition_dim"]) for p in group["params"]: if p.grad is None: continue # pragma: no cover local_grad = p.grad.to(torch.float32) state = self.state[p] - partition_dim = self._get_partition_dim(p) curr_iter_1_based = state["step"] + 1 # Apply weight decay before the gather so l2 mode propagates into full_grad. 
diff --git a/tests/test_distributed_rekls_cpu.py b/tests/test_distributed_rekls_cpu.py index 8f76b7da..096a3dd5 100644 --- a/tests/test_distributed_rekls_cpu.py +++ b/tests/test_distributed_rekls_cpu.py @@ -85,11 +85,15 @@ def test_5steps_matches_non_distributed_rekls(self): local_data = d.clone() else: local_data = d.chunk(self.world_size, dim=pd)[self.rank].contiguous() - local_param = torch.nn.Parameter(local_data) - if pd is not None: - local_param.partition_dim = pd - tp_params.append(local_param) - tp_optimizer = TpRekls(tp_params, lr=1e-3, tp_group=self.tp_group) + tp_params.append(torch.nn.Parameter(local_data)) + + # One param group per distinct partition_dim — TpRekls reads partition_dim from group. + tp_param_groups: list[dict] = [] + for pd in (0, 1, None): + members = [tp_p for tp_p, cfg in zip(tp_params, params_config) if cfg["partition_dim"] == pd] + if members: + tp_param_groups.append({"params": members, "partition_dim": pd}) + tp_optimizer = TpRekls(tp_param_groups, lr=1e-3, tp_group=self.tp_group) num_steps = 5 for _ in range(num_steps): From 6eee42fd11999e13c88b4132b65905c87b940f8f Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 6 May 2026 15:45:06 -0700 Subject: [PATCH 12/16] add dimension guard Signed-off-by: Hao Wu --- emerging_optimizers/soap/rekls.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/emerging_optimizers/soap/rekls.py b/emerging_optimizers/soap/rekls.py index 970da0bf..e34143c2 100644 --- a/emerging_optimizers/soap/rekls.py +++ b/emerging_optimizers/soap/rekls.py @@ -169,6 +169,14 @@ def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None: n *= self.tp_size # When partition_dim is None: param is replicated, m and n are already full. + # Both dimensions must be divisible by tp_size for the L/R shards (each sharded + # along dim 0) to gather back to the full square shape via torch.cat. 
+ if partition_dim is not None and (m % self.tp_size or n % self.tp_size): + raise ValueError( + f"TpRekls requires both dimensions to be divisible by tp_size={self.tp_size}; " + f"got full shape ({m}, {n}) for a parameter with partition_dim={partition_dim}." + ) + state["step"] = 0 state["exp_avg"] = torch.zeros((m, n), dtype=torch.float32, device=p.device) state["exp_avg_sq"] = torch.zeros((m, n), dtype=torch.float32, device=p.device) From 391845931e116051a012541640a169cd84e962d9 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 14 May 2026 09:01:02 -0700 Subject: [PATCH 13/16] improve naming consistency Signed-off-by: Hao Wu --- emerging_optimizers/soap/rekls.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/emerging_optimizers/soap/rekls.py b/emerging_optimizers/soap/rekls.py index e34143c2..1e5cd034 100644 --- a/emerging_optimizers/soap/rekls.py +++ b/emerging_optimizers/soap/rekls.py @@ -217,9 +217,9 @@ def step(self, closure: None = None) -> None: if partition_dim is None: # Replicated parameter: no all-gather, state is already full-size. full_grad = local_grad - full_factors = [state["L"], state["R"]] + kronecker_factor_list = [state["L"], state["R"]] else: - full_grad, full_factors = tp_utils.all_gather_grad_and_kronecker_factors_tp( + full_grad, kronecker_factor_list = tp_utils.all_gather_grad_and_kronecker_factors_tp( kronecker_factor_list=[state["L"], state["R"]], grad=local_grad, partition_dim=partition_dim, @@ -233,9 +233,9 @@ def step(self, closure: None = None) -> None: # KL-Shampoo correction needs the eigenbasis of the *pre-update* L, R; recompute it # via eigh since we do not persist eigenbases across steps. 
with utils.fp32_matmul_precision(self.fp32_matmul_prec): - pre_eigenbasis_list = soap_utils.get_eigenbasis_eigh(full_factors) + pre_eigenbasis_list = soap_utils.get_eigenbasis_eigh(kronecker_factor_list) soap.update_kronecker_factors_kl_shampoo( - full_factors, + kronecker_factor_list, full_grad, shampoo_beta=shampoo_beta, eigenbasis_list=pre_eigenbasis_list, @@ -245,14 +245,14 @@ def step(self, closure: None = None) -> None: # Persist the updated local shard back into state — only needed for the TP path, # since the replicated path updated state["L"], state["R"] in place via the alias. if partition_dim is not None: - state["L"].copy_(full_factors[0].chunk(self.tp_size, dim=0)[self.tp_rank]) - state["R"].copy_(full_factors[1].chunk(self.tp_size, dim=0)[self.tp_rank]) + state["L"].copy_(kronecker_factor_list[0].chunk(self.tp_size, dim=0)[self.tp_rank]) + state["R"].copy_(kronecker_factor_list[1].chunk(self.tp_size, dim=0)[self.tp_rank]) with utils.fp32_matmul_precision(self.fp32_matmul_prec): # Rotate exp_avg from the pre-update eigenbasis to the post-update eigenbasis, # and recompute the post-update eigenbasis via eigh. 
eigenbasis_list, state["exp_avg"], state["exp_avg_sq"] = soap.update_eigenbasis_and_exp_avgs( - kronecker_factor_list=full_factors, + kronecker_factor_list=kronecker_factor_list, eigenbasis_list=pre_eigenbasis_list, exp_avg_sq=state["exp_avg_sq"], exp_avg=state["exp_avg"], From 619cc158c3d9f91fda940fbfdd06c07bab17a93a Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 14 May 2026 09:15:44 -0700 Subject: [PATCH 14/16] test svd for tp rekls Signed-off-by: Hao Wu --- emerging_optimizers/soap/rekls.py | 22 ++++++--------- emerging_optimizers/soap/soap_utils.py | 39 ++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/emerging_optimizers/soap/rekls.py b/emerging_optimizers/soap/rekls.py index 1e5cd034..b91c7c00 100644 --- a/emerging_optimizers/soap/rekls.py +++ b/emerging_optimizers/soap/rekls.py @@ -76,8 +76,8 @@ class TpRekls(opt_mixin.WeightDecayMixin, optim.Optimizer): Reimplemented from scratch (not inheriting from :class:`~emerging_optimizers.soap.soap.SOAP`) so the tensor-parallel bookkeeping stays isolated. Eigenbases are not stored in optimizer state; they are - recomputed via :func:`~emerging_optimizers.soap.soap_utils.get_eigenbasis_eigh` from the kronecker - factors. Each step calls eigh twice — once on the pre-update L, R for the + recomputed via :func:`~emerging_optimizers.soap.soap_utils.get_eigenbasis_svd` from the kronecker + factors. Each step calls svd twice — once on the pre-update L, R for the :func:`~emerging_optimizers.soap.soap.update_kronecker_factors_kl_shampoo` correction, and once on the post-update L, R for the gradient projection. @@ -231,9 +231,9 @@ def step(self, closure: None = None) -> None: shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta**curr_iter_1_based) # KL-Shampoo correction needs the eigenbasis of the *pre-update* L, R; recompute it - # via eigh since we do not persist eigenbases across steps. + # via svd since we do not persist eigenbases across steps. 
with utils.fp32_matmul_precision(self.fp32_matmul_prec): - pre_eigenbasis_list = soap_utils.get_eigenbasis_eigh(kronecker_factor_list) + pre_eigenbasis_list = soap_utils.get_eigenbasis_svd(kronecker_factor_list) soap.update_kronecker_factors_kl_shampoo( kronecker_factor_list, full_grad, @@ -249,15 +249,11 @@ def step(self, closure: None = None) -> None: state["R"].copy_(kronecker_factor_list[1].chunk(self.tp_size, dim=0)[self.tp_rank]) with utils.fp32_matmul_precision(self.fp32_matmul_prec): - # Rotate exp_avg from the pre-update eigenbasis to the post-update eigenbasis, - # and recompute the post-update eigenbasis via eigh. - eigenbasis_list, state["exp_avg"], state["exp_avg_sq"] = soap.update_eigenbasis_and_exp_avgs( - kronecker_factor_list=kronecker_factor_list, - eigenbasis_list=pre_eigenbasis_list, - exp_avg_sq=state["exp_avg_sq"], - exp_avg=state["exp_avg"], - use_eigh=True, - ) + # Rotate exp_avg from the pre-update eigenbasis to the post-update eigenbasis + # (matches update_eigenbasis_and_exp_avgs in soap.py for the eigh path; we use svd here). + state["exp_avg"] = soap.precondition(state["exp_avg"], pre_eigenbasis_list, dims=[[0], [1]]) + eigenbasis_list = soap_utils.get_eigenbasis_svd(kronecker_factor_list) + state["exp_avg"] = soap.precondition(state["exp_avg"], eigenbasis_list, dims=[[0], [0]]) full_grad_projected = soap.precondition(full_grad, eigenbasis_list, dims=[[0], [0]]) diff --git a/emerging_optimizers/soap/soap_utils.py b/emerging_optimizers/soap/soap_utils.py index 507c4aad..12d9c3e9 100644 --- a/emerging_optimizers/soap/soap_utils.py +++ b/emerging_optimizers/soap/soap_utils.py @@ -25,6 +25,7 @@ __all__ = [ "get_eigenbasis_eigh", "get_eigenbasis_qr", + "get_eigenbasis_svd", ] @@ -61,6 +62,44 @@ def get_eigenbasis_eigh( return updated_eigenbasis_list +def get_eigenbasis_svd( + kronecker_factor_list: TensorList, +) -> TensorList: + """Computes the eigenbases of the preconditioner using torch.linalg.svd decomposition. 
+ + The kronecker factors :math:`L = GG^\\top` and :math:`R = G^\\top G` are symmetric positive + semi-definite, so the left and right singular vectors coincide (up to sign in the presence + of repeated singular values); this function returns the left singular vectors :math:`U` as + the eigenbasis. Singular values from ``torch.linalg.svd`` are returned in descending order. + + Args: + kronecker_factor_list: Matrix List to compute eigenbases of + + Returns: + List of orthonormal kronecker factor eigenbases matrices + + Example: + .. code-block:: python + + # Create sample Kronecker factors (symmetric positive definite matrices) + k_factor1 = torch.randn(4, 4) + k_factor1 = k_factor1 @ k_factor1.T # Make symmetric positive definite + k_factor2 = torch.randn(5, 5) + k_factor2 = k_factor2 @ k_factor2.T # Make symmetric positive definite + + # Get orthogonal matrices for these factors + ortho_matrices = get_eigenbasis_svd([k_factor1, k_factor2]) + # ortho_matrices[0] has shape [4, 4] and ortho_matrices[1] has shape [5, 5] + """ + updated_eigenbasis_list: TensorList = [] + + for kronecker_factor in kronecker_factor_list: + U, _, _ = torch.linalg.svd(kronecker_factor) + updated_eigenbasis_list.append(U) + + return updated_eigenbasis_list + + def get_eigenbasis_qr( kronecker_factor_list: TensorList, eigenbasis_list: TensorList, From c8418a5a65690ff0e5eff249dc841725a6a77261 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 14 May 2026 14:33:21 -0700 Subject: [PATCH 15/16] remove svd exp Signed-off-by: Hao Wu --- emerging_optimizers/soap/rekls.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/emerging_optimizers/soap/rekls.py b/emerging_optimizers/soap/rekls.py index b91c7c00..1e5cd034 100644 --- a/emerging_optimizers/soap/rekls.py +++ b/emerging_optimizers/soap/rekls.py @@ -76,8 +76,8 @@ class TpRekls(opt_mixin.WeightDecayMixin, optim.Optimizer): Reimplemented from scratch (not inheriting from 
:class:`~emerging_optimizers.soap.soap.SOAP`) so the tensor-parallel bookkeeping stays isolated. Eigenbases are not stored in optimizer state; they are - recomputed via :func:`~emerging_optimizers.soap.soap_utils.get_eigenbasis_svd` from the kronecker - factors. Each step calls svd twice — once on the pre-update L, R for the + recomputed via :func:`~emerging_optimizers.soap.soap_utils.get_eigenbasis_eigh` from the kronecker + factors. Each step calls eigh twice — once on the pre-update L, R for the :func:`~emerging_optimizers.soap.soap.update_kronecker_factors_kl_shampoo` correction, and once on the post-update L, R for the gradient projection. @@ -231,9 +231,9 @@ def step(self, closure: None = None) -> None: shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta**curr_iter_1_based) # KL-Shampoo correction needs the eigenbasis of the *pre-update* L, R; recompute it - # via svd since we do not persist eigenbases across steps. + # via eigh since we do not persist eigenbases across steps. with utils.fp32_matmul_precision(self.fp32_matmul_prec): - pre_eigenbasis_list = soap_utils.get_eigenbasis_svd(kronecker_factor_list) + pre_eigenbasis_list = soap_utils.get_eigenbasis_eigh(kronecker_factor_list) soap.update_kronecker_factors_kl_shampoo( kronecker_factor_list, full_grad, @@ -249,11 +249,15 @@ def step(self, closure: None = None) -> None: state["R"].copy_(kronecker_factor_list[1].chunk(self.tp_size, dim=0)[self.tp_rank]) with utils.fp32_matmul_precision(self.fp32_matmul_prec): - # Rotate exp_avg from the pre-update eigenbasis to the post-update eigenbasis - # (matches update_eigenbasis_and_exp_avgs in soap.py for the eigh path; we use svd here). 
- state["exp_avg"] = soap.precondition(state["exp_avg"], pre_eigenbasis_list, dims=[[0], [1]]) - eigenbasis_list = soap_utils.get_eigenbasis_svd(kronecker_factor_list) - state["exp_avg"] = soap.precondition(state["exp_avg"], eigenbasis_list, dims=[[0], [0]]) + # Rotate exp_avg from the pre-update eigenbasis to the post-update eigenbasis, + # and recompute the post-update eigenbasis via eigh. + eigenbasis_list, state["exp_avg"], state["exp_avg_sq"] = soap.update_eigenbasis_and_exp_avgs( + kronecker_factor_list=kronecker_factor_list, + eigenbasis_list=pre_eigenbasis_list, + exp_avg_sq=state["exp_avg_sq"], + exp_avg=state["exp_avg"], + use_eigh=True, + ) full_grad_projected = soap.precondition(full_grad, eigenbasis_list, dims=[[0], [0]]) From 208720cf51ff802d82778966dac2988190e03fec Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Fri, 22 May 2026 14:32:17 -0700 Subject: [PATCH 16/16] EXP: use laprop inside REKLS Signed-off-by: Hao Wu --- emerging_optimizers/soap/rekls.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/emerging_optimizers/soap/rekls.py b/emerging_optimizers/soap/rekls.py index 1e5cd034..5efbe061 100644 --- a/emerging_optimizers/soap/rekls.py +++ b/emerging_optimizers/soap/rekls.py @@ -163,11 +163,11 @@ def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None: state = self.state[p] if len(state) == 0: m, n = p.shape + # Get full size of m, n if the parameter is tensor-parallel. if partition_dim == 0: m *= self.tp_size elif partition_dim == 1: n *= self.tp_size - # When partition_dim is None: param is replicated, m and n are already full. # Both dimensions must be divisible by tp_size for the L/R shards (each sharded # along dim 0) to gather back to the full square shape via torch.cat. @@ -262,13 +262,12 @@ def step(self, closure: None = None) -> None: full_grad_projected = soap.precondition(full_grad, eigenbasis_list, dims=[[0], [0]]) # No matmul inside adam update. 
Put it under fp32_matmul_precision for code simplicity. - full_adam_update = update_functions.calculate_adam_update( + full_adam_update = update_functions.calculate_laprop_update( full_grad_projected, state["exp_avg"], state["exp_avg_sq"], - group["betas"], True, # correct_bias - False, # nesterov + group["betas"], curr_iter_1_based, group["eps"], )