2323 WithReplacementSampleOrder ,
2424)
2525
26+ # Exercise small/medium/large dataset sizes so shuffle-buffer behavior is
27+ # covered for inputs both much smaller and much larger than typical batches.
28+ _DATASET_SIZES = [3 , 100 , 10_000 ]
29+
2630
2731@pytest .mark .unit
2832class TestWithoutReplacementSampleOrder :
29- def test_yields_all_indices (self ):
33+ @pytest .mark .parametrize ("n_samples" , _DATASET_SIZES )
34+ def test_yields_all_indices (self , n_samples : int ):
3035 order = WithoutReplacementSampleOrder (
31- n_samples_in_dataset = 5 , rng = random .Random (42 )
36+ n_samples_in_dataset = n_samples , rng = random .Random (42 )
3237 )
33- indices = [next (order ) for _ in range (5 )]
34- assert sorted (indices ) == [ 0 , 1 , 2 , 3 , 4 ]
38+ indices = [next (order ) for _ in range (n_samples )]
39+ assert sorted (indices ) == list ( range ( n_samples ))
3540
36- def test_reshuffles_after_exhaustion (self ):
41+ @pytest .mark .parametrize ("n_samples" , _DATASET_SIZES )
42+ def test_reshuffles_after_exhaustion (self , n_samples : int ):
3743 order = WithoutReplacementSampleOrder (
38- n_samples_in_dataset = 3 , rng = random .Random (42 )
44+ n_samples_in_dataset = n_samples , rng = random .Random (42 )
3945 )
40- first_pass = [next (order ) for _ in range (3 )]
41- second_pass = [next (order ) for _ in range (3 )]
42- assert sorted (first_pass ) == [ 0 , 1 , 2 ]
43- assert sorted (second_pass ) == [ 0 , 1 , 2 ]
46+ first_pass = [next (order ) for _ in range (n_samples )]
47+ second_pass = [next (order ) for _ in range (n_samples )]
48+ assert sorted (first_pass ) == list ( range ( n_samples ))
49+ assert sorted (second_pass ) == list ( range ( n_samples ))
4450
45- def test_never_raises_stop_iteration (self ):
51+ @pytest .mark .parametrize ("n_samples" , _DATASET_SIZES )
52+ def test_never_raises_stop_iteration (self , n_samples : int ):
4653 order = WithoutReplacementSampleOrder (
47- n_samples_in_dataset = 2 , rng = random .Random (42 )
54+ n_samples_in_dataset = n_samples , rng = random .Random (42 )
4855 )
4956 # Should be able to draw far more samples than the dataset size (at least 3x here) without exhaustion
50- indices = [next (order ) for _ in range (100 )]
51- assert len (indices ) == 100
52- assert all (0 <= i < 2 for i in indices )
57+ draws = max (100 , n_samples * 3 )
58+ indices = [next (order ) for _ in range (draws )]
59+ assert len (indices ) == draws
60+ assert all (0 <= i < n_samples for i in indices )
5361
54- def test_reproducible_with_seed (self ):
62+ @pytest .mark .parametrize ("n_samples" , _DATASET_SIZES )
63+ def test_reproducible_with_seed (self , n_samples : int ):
5564 order1 = WithoutReplacementSampleOrder (
56- n_samples_in_dataset = 10 , rng = random .Random (42 )
65+ n_samples_in_dataset = n_samples , rng = random .Random (42 )
5766 )
5867 order2 = WithoutReplacementSampleOrder (
59- n_samples_in_dataset = 10 , rng = random .Random (42 )
68+ n_samples_in_dataset = n_samples , rng = random .Random (42 )
6069 )
61- seq1 = [next (order1 ) for _ in range (20 )]
62- seq2 = [next (order2 ) for _ in range (20 )]
70+ seq1 = [next (order1 ) for _ in range (n_samples * 2 )]
71+ seq2 = [next (order2 ) for _ in range (n_samples * 2 )]
6372 assert seq1 == seq2
6473
6574 def test_invalid_size_raises (self ):
@@ -69,20 +78,22 @@ def test_invalid_size_raises(self):
6978
7079@pytest .mark .unit
7180class TestWithReplacementSampleOrder :
72- def test_yields_valid_indices (self ):
81+ @pytest .mark .parametrize ("n_samples" , _DATASET_SIZES )
82+ def test_yields_valid_indices (self , n_samples : int ):
7383 order = WithReplacementSampleOrder (
74- n_samples_in_dataset = 5 , rng = random .Random (42 )
84+ n_samples_in_dataset = n_samples , rng = random .Random (42 )
7585 )
76- indices = [next (order ) for _ in range (100 )]
77- assert all (0 <= i < 5 for i in indices )
86+ indices = [next (order ) for _ in range (max ( 100 , n_samples ) )]
87+ assert all (0 <= i < n_samples for i in indices )
7888
79- def test_reproducible_with_seed (self ):
89+ @pytest .mark .parametrize ("n_samples" , _DATASET_SIZES )
90+ def test_reproducible_with_seed (self , n_samples : int ):
8091 order1 = WithReplacementSampleOrder (
81- n_samples_in_dataset = 10 , rng = random .Random (42 )
92+ n_samples_in_dataset = n_samples , rng = random .Random (42 )
8293 )
8394 order2 = WithReplacementSampleOrder (
84- n_samples_in_dataset = 10 , rng = random .Random (42 )
95+ n_samples_in_dataset = n_samples , rng = random .Random (42 )
8596 )
86- seq1 = [next (order1 ) for _ in range (20 )]
87- seq2 = [next (order2 ) for _ in range (20 )]
97+ seq1 = [next (order1 ) for _ in range (n_samples * 2 )]
98+ seq2 = [next (order2 ) for _ in range (n_samples * 2 )]
8899 assert seq1 == seq2
0 commit comments