diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index f940997b9..f33911f05 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -10427,26 +10427,7 @@ "Matcher": "ChangePrefixMatcher" }, "torch.randint": { - "Matcher": "RandintMatcher", - "paddle_api": "paddle.randint", - "min_input_args": 2, - "args_list": [ - "low", - "high", - "size", - "*", - "generator", - "out", - "dtype", - "layout", - "device", - "pin_memory", - "requires_grad" - ], - "kwargs_change": { - "size": "shape", - "dtype": "dtype" - } + "Matcher": "ChangePrefixMatcher" }, "torch.randint_like": { "Matcher": "RandintLikeMatcher", diff --git a/tests/test_randint.py b/tests/test_randint.py index 4ddde41cb..d8f5750c2 100644 --- a/tests/test_randint.py +++ b/tests/test_randint.py @@ -155,3 +155,173 @@ def test_case_12(): """ ) obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_13(): + """Test with size keyword argument explicitly""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.randint(0, 10, size=(3, 3)) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_14(): + """Test with only high and size as keyword arguments""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.randint(high=5, size=(2, 3)) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_15(): + """Test mixed: low positional, high and size as keyword""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.randint(1, high=10, size=(2, 2)) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_16(): + """Test with dtype=torch.int32 and size keyword""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.randint(0, 100, size=(4, 4), dtype=torch.int32) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_17(): + """Test 1D tensor with all keyword""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.randint(low=0, high=10, size=(5,)) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_18(): + """Test 3D tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.randint(0, 5, (2, 3, 4)) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_19(): + """Test with out parameter""" + pytorch_code = textwrap.dedent( + """ + import torch + out = torch.empty(3, 3, dtype=torch.int64) + result = torch.randint(0, 10, size=(3, 3), out=out) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_20(): + """Test with expression as high parameter""" + pytorch_code = textwrap.dedent( + """ + import torch + base = 5 + result = torch.randint(0, base * 2, (2, 2)) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_21(): + """Test with all keyword arguments in different order""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.randint(size=(2, 2), high=10, low=0, dtype=torch.int64) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_22(): + """Test with negative low value""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.randint(-10, 10, (3, 3)) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_23(): + """Test single element tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.randint(0, 10, (1,)) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_24(): + """Test 4D tensor with all keyword""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.randint(low=0, high=5, size=(2, 2, 2, 2)) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_25(): + """Test variable shape""" + pytorch_code = textwrap.dedent( + """ + import torch + shape = (3, 4) + result = torch.randint(0, 10, shape) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_26(): + """Test variable args unpacking""" + pytorch_code = textwrap.dedent( + """ + import torch + args = (0, 10, (2, 2)) + result = torch.randint(*args) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_27(): + """Test kwargs dict unpacking""" + pytorch_code = textwrap.dedent( + """ + import torch + kwargs = {'low': 0, 'high': 10, 'size': (3, 3)} + result = torch.randint(**kwargs) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False)