Skip to content

Commit b2e17ab

Browse files
committed
Fix: Add optimizer initialization validation for backward and ZeRO-3
Signed-off-by: amadhan882 <amadhan882@gmail.com>
1 parent b36d39a commit b2e17ab

1 file changed

Lines changed: 2 additions & 10 deletions

File tree

deepspeed/runtime/engine.py

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -419,7 +419,7 @@ def __init__(self,
419419
raise RuntimeError(
420420
"DeepSpeedEngine: Optimizer initialization failed. Check for JIT compilation errors.")
421421

422-
optimizer_methods = ['step', 'load_state_dict']
422+
optimizer_methods = ['step', 'load_state_dict','backward']
423423

424424
if self.zero_optimization_partition_gradients():
425425
optimizer_methods.append('overlapping_partition_gradients_reduce_epilogue')
@@ -432,13 +432,6 @@ def __init__(self,
432432
"This indicates incomplete initialization (e.g., JIT/toolchain failure)."
433433
)
434434

435-
# Validate engine separately
436-
if not hasattr(self, "backward") or not callable(getattr(self, "backward")):
437-
raise RuntimeError(
438-
"DeepSpeedEngine initialization failed: missing callable `backward`. "
439-
"Engine may be partially initialized."
440-
)
441-
442435
if self.global_rank == 0:
443436
self._config.print("DeepSpeedEngine configuration")
444437
if self.dump_state():
@@ -2438,8 +2431,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
24382431
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
24392432
# ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
24402433
if self.zero_optimization_partition_gradients():
2441-
if hasattr(self.optimizer, 'overlapping_partition_gradients_reduce_epilogue'):
2442-
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
2434+
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
24432435

24442436
# Communicate only at gradient accumulation boundaries
24452437
elif self.is_gradient_accumulation_boundary():

0 commit comments

Comments
 (0)