Skip to content

Commit be87a3c

Browse files
Copilotnathon-lee
authored andcommitted
Reapply "fix: update 1 file reformatted."
This reverts commit b90aee5. Signed-off-by: nathon-lee <leejianwoo@gmail.com>
1 parent 6033a0d commit be87a3c

1 file changed

Lines changed: 1 addition & 10 deletions

File tree

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -284,18 +284,11 @@ def _enforce_cpu_offload():
284284

285285
self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32
286286

287-
# Check for Muon optimizer usage
288-
self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params'])
289-
290287
if self.reduce_scatter and self.partition_gradients:
291288
valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
292289
assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
293290
assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
294291
assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
295-
296-
# Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2)
297-
if self.reduce_scatter and self.uses_muon:
298-
assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer."
299292

300293
# param flattened by groups
301294
self.bit16_groups = []
@@ -1224,9 +1217,7 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt
12241217
stream = get_accelerator().current_stream()
12251218

12261219
with get_accelerator().stream(stream):
1227-
# Check if current configuration requires full all-reduce
1228-
if not self.reduce_scatter or any(self.group_uses_muon):
1229-
# Force full all-reduce for Muon parameters or when reduce_scatter is disabled
1220+
if not self.reduce_scatter:
12301221
self.gradient_reduction_w_predivide(tensor, communication_data_type)
12311222
return
12321223

0 commit comments

Comments
 (0)