Skip to content

Commit ea6c697

Browse files
committed
Generalize partition_dim to be a per-parameter-group setting
Signed-off-by: Hao Wu <skyw@nvidia.com>
1 parent c602e85 commit ea6c697

2 files changed

Lines changed: 27 additions & 22 deletions

File tree

emerging_optimizers/soap/rekls.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,6 @@ class TpRekls(opt_mixin.WeightDecayMixin, optim.Optimizer):
9090
rotated, matching SOAP's eigh path.
9191
- ``L``, ``R``: Kronecker factor matrices, sharded along dimension 0 across ``tp_group``.
9292
93-
Each step issues exactly one collective: an all-gather of the local gradient and ``L``/``R`` shards
94-
via :func:`~emerging_optimizers.soap.tp_utils.all_gather_grad_and_kronecker_factors_tp`.
95-
9693
Args:
9794
params: Iterable of parameters to optimize or dicts defining parameter groups.
9895
lr: Learning rate.
@@ -105,11 +102,18 @@ class TpRekls(opt_mixin.WeightDecayMixin, optim.Optimizer):
105102
fp32_matmul_prec: Precision for the optimizer-state GEMM operations.
106103
107104
Note:
108-
A parameter is treated as tensor-parallel iff it carries a ``partition_dim`` attribute
109-
(an int in ``{0, 1}``) describing the dimension along which it is sharded across
110-
``tp_group``. This matches the megatron-lm convention. Parameters without
111-
``partition_dim`` are treated as replicated and updated with the plain (non-TP) REKLS step
112-
on each rank — no collectives, full-size ``L``/``R``.
105+
Sharding is configured per-parameter-group via ``partition_dim``, the dimension along which the group's parameters are sharded across ``tp_group`` (an int in ``{0, 1}``,
106+
or ``None`` for replicated parameters). Mixed-layout models should use one group per
107+
distinct ``partition_dim``::
108+
109+
optimizer = TpRekls([
110+
{"params": column_parallel_params, "partition_dim": 0},
111+
{"params": row_parallel_params, "partition_dim": 1},
112+
{"params": replicated_params, "partition_dim": None},
113+
], lr=1e-3, tp_group=tp_group)
114+
115+
Groups without ``partition_dim`` use the default (``None`` → replicated, plain non-TP REKLS
116+
step on each rank, no collectives, full-size ``L``/``R``).
113117
"""
114118

115119
def __init__(
@@ -138,29 +142,26 @@ def __init__(
138142
"shampoo_beta": shampoo_beta,
139143
"eps": eps,
140144
"weight_decay": weight_decay,
145+
"partition_dim": None,
141146
}
142147
super().__init__(params, defaults)
143148

144149
@staticmethod
145-
def _get_partition_dim(p: torch.Tensor) -> int | None:
146-
"""Returns ``p.partition_dim`` if set, else ``None`` (param is treated as replicated)."""
147-
partition_dim = getattr(p, "partition_dim", None)
148-
if partition_dim is None:
149-
return None
150-
if partition_dim not in (0, 1):
151-
raise ValueError(f"partition_dim must be 0 or 1, got {partition_dim}")
150+
def _validate_partition_dim(partition_dim: int | None) -> int | None:
151+
if partition_dim is not None and partition_dim not in (0, 1):
152+
raise ValueError(f"partition_dim must be 0, 1, or None, got {partition_dim}")
152153
return partition_dim
153154

154155
@torch.no_grad() # type: ignore[misc]
155156
def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None:
157+
partition_dim = self._validate_partition_dim(group["partition_dim"])
156158
for p in group["params"]:
157159
if skip_non_grad_params and p.grad is None:
158160
continue
159161
if p.dim() != 2:
160162
raise TypeError("TpRekls is only supported for 2D tensors")
161163
state = self.state[p]
162164
if len(state) == 0:
163-
partition_dim = self._get_partition_dim(p)
164165
m, n = p.shape
165166
if partition_dim == 0:
166167
m *= self.tp_size
@@ -193,13 +194,13 @@ def step(self, closure: None = None) -> None:
193194
self._init_group(group)
194195

195196
for group in self.param_groups:
197+
partition_dim = self._validate_partition_dim(group["partition_dim"])
196198
for p in group["params"]:
197199
if p.grad is None:
198200
continue # pragma: no cover
199201

200202
local_grad = p.grad.to(torch.float32)
201203
state = self.state[p]
202-
partition_dim = self._get_partition_dim(p)
203204
curr_iter_1_based = state["step"] + 1
204205

205206
# Apply weight decay before the gather so l2 mode propagates into full_grad.

tests/test_distributed_rekls_cpu.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,15 @@ def test_5steps_matches_non_distributed_rekls(self):
8585
local_data = d.clone()
8686
else:
8787
local_data = d.chunk(self.world_size, dim=pd)[self.rank].contiguous()
88-
local_param = torch.nn.Parameter(local_data)
89-
if pd is not None:
90-
local_param.partition_dim = pd
91-
tp_params.append(local_param)
92-
tp_optimizer = TpRekls(tp_params, lr=1e-3, tp_group=self.tp_group)
88+
tp_params.append(torch.nn.Parameter(local_data))
89+
90+
# One param group per distinct partition_dim — TpRekls reads partition_dim from group.
91+
tp_param_groups: list[dict] = []
92+
for pd in (0, 1, None):
93+
members = [tp_p for tp_p, cfg in zip(tp_params, params_config) if cfg["partition_dim"] == pd]
94+
if members:
95+
tp_param_groups.append({"params": members, "partition_dim": pd})
96+
tp_optimizer = TpRekls(tp_param_groups, lr=1e-3, tp_group=self.tp_group)
9397

9498
num_steps = 5
9599
for _ in range(num_steps):

0 commit comments

Comments
 (0)