|
1 | | -"""Test the random function with large entity IDs to ensure no overflow.""" |
2 | | - |
3 | | -import numpy as np |
4 | 1 | import pytest |
5 | | -from unittest.mock import Mock |
6 | | -from policyengine_core.commons.formulas import random |
7 | | - |
8 | | - |
9 | | -class TestRandomSeed: |
10 | | - """Test random seed handling to prevent NumPy overflow errors.""" |
11 | | - |
12 | | - def test_random_with_large_entity_ids(self): |
13 | | - """Test that random() handles large entity IDs without overflow.""" |
14 | | - # Create a mock population with simulation |
15 | | - population = Mock() |
16 | | - population.simulation = Mock() |
17 | | - population.simulation.count_random_calls = 0 |
18 | | - population.entity = Mock() |
19 | | - population.entity.key = "person" |
20 | | - |
21 | | - # Mock the get_holder and get_known_periods |
22 | | - holder = Mock() |
23 | | - holder.get_known_periods.return_value = [] |
24 | | - population.simulation.get_holder.return_value = holder |
25 | | - population.simulation.default_calculation_period = Mock() |
26 | | - |
27 | | - # Test with very large entity IDs that would cause overflow |
28 | | - # if not handled properly |
29 | | - large_ids = np.array( |
30 | | - [ |
31 | | - np.iinfo(np.int64).max - 1000, # Very large positive ID |
32 | | - np.iinfo(np.int64).max // 2, # Large positive ID |
33 | | - 1234567890123456789, # Another large ID |
34 | | - ] |
35 | | - ) |
36 | | - |
37 | | - # Mock the population call to return large IDs |
38 | | - population.side_effect = lambda key, period: large_ids |
39 | | - |
40 | | - # This should not raise a ValueError about negative seeds |
41 | | - result = random(population) |
42 | | - |
43 | | - # Check that we got valid random values |
44 | | - assert isinstance(result, np.ndarray) |
45 | | - assert len(result) == len(large_ids) |
46 | | - assert all(0 <= val <= 1 for val in result) |
47 | | - |
48 | | - def test_random_seed_consistency(self): |
49 | | - """Test that random() produces consistent results for same inputs.""" |
50 | | - # Create mock population |
51 | | - population = Mock() |
52 | | - population.simulation = Mock() |
53 | | - population.simulation.count_random_calls = 0 |
54 | | - population.entity = Mock() |
55 | | - population.entity.key = "household" |
56 | | - |
57 | | - holder = Mock() |
58 | | - holder.get_known_periods.return_value = [] |
59 | | - population.simulation.get_holder.return_value = holder |
60 | | - population.simulation.default_calculation_period = Mock() |
61 | | - |
62 | | - # Use same IDs |
63 | | - ids = np.array([1, 2, 3]) |
64 | | - population.side_effect = lambda key, period: ids |
65 | | - |
66 | | - # First call |
67 | | - result1 = random(population) |
68 | | - |
69 | | - # Reset count to simulate same conditions |
70 | | - population.simulation.count_random_calls = 0 |
71 | 2 |
|
72 | | - # Second call with same conditions |
73 | | - result2 = random(population) |
74 | | - |
75 | | - # Results should be identical |
76 | | - np.testing.assert_array_equal(result1, result2) |
77 | | - |
78 | | - def test_random_increments_call_count(self): |
79 | | - """Test that random() increments the call counter.""" |
80 | | - population = Mock() |
81 | | - population.simulation = Mock() |
82 | | - population.simulation.count_random_calls = 0 |
83 | | - population.entity = Mock() |
84 | | - population.entity.key = "person" |
85 | | - |
86 | | - holder = Mock() |
87 | | - holder.get_known_periods.return_value = [] |
88 | | - population.simulation.get_holder.return_value = holder |
89 | | - population.simulation.default_calculation_period = Mock() |
90 | | - |
91 | | - ids = np.array([1, 2, 3]) |
92 | | - population.side_effect = lambda key, period: ids |
93 | | - |
94 | | - # First call |
95 | | - random(population) |
96 | | - assert population.simulation.count_random_calls == 1 |
97 | | - |
98 | | - # Second call |
99 | | - random(population) |
100 | | - assert population.simulation.count_random_calls == 2 |
101 | | - |
102 | | - def test_random_handles_negative_ids(self): |
103 | | - """Test that random() handles negative IDs properly.""" |
104 | | - population = Mock() |
105 | | - population.simulation = Mock() |
106 | | - population.simulation.count_random_calls = 0 |
107 | | - population.entity = Mock() |
108 | | - population.entity.key = "person" |
109 | | - |
110 | | - holder = Mock() |
111 | | - holder.get_known_periods.return_value = [] |
112 | | - population.simulation.get_holder.return_value = holder |
113 | | - population.simulation.default_calculation_period = Mock() |
114 | | - |
115 | | - # Include negative IDs |
116 | | - ids = np.array([-100, -1, 0, 1, 100]) |
117 | | - population.side_effect = lambda key, period: ids |
118 | | - |
119 | | - # Should handle negative IDs without errors |
120 | | - result = random(population) |
121 | | - |
122 | | - assert isinstance(result, np.ndarray) |
123 | | - assert len(result) == len(ids) |
124 | | - assert all(0 <= val <= 1 for val in result) |
125 | | - |
126 | | - def test_no_negative_seed_error_with_overflow(self): |
127 | | - """Test that seed calculation overflow doesn't cause negative seed error.""" |
128 | | - population = Mock() |
129 | | - population.simulation = Mock() |
130 | | - population.simulation.count_random_calls = 999999999 # Large count |
131 | | - population.entity = Mock() |
132 | | - population.entity.key = "person" |
133 | | - |
134 | | - holder = Mock() |
135 | | - holder.get_known_periods.return_value = [] |
136 | | - population.simulation.get_holder.return_value = holder |
137 | | - population.simulation.default_calculation_period = Mock() |
138 | | - |
139 | | - # Use the exact ID that would cause overflow in old implementation |
140 | | - # This ID when multiplied by 100 and added to count_random_calls |
141 | | - # would overflow int64 and become negative |
142 | | - overflow_id = np.array([np.iinfo(np.int64).max // 100]) |
143 | | - population.side_effect = lambda key, period: overflow_id |
| 3 | +from policyengine_core.commons.formulas import random |
144 | 4 |
|
145 | | - # In the old implementation, this would raise: |
146 | | - # ValueError: Seed must be between 0 and 2**32 - 1 |
147 | | - # With the fix using abs(), it should work fine |
148 | | - result = random(population) |
149 | 5 |
|
150 | | - assert isinstance(result, np.ndarray) |
151 | | - assert len(result) == 1 |
152 | | - assert 0 <= result[0] <= 1 |
| 6 | +def test_random_raises_for_formula_time_randomness(): |
| 7 | + with pytest.raises(RuntimeError, match="Formula-time randomness is not allowed"): |
| 8 | + random(None) |
0 commit comments