Skip to content

Commit 8f8a1f9

Browse files
committed
Update tests for random.choice
1 parent 7ee68fd commit 8f8a1f9

1 file changed

Lines changed: 160 additions & 0 deletions

File tree

dpnp/tests/third_party/cupy/random_tests/test_generator.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import functools
24
import os
35
import threading
@@ -850,6 +852,7 @@ def test_goodness_of_fit(self):
850852
assert _hypothesis.chi_square_test(counts, expected)
851853

852854
@_condition.repeat(3, 10)
855+
# @pytest.mark.xfail(runtime.is_hip, reason="ROCm/HIP may have a bug")
853856
def test_goodness_of_fit_2(self):
854857
vals = self.generate(3, (5, 20), True, [0.3, 0.3, 0.4]).get()
855858
counts = numpy.histogram(vals, bins=numpy.arange(4))[0]
@@ -929,6 +932,163 @@ def test_bound(self):
929932
assert numpy.unique(val).size == val.size
930933

931934

935+
@testing.parameterize(
936+
# Edge cases with small domain sizes
937+
{"a": 0, "size": 0},
938+
{"a": 1, "size": 1},
939+
{"a": 2, "size": 1},
940+
{"a": 256, "size": 100}, # Minimum cipher bits threshold
941+
{"a": 257, "size": 100},
942+
# large scalare uniqueness
943+
{"a": 100, "size": 50},
944+
{"a": 1000, "size": 500},
945+
{"a": 10000, "size": 5000},
946+
{"a": 100000, "size": 50000},
947+
# full inpupt permutation
948+
{"a": 10, "size": 10},
949+
{"a": 100, "size": 100},
950+
{"a": 1000, "size": 1000},
951+
# Power of 2
952+
{"a": 2**8, "size": 100},
953+
{"a": 2**10, "size": 500},
954+
{"a": 2**16, "size": 1000},
955+
{"a": 2**20, "size": 5000},
956+
{"a": 2**24, "size": 10000},
957+
# Just below power of 2
958+
{"a": 2**8 - 1, "size": 100},
959+
{"a": 2**16 - 1, "size": 1000},
960+
{"a": 2**20 - 1, "size": 5000},
961+
# Just above power of 2
962+
{"a": 2**8 + 1, "size": 100},
963+
{"a": 2**16 + 1, "size": 1000},
964+
{"a": 2**20 + 1, "size": 5000},
965+
# Test multi-dimensional shapes.
966+
{"a": 6, "size": (2, 3)},
967+
{"a": 32, "size": (4, 5)},
968+
{"a": 120, "size": (5, 4, 5)},
969+
)
970+
@testing.fix_random()
971+
class TestChoiceReplaceFalseLargeScale(RandomGeneratorTestCase):
972+
"""Test large-scale uniqueness for Feistel bijection implementation."""
973+
974+
target_method = "choice"
975+
976+
def test_uniqueness_and_bounds(self):
977+
"""Test that samples have no duplicates and correct bounds."""
978+
val = self.generate(a=self.a, size=self.size, replace=False).get()
979+
size = self.size if isinstance(self.size, tuple) else (self.size,)
980+
981+
# Check shape
982+
assert val.shape == size
983+
984+
# Check bounds
985+
assert (0 <= val).all()
986+
assert (val < self.a).all()
987+
988+
# Check uniqueness
989+
val_flat = numpy.asarray(val).flatten()
990+
assert (
991+
numpy.unique(val_flat).size == val_flat.size
992+
), "Found duplicate values in replace=False sample"
993+
994+
995+
@testing.fix_random()
996+
class TestChoiceReplaceFalseStatistical(RandomGeneratorTestCase):
997+
"""Statistical tests for uniformity of Feistel bijection."""
998+
999+
target_method = "choice"
1000+
1001+
@_condition.repeat(3)
1002+
def test_small_domain_uniformity(self):
1003+
"""Chi-square test for uniform sampling in small domain."""
1004+
# Sample from domain of size 10, taking 5 elements
1005+
# Repeat many times and check each index appears uniformly
1006+
n = 10
1007+
sample_size = 5
1008+
n_trials = 1000
1009+
1010+
counts = cupy.zeros(n, dtype=int)
1011+
vals = self.generate_many(
1012+
n, size=sample_size, replace=False, _count=n_trials
1013+
)
1014+
for val in vals:
1015+
counts[val] += 1
1016+
counts = counts.get()
1017+
1018+
# Each index should appear ~500 times (5/10 * 1000)
1019+
expected = numpy.ones(n, dtype=int) * (sample_size * n_trials // n)
1020+
assert _hypothesis.chi_square_test(counts, expected)
1021+
1022+
@_condition.repeat(3, 10)
1023+
def test_permutation_variability(self):
1024+
"""Test that repeated full permutations are different."""
1025+
n = 20
1026+
n_trials = 10
1027+
1028+
vals = self.generate_many(n, size=n, replace=False, _count=n_trials)
1029+
perms = cupy.vstack(vals)
1030+
1031+
# Should have multiple unique permutations
1032+
unique_perms = cupy.unique(perms, axis=0)
1033+
assert (
1034+
len(unique_perms) == n_trials
1035+
), "Permutations should vary across multiple calls"
1036+
1037+
1038+
@testing.slow
1039+
@testing.fix_random()
1040+
class TestChoiceReplaceFalseVeryLargeDomain(unittest.TestCase):
1041+
"""Test memory efficiency with very large domains."""
1042+
1043+
def setUp(self):
1044+
self.rs = _generator.RandomState(seed=testing.generate_seed())
1045+
1046+
def test_large_domain_memory_efficiency(self):
1047+
"""Test that very large domains don't allocate full arrays."""
1048+
# This should NOT allocate a 2^30 element array
1049+
# If it did, it would require ~8GB of memory
1050+
a = 2**30
1051+
size = 1000
1052+
1053+
val = self.rs.choice(a=a, size=size, replace=False).get()
1054+
1055+
# Check bounds
1056+
assert (0 <= val).all()
1057+
assert (val < a).all()
1058+
1059+
# Check uniqueness
1060+
assert numpy.unique(val).size == size
1061+
1062+
def test_near_32bit_limit(self):
1063+
"""Test at the 32-bit boundary."""
1064+
# Current implementation supports up to 2^32
1065+
a = 2**31
1066+
size = 500
1067+
1068+
val = self.rs.choice(a=a, size=size, replace=False).get()
1069+
1070+
# Check bounds
1071+
assert (0 <= val).all()
1072+
assert (val < a).all()
1073+
1074+
# Check uniqueness
1075+
assert numpy.unique(val).size == size
1076+
1077+
1078+
@testing.fix_random()
1079+
class TestChoiceReplaceFalseDtypeConsistency(RandomGeneratorTestCase):
1080+
"""Test output dtype consistency."""
1081+
1082+
target_method = "choice"
1083+
1084+
def test_integer_input_dtype(self):
1085+
"""Integer input should produce int64/long dtype."""
1086+
val = self.generate(a=100, size=50, replace=False)
1087+
1088+
# Should be 'l' (long) dtype, which is int64 on most platforms
1089+
assert val.dtype == numpy.dtype("l") or val.dtype == numpy.int64
1090+
1091+
9321092
@testing.fix_random()
9331093
class TestGumbel(RandomGeneratorTestCase):
9341094

0 commit comments

Comments
 (0)