@@ -415,13 +415,20 @@ def __init__(self,
415415 self .engine_timers_cache = {}
416416
417417 if self .optimizer_name () or self .client_optimizer is not None :
418- if self .optimizer is None :
419- raise RuntimeError ("DeepSpeedEngine: Optimizer initialization failed. Check for JIT compilation errors." )
420- # ZeRO-3 specific check to prevent step 0 deadlocks
421- if self .zero_optimization_stage () == 3 :
422- if not hasattr (self .optimizer , 'step' ):
423- raise AttributeError ("DeepSpeedEngine: ZeRO-3 optimizer is missing core functional attributes (.step). "
424- "This usually indicates a toolchain mismatch or failed JIT kernels." )
418+ if self .optimizer is None :
419+ raise RuntimeError (
420+ "DeepSpeedEngine: Optimizer initialization failed. Check for JIT compilation errors." )
421+
422+ required_methods = ['step' , 'backward' , 'load_state_dict' ]
423+
424+ if self .zero_optimization_partition_gradients ():
425+ required_methods .append ('overlapping_partition_gradients_reduce_epilogue' )
426+
427+ for method in required_methods :
428+ if not hasattr (self .optimizer , method ):
429+ raise AttributeError (
430+ f"DeepSpeedEngine: Optimizer is missing core functional attribute (.{ method } ). "
431+ "This usually indicates a toolchain mismatch or failed JIT kernels." )
425432
426433 if self .global_rank == 0 :
427434 self ._config .print ("DeepSpeedEngine configuration" )
@@ -2422,9 +2429,8 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
24222429 self .optimizer .is_gradient_accumulation_boundary = self .is_gradient_accumulation_boundary ()
24232430 # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
24242431 if self .zero_optimization_partition_gradients ():
2425- if hasattr (self .optimizer , 'overlapping_partition_gradients_reduce_epilogue' ):
2432+ if hasattr (self .optimizer , 'overlapping_partition_gradients_reduce_epilogue' ):
24262433 self .optimizer .overlapping_partition_gradients_reduce_epilogue ()
2427-
24282434
24292435 # Communicate only at gradient accumulation boundaries
24302436 elif self .is_gradient_accumulation_boundary ():
0 commit comments