
Commit 33db7c4
fix: fix LinearFunctionForZeroStage3 to support torch.func transforms
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
1 parent 5f7b687

1 file changed: deepspeed/runtime/zero/linear.py (110 additions, 57 deletions)
@@ -35,69 +35,122 @@ def print_rank_0(message, debug=False, force=False):
 autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name())
 autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name())
 
+# PyTorch >= 2.0 supports setup_context, which is required for
+# torch.func transforms (vmap, grad, jvp, jacrev, etc.)
+_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, 'setup_context')
 
-class LinearFunctionForZeroStage3(torch.autograd.Function):
+if _SUPPORTS_SETUP_CONTEXT:
 
-    # Note that both forward and backward are @staticmethods
-    @staticmethod
-    @autocast_custom_fwd
-    # bias is an optional argument
-    def forward(ctx, input, weight, bias=None):
+    class LinearFunctionForZeroStage3(torch.autograd.Function):
 
-        ctx.save_for_backward(input, weight, bias)
+        @staticmethod
+        @autocast_custom_fwd
+        def forward(input, weight, bias=None):
 
-        if input.dim() == 2 and bias is not None:
-            # fused op is marginally faster
-            ret = torch.addmm(bias, input, weight.t())
-        else:
-            output = input.matmul(weight.t())
-            if bias is not None:
-                output += bias
-            ret = output
-
-        return ret
-
-    # This function has only a single output, so it gets only one gradient
-    @staticmethod
-    @autocast_custom_bwd
-    def backward(ctx, grad_output):
-        # This is a pattern that is very convenient - at the top of backward
-        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
-        # None. Thanks to the fact that additional trailing Nones are
-        # ignored, the return statement is simple even when the function has
-        # optional inputs.
-        input, weight, bias = ctx.saved_tensors
-
-        grad_input = grad_weight = grad_bias = None
-
-        #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}")
-        # These needs_input_grad checks are optional and there only to
-        # improve efficiency. If you want to make your code simpler, you can
-        # skip them. Returning gradients for inputs that don't require it is
-        # not an error.
-        dim = grad_output.dim()
-        if ctx.needs_input_grad[0]:
-            #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}")
-            grad_input = grad_output.matmul(weight)
-            #print(f"Computed grad input {grad_input.shape}")
-        if ctx.needs_input_grad[1]:
-            #print("Computing grad weight")
-            if dim > 2:
-                grad_weight = grad_output.reshape(-1,
-                                                  grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1]))
+            if input.dim() == 2 and bias is not None:
+                # fused op is marginally faster
+                ret = torch.addmm(bias, input, weight.t())
             else:
-                grad_weight = grad_output.t().matmul(input)
-            #print(f"Computed grad weight grad_weight {grad_weight.shape}")
-        if bias is not None and ctx.needs_input_grad[2]:
-            #print("Computing grad bias")
-            if dim > 2:
-                grad_bias = grad_output.sum([i for i in range(dim - 1)])
+                output = input.matmul(weight.t())
+                if bias is not None:
+                    output += bias
+                ret = output
+
+            return ret
+
+        @staticmethod
+        def setup_context(ctx, inputs, output):
+            input, weight, bias = inputs
+            ctx.save_for_backward(input, weight, bias)
+
+        # This function has only a single output, so it gets only one gradient
+        @staticmethod
+        @autocast_custom_bwd
+        def backward(ctx, grad_output):
+            input, weight, bias = ctx.saved_tensors
+
+            grad_input = grad_weight = grad_bias = None
+
+            dim = grad_output.dim()
+            if ctx.needs_input_grad[0]:
+                grad_input = grad_output.matmul(weight)
+            if ctx.needs_input_grad[1]:
+                if dim > 2:
+                    grad_weight = grad_output.reshape(-1,
+                                                      grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1]))
+                else:
+                    grad_weight = grad_output.t().matmul(input)
+            if bias is not None and ctx.needs_input_grad[2]:
+                if dim > 2:
+                    grad_bias = grad_output.sum([i for i in range(dim - 1)])
+                else:
+                    grad_bias = grad_output.sum(0)
+            return grad_input, grad_weight, grad_bias
+
+else:
+
+    class LinearFunctionForZeroStage3(torch.autograd.Function):
+
+        # Note that both forward and backward are @staticmethods
+        @staticmethod
+        @autocast_custom_fwd
+        # bias is an optional argument
+        def forward(ctx, input, weight, bias=None):
+
+            ctx.save_for_backward(input, weight, bias)
+
+            if input.dim() == 2 and bias is not None:
+                # fused op is marginally faster
+                ret = torch.addmm(bias, input, weight.t())
             else:
-                grad_bias = grad_output.sum(0)
-            #print("Done computing grad bias")
-            #print("needs bias")
-        #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}")
-        return grad_input, grad_weight, grad_bias
+                output = input.matmul(weight.t())
+                if bias is not None:
+                    output += bias
+                ret = output
+
+            return ret
+
+        # This function has only a single output, so it gets only one gradient
+        @staticmethod
+        @autocast_custom_bwd
+        def backward(ctx, grad_output):
+            # This is a pattern that is very convenient - at the top of backward
+            # unpack saved_tensors and initialize all gradients w.r.t. inputs to
+            # None. Thanks to the fact that additional trailing Nones are
+            # ignored, the return statement is simple even when the function has
+            # optional inputs.
+            input, weight, bias = ctx.saved_tensors
+
+            grad_input = grad_weight = grad_bias = None
+
+            #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}")
+            # These needs_input_grad checks are optional and there only to
+            # improve efficiency. If you want to make your code simpler, you can
+            # skip them. Returning gradients for inputs that don't require it is
+            # not an error.
+            dim = grad_output.dim()
+            if ctx.needs_input_grad[0]:
+                #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}")
+                grad_input = grad_output.matmul(weight)
+                #print(f"Computed grad input {grad_input.shape}")
+            if ctx.needs_input_grad[1]:
+                #print("Computing grad weight")
+                if dim > 2:
+                    grad_weight = grad_output.reshape(-1,
+                                                      grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1]))
+                else:
+                    grad_weight = grad_output.t().matmul(input)
+                #print(f"Computed grad weight grad_weight {grad_weight.shape}")
+            if bias is not None and ctx.needs_input_grad[2]:
+                #print("Computing grad bias")
+                if dim > 2:
+                    grad_bias = grad_output.sum([i for i in range(dim - 1)])
+                else:
+                    grad_bias = grad_output.sum(0)
+                #print("Done computing grad bias")
+                #print("needs bias")
+            #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}")
+            return grad_input, grad_weight, grad_bias
 
 
 def zero3_linear_wrap(input, weight, bias=None):
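For readers unfamiliar with the new-style autograd.Function: torch.func transforms must be able to re-run forward as a pure function, so PyTorch requires the autograd bookkeeping (ctx.save_for_backward) to live in a separate setup_context staticmethod rather than in forward itself. Below is a minimal, self-contained sketch of the pattern this commit adopts; MyLinear, x, w, and loss_fn are illustrative names, not DeepSpeed code.

import torch
from torch.func import grad

class MyLinear(torch.autograd.Function):

    @staticmethod
    def forward(input, weight, bias):
        # No ctx argument: forward is a pure function of its inputs,
        # which is what lets torch.func re-run it under a transform.
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        return output

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Autograd bookkeeping moves here, out of forward.
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        grad_weight = grad_output.t().matmul(input)
        # One gradient per forward input; None for the unused bias.
        return grad_input, grad_weight, None

x = torch.randn(4, 8)
w = torch.randn(16, 8)

# This works only because setup_context is defined; with the old
# ctx-in-forward style, torch.func raises a RuntimeError.
loss_fn = lambda weight: MyLinear.apply(x, weight, None).sum()
print(grad(loss_fn)(w).shape)  # torch.Size([16, 8])

Note that setup_context is what grad and vjp need; vmap additionally expects generate_vmap_rule = True or an explicit vmap staticmethod on the Function, per the PyTorch docs on extending torch.func.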
