Skip to content

Commit a4b5bad

Browse files
author
Paul Prescod
committed
Handle a corner case.
1 parent 68fbd48 commit a4b5bad

2 files changed

Lines changed: 65 additions & 33 deletions

File tree

snowfakery/standard_plugins/_math.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ def __init__(self, repeat, iterable):
4242

4343

4444
def parts(
45-
total: int,
46-
min_: int = 1,
47-
max_: Optional[int] = None,
48-
requested_step: float = 1,
45+
user_total: int,
46+
user_min: int = 1,
47+
user_max: Optional[int] = None,
48+
user_step: float = 1,
4949
rand: Optional[Random] = None,
5050
) -> List[Union[int, float]]:
5151
"""Split a number into a randomized set of 'pieces'.
@@ -64,22 +64,24 @@ def parts(
6464
will be inconsistent with them. e.g. if `total` is not a multiple
6565
of `step`.
6666
"""
67-
max_ = max_ or total
67+
max_ = user_max or user_total
6868
rand = rand or Random()
6969

70-
if requested_step < 1:
70+
if user_step < 1:
7171
allowed_steps = [0.01, 0.5, 0.1, 0.20, 0.25, 0.50]
7272
assert (
73-
requested_step in allowed_steps
74-
), f"`step` must be one of {', '.join(str(f) for f in allowed_steps)}, not {requested_step}"
73+
user_step in allowed_steps
74+
), f"`step` must be one of {', '.join(str(f) for f in allowed_steps)}, not {user_step}"
7575
# multiply up into the integer range so we don't need to do float math
76-
total = int(total / requested_step)
76+
total = int(user_total / user_step)
7777
step = 1
78-
min_ = int(min_ / requested_step)
79-
max_ = int(max_ / requested_step)
78+
min_ = int(user_min / user_step)
79+
max_ = int(max_ / user_step)
8080
else:
81-
step = int(requested_step)
82-
assert step == requested_step, f"`step` should be an integer, not {step}"
81+
step = int(user_step)
82+
min_ = user_min
83+
total = user_total
84+
assert step == user_step, f"`step` should be an integer, not {step}"
8385

8486
pieces = []
8587

@@ -88,47 +90,55 @@ def parts(
8890
smallest = max(min_, step)
8991
if remaining < smallest:
9092
# mutates pieces
91-
handle_last_bit(pieces, rand, remaining, min_, max_)
93+
success = handle_last_bit(pieces, rand, remaining, min_, max_)
94+
# our constraints must have been impossible to fulfill
95+
assert (
96+
success
97+
), f"No way to match all constraints: total: {user_total}, min: {user_min}, max: {user_max}, step: {user_step}"
9298

9399
else:
94-
pieces.append(generate_piece(pieces, rand, smallest, remaining, max_, step))
100+
pieces.append(generate_piece(rand, smallest, remaining, max_, step))
95101

96102
assert sum(pieces) == total, pieces
97103
assert 0 not in pieces, pieces
98104

99-
if requested_step != step:
100-
pieces = [round(p * requested_step, 2) for p in pieces]
105+
if user_step != step:
106+
pieces = [round(p * user_step, 2) for p in pieces]
101107
return pieces
102108

103109

104110
def handle_last_bit(
105111
pieces: List[int], rand: Random, remaining: int, min_: int, max_: int
106-
):
112+
) -> bool:
107113
"""If the piece is big enough, add it.
108114
Otherwise, try to add it to another piece."""
109115

110116
if remaining > min_:
111117
pos = rand.randint(0, len(pieces))
112118
pieces.insert(pos, remaining)
113-
return
119+
return True
114120

115121
# try to add it to some other piece
116122
for i, val in enumerate(pieces):
117123
if val + remaining <= max_:
118124
pieces[i] += remaining
119125
remaining = 0
120-
return
126+
return True
121127

122-
# just insert it despite it being too small...our
123-
# constraints must have been impossible to fulfill
124-
if remaining:
125-
pos = rand.randint(0, len(pieces))
126-
pieces.insert(pos, remaining)
128+
# No other piece has enough room...so
129+
# split it up among several other pieces
130+
for i, val in enumerate(pieces):
131+
chunk = min(max_ - pieces[i], remaining)
132+
remaining -= chunk
133+
pieces[i] = max_
134+
assert remaining >= 0
135+
if remaining == 0:
136+
return True
137+
138+
return False
127139

128140

129-
def generate_piece(
130-
pieces: List[int], rand: Random, smallest: int, remaining: int, max_: int, step: int
131-
):
141+
def generate_piece(rand: Random, smallest: int, remaining: int, max_: int, step: int):
132142
part = rand.randint(smallest, min(remaining, max_))
133143
round_up = part + step - (part % step)
134144
if round_up <= min(remaining, max_) and rand.randint(0, 1):
Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from snowfakery.data_gen_exceptions import DataGenError
66

77
REPS = 1
8+
SEEDS = [randint(0, 2 ** 32) for r in range(REPS)]
89

910

10-
@pytest.mark.parametrize("seed", [randint(0, 2 ** 32) for r in range(REPS)])
11-
class TestSummation:
11+
@pytest.mark.parametrize("seed", SEEDS)
12+
class TestMathPartition:
1213
def test_example(self, generated_rows, seed):
1314
generate_data(
1415
"examples/math_partition/math_partition_simple.recipe.yml", seed=seed
@@ -19,8 +20,11 @@ def test_example(self, generated_rows, seed):
1920
c["Amount__c"] for c in children
2021
), (parents, children)
2122

22-
def test_example_pennies(self, generated_rows, seed):
23-
generate_data("examples/math_partition/sum_pennies.recipe.yml", seed=seed)
23+
regression_seeds = [824956277]
24+
25+
@pytest.mark.parametrize("seed2", regression_seeds + SEEDS)
26+
def test_example_pennies(self, generated_rows, seed, seed2):
27+
generate_data("examples/math_partition/sum_pennies.recipe.yml", seed=seed2)
2428
objs = generated_rows.table_values("Values")
2529
assert round(sum(p["Amount"] for p in objs)) == 100, sum(
2630
p["Amount"] for p in objs
@@ -31,7 +35,7 @@ def test_example_pennies_param(self, generated_rows, seed, step: int):
3135
generate_data(
3236
"examples/math_partition/sum_pennies_param.recipe.yml",
3337
user_options={"step": step},
34-
seed=1,
38+
seed=seed,
3539
)
3640
objs = generated_rows.table_values("Values")
3741
assert round(sum(p["Amount"] for p in objs)) == 100, sum(
@@ -151,3 +155,21 @@ def test_bad_step(self, generated_rows, seed):
151155
"""
152156
with pytest.raises(DataGenError, match="step.*0.3"):
153157
generate_data(StringIO(yaml), seed=seed)
158+
159+
def test_inconsistent_constraints(self, generated_rows, seed):
160+
yaml = """
161+
- plugin: snowfakery.standard_plugins.Math
162+
- object: Obj
163+
for_each:
164+
var: child_value
165+
value:
166+
Math.random_partition:
167+
total: 10
168+
min: 8
169+
max: 8
170+
step: 5
171+
fields:
172+
Amount: ${{child_value}}
173+
"""
174+
with pytest.raises(DataGenError, match="constraints"):
175+
generate_data(StringIO(yaml), seed=seed)

0 commit comments

Comments
 (0)