Skip to content

Commit 5e83d05

Browse files
committed
fix(zero): use setup_context for offload pre/post backward Functions
PyTorch versions that expose autograd.Function.setup_context require the modern forward + setup_context signature for torch.func / functorch transforms; this change adopts that form while keeping the legacy forward(ctx, ...) path for older releases. Signed-off-by: Zhang <jianmusings@gmail.com>
1 parent c0b9694 commit 5e83d05

File tree

1 file changed

+99
-40
lines changed

1 file changed

+99
-40
lines changed

deepspeed/runtime/zero/parameter_offload.py

Lines changed: 99 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818

1919
FWD_MODULE_STACK = list()
2020

21+
# PyTorch >= 2.0: setup_context on autograd.Function is required for torch.func transforms.
22+
# Match deepspeed/runtime/zero/linear.py: keep legacy forward(ctx, ...) when unavailable.
23+
_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, "setup_context")
24+
2125

2226
#for each tensor in outputs run the forward_function and register backward_function as hook
2327
def _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, outputs):
@@ -401,23 +405,45 @@ def _run_before_backward_function(sub_module):
401405
sub_module.applied_pre_backward_ref_cnt -= 1
402406
#print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
403407

404-
class PreBackwardFunctionForModule(torch.autograd.Function):
408+
if _SUPPORTS_SETUP_CONTEXT:
409+
410+
class PreBackwardFunctionForModule(torch.autograd.Function):
411+
412+
@staticmethod
413+
def forward(outputs):
414+
return outputs.detach()
405415

406-
@staticmethod
407-
def forward(ctx, outputs):
408-
# Capture `module` and _run_before_backward_function
409-
ctx.module = module
410-
ctx.pre_backward_function = _run_before_backward_function
411-
if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
412-
ctx.module.applied_pre_backward_ref_cnt = 0
413-
ctx.module.applied_pre_backward_ref_cnt += 1
414-
outputs = outputs.detach()
415-
return outputs
416+
@staticmethod
417+
def setup_context(ctx, inputs, output):
418+
ctx.module = module
419+
ctx.pre_backward_function = _run_before_backward_function
420+
if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
421+
ctx.module.applied_pre_backward_ref_cnt = 0
422+
ctx.module.applied_pre_backward_ref_cnt += 1
423+
424+
@staticmethod
425+
def backward(ctx, *args):
426+
ctx.pre_backward_function(ctx.module)
427+
return args
428+
429+
else:
416430

417-
@staticmethod
418-
def backward(ctx, *args):
419-
ctx.pre_backward_function(ctx.module)
420-
return args
431+
class PreBackwardFunctionForModule(torch.autograd.Function):
432+
433+
@staticmethod
434+
def forward(ctx, outputs):
435+
ctx.module = module
436+
ctx.pre_backward_function = _run_before_backward_function
437+
if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
438+
ctx.module.applied_pre_backward_ref_cnt = 0
439+
ctx.module.applied_pre_backward_ref_cnt += 1
440+
outputs = outputs.detach()
441+
return outputs
442+
443+
@staticmethod
444+
def backward(ctx, *args):
445+
ctx.pre_backward_function(ctx.module)
446+
return args
421447

422448
module.pre_bwd_fn = PreBackwardFunctionForModule
423449

@@ -431,31 +457,64 @@ def _run_after_backward_function(sub_module):
431457
if sub_module.ds_grads_remaining == 0:
432458
self.post_sub_module_backward_function(sub_module)
433459

434-
class PostBackwardFunctionModule(torch.autograd.Function):
435-
436-
@staticmethod
437-
def forward(ctx, output):
438-
ctx.module = module
439-
if output.requires_grad:
440-
#TODO SOME TIMES post backward does not seem to be triggered debug in detail
441-
#Should only cause increase in memory not correctness issue
442-
#if output.grad_fn.__class__.__name__ == 'ViewBackward':
443-
# ctx.view=True
444-
# print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
445-
#assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
446-
#if module.ds_grads_remaining == 0:
447-
# print(f"Before Forward: {ctx.module.__class__.__name__}")
448-
module.ds_grads_remaining += 1
449-
ctx.post_backward_function = _run_after_backward_function
450-
output = output.detach()
451-
return output
452-
453-
@staticmethod
454-
def backward(ctx, *args):
455-
ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
456-
if ctx.module.ds_grads_remaining == 0:
457-
ctx.post_backward_function(ctx.module)
458-
return args
460+
if _SUPPORTS_SETUP_CONTEXT:
461+
462+
class PostBackwardFunctionModule(torch.autograd.Function):
463+
464+
@staticmethod
465+
def forward(output):
466+
return output.detach()
467+
468+
@staticmethod
469+
def setup_context(ctx, inputs, output):
470+
(output_in,) = inputs
471+
ctx.module = module
472+
if output_in.requires_grad:
473+
#TODO SOME TIMES post backward does not seem to be triggered debug in detail
474+
#Should only cause increase in memory not correctness issue
475+
#if output.grad_fn.__class__.__name__ == 'ViewBackward':
476+
# ctx.view=True
477+
# print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
478+
#assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
479+
#if module.ds_grads_remaining == 0:
480+
# print(f"Before Forward: {ctx.module.__class__.__name__}")
481+
module.ds_grads_remaining += 1
482+
ctx.post_backward_function = _run_after_backward_function
483+
484+
@staticmethod
485+
def backward(ctx, *args):
486+
ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
487+
if ctx.module.ds_grads_remaining == 0:
488+
ctx.post_backward_function(ctx.module)
489+
return args
490+
491+
else:
492+
493+
class PostBackwardFunctionModule(torch.autograd.Function):
494+
495+
@staticmethod
496+
def forward(ctx, output):
497+
ctx.module = module
498+
if output.requires_grad:
499+
#TODO SOME TIMES post backward does not seem to be triggered debug in detail
500+
#Should only cause increase in memory not correctness issue
501+
#if output.grad_fn.__class__.__name__ == 'ViewBackward':
502+
# ctx.view=True
503+
# print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
504+
#assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
505+
#if module.ds_grads_remaining == 0:
506+
# print(f"Before Forward: {ctx.module.__class__.__name__}")
507+
module.ds_grads_remaining += 1
508+
ctx.post_backward_function = _run_after_backward_function
509+
output = output.detach()
510+
return output
511+
512+
@staticmethod
513+
def backward(ctx, *args):
514+
ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
515+
if ctx.module.ds_grads_remaining == 0:
516+
ctx.post_backward_function(ctx.module)
517+
return args
459518

460519
module.post_bwd_fn = PostBackwardFunctionModule
461520

0 commit comments

Comments
 (0)