@@ -419,7 +419,7 @@ def __init__(self,
419419 raise RuntimeError (
420420 "DeepSpeedEngine: Optimizer initialization failed. Check for JIT compilation errors." )
421421
422- optimizer_methods = ['step' , 'load_state_dict' ]
422+ optimizer_methods = ['step' , 'load_state_dict' , 'backward' ]
423423
424424 if self .zero_optimization_partition_gradients ():
425425 optimizer_methods .append ('overlapping_partition_gradients_reduce_epilogue' )
@@ -432,13 +432,6 @@ def __init__(self,
432432 "This indicates incomplete initialization (e.g., JIT/toolchain failure)."
433433 )
434434
435- # Validate engine separately
436- if not hasattr (self , "backward" ) or not callable (getattr (self , "backward" )):
437- raise RuntimeError (
438- "DeepSpeedEngine initialization failed: missing callable `backward`. "
439- "Engine may be partially initialized."
440- )
441-
442435 if self .global_rank == 0 :
443436 self ._config .print ("DeepSpeedEngine configuration" )
444437 if self .dump_state ():
@@ -2438,8 +2431,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
24382431 self .optimizer .is_gradient_accumulation_boundary = self .is_gradient_accumulation_boundary ()
24392432 # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
24402433 if self .zero_optimization_partition_gradients ():
2441- if hasattr (self .optimizer , 'overlapping_partition_gradients_reduce_epilogue' ):
2442- self .optimizer .overlapping_partition_gradients_reduce_epilogue ()
2434+ self .optimizer .overlapping_partition_gradients_reduce_epilogue ()
24432435
24442436 # Communicate only at gradient accumulation boundaries
24452437 elif self .is_gradient_accumulation_boundary ():
0 commit comments