Skip to content

Commit b956af6

Browse files
committed
enhance tests for randint
1 parent 009444b commit b956af6

1 file changed

Lines changed: 210 additions & 5 deletions

File tree

tests/test_randint.py

Lines changed: 210 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import textwrap
1616

17-
import paddle
1817
import pytest
1918
from apibase import APIBase
2019

@@ -141,11 +140,11 @@ def test_case_11():
141140
obj.run(pytorch_code, ["result"], check_value=False)
142141

143142

144-
@pytest.mark.skipif(
145-
condition=not paddle.device.is_compiled_with_cuda(),
146-
reason="can only run on paddle with CUDA",
147-
)
148143
def test_case_12():
144+
import torch
145+
146+
if not torch.cuda.is_available():
147+
pytest.skip("pin_memory=True requires CUDA")
149148
pytorch_code = textwrap.dedent(
150149
"""
151150
import torch
@@ -155,3 +154,209 @@ def test_case_12():
155154
"""
156155
)
157156
obj.run(pytorch_code, ["result"], check_value=False)
157+
158+
159+
# Additional test cases for comprehensive coverage
160+
161+
162+
def test_case_13():
163+
"""Test with size keyword argument explicitly"""
164+
pytorch_code = textwrap.dedent(
165+
"""
166+
import torch
167+
result = torch.randint(0, 10, size=(3, 3))
168+
"""
169+
)
170+
obj.run(pytorch_code, ["result"], check_value=False)
171+
172+
173+
def test_case_14():
174+
"""Test with only high and size as keyword arguments"""
175+
pytorch_code = textwrap.dedent(
176+
"""
177+
import torch
178+
result = torch.randint(high=5, size=(2, 3))
179+
"""
180+
)
181+
obj.run(pytorch_code, ["result"], check_value=False)
182+
183+
184+
def test_case_15():
185+
"""Test with mixed positional and keyword: low positional, high and size as keyword"""
186+
pytorch_code = textwrap.dedent(
187+
"""
188+
import torch
189+
result = torch.randint(1, high=10, size=(2, 2))
190+
"""
191+
)
192+
obj.run(pytorch_code, ["result"], check_value=False)
193+
194+
195+
def test_case_16():
196+
"""Test with dtype=torch.int32 and size keyword"""
197+
pytorch_code = textwrap.dedent(
198+
"""
199+
import torch
200+
result = torch.randint(0, 100, size=(4, 4), dtype=torch.int32)
201+
"""
202+
)
203+
obj.run(pytorch_code, ["result"], check_value=False)
204+
205+
206+
def test_case_17():
207+
"""Test 1D tensor"""
208+
pytorch_code = textwrap.dedent(
209+
"""
210+
import torch
211+
result = torch.randint(low=0, high=10, size=(5,))
212+
"""
213+
)
214+
obj.run(pytorch_code, ["result"], check_value=False)
215+
216+
217+
def test_case_18():
218+
"""Test 3D tensor"""
219+
pytorch_code = textwrap.dedent(
220+
"""
221+
import torch
222+
result = torch.randint(0, 5, (2, 3, 4))
223+
"""
224+
)
225+
obj.run(pytorch_code, ["result"], check_value=False)
226+
227+
228+
def test_case_19():
229+
"""Test with out parameter and size keyword"""
230+
pytorch_code = textwrap.dedent(
231+
"""
232+
import torch
233+
out = torch.empty(3, 3, dtype=torch.int64)
234+
result = torch.randint(0, 10, size=(3, 3), out=out)
235+
"""
236+
)
237+
obj.run(pytorch_code, ["result"], check_value=False)
238+
239+
240+
def test_case_20():
241+
"""Test with expression as high parameter"""
242+
pytorch_code = textwrap.dedent(
243+
"""
244+
import torch
245+
base = 5
246+
result = torch.randint(0, base * 2, (2, 2))
247+
"""
248+
)
249+
obj.run(pytorch_code, ["result"], check_value=False)
250+
251+
252+
def test_case_21():
253+
"""Test with expression as size parameter"""
254+
pytorch_code = textwrap.dedent(
255+
"""
256+
import torch
257+
dim = 2
258+
result = torch.randint(0, 10, (dim + 1, dim + 1))
259+
"""
260+
)
261+
obj.run(pytorch_code, ["result"], check_value=False)
262+
263+
264+
def test_case_22():
265+
"""Test with all keyword arguments in different order"""
266+
pytorch_code = textwrap.dedent(
267+
"""
268+
import torch
269+
result = torch.randint(size=(2, 2), high=10, low=0, dtype=torch.int64)
270+
"""
271+
)
272+
obj.run(pytorch_code, ["result"], check_value=False)
273+
274+
275+
def test_case_23():
276+
"""Test with negative low value"""
277+
pytorch_code = textwrap.dedent(
278+
"""
279+
import torch
280+
result = torch.randint(-10, 10, (3, 3))
281+
"""
282+
)
283+
obj.run(pytorch_code, ["result"], check_value=False)
284+
285+
286+
def test_case_24():
287+
"""Test with large range"""
288+
pytorch_code = textwrap.dedent(
289+
"""
290+
import torch
291+
result = torch.randint(0, 1000000, (2, 2), dtype=torch.int64)
292+
"""
293+
)
294+
obj.run(pytorch_code, ["result"], check_value=False)
295+
296+
297+
def test_case_25():
298+
"""Test with single element tensor"""
299+
pytorch_code = textwrap.dedent(
300+
"""
301+
import torch
302+
result = torch.randint(0, 10, (1,))
303+
"""
304+
)
305+
obj.run(pytorch_code, ["result"], check_value=False)
306+
307+
308+
def test_case_26():
309+
"""Test with 4D tensor"""
310+
pytorch_code = textwrap.dedent(
311+
"""
312+
import torch
313+
result = torch.randint(low=0, high=5, size=(2, 2, 2, 2))
314+
"""
315+
)
316+
obj.run(pytorch_code, ["result"], check_value=False)
317+
318+
319+
def test_case_27():
320+
"""Test with variable unpacking for size"""
321+
pytorch_code = textwrap.dedent(
322+
"""
323+
import torch
324+
shape = (3, 4)
325+
result = torch.randint(0, 10, shape)
326+
"""
327+
)
328+
obj.run(pytorch_code, ["result"], check_value=False)
329+
330+
331+
def test_case_28():
332+
"""Test with only positional arguments: low, high, size"""
333+
pytorch_code = textwrap.dedent(
334+
"""
335+
import torch
336+
result = torch.randint(5, 15, (2, 3))
337+
"""
338+
)
339+
obj.run(pytorch_code, ["result"], check_value=False)
340+
341+
342+
def test_case_29():
343+
"""Test with out parameter as keyword, other as positional"""
344+
pytorch_code = textwrap.dedent(
345+
"""
346+
import torch
347+
out = torch.empty(2, 2, dtype=torch.int64)
348+
result = torch.randint(0, 10, (2, 2), out=out)
349+
"""
350+
)
351+
obj.run(pytorch_code, ["result"], check_value=False)
352+
353+
354+
def test_case_30():
355+
"""Test with dtype as keyword only"""
356+
pytorch_code = textwrap.dedent(
357+
"""
358+
import torch
359+
result = torch.randint(10, (3, 3), dtype=torch.int32)
360+
"""
361+
)
362+
obj.run(pytorch_code, ["result"], check_value=False)

0 commit comments

Comments
 (0)