Skip to content

Commit ae8f470

Browse files
committed
test(core.utils): multiprocess stress for FileStreamProgramCache
Spawns multiple processes to hammer the cache: writers on a shared key prove last-write-wins without corruption, writers on distinct keys prove nothing is lost under contention, and a reader racing against a writer confirms torn files are never observed because os.replace is atomic. Part of issue #178.
1 parent 4f368b1 commit ae8f470

File tree

1 file changed

+118
-0
lines changed

1 file changed

+118
-0
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4+
"""Multiprocess stress tests for FileStreamProgramCache.
5+
6+
These run without a GPU. They exercise the atomic-rename write path from
7+
multiple processes launched via ``multiprocessing.get_context("spawn")``.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
import multiprocessing as _mp
13+
14+
15+
def _worker_write(root: str, key: bytes, payload: bytes, name: str) -> None:
16+
from cuda.core._module import ObjectCode
17+
from cuda.core.utils import FileStreamProgramCache
18+
19+
with FileStreamProgramCache(root) as cache:
20+
cache[key] = ObjectCode._init(payload, "cubin", name=name)
21+
22+
23+
def _worker_write_many(root: str, base: int, n: int) -> None:
24+
from cuda.core._module import ObjectCode
25+
from cuda.core.utils import FileStreamProgramCache
26+
27+
with FileStreamProgramCache(root) as cache:
28+
for i in range(n):
29+
key = f"proc-{base}-key-{i}".encode()
30+
cache[key] = ObjectCode._init(
31+
f"payload-{base}-{i}".encode(), "cubin", name=f"p{base}-{i}"
32+
)
33+
34+
35+
def _worker_reader(root: str, key: bytes, rounds: int, result_queue) -> None:
36+
from cuda.core.utils import FileStreamProgramCache
37+
38+
hits = 0
39+
for _ in range(rounds):
40+
with FileStreamProgramCache(root) as cache:
41+
got = cache.get(key)
42+
if got is not None:
43+
hits += 1
44+
result_queue.put(hits)
45+
46+
47+
def test_concurrent_writers_same_key_no_corruption(tmp_path):
48+
from cuda.core.utils import FileStreamProgramCache
49+
50+
root = str(tmp_path / "fc")
51+
ctx = _mp.get_context("spawn")
52+
procs = [
53+
ctx.Process(
54+
target=_worker_write,
55+
args=(root, b"shared", f"v{i}".encode() * 64, f"p{i}"),
56+
)
57+
for i in range(6)
58+
]
59+
for p in procs:
60+
p.start()
61+
for p in procs:
62+
p.join(timeout=60)
63+
assert p.exitcode == 0, f"worker exited with {p.exitcode}"
64+
65+
with FileStreamProgramCache(root) as cache:
66+
got = cache[b"shared"] # must not raise; payload is one of the writers'
67+
assert bytes(got._module).startswith(b"v")
68+
69+
70+
def test_concurrent_writers_distinct_keys_all_survive(tmp_path):
71+
from cuda.core.utils import FileStreamProgramCache
72+
73+
root = str(tmp_path / "fc")
74+
n_procs = 4
75+
per_proc = 25
76+
ctx = _mp.get_context("spawn")
77+
procs = [
78+
ctx.Process(target=_worker_write_many, args=(root, base, per_proc))
79+
for base in range(n_procs)
80+
]
81+
for p in procs:
82+
p.start()
83+
for p in procs:
84+
p.join(timeout=60)
85+
assert p.exitcode == 0
86+
87+
with FileStreamProgramCache(root) as cache:
88+
for base in range(n_procs):
89+
for i in range(per_proc):
90+
key = f"proc-{base}-key-{i}".encode()
91+
assert key in cache
92+
93+
94+
def test_concurrent_reader_never_sees_torn_file(tmp_path):
95+
from cuda.core._module import ObjectCode
96+
from cuda.core.utils import FileStreamProgramCache
97+
98+
root = str(tmp_path / "fc")
99+
# Seed 'k' so the reader can hit; the writer writes unrelated keys so 'k'
100+
# is never overwritten while the reader is active.
101+
with FileStreamProgramCache(root) as cache:
102+
cache[b"k"] = ObjectCode._init(b"seed" * 256, "cubin", name="seed")
103+
104+
ctx = _mp.get_context("spawn")
105+
queue = ctx.Queue()
106+
writer = ctx.Process(target=_worker_write_many, args=(root, 99, 50))
107+
reader = ctx.Process(
108+
target=_worker_reader, args=(root, b"k", 200, queue)
109+
)
110+
reader.start()
111+
writer.start()
112+
writer.join(timeout=60)
113+
reader.join(timeout=60)
114+
assert writer.exitcode == 0
115+
assert reader.exitcode == 0
116+
hits = queue.get(timeout=5)
117+
# 'k' was never overwritten, so every read must hit.
118+
assert hits == 200

0 commit comments

Comments
 (0)