@@ -90,9 +90,6 @@ class TpRekls(opt_mixin.WeightDecayMixin, optim.Optimizer):
9090 rotated, matching SOAP's eigh path.
9191 - ``L``, ``R``: kronecker factor matrices, sharded along dimension 0 across ``tp_group``.
9292
93- Each step issues exactly one collective: an all-gather of the local gradient and ``L``/``R`` shards
94- via :func:`~emerging_optimizers.soap.tp_utils.all_gather_grad_and_kronecker_factors_tp`.
95-
9693 Args:
9794 params: Iterable of parameters to optimize or dicts defining parameter groups.
9895 lr: Learning rate.
@@ -105,11 +102,18 @@ class TpRekls(opt_mixin.WeightDecayMixin, optim.Optimizer):
105102 fp32_matmul_prec: Precision for the optimizer-state GEMM operations.
106103
107104 Note:
108- A parameter is treated as tensor-parallel iff it carries a ``partition_dim`` attribute
109- (an int in ``{0, 1}``) describing the dimension along which it is sharded across
110- ``tp_group``. This matches the megatron-lm convention. Parameters without
111- ``partition_dim`` are treated as replicated and updated with the plain (non-TP) REKLS step
112- on each rank — no collectives, full-size ``L``/``R``.
105+ Sharding is configured per-parameter-group via ``partition_dim`` (an int in ``{0, 1}``,
106+ or ``None`` for replicated parameters). Mixed-layout models should use one group per
107+ distinct ``partition_dim``::
108+
109+ optimizer = TpRekls([
110+ {"params": column_parallel_params, "partition_dim": 0},
111+ {"params": row_parallel_params, "partition_dim": 1},
112+ {"params": replicated_params, "partition_dim": None},
113+ ], lr=1e-3, tp_group=tp_group)
114+
115+ Groups that omit ``partition_dim`` fall back to the default ``None``: their parameters are
116+ treated as replicated (plain non-TP REKLS step on each rank, no collectives, full-size ``L``/``R``).
113117 """
114118
115119 def __init__ (
@@ -138,29 +142,26 @@ def __init__(
138142 "shampoo_beta" : shampoo_beta ,
139143 "eps" : eps ,
140144 "weight_decay" : weight_decay ,
145+ "partition_dim" : None ,
141146 }
142147 super ().__init__ (params , defaults )
143148
144149 @staticmethod
145- def _get_partition_dim (p : torch .Tensor ) -> int | None :
146- """Returns ``p.partition_dim`` if set, else ``None`` (param is treated as replicated)."""
147- partition_dim = getattr (p , "partition_dim" , None )
148- if partition_dim is None :
149- return None
150- if partition_dim not in (0 , 1 ):
151- raise ValueError (f"partition_dim must be 0 or 1, got { partition_dim } " )
150+ def _validate_partition_dim (partition_dim : int | None ) -> int | None :
151+ if partition_dim is not None and partition_dim not in (0 , 1 ):
152+ raise ValueError (f"partition_dim must be 0, 1, or None, got { partition_dim } " )
152153 return partition_dim
153154
154155 @torch .no_grad () # type: ignore[misc]
155156 def _init_group (self , group : dict , skip_non_grad_params : bool = True ) -> None :
157+ partition_dim = self ._validate_partition_dim (group ["partition_dim" ])
156158 for p in group ["params" ]:
157159 if skip_non_grad_params and p .grad is None :
158160 continue
159161 if p .dim () != 2 :
160162 raise TypeError ("TpRekls is only supported for 2D tensors" )
161163 state = self .state [p ]
162164 if len (state ) == 0 :
163- partition_dim = self ._get_partition_dim (p )
164165 m , n = p .shape
165166 if partition_dim == 0 :
166167 m *= self .tp_size
@@ -193,13 +194,13 @@ def step(self, closure: None = None) -> None:
193194 self ._init_group (group )
194195
195196 for group in self .param_groups :
197+ partition_dim = self ._validate_partition_dim (group ["partition_dim" ])
196198 for p in group ["params" ]:
197199 if p .grad is None :
198200 continue # pragma: no cover
199201
200202 local_grad = p .grad .to (torch .float32 )
201203 state = self .state [p ]
202- partition_dim = self ._get_partition_dim (p )
203204 curr_iter_1_based = state ["step" ] + 1
204205
205206 # Apply weight decay before the gather so l2 mode propagates into full_grad.
0 commit comments