Skip to content

Commit 6eee42f

Browse files
committed
add dimension guard
Signed-off-by: Hao Wu <skyw@nvidia.com>
1 parent 8316597 commit 6eee42f

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

emerging_optimizers/soap/rekls.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,14 @@ def _init_group(self, group: dict, skip_non_grad_params: bool = True) -> None:
169169
n *= self.tp_size
170170
# When partition_dim is None: param is replicated, m and n are already full.
171171

172+
# Both dimensions must be divisible by tp_size for the L/R shards (each sharded
173+
# along dim 0) to gather back to the full square shape via torch.cat.
174+
if partition_dim is not None and (m % self.tp_size or n % self.tp_size):
175+
raise ValueError(
176+
f"TpRekls requires both dimensions to be divisible by tp_size={self.tp_size}; "
177+
f"got full shape ({m}, {n}) for a parameter with partition_dim={partition_dim}."
178+
)
179+
172180
state["step"] = 0
173181
state["exp_avg"] = torch.zeros((m, n), dtype=torch.float32, device=p.device)
174182
state["exp_avg_sq"] = torch.zeros((m, n), dtype=torch.float32, device=p.device)

0 commit comments

Comments
 (0)