Skip to content

Commit 310b196

Browse files
committed
fix(graph): correct MEM_FREE node dependencies for substream allocations
During stream capture, the MEM_FREE node added via cudaGraphAddMemFreeNode was inheriting the begin-stream's dependency frontier instead of the substream's, causing use-after-free on replay. Also adds reproducer.py with a CPU fallback for GPU-less environments. Signed-off-by: theonlychant <sacehenry@gmail.com>
1 parent 2a8b310 commit 310b196

2 files changed

Lines changed: 108 additions & 8 deletions

File tree

reproducer.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import warp as wp
2+
3+
# Reproducer for the substream MEM_FREE capture issue; tries CUDA first and falls back to a CPU path
4+
5+
6+
def main():
7+
# Try CUDA first, fall back to CPU
8+
try:
9+
device = wp.get_device("cuda:0")
10+
except Exception:
11+
device = wp.get_device("cpu")
12+
13+
try:
14+
wp.load_module(device=device)
15+
except Exception:
16+
if not device.is_cpu:
17+
raise
18+
# load_module failures on the CPU device are expected here; ignore them
19+
20+
@wp.kernel
21+
def touch(x: wp.array(dtype=wp.float32)):
22+
i = wp.tid()
23+
if i < x.shape[0]:
24+
x[i] = x[i] + 1.0
25+
26+
# CPU devices do not have CUDA streams; run a CPU-friendly capture path.
27+
if device.is_cpu:
28+
wp.capture_begin(device=device, force_module_load=False)
29+
try:
30+
for _ in range(4):
31+
t = wp.empty(4096, dtype=wp.float32, device=device)
32+
wp.launch(touch, dim=4096, inputs=[t])
33+
del t
34+
finally:
35+
g = wp.capture_end(device=device)
36+
37+
for _ in range(8):
38+
wp.capture_launch(g)
39+
else:
40+
main_stream = wp.get_stream(device)
41+
sub_stream = wp.Stream(device)
42+
43+
wp.capture_begin(device=device, stream=main_stream, force_module_load=False)
44+
try:
45+
sub_stream.wait_stream(main_stream)
46+
with wp.ScopedStream(sub_stream, sync_enter=False):
47+
for _ in range(4):
48+
t = wp.empty(4096, dtype=wp.float32, device=device)
49+
wp.launch(touch, dim=4096, inputs=[t], stream=sub_stream)
50+
del t
51+
main_stream.wait_stream(sub_stream)
52+
finally:
53+
g = wp.capture_end(device=device, stream=main_stream)
54+
55+
replay = wp.Stream(device)
56+
for _ in range(8):
57+
wp.capture_launch(g, stream=replay)
58+
wp.synchronize_stream(replay)
59+
60+
61+
if __name__ == "__main__":
62+
try:
63+
main()
64+
print("Reproducer finished successfully")
65+
except Exception as e:
66+
print("Reproducer failed:", e)
67+
raise

warp/native/warp.cu

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -827,18 +827,51 @@ void wp_free_device_async(void* context, void* ptr)
827827
// check if the capture is still active
828828
auto capture_iter = g_captures.find(capture_id);
829829
if (capture_iter != g_captures.end()) {
830-
// Add a mem free node. Use all current leaf nodes as dependencies to ensure that all prior
831-
// work completes before deallocating. This works with both Warp-initiated and external captures
832-
// and avoids the need to explicitly track all streams used during the capture.
830+
// Add a mem free node. Prefer per-caller-stream capture dependencies so frees
831+
// are ordered with respect to work recorded on the stream where the free
832+
// occurs (handles forked substreams correctly). Fall back to using the
833+
// graph's leaf nodes if per-stream info isn't available.
833834
CaptureInfo* capture = capture_iter->second;
834-
cudaGraph_t graph = get_capture_graph(capture->stream);
835-
std::vector<cudaGraphNode_t> leaf_nodes;
836-
if (graph && get_graph_leaf_nodes(graph, leaf_nodes)) {
835+
cudaGraph_t begin_graph = get_capture_graph(capture->stream);
836+
837+
// get the caller stream (the stream on which wp_free_device_async was invoked)
838+
CUstream caller_cuda_stream = get_current_stream();
839+
840+
cudaGraph_t caller_graph = nullptr;
841+
const cudaGraphNode_t* capture_deps = nullptr;
842+
size_t dep_count = 0;
843+
CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
844+
845+
bool added = false;
846+
847+
// Query per-stream capture info into separate variables so we don't
848+
// clobber the begin_graph value used for the fallback.
849+
if (begin_graph
850+
&& check_cu(cuStreamGetCaptureInfo_f(
851+
caller_cuda_stream, &capture_status, nullptr, &caller_graph, &capture_deps, &dep_count
852+
))
853+
&& caller_graph && capture_status == CU_STREAM_CAPTURE_STATUS_ACTIVE) {
837854
cudaGraphNode_t free_node;
838-
if (check_cuda(cudaGraphAddMemFreeNode(&free_node, graph, leaf_nodes.data(), leaf_nodes.size(), ptr))) {
855+
if (check_cuda(cudaGraphAddMemFreeNode(&free_node, caller_graph, capture_deps, dep_count, ptr))) {
839856
check_cu(cuStreamUpdateCaptureDependencies_f(
840-
capture->stream, &free_node, 1, cudaStreamSetCaptureDependencies
857+
caller_cuda_stream, &free_node, 1, cudaStreamSetCaptureDependencies
841858
));
859+
added = true;
860+
}
861+
}
862+
863+
// Fallback: use the graph leaf nodes from the original capture if needed
864+
if (!added && begin_graph) {
865+
std::vector<cudaGraphNode_t> leaf_nodes;
866+
if (get_graph_leaf_nodes(begin_graph, leaf_nodes)) {
867+
cudaGraphNode_t free_node;
868+
if (check_cuda(
869+
cudaGraphAddMemFreeNode(&free_node, begin_graph, leaf_nodes.data(), leaf_nodes.size(), ptr)
870+
)) {
871+
check_cu(cuStreamUpdateCaptureDependencies_f(
872+
capture->stream, &free_node, 1, cudaStreamSetCaptureDependencies
873+
));
874+
}
842875
}
843876
}
844877

0 commit comments

Comments
 (0)