Skip to content

Commit f3f639a

Browse files
author
Paul Prescod
committed
Semi-deterministic testing
1 parent 82e7a4d commit f3f639a

9 files changed

Lines changed: 130 additions & 58 deletions

File tree

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1954,7 +1954,7 @@ granularity), `0.05` (nickle), `0.10` (dime), `0.25` (quarter) and
19541954
total: 100
19551955
min: 10
19561956
max: 50
1957-
step: 0.1
1957+
step: 0.01
19581958
fields:
19591959
Amount: ${{current_value}}
19601960
```

examples/sum_pennies.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@
88
total: 100
99
min: 10
1010
max: 50
11-
step: 0.1
11+
step: 0.01
1212
fields:
1313
Amount: ${{current_value}}

snowfakery/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def generate_data(
151151
update_passthrough_fields: T.Sequence[
152152
str
153153
] = (), # pass through these fields from input to output
154+
seed: T.Optional[int] = None,
154155
) -> None:
155156
stopping_criteria = stopping_criteria_from_target_number(target_number)
156157
dburls = dburls or ([dburl] if dburl else [])
@@ -193,6 +194,7 @@ def open_with_cleanup(file, mode, **kwargs):
193194
plugin_options=plugin_options,
194195
update_input_file=open_update_input_file,
195196
update_passthrough_fields=update_passthrough_fields,
197+
seed=seed,
196198
)
197199

198200
if open_cci_mapping_file:

snowfakery/data_generator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def generate(
131131
plugin_options: dict = None,
132132
update_input_file: OpenFileLike = None,
133133
update_passthrough_fields: T.Sequence[str] = (),
134+
seed: T.Optional[int] = None,
134135
) -> ExecutionSummary:
135136
"""The main entry point to the package for Python applications."""
136137
from .api import SnowfakeryApplication
@@ -188,6 +189,7 @@ def generate(
188189
parse_result=parse_result,
189190
globals=globls,
190191
continuing=bool(continuation_data),
192+
seed=seed,
191193
) as interpreter:
192194
runtime_context = interpreter.execute()
193195

snowfakery/data_generator_runtime.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections import defaultdict, ChainMap
44
from datetime import date, datetime, timezone
55
from contextlib import contextmanager
6+
from random import Random
67

78
from typing import Optional, Dict, Sequence, Mapping, NamedTuple, Set
89
import typing as T
@@ -300,6 +301,7 @@ def __init__(
300301
snowfakery_plugins: Optional[Mapping[str, callable]] = None,
301302
faker_providers: Sequence[object] = (),
302303
continuing=False,
304+
seed: Optional[int] = None,
303305
):
304306
self.output_stream = output_stream
305307
self.options = options or {}
@@ -354,6 +356,7 @@ def __init__(
354356
self.globals.nicknames_and_tables,
355357
)
356358
self.resave_objects_from_continuation(globals, self.tables_to_keep_history_for)
359+
self.random_number_generator = Random(seed)
357360

358361
def resave_objects_from_continuation(
359362
self, globals: Globals, tables_to_keep_history_for: T.Iterable[str]

snowfakery/plugins.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from random import Random
12
import sys
23

34
from typing import Any, Callable, Mapping, Union, NamedTuple, List, Tuple
@@ -141,8 +142,8 @@ def current_filename(self):
141142
return self.interpreter.current_context.current_template.filename
142143

143144
@property
144-
def current_filename(self):
145-
return self.interpreter.current_context.current_template.filename
145+
def random_number_generator(self) -> Random:
146+
return self.interpreter.random_number_generator
146147

147148

148149
def lazy(func: Any) -> Callable:
Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from random import randint, shuffle
2+
from random import Random
33
from types import SimpleNamespace
44
from typing import List, Optional, Union
55
from snowfakery.plugins import SnowfakeryPlugin, memorable, PluginResultIterator
@@ -17,9 +17,12 @@ def random_partition(
1717
*,
1818
min: int = 1,
1919
max: Optional[int] = None,
20-
step: int = 1,
20+
step: float = 1,
2121
):
22-
return GenericPluginResultIterator(False, parts(total, min, max, step))
22+
random = self.context.random_number_generator
23+
return GenericPluginResultIterator(
24+
False, parts(total, min, max, step, random)
25+
)
2326

2427
mathns = MathNamespace()
2528
mathns.__dict__.update(math.__dict__.copy())
@@ -38,7 +41,13 @@ def __init__(self, repeat, iterable):
3841
self.next = iter(iterable).__next__
3942

4043

41-
def parts(total: int, min_: int = 1, max_=None, step=1) -> List[Union[int, float]]:
44+
def parts(
45+
total: int,
46+
min_: int = 1,
47+
max_: Optional[int] = None,
48+
requested_step: float = 1,
49+
rand: Optional[Random] = None,
50+
) -> List[Union[int, float]]:
4251
"""Split a number into a randomized set of 'pieces'.
4352
The pieces add up to the `total`. E.g.
4453
@@ -56,50 +65,75 @@ def parts(total: int, min_: int = 1, max_=None, step=1) -> List[Union[int, float
5665
of `step`.
5766
"""
5867
max_ = max_ or total
59-
factor = 0
60-
61-
if step < 1:
62-
assert step in [0.01, 0.5, 0.1, 0.20, 0.25, 0.50], step
63-
factor = step
64-
total = int(total / factor)
65-
step = int(total / factor)
66-
min_ = int(total / factor)
67-
max_ = int(total / factor)
68+
rand = rand or Random()
69+
70+
if requested_step < 1:
71+
allowed_steps = [0.01, 0.5, 0.1, 0.20, 0.25, 0.50]
72+
assert (
73+
requested_step in allowed_steps
74+
), f"`step` must be one of {', '.join(str(f) for f in allowed_steps)}, not {requested_step}"
75+
# multiply up into the integer range so we don't need to do float math
76+
total = int(total / requested_step)
77+
step = 1
78+
min_ = int(min_ / requested_step)
79+
max_ = int(max_ / requested_step)
80+
else:
81+
step = int(requested_step)
82+
assert step == requested_step, f"`step` should be an integer, not {step}"
6883

6984
pieces = []
7085

7186
while sum(pieces) < total:
7287
remaining = total - sum(pieces)
7388
smallest = max(min_, step)
7489
if remaining < smallest:
75-
# try to add it to a random other piece
76-
for i, val in enumerate(pieces):
77-
if val + remaining <= max_:
78-
pieces[i] += remaining
79-
remaining = 0
80-
break
81-
82-
# just tack it on the end despite
83-
# it being too small...our
84-
# constraints must have been impossible
85-
# to fulfil
86-
if remaining:
87-
pieces.append(remaining)
90+
# mutates pieces
91+
handle_last_bit(pieces, rand, remaining, min_, max_)
8892

8993
else:
90-
part = randint(smallest, min(remaining, max_))
91-
round_up = part + step - (part % step)
92-
if round_up <= min(remaining, max_) and randint(0, 1):
93-
part = round_up
94-
else:
95-
part -= part % step
96-
97-
pieces.append(part)
94+
pieces.append(generate_piece(pieces, rand, smallest, remaining, max_, step))
9895

9996
assert sum(pieces) == total, pieces
10097
assert 0 not in pieces, pieces
10198

102-
shuffle(pieces)
103-
if factor:
104-
pieces = [round(p * factor, 2) for p in pieces]
99+
if requested_step != step:
100+
pieces = [round(p * requested_step, 2) for p in pieces]
105101
return pieces
102+
103+
104+
def handle_last_bit(
105+
pieces: List[int], rand: Random, remaining: int, min_: int, max_: int
106+
):
107+
"""If the piece is big enough, add it.
108+
Otherwise, try to add it to another piece."""
109+
110+
if remaining > min_:
111+
pos = rand.randint(0, len(pieces))
112+
pieces.insert(pos, remaining)
113+
return
114+
115+
# try to add it to some other piece
116+
for i, val in enumerate(pieces):
117+
if val + remaining <= max_:
118+
pieces[i] += remaining
119+
remaining = 0
120+
return
121+
122+
# just insert it despite it being too small...our
123+
# constraints must have been impossible to fulfill
124+
if remaining:
125+
pos = rand.randint(0, len(pieces))
126+
pieces.insert(pos, remaining)
127+
128+
129+
def generate_piece(
130+
pieces: List[int], rand: Random, smallest: int, remaining: int, max_: int, step: int
131+
):
132+
part = rand.randint(smallest, min(remaining, max_))
133+
round_up = part + step - (part % step)
134+
if round_up <= min(remaining, max_) and rand.randint(0, 1):
135+
part = round_up
136+
else:
137+
part -= part % step
138+
139+
return part

tests/test_bad_step.recipe.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
- plugin: snowfakery.standard_plugins.Math
2+
- object: Obj
3+
for_each:
4+
var: child_value
5+
value:
6+
Math.random_partition:
7+
total: 28
8+
step: 0.3
9+
fields:
10+
Amount: ${{child_value}}

tests/test_summation.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,40 @@
11
import pytest
2+
from random import randint
23
from io import StringIO
34
from snowfakery import generate_data
5+
from snowfakery.data_gen_exceptions import DataGenError
46

57
REPS = 1
68

79

8-
@pytest.mark.parametrize("_", range(REPS))
10+
@pytest.mark.parametrize("seed", [randint(0, 2 ** 32) for r in range(REPS)])
911
class TestSummation:
10-
def test_example(self, generated_rows, _):
11-
generate_data("examples/math_partition_simple.recipe.yml")
12+
def test_example(self, generated_rows, seed):
13+
generate_data("examples/math_partition_simple.recipe.yml", seed=seed)
1214
parents = generated_rows.table_values("ParentObject__c")
1315
children = generated_rows.table_values("ChildObject__c")
1416
assert sum(p["TotalAmount__c"] for p in parents) == sum(
1517
c["Amount__c"] for c in children
1618
), (parents, children)
1719

18-
def test_example_pennies(self, generated_rows, _):
19-
generate_data("examples/sum_pennies.yml")
20+
def test_example_pennies(self, generated_rows, seed):
21+
generate_data("examples/sum_pennies.yml", seed=seed)
2022
objs = generated_rows.table_values("Values")
2123
assert round(sum(p["Amount"] for p in objs)) == 100, sum(
2224
p["Amount"] for p in objs
2325
)
2426

2527
@pytest.mark.parametrize("step", [0.01, 0.5, 0.1, 0.20, 0.25, 0.50])
26-
def test_example_pennies_param(self, generated_rows, _, step: int):
27-
generate_data("examples/sum_pennies_param.yml", user_options={"step": step})
28+
def test_example_pennies_param(self, generated_rows, seed, step: int):
29+
generate_data(
30+
"examples/sum_pennies_param.yml", user_options={"step": step}, seed=1
31+
)
2832
objs = generated_rows.table_values("Values")
2933
assert round(sum(p["Amount"] for p in objs)) == 100, sum(
3034
p["Amount"] for p in objs
3135
)
3236

33-
def test_step(self, generated_rows, _):
37+
def test_step(self, generated_rows, seed):
3438
yaml = """
3539
- plugin: snowfakery.standard_plugins.Math
3640
- object: Obj
@@ -43,14 +47,14 @@ def test_step(self, generated_rows, _):
4347
fields:
4448
Amount: ${{child_value}}
4549
"""
46-
generate_data(StringIO(yaml))
50+
generate_data(StringIO(yaml), seed=seed)
4751
values = generated_rows.table_values("Obj")
4852
assert 1 <= len(values) <= 6
4953
amounts = [r["Amount"] for r in values]
5054
assert sum(amounts) == 60, amounts
5155
assert sum([r % 10 for r in amounts]) == 0, amounts
5256

53-
def test_min(self, generated_rows, _):
57+
def test_min(self, generated_rows, seed):
5458
yaml = """
5559
- plugin: snowfakery.standard_plugins.Math
5660
- object: Obj
@@ -63,13 +67,13 @@ def test_min(self, generated_rows, _):
6367
fields:
6468
Amount: ${{child_value}}
6569
"""
66-
generate_data(StringIO(yaml))
70+
generate_data(StringIO(yaml), seed=seed)
6771
values = generated_rows.table_values("Obj")
6872
results = [r["Amount"] for r in values]
6973
assert sum(results) == 60, results
7074
assert not [r for r in results if r < 5], results
7175

72-
def test_min_not_factor_of_total(self, generated_rows, _):
76+
def test_min_not_factor_of_total(self, generated_rows, seed):
7377
yaml = """
7478
- plugin: snowfakery.standard_plugins.Math
7579
- object: Obj
@@ -82,13 +86,13 @@ def test_min_not_factor_of_total(self, generated_rows, _):
8286
fields:
8387
Amount: ${{child_value}}
8488
"""
85-
generate_data(StringIO(yaml))
89+
generate_data(StringIO(yaml), seed=seed)
8690
values = generated_rows.table_values("Obj")
8791
results = [r["Amount"] for r in values]
8892
assert sum(results) == 63
8993
assert not [r for r in results if r < 5], results
9094

91-
def test_step_not_factor_of_total(self, generated_rows, _):
95+
def test_step_not_factor_of_total(self, generated_rows, seed):
9296
yaml = """
9397
- plugin: snowfakery.standard_plugins.Math
9498
- object: Obj
@@ -101,13 +105,13 @@ def test_step_not_factor_of_total(self, generated_rows, _):
101105
fields:
102106
Amount: ${{child_value}}
103107
"""
104-
generate_data(StringIO(yaml))
108+
generate_data(StringIO(yaml), seed=seed)
105109
values = generated_rows.table_values("Obj")
106110
results = [r["Amount"] for r in values]
107111
assert sum(results) == 63, results
108112
assert len([r for r in results if r < 5]) <= 1, results
109113

110-
def test_max(self, generated_rows, _):
114+
def test_max(self, generated_rows, seed):
111115
yaml = """
112116
- plugin: snowfakery.standard_plugins.Math
113117
- object: Obj
@@ -121,9 +125,25 @@ def test_max(self, generated_rows, _):
121125
fields:
122126
Amount: ${{child_value}}
123127
"""
124-
generate_data(StringIO(yaml))
128+
generate_data(StringIO(yaml), seed=seed)
125129
values = generated_rows.table_values("Obj")
126130
results = [r["Amount"] for r in values]
127131
assert sum(results) == 28, results
128132
assert not [r for r in results if r % 2], results
129133
assert not [r for r in results if r > 6], results
134+
135+
def test_bad_step(self, generated_rows, seed):
136+
yaml = """
137+
- plugin: snowfakery.standard_plugins.Math
138+
- object: Obj
139+
for_each:
140+
var: child_value
141+
value:
142+
Math.random_partition:
143+
total: 28
144+
step: 0.3
145+
fields:
146+
Amount: ${{child_value}}
147+
"""
148+
with pytest.raises(DataGenError, match="step.*0.3"):
149+
generate_data(StringIO(yaml), seed=seed)

0 commit comments

Comments
 (0)