@@ -43,7 +43,7 @@ def test_case_2():
4343
4444
4545def test_case_3 ():
46- """Out parameter with tensor """
46+ """Out parameter with positional and keyword arguments """
4747 pytorch_code = textwrap .dedent (
4848 """
4949 import torch
@@ -56,7 +56,7 @@ def test_case_3():
5656
5757
5858def test_case_4 ():
59- """Out parameter with keyword arguments"""
59+ """Out parameter with all keyword arguments"""
6060 pytorch_code = textwrap .dedent (
6161 """
6262 import torch
@@ -118,18 +118,6 @@ def test_case_8():
118118
119119
120120def test_case_9 ():
121- """Float32 tensor"""
122- pytorch_code = textwrap .dedent (
123- """
124- import torch
125- x = torch.tensor([1.5, -2.5, 3.5], dtype=torch.float32)
126- result = torch.neg(x)
127- """
128- )
129- obj .run (pytorch_code , ["result" ])
130-
131-
132- def test_case_10 ():
133121 """Float64 tensor"""
134122 pytorch_code = textwrap .dedent (
135123 """
@@ -141,20 +129,7 @@ def test_case_10():
141129 obj .run (pytorch_code , ["result" ])
142130
143131
144- def test_case_11 ():
145- """Variable argument"""
146- pytorch_code = textwrap .dedent (
147- """
148- import torch
149- x = torch.tensor([1.0, -2.0, 3.0])
150- result = torch.neg(x)
151- result2 = torch.neg(result)
152- """
153- )
154- obj .run (pytorch_code , ["result2" ])
155-
156-
157- def test_case_12 ():
132+ def test_case_10 ():
158133 """Expression argument"""
159134 pytorch_code = textwrap .dedent (
160135 """
@@ -165,37 +140,15 @@ def test_case_12():
165140 obj .run (pytorch_code , ["result" ])
166141
167142
168- def test_case_13 ():
169- """Zero tensor"""
170- pytorch_code = textwrap .dedent (
171- """
172- import torch
173- x = torch.zeros(3)
174- result = torch.neg(x)
175- """
176- )
177- obj .run (pytorch_code , ["result" ])
178-
179-
180- def test_case_14 ():
181- """Ones tensor (negated)"""
143+ def test_case_11 ():
144+ """Gradient computation"""
182145 pytorch_code = textwrap .dedent (
183146 """
184147 import torch
185- x = torch.ones(3)
186- result = torch.neg(x)
187- """
188- )
189- obj .run (pytorch_code , ["result" ])
190-
191-
192- def test_case_15 ():
193- """Empty parameter specification (default)"""
194- pytorch_code = textwrap .dedent (
148+ x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
149+ y = torch.neg(x)
150+ y.sum().backward()
151+ x_grad = x.grad
195152 """
196- import torch
197- x = torch.tensor([1.0, -2.0, 3.0])
198- result = torch.neg(x)
199- """
200153 )
201- obj .run (pytorch_code , ["result" ] )
154+ obj .run (pytorch_code , ["y" , "x_grad" ], check_stop_gradient = False )
0 commit comments