Skip to content

Commit ddca910

Browse files
committed
Refactor: Generalize ZeROOptimizer support and extend core APIs for BF16/ZenFlow integration
Signed-off-by: amadhan882 <amadhan882@gmail.com>
1 parent d931be0 commit ddca910

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

deepspeed/runtime/engine.py

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -415,13 +415,20 @@ def __init__(self,
415415
self.engine_timers_cache = {}
416416

417417
if self.optimizer_name() or self.client_optimizer is not None:
418-
if self.optimizer is None:
419-
raise RuntimeError("DeepSpeedEngine: Optimizer initialization failed. Check for JIT compilation errors.")
420-
# ZeRO-3 specific check to prevent step 0 deadlocks
421-
if self.zero_optimization_stage() == 3:
422-
if not hasattr(self.optimizer, 'step'):
423-
raise AttributeError("DeepSpeedEngine: ZeRO-3 optimizer is missing core functional attributes (.step). "
424-
"This usually indicates a toolchain mismatch or failed JIT kernels.")
418+
if self.optimizer is None:
419+
raise RuntimeError(
420+
"DeepSpeedEngine: Optimizer initialization failed. Check for JIT compilation errors.")
421+
422+
required_methods = ['step', 'backward', 'load_state_dict']
423+
424+
if self.zero_optimization_partition_gradients():
425+
required_methods.append('overlapping_partition_gradients_reduce_epilogue')
426+
427+
for method in required_methods:
428+
if not hasattr(self.optimizer, method):
429+
raise AttributeError(
430+
f"DeepSpeedEngine: Optimizer is missing core functional attribute (.{method}). "
431+
"This usually indicates a toolchain mismatch or failed JIT kernels.")
425432

426433
if self.global_rank == 0:
427434
self._config.print("DeepSpeedEngine configuration")
@@ -2422,9 +2429,8 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
24222429
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
24232430
# ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
24242431
if self.zero_optimization_partition_gradients():
2425-
if hasattr(self.optimizer, 'overlapping_partition_gradients_reduce_epilogue'):
2432+
if hasattr(self.optimizer, 'overlapping_partition_gradients_reduce_epilogue'):
24262433
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
2427-
24282434

24292435
# Communicate only at gradient accumulation boundaries
24302436
elif self.is_gradient_accumulation_boundary():

0 commit comments

Comments (0)