Skip to content

Commit 5e989c9

Browse files
committed
update muon_sharding_optimizer with rebuilding 2d_params.
1 parent 72c2189 commit 5e989c9

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

test/collective/fleet/hybrid_parallel_sharding_muon_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,9 @@ def train_batch(self, batch, model, optimizer):
243243
output = model(batch)
244244
loss = output.mean()
245245
loss.backward()
246-
if isinstance(optimizer, MuonShardingOptimizer):
246+
inner_opt = getattr(optimizer, '_inner_opt', optimizer)
247+
if isinstance(inner_opt, MuonShardingOptimizer):
247248
optimizer.clear_param_storage('test_color')
248-
optimizer.reset_param_storage()
249249
optimizer.step()
250250
optimizer.clear_grad()
251251
return loss

0 commit comments

Comments
 (0)