# Stack of modules currently executing forward; used to track the active module chain.
FWD_MODULE_STACK = list()

# PyTorch >= 2.0 exposes autograd.Function.setup_context, which is required for
# torch.func transforms. Mirror deepspeed/runtime/zero/linear.py: fall back to the
# legacy forward(ctx, ...) signature when setup_context is unavailable.
_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, "setup_context")
# For each tensor in `outputs`, run `forward_function` on it and register
# `backward_function` as a backward hook.
2327def _apply_forward_and_backward_to_tensors_only (module , forward_function , backward_function , outputs ):
@@ -401,23 +405,45 @@ def _run_before_backward_function(sub_module):
401405 sub_module .applied_pre_backward_ref_cnt -= 1
402406 #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
403407
# Reconstructed post-patch code (the diff residue above interleaved the old and new
# class bodies). NOTE(review): in the full file this sits inside the hook-registration
# routine and closes over `module` and `_run_before_backward_function`; confirm the
# enclosing indentation when merging.
if _SUPPORTS_SETUP_CONTEXT:

    class PreBackwardFunctionForModule(torch.autograd.Function):
        """Inserts a grad_fn on `module`'s outputs so `_run_before_backward_function`
        fires when gradients first flow back into them (PyTorch >= 2.0 form, with
        forward/setup_context split for torch.func compatibility)."""

        @staticmethod
        def forward(outputs):
            # Detach so this Function becomes the autograd boundary for `outputs`.
            return outputs.detach()

        @staticmethod
        def setup_context(ctx, inputs, output):
            # Capture `module` and the pre-backward callback from the enclosing scope,
            # and bump the ref count of pending pre-backward invocations.
            ctx.module = module
            ctx.pre_backward_function = _run_before_backward_function
            if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
                ctx.module.applied_pre_backward_ref_cnt = 0
            ctx.module.applied_pre_backward_ref_cnt += 1

        @staticmethod
        def backward(ctx, *args):
            ctx.pre_backward_function(ctx.module)
            return args

else:

    # Legacy PyTorch (< 2.0): setup_context is unavailable, so capture state in forward.
    class PreBackwardFunctionForModule(torch.autograd.Function):

        @staticmethod
        def forward(ctx, outputs):
            ctx.module = module
            ctx.pre_backward_function = _run_before_backward_function
            if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
                ctx.module.applied_pre_backward_ref_cnt = 0
            ctx.module.applied_pre_backward_ref_cnt += 1
            outputs = outputs.detach()
            return outputs

        @staticmethod
        def backward(ctx, *args):
            ctx.pre_backward_function(ctx.module)
            return args

module.pre_bwd_fn = PreBackwardFunctionForModule
423449
@@ -431,31 +457,64 @@ def _run_after_backward_function(sub_module):
431457 if sub_module .ds_grads_remaining == 0 :
432458 self .post_sub_module_backward_function (sub_module )
433459
# Reconstructed post-patch code (the diff residue above interleaved the old and new
# class bodies). NOTE(review): in the full file this sits inside the hook-registration
# routine and closes over `module` and `_run_after_backward_function`; confirm the
# enclosing indentation when merging.
if _SUPPORTS_SETUP_CONTEXT:

    class PostBackwardFunctionModule(torch.autograd.Function):
        """Inserts a grad_fn on `module`'s output so `_run_after_backward_function`
        fires once all of the module's pending gradients have arrived
        (PyTorch >= 2.0 form, with forward/setup_context split)."""

        @staticmethod
        def forward(output):
            # Detach so this Function becomes the autograd boundary for `output`.
            return output.detach()

        @staticmethod
        def setup_context(ctx, inputs, output):
            (output_in,) = inputs
            ctx.module = module
            if output_in.requires_grad:
                # TODO SOME TIMES post backward does not seem to be triggered debug in detail
                # Should only cause increase in memory not correctness issue
                module.ds_grads_remaining += 1
                ctx.post_backward_function = _run_after_backward_function

        @staticmethod
        def backward(ctx, *args):
            ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
            if ctx.module.ds_grads_remaining == 0:
                ctx.post_backward_function(ctx.module)
            return args

else:

    # Legacy PyTorch (< 2.0): setup_context is unavailable, so capture state in forward.
    class PostBackwardFunctionModule(torch.autograd.Function):

        @staticmethod
        def forward(ctx, output):
            ctx.module = module
            if output.requires_grad:
                # TODO SOME TIMES post backward does not seem to be triggered debug in detail
                # Should only cause increase in memory not correctness issue
                module.ds_grads_remaining += 1
                ctx.post_backward_function = _run_after_backward_function
            output = output.detach()
            return output

        @staticmethod
        def backward(ctx, *args):
            ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
            if ctx.module.ds_grads_remaining == 0:
                ctx.post_backward_function(ctx.module)
            return args

module.post_bwd_fn = PostBackwardFunctionModule
461520
0 commit comments