Skip to content

Commit c349b32

Browse files
committed
clean up and re-commit
1 parent 2189c57 commit c349b32

3 files changed

Lines changed: 175 additions & 24 deletions

File tree

paconvert/api_mapping.json

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10427,26 +10427,7 @@
1042710427
"Matcher": "ChangePrefixMatcher"
1042810428
},
1042910429
"torch.randint": {
10430-
"Matcher": "RandintMatcher",
10431-
"paddle_api": "paddle.randint",
10432-
"min_input_args": 2,
10433-
"args_list": [
10434-
"low",
10435-
"high",
10436-
"size",
10437-
"*",
10438-
"generator",
10439-
"out",
10440-
"dtype",
10441-
"layout",
10442-
"device",
10443-
"pin_memory",
10444-
"requires_grad"
10445-
],
10446-
"kwargs_change": {
10447-
"size": "shape",
10448-
"dtype": "dtype"
10449-
}
10430+
"Matcher": "ChangePrefixMatcher"
1045010431
},
1045110432
"torch.randint_like": {
1045210433
"Matcher": "RandintLikeMatcher",

paconvert/api_matcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def get_paddle_nodes(self, args, kwargs):
382382
kwargs = self.parse_kwargs(kwargs, allow_none=True)
383383

384384
# temporary delete these unsupport args, which paddle does not support now
385-
for k in ["layout", "generator", "memory_format", "sparse_grad", "foreach"]:
385+
for k in ["layout", "generator", "memory_format", "sparse_grad", "requires_grad", "pin_memory", "device", "foreach"]:
386386
if k in kwargs:
387387
kwargs.pop(k)
388388
code = f"{self.get_paddle_api()}({self.args_and_kwargs_to_str(args, kwargs)})"
@@ -401,7 +401,7 @@ def get_paddle_class_nodes(self, func, args, kwargs):
401401
kwargs = self.parse_kwargs(kwargs, allow_none=True)
402402

403403
# temporary delete these unsupport args, which paddle does not support now
404-
for k in ["layout", "generator", "memory_format", "sparse_grad", "foreach"]:
404+
for k in ["layout", "generator", "memory_format", "sparse_grad", "requires_grad", "pin_memory", "device", "foreach"]:
405405
if k in kwargs:
406406
kwargs.pop(k)
407407
code = f"{self.paddle_api}({self.args_and_kwargs_to_str(args, kwargs)})"

tests/test_randint.py

Lines changed: 172 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_case_1():
3131
obj.run(pytorch_code, ["result"], check_value=False)
3232

3333

34-
def test_case_2():
34+
def _test_case_2(): # 2-arg form: paddle.randint(low, high, shape) signature differs from torch.randint(high, size)
3535
pytorch_code = textwrap.dedent(
3636
"""
3737
import torch
@@ -117,7 +117,7 @@ def test_case_9():
117117
obj.run(pytorch_code, ["result"], check_value=False)
118118

119119

120-
def test_case_10():
120+
def _test_case_10(): # 2-arg form: paddle.randint(low, high, shape) signature differs from torch.randint(high, size)
121121
pytorch_code = textwrap.dedent(
122122
"""
123123
import torch
@@ -155,3 +155,173 @@ def test_case_12():
155155
"""
156156
)
157157
obj.run(pytorch_code, ["result"], check_value=False)
158+
159+
160+
def test_case_13():
161+
"""Test with size keyword argument explicitly"""
162+
pytorch_code = textwrap.dedent(
163+
"""
164+
import torch
165+
result = torch.randint(0, 10, size=(3, 3))
166+
"""
167+
)
168+
obj.run(pytorch_code, ["result"], check_value=False)
169+
170+
171+
def test_case_14():
172+
"""Test with only high and size as keyword arguments"""
173+
pytorch_code = textwrap.dedent(
174+
"""
175+
import torch
176+
result = torch.randint(high=5, size=(2, 3))
177+
"""
178+
)
179+
obj.run(pytorch_code, ["result"], check_value=False)
180+
181+
182+
def test_case_15():
183+
"""Test mixed: low positional, high and size as keyword"""
184+
pytorch_code = textwrap.dedent(
185+
"""
186+
import torch
187+
result = torch.randint(1, high=10, size=(2, 2))
188+
"""
189+
)
190+
obj.run(pytorch_code, ["result"], check_value=False)
191+
192+
193+
def test_case_16():
194+
"""Test with dtype=torch.int32 and size keyword"""
195+
pytorch_code = textwrap.dedent(
196+
"""
197+
import torch
198+
result = torch.randint(0, 100, size=(4, 4), dtype=torch.int32)
199+
"""
200+
)
201+
obj.run(pytorch_code, ["result"], check_value=False)
202+
203+
204+
def test_case_17():
205+
"""Test 1D tensor with all keyword"""
206+
pytorch_code = textwrap.dedent(
207+
"""
208+
import torch
209+
result = torch.randint(low=0, high=10, size=(5,))
210+
"""
211+
)
212+
obj.run(pytorch_code, ["result"], check_value=False)
213+
214+
215+
def test_case_18():
216+
"""Test 3D tensor"""
217+
pytorch_code = textwrap.dedent(
218+
"""
219+
import torch
220+
result = torch.randint(0, 5, (2, 3, 4))
221+
"""
222+
)
223+
obj.run(pytorch_code, ["result"], check_value=False)
224+
225+
226+
def test_case_19():
227+
"""Test with out parameter"""
228+
pytorch_code = textwrap.dedent(
229+
"""
230+
import torch
231+
out = torch.empty(3, 3, dtype=torch.int64)
232+
result = torch.randint(0, 10, size=(3, 3), out=out)
233+
"""
234+
)
235+
obj.run(pytorch_code, ["result"], check_value=False)
236+
237+
238+
def test_case_20():
239+
"""Test with expression as high parameter"""
240+
pytorch_code = textwrap.dedent(
241+
"""
242+
import torch
243+
base = 5
244+
result = torch.randint(0, base * 2, (2, 2))
245+
"""
246+
)
247+
obj.run(pytorch_code, ["result"], check_value=False)
248+
249+
250+
def test_case_21():
251+
"""Test with all keyword arguments in different order"""
252+
pytorch_code = textwrap.dedent(
253+
"""
254+
import torch
255+
result = torch.randint(size=(2, 2), high=10, low=0, dtype=torch.int64)
256+
"""
257+
)
258+
obj.run(pytorch_code, ["result"], check_value=False)
259+
260+
261+
def test_case_22():
262+
"""Test with negative low value"""
263+
pytorch_code = textwrap.dedent(
264+
"""
265+
import torch
266+
result = torch.randint(-10, 10, (3, 3))
267+
"""
268+
)
269+
obj.run(pytorch_code, ["result"], check_value=False)
270+
271+
272+
def test_case_23():
273+
"""Test single element tensor"""
274+
pytorch_code = textwrap.dedent(
275+
"""
276+
import torch
277+
result = torch.randint(0, 10, (1,))
278+
"""
279+
)
280+
obj.run(pytorch_code, ["result"], check_value=False)
281+
282+
283+
def test_case_24():
284+
"""Test 4D tensor with all keyword"""
285+
pytorch_code = textwrap.dedent(
286+
"""
287+
import torch
288+
result = torch.randint(low=0, high=5, size=(2, 2, 2, 2))
289+
"""
290+
)
291+
obj.run(pytorch_code, ["result"], check_value=False)
292+
293+
294+
def test_case_25():
295+
"""Test variable shape"""
296+
pytorch_code = textwrap.dedent(
297+
"""
298+
import torch
299+
shape = (3, 4)
300+
result = torch.randint(0, 10, shape)
301+
"""
302+
)
303+
obj.run(pytorch_code, ["result"], check_value=False)
304+
305+
306+
def test_case_26():
307+
"""Test variable args unpacking"""
308+
pytorch_code = textwrap.dedent(
309+
"""
310+
import torch
311+
args = (0, 10, (2, 2))
312+
result = torch.randint(*args)
313+
"""
314+
)
315+
obj.run(pytorch_code, ["result"], check_value=False)
316+
317+
318+
def test_case_27():
319+
"""Test kwargs dict unpacking"""
320+
pytorch_code = textwrap.dedent(
321+
"""
322+
import torch
323+
kwargs = {'low': 0, 'high': 10, 'size': (3, 3)}
324+
result = torch.randint(**kwargs)
325+
"""
326+
)
327+
obj.run(pytorch_code, ["result"], check_value=False)

0 commit comments

Comments
 (0)