Skip to content

Commit 7a04804

Browse files
committed
dynamic seed updates in python evaluations
1 parent 768e71f commit 7a04804

1 file changed

Lines changed: 22 additions & 7 deletions

File tree

examples/eval.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sys
66
import math
77
from pathlib import Path
8-
from typing import Any
8+
from typing import Any, Optional
99

1010
import torch.cuda
1111

@@ -45,7 +45,18 @@ class TestCase:
4545
spec: str
4646

4747

48-
def get_test_cases(file_name: str) -> list[TestCase]:
48+
def _combine(a: int, b: int) -> int:
49+
# combine two integers into one:
50+
# we need this to generate a secret seed based on the test-level seed and
51+
# the global secret seed.
52+
# the test-level seeds are public knowledge, and typically relatively small numbers,
53+
# so we need to make sure they don't provide any useful info for the full seed.
54+
# This Cantor construction ensures that if the secret seed is a large number,
55+
# then so is the overall seed.
56+
return int(a + (a+b)*(a+b+1)//2)
57+
58+
59+
def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]:
4960
try:
5061
content = Path(file_name).read_text()
5162
except Exception as E:
@@ -73,6 +84,11 @@ def get_test_cases(file_name: str) -> list[TestCase]:
7384
case[key] = val
7485
tests.append(TestCase(spec=line, args=case))
7586

87+
if seed is not None:
88+
for test in tests:
89+
if "seed" in test.args:
90+
test.args["seed"] = _combine(test.args["seed"], seed)
91+
7692
return tests
7793

7894

@@ -236,13 +252,12 @@ def main():
236252
return 2
237253

238254
mode = sys.argv[1]
239-
tests = get_test_cases(sys.argv[2])
255+
seed = os.getenv("POPCORN_SEED")
256+
seed = int(seed) if seed else None
257+
set_seed(seed or 42)
258+
tests = get_test_cases(sys.argv[2], seed)
240259

241260
with PopcornOutput(int(fd)) as logger:
242-
seed = os.getenv("POPCORN_SEED")
243-
seed = int(seed) if seed else 42
244-
set_seed(seed)
245-
246261
if mode == "test":
247262
return run_testing(logger, tests)
248263

0 commit comments

Comments
 (0)