Diff hunks for the custom linear `autograd.Function` (the class definition itself is outside this excerpt). First, the removed lines of the old `backward`:

```diff
-    # This function has only a single output, so it gets only one gradient
-    @staticmethod
-    @autocast_custom_bwd
-    def backward(ctx, grad_output):
-        # This is a pattern that is very convenient - at the top of backward
-        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
-        # None. Thanks to the fact that additional trailing Nones are
-        # ignored, the return statement is simple even when the function has
-        # optional inputs.
-        input, weight, bias = ctx.saved_tensors
-
-        grad_input = grad_weight = grad_bias = None
-
-        #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}")
-        # These needs_input_grad checks are optional and there only to
-        # improve efficiency. If you want to make your code simpler, you can
-        # skip them. Returning gradients for inputs that don't require it is
-        # not an error.
-        dim = grad_output.dim()
-        if ctx.needs_input_grad[0]:
-            #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}")
```
Next, the `forward` additions, with the shared `else:` that the diff matched as context between the old `backward` and the new `forward`:

```diff
     # Note that both forward and backward are @staticmethods
+    @staticmethod
+    @autocast_custom_fwd
+    # bias is an optional argument
+    def forward(ctx, input, weight, bias=None):
+
+        ctx.save_for_backward(input, weight, bias)
+
+        if input.dim() == 2 and bias is not None:
+            # fused op is marginally faster
+            ret = torch.addmm(bias, input, weight.t())
         else:
-            grad_bias = grad_output.sum(0)
-            #print("Done computing grad bias")
-        #print("needs bias")
-        #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}")
-        return grad_input, grad_weight, grad_bias
+            output = input.matmul(weight.t())
+            if bias is not None:
+                output += bias
+            ret = output
+
+        return ret
+
```
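The two branches of `forward` compute the same result; `torch.addmm` just fuses the bias add with the matrix multiply in the 2-D case. A quick equivalence check (the shapes here are illustrative, not from the repository):

```python
import torch

inp = torch.randn(4, 3)     # (batch, in_features)
weight = torch.randn(5, 3)  # (out_features, in_features)
bias = torch.randn(5)       # (out_features,)

# Fused path: bias is broadcast and added to inp @ weight.t()
fused = torch.addmm(bias, inp, weight.t())
# Unfused path, as in the else branch
unfused = inp.matmul(weight.t()) + bias

print(torch.allclose(fused, unfused))  # True
```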
Then the new `backward` additions:

```diff
+
+    # This function has only a single output, so it gets only one gradient
+    @staticmethod
+    @autocast_custom_bwd
+    def backward(ctx, grad_output):
+        # This is a pattern that is very convenient - at the top of backward
+        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
+        # None. Thanks to the fact that additional trailing Nones are
+        # ignored, the return statement is simple even when the function has
+        # optional inputs.
+        input, weight, bias = ctx.saved_tensors
+
+        grad_input = grad_weight = grad_bias = None
+
+        #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}")
+        # These needs_input_grad checks are optional and there only to
+        # improve efficiency. If you want to make your code simpler, you can
+        # skip them. Returning gradients for inputs that don't require it is
+        # not an error.
+        dim = grad_output.dim()
+        if ctx.needs_input_grad[0]:
+            #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}")
```
#print(f"Computed grad weight grad_weight {grad_weight.shape}")
144
+
ifbiasisnotNoneandctx.needs_input_grad[2]:
145
+
#print("Computing grad bias")
146
+
ifdim>2:
147
+
grad_bias=grad_output.sum([iforiinrange(dim-1)])
148
+
else:
149
+
grad_bias=grad_output.sum(0)
150
+
#print("Done computing grad bias")
151
+
#print("needs bias")
152
+
#print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}")