Skip to content

Commit eb4902f

Browse files
committed
check that we're not leaking the seed
1 parent 3ad8f90 commit eb4902f

3 files changed

Lines changed: 4 additions & 0 deletions

File tree

examples/eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def main():
288288

289289
mode = sys.argv[1]
290290
seed = os.getenv("POPCORN_SEED")
291+
os.unsetenv("POPCORN_SEED")
291292
seed = int(seed) if seed else None
292293
set_seed(seed or 42)
293294
tests = get_test_cases(sys.argv[2], seed)

examples/identity_py/cheat-rng.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
def custom_kernel(data: input_t) -> output_t:
88
if data.numel() == 65536:
99
gen = torch.Generator(device='cuda')
10+
assert "POPCORN_SEED" not in os.environ
1011
gen.manual_seed(125432)
1112
data = torch.empty(65536, device='cuda', dtype=torch.float16)
1213
data.uniform_(0, 1, generator=gen)

scripts/ci_test_python.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def test_randomization():
122122
assert run.passed is False
123123
assert run.stdout == ""
124124
assert run.result['check'] == 'fail'
125+
assert "mismatch found!" in run.result['test.0.error']
125126

126127

127128
def test_fd_hacking():
@@ -135,3 +136,4 @@ def test_overwrite_input():
135136
run = run_pytorch_helper(
136137
{**files, "submission.py": Path("examples/identity_py/cheat-input.py").read_text()})
137138
assert run.result['check'] == 'fail'
139+
assert "mismatch found!" in run.result['test.0.error']

0 commit comments

Comments
 (0)