|
5 | 5 | import sys |
6 | 6 | import math |
7 | 7 | from pathlib import Path |
8 | | -from typing import Any |
| 8 | +from typing import Any, Optional |
9 | 9 |
|
10 | 10 | import torch.cuda |
11 | 11 |
|
@@ -45,7 +45,18 @@ class TestCase: |
45 | 45 | spec: str |
46 | 46 |
|
47 | 47 |
|
48 | | -def get_test_cases(file_name: str) -> list[TestCase]: |
| 48 | +def _combine(a: int, b: int) -> int: |
| 49 | + # combine two integers into one: |
| 50 | + # we need this to generate a secret seed based on the test-level seed and |
| 51 | + # the global secret seed. |
| 52 | + # the test-level seeds are public knowledge, and typically relatively small numbers, |
| 53 | + # so we need to make sure they don't provide any useful info for the full seed. |
| 54 | + # This Cantor construction ensures that if the secret seed is a large number, |
| 55 | + # then so is the overall seed. |
| 56 | + return int(a + (a+b)*(a+b+1)//2) |
| 57 | + |
| 58 | + |
| 59 | +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: |
49 | 60 | try: |
50 | 61 | content = Path(file_name).read_text() |
51 | 62 | except Exception as E: |
@@ -73,6 +84,11 @@ def get_test_cases(file_name: str) -> list[TestCase]: |
73 | 84 | case[key] = val |
74 | 85 | tests.append(TestCase(spec=line, args=case)) |
75 | 86 |
|
| 87 | + if seed is not None: |
| 88 | + for test in tests: |
| 89 | + if "seed" in test.args: |
| 90 | + test.args["seed"] = _combine(test.args["seed"], seed) |
| 91 | + |
76 | 92 | return tests |
77 | 93 |
|
78 | 94 |
|
@@ -236,13 +252,12 @@ def main(): |
236 | 252 | return 2 |
237 | 253 |
|
238 | 254 | mode = sys.argv[1] |
239 | | - tests = get_test_cases(sys.argv[2]) |
| 255 | + seed = os.getenv("POPCORN_SEED") |
| 256 | + seed = int(seed) if seed else None |
| 257 | + set_seed(seed or 42) |
| 258 | + tests = get_test_cases(sys.argv[2], seed) |
240 | 259 |
|
241 | 260 | with PopcornOutput(int(fd)) as logger: |
242 | | - seed = os.getenv("POPCORN_SEED") |
243 | | - seed = int(seed) if seed else 42 |
244 | | - set_seed(seed) |
245 | | - |
246 | 261 | if mode == "test": |
247 | 262 | return run_testing(logger, tests) |
248 | 263 |
|
|
0 commit comments