Skip to content

Commit e5b6c23

Browse files
author
zkjh
committed
test: generate fallback randint data on cpu
1 parent f44f9af commit e5b6c23

1 file changed

Lines changed: 15 additions & 3 deletions

File tree

tests/utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,22 @@ def rand_strided(shape, strides, *, dtype=None, device=None):
7474

7575
def randint_strided(low, high, shape, strides, *, dtype=None, device=None):
7676
output = empty_strided(shape, strides, dtype=dtype, device=device)
77-
78-
output.as_strided(
77+
output_flat = output.as_strided(
7978
(output.untyped_storage().size() // output.element_size(),), (1,)
80-
).random_(low, high)
79+
)
80+
81+
try:
82+
output_flat.random_(low, high)
83+
except RuntimeError as exc:
84+
if "random_" not in str(exc):
85+
raise
86+
87+
cpu_output = empty_strided(shape, strides, dtype=dtype, device="cpu")
88+
cpu_flat = cpu_output.as_strided(
89+
(cpu_output.untyped_storage().size() // cpu_output.element_size(),), (1,)
90+
)
91+
cpu_flat.random_(low, high)
92+
output_flat.copy_(cpu_flat.to(device=output.device))
8193

8294
return output
8395

0 commit comments

Comments
 (0)