Skip to content

Commit a6ff7cd

Browse files
authored
release memory in muon optimizer (#78856)
1 parent f2d2763 commit a6ff7cd

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -549,12 +549,16 @@ def clear_param_storage(self, color):
549549
comm_buffer._clear_param_storage()
550550
# 2D params
551551
if color in self._params_2d_by_color.keys():
552-
for param in self._params_2d_by_color[color]:
552+
group_info = self._color_to_group_info[color]
553+
sharding_rank = max(group_info["rank"], 0)
554+
rank2params_2d_by_color = self._rank2params_2d_by_color[color]
555+
local_2d = rank2params_2d_by_color[sharding_rank]
556+
for param in local_2d:
553557
if not g_shard_bypass_dygraph_optimizer:
554558
self._create_master_weight(param)
555559

556-
if param.name in self._master_weights:
557-
param._clear_to_zero_allocation()
560+
for param in self._params_2d_by_color[color]:
561+
param._clear_to_zero_allocation()
558562

559563
def reset_param_storage(self):
560564
for color in self.clear_color:

0 commit comments

Comments
 (0)