Skip to content

Commit b0c9ec5

Browse files
committed
fix Tensor.cuda/requires_grad/ctx.saved_tensors
1 parent 9e25fff commit b0c9ec5

8 files changed

Lines changed: 190 additions & 86 deletions

File tree

paconvert/api_mapping.json

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -818,18 +818,7 @@
818818
"paddle_api": "paddle.Tensor.crows"
819819
},
820820
"torch.Tensor.cuda": {
821-
"Matcher": "Device2IntMatcher",
822-
"paddle_api": "paddle.Tensor.cuda",
823-
"min_input_args": 0,
824-
"args_list": [
825-
"device",
826-
"non_blocking",
827-
"memory_format"
828-
],
829-
"kwargs_change": {
830-
"device": "device_id",
831-
"memory_format": ""
832-
}
821+
"Matcher": "ChangePrefixMatcher"
833822
},
834823
"torch.Tensor.cummax": {
835824
"Matcher": "ChangePrefixMatcher"
@@ -3319,9 +3308,6 @@
33193308
"torch.autograd.function.FunctionCtx.save_for_forward": {
33203309
"min_input_args": 0
33213310
},
3322-
"torch.autograd.function.FunctionCtx.saved_tensors": {
3323-
"Matcher": "ChangePrefixMatcher"
3324-
},
33253311
"torch.autograd.function.FunctionCtx.set_materialize_grads": {
33263312
"Matcher": "ChangePrefixMatcher"
33273313
},

paconvert/converter.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@
2727

2828
from paconvert.transformer.basic_transformer import BasicTransformer
2929
from paconvert.transformer.import_transformer import ImportTransformer
30-
from paconvert.transformer.tensor_requires_grad_transformer import (
31-
TensorRequiresGradTransformer,
32-
)
3330
from paconvert.transformer.custom_op_transformer import (
3431
PreCustomOpTransformer,
3532
CustomOpTransformer,
@@ -378,7 +375,6 @@ def transfer_file(self, old_path, new_path):
378375
def transfer_node(self, root, file):
379376
transformers = [
380377
ImportTransformer, # import ast transformer
381-
TensorRequiresGradTransformer, # attribute requires_grad transformer
382378
BasicTransformer, # most of api transformer
383379
PreCustomOpTransformer, # pre process for C++ custom op
384380
CustomOpTransformer, # C++ custom op transformer

tests/code_library/code_case/paddle_code/attribute_paddle_Tensor_stop_gradient.py renamed to tests/code_library/code_case/paddle_code/attribute_paddle_Tensor_requires_grad.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
print("#########################case2#########################")
88
print(data.requires_grad)
99
print("#########################case3#########################")
10-
data.stop_gradient = not False
10+
data.requires_grad = False
1111
print("#########################case4#########################")
1212
requires_grad = data.requires_grad
1313
print("#########################case5#########################")
@@ -25,9 +25,8 @@ def test():
2525
return True
2626

2727

28-
data.stop_gradient = not test()
28+
data.requires_grad = test()
2929
print("#########################case10#########################")
3030
z = True, False, True
31-
a, temp, c = z
32-
data.stop_gradient = not temp
31+
a, data.requires_grad, c = z
3332
print(data.requires_grad)

tests/test_inner.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,34 @@ def test_case_4():
7070

7171

7272
# The paddle input does not support integer type
73-
def _test_case_5():
73+
def test_case_5():
7474
pytorch_code = textwrap.dedent(
7575
"""
7676
import torch
7777
result = torch.inner(torch.tensor([1, 2, 3]), torch.tensor([0, 2, 1]))
7878
"""
7979
)
8080
obj.run(pytorch_code, ["result"])
81+
82+
83+
def test_case_6():
84+
pytorch_code = textwrap.dedent(
85+
"""
86+
import torch
87+
out = torch.randn([])
88+
result = torch.inner(input=torch.tensor([1., 2, 3]), other=torch.tensor([0., 2, 1]), out=out)
89+
"""
90+
)
91+
obj.run(pytorch_code, ["result", "out"])
92+
93+
94+
# generated by validate_unittest autofix, based on test_case_6
95+
def test_case_7():
96+
pytorch_code = textwrap.dedent(
97+
"""
98+
import torch
99+
out = torch.randn([])
100+
result = torch.inner(out=out, other=torch.tensor([0., 2, 1]), input=torch.tensor([1., 2, 3]))
101+
"""
102+
)
103+
obj.run(pytorch_code, ["result", "out"])

tests/test_neg.py

Lines changed: 10 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_case_2():
4343

4444

4545
def test_case_3():
46-
"""Out parameter with tensor"""
46+
"""Out parameter with positional and keyword arguments"""
4747
pytorch_code = textwrap.dedent(
4848
"""
4949
import torch
@@ -56,7 +56,7 @@ def test_case_3():
5656

5757

5858
def test_case_4():
59-
"""Out parameter with keyword arguments"""
59+
"""Out parameter with all keyword arguments"""
6060
pytorch_code = textwrap.dedent(
6161
"""
6262
import torch
@@ -118,18 +118,6 @@ def test_case_8():
118118

119119

120120
def test_case_9():
121-
"""Float32 tensor"""
122-
pytorch_code = textwrap.dedent(
123-
"""
124-
import torch
125-
x = torch.tensor([1.5, -2.5, 3.5], dtype=torch.float32)
126-
result = torch.neg(x)
127-
"""
128-
)
129-
obj.run(pytorch_code, ["result"])
130-
131-
132-
def test_case_10():
133121
"""Float64 tensor"""
134122
pytorch_code = textwrap.dedent(
135123
"""
@@ -141,20 +129,7 @@ def test_case_10():
141129
obj.run(pytorch_code, ["result"])
142130

143131

144-
def test_case_11():
145-
"""Variable argument"""
146-
pytorch_code = textwrap.dedent(
147-
"""
148-
import torch
149-
x = torch.tensor([1.0, -2.0, 3.0])
150-
result = torch.neg(x)
151-
result2 = torch.neg(result)
152-
"""
153-
)
154-
obj.run(pytorch_code, ["result2"])
155-
156-
157-
def test_case_12():
132+
def test_case_10():
158133
"""Expression argument"""
159134
pytorch_code = textwrap.dedent(
160135
"""
@@ -165,37 +140,15 @@ def test_case_12():
165140
obj.run(pytorch_code, ["result"])
166141

167142

168-
def test_case_13():
169-
"""Zero tensor"""
170-
pytorch_code = textwrap.dedent(
171-
"""
172-
import torch
173-
x = torch.zeros(3)
174-
result = torch.neg(x)
175-
"""
176-
)
177-
obj.run(pytorch_code, ["result"])
178-
179-
180-
def test_case_14():
181-
"""Ones tensor (negated)"""
143+
def test_case_11():
144+
"""Gradient computation"""
182145
pytorch_code = textwrap.dedent(
183146
"""
184147
import torch
185-
x = torch.ones(3)
186-
result = torch.neg(x)
187-
"""
188-
)
189-
obj.run(pytorch_code, ["result"])
190-
191-
192-
def test_case_15():
193-
"""Empty parameter specification (default)"""
194-
pytorch_code = textwrap.dedent(
148+
x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
149+
y = torch.neg(x)
150+
y.sum().backward()
151+
x_grad = x.grad
195152
"""
196-
import torch
197-
x = torch.tensor([1.0, -2.0, 3.0])
198-
result = torch.neg(x)
199-
"""
200153
)
201-
obj.run(pytorch_code, ["result"])
154+
obj.run(pytorch_code, ["y", "x_grad"], check_stop_gradient=False)

tests/test_positive.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
#
1514

1615
import textwrap
1716

@@ -21,6 +20,7 @@
2120

2221

2322
def test_case_1():
23+
"""Basic usage with 1D tensor"""
2424
pytorch_code = textwrap.dedent(
2525
"""
2626
import torch
@@ -32,6 +32,7 @@ def test_case_1():
3232

3333

3434
def test_case_2():
35+
"""2D tensor with positional argument"""
3536
pytorch_code = textwrap.dedent(
3637
"""
3738
import torch
@@ -42,6 +43,7 @@ def test_case_2():
4243

4344

4445
def test_case_3():
46+
"""Keyword argument"""
4547
pytorch_code = textwrap.dedent(
4648
"""
4749
import torch
@@ -50,3 +52,63 @@ def test_case_3():
5052
"""
5153
)
5254
obj.run(pytorch_code, ["result"])
55+
56+
57+
def test_case_4():
58+
"""Keyword argument with expression"""
59+
pytorch_code = textwrap.dedent(
60+
"""
61+
import torch
62+
result = torch.positive(input=torch.tensor([[-4., 1., 1., 16.]]))
63+
"""
64+
)
65+
obj.run(pytorch_code, ["result"])
66+
67+
68+
def test_case_5():
69+
"""Gradient computation"""
70+
pytorch_code = textwrap.dedent(
71+
"""
72+
import torch
73+
x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
74+
y = torch.positive(x)
75+
y.sum().backward()
76+
x_grad = x.grad
77+
"""
78+
)
79+
obj.run(pytorch_code, ["y", "x_grad"], check_stop_gradient=False)
80+
81+
82+
def test_case_6():
83+
"""3D tensor"""
84+
pytorch_code = textwrap.dedent(
85+
"""
86+
import torch
87+
x = torch.tensor([[[1.0, -2.0], [3.0, -4.0]], [[5.0, -6.0], [7.0, -8.0]]])
88+
result = torch.positive(x)
89+
"""
90+
)
91+
obj.run(pytorch_code, ["result"])
92+
93+
94+
def test_case_7():
95+
"""Integer tensor"""
96+
pytorch_code = textwrap.dedent(
97+
"""
98+
import torch
99+
x = torch.tensor([1, -2, 3, -4])
100+
result = torch.positive(x)
101+
"""
102+
)
103+
obj.run(pytorch_code, ["result"])
104+
105+
106+
def test_case_8():
107+
"""Expression argument"""
108+
pytorch_code = textwrap.dedent(
109+
"""
110+
import torch
111+
result = torch.positive(torch.tensor([1.0, -2.0, 3.0]) + torch.tensor([0.5, 0.5, 0.5]))
112+
"""
113+
)
114+
obj.run(pytorch_code, ["result"])

tests/test_rad2deg.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,31 @@
2020

2121

2222
def test_case_1():
23+
"""Basic usage with positional argument"""
2324
pytorch_code = textwrap.dedent(
2425
"""
2526
import torch
26-
x = torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]])
27+
x = torch.tensor([3.142, -3.142, 6.283, -6.283])
2728
result = torch.rad2deg(x)
2829
"""
2930
)
3031
obj.run(pytorch_code, ["result"])
3132

3233

3334
def test_case_2():
35+
"""2D tensor"""
3436
pytorch_code = textwrap.dedent(
3537
"""
3638
import torch
37-
x = torch.tensor([3.142, -3.142, 6.283, -6.283])
39+
x = torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]])
3840
result = torch.rad2deg(x)
3941
"""
4042
)
4143
obj.run(pytorch_code, ["result"])
4244

4345

4446
def test_case_3():
47+
"""Keyword argument"""
4548
pytorch_code = textwrap.dedent(
4649
"""
4750
import torch
@@ -50,3 +53,52 @@ def test_case_3():
5053
"""
5154
)
5255
obj.run(pytorch_code, ["result"])
56+
57+
58+
def test_case_4():
59+
"""Keyword argument out of order"""
60+
pytorch_code = textwrap.dedent(
61+
"""
62+
import torch
63+
x = torch.tensor([3.142, -3.142, 6.283, -6.283])
64+
result = torch.rad2deg(input=x)
65+
"""
66+
)
67+
obj.run(pytorch_code, ["result"])
68+
69+
70+
def test_case_5():
71+
"""Gradient computation"""
72+
pytorch_code = textwrap.dedent(
73+
"""
74+
import torch
75+
x = torch.tensor([3.142, -3.142], requires_grad=True)
76+
y = torch.rad2deg(x)
77+
y.sum().backward()
78+
x_grad = x.grad
79+
"""
80+
)
81+
obj.run(pytorch_code, ["y", "x_grad"], check_stop_gradient=False)
82+
83+
84+
def test_case_6():
85+
"""Edge case with 3D tensor"""
86+
pytorch_code = textwrap.dedent(
87+
"""
88+
import torch
89+
x = torch.tensor([[[1.57, -1.57], [3.14, -3.14]]])
90+
result = torch.rad2deg(x)
91+
"""
92+
)
93+
obj.run(pytorch_code, ["result"])
94+
95+
96+
def test_case_7():
97+
"""Expression argument"""
98+
pytorch_code = textwrap.dedent(
99+
"""
100+
import torch
101+
result = torch.rad2deg(torch.tensor([1.57, 3.14]) * 1.0)
102+
"""
103+
)
104+
obj.run(pytorch_code, ["result"])

0 commit comments

Comments
 (0)