Skip to content

Commit fbb92b0

Browse files
committed
fix(graph): correct MEM_FREE node dependencies for substream allocations
Captured cudaGraphAddMemFreeNode was inheriting the begin-stream's frontier instead of the substream's, causing use-after-free on replay. Also adds reproducer.py with CPU fallback for GPU-less environments.
1 parent 2a8b310 commit fbb92b0

2 files changed

Lines changed: 101 additions & 10 deletions

File tree

reproducer.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import warp as wp
2+
3+
# CPU-mode reproducer adapted from your CUDA snippet
4+
5+
def main():
6+
device = wp.get_device("cpu")
7+
try:
8+
wp.load_module(device=device)
9+
except Exception:
10+
# load_module may be unnecessary on CPU; ignore failures
11+
pass
12+
13+
@wp.kernel
14+
def touch(x: wp.array(dtype=wp.float32)):
15+
i = wp.tid()
16+
if i < x.shape[0]:
17+
x[i] = x[i] + 1.0
18+
19+
# CPU devices do not have CUDA streams; run a CPU-friendly capture path.
20+
if device.is_cpu:
21+
wp.capture_begin(device=device, force_module_load=False)
22+
try:
23+
for _ in range(4):
24+
t = wp.empty(4096, dtype=wp.float32, device=device)
25+
wp.launch(touch, dim=4096, inputs=[t])
26+
del t
27+
finally:
28+
g = wp.capture_end(device=device)
29+
30+
for _ in range(8):
31+
wp.capture_launch(g)
32+
else:
33+
main_stream = wp.get_stream(device)
34+
sub_stream = wp.Stream(device)
35+
36+
wp.capture_begin(device=device, stream=main_stream, force_module_load=False)
37+
try:
38+
sub_stream.wait_stream(main_stream)
39+
with wp.ScopedStream(sub_stream, sync_enter=False):
40+
for _ in range(4):
41+
t = wp.empty(4096, dtype=wp.float32, device=device)
42+
wp.launch(touch, dim=4096, inputs=[t], stream=sub_stream)
43+
del t
44+
main_stream.wait_stream(sub_stream)
45+
finally:
46+
g = wp.capture_end(device=device, stream=main_stream)
47+
48+
replay = wp.Stream(device)
49+
for _ in range(8):
50+
wp.capture_launch(g, stream=replay)
51+
wp.synchronize_stream(replay)
52+
53+
54+
if __name__ == '__main__':
55+
try:
56+
main()
57+
print('Reproducer finished successfully')
58+
except Exception as e:
59+
print('Reproducer failed:', e)
60+
raise

warp/native/warp.cu

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -827,18 +827,49 @@ void wp_free_device_async(void* context, void* ptr)
827827
// check if the capture is still active
828828
auto capture_iter = g_captures.find(capture_id);
829829
if (capture_iter != g_captures.end()) {
830-
// Add a mem free node. Use all current leaf nodes as dependencies to ensure that all prior
831-
// work completes before deallocating. This works with both Warp-initiated and external captures
832-
// and avoids the need to explicitly track all streams used during the capture.
830+
// Add a mem free node. Use the caller stream's capture dependencies so frees
831+
// are ordered with respect to work recorded on the stream where the free
832+
// occurs (handles forked substreams correctly). Fall back to using the
833+
// global graph leaf nodes if capture info isn't available for the caller.
833834
CaptureInfo* capture = capture_iter->second;
834835
cudaGraph_t graph = get_capture_graph(capture->stream);
835-
std::vector<cudaGraphNode_t> leaf_nodes;
836-
if (graph && get_graph_leaf_nodes(graph, leaf_nodes)) {
837-
cudaGraphNode_t free_node;
838-
if (check_cuda(cudaGraphAddMemFreeNode(&free_node, graph, leaf_nodes.data(), leaf_nodes.size(), ptr))) {
839-
check_cu(cuStreamUpdateCaptureDependencies_f(
840-
capture->stream, &free_node, 1, cudaStreamSetCaptureDependencies
841-
));
836+
837+
// get the caller stream (the stream on which wp_free_device_async was invoked)
838+
CUstream caller_cuda_stream = get_current_stream();
839+
840+
const cudaGraphNode_t* capture_deps = nullptr;
841+
size_t dep_count = 0;
842+
CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
843+
844+
bool added = false;
845+
846+
// Try to get per-stream capture dependencies for the caller stream and use
847+
// them as predecessors for the memfree node. This ensures the memfree will
848+
// be ordered after work recorded on the caller stream (including forked
849+
// substreams brought into the capture via wait_stream/wait_event).
850+
if (graph && check_cu(cuStreamGetCaptureInfo_f(caller_cuda_stream, &capture_status, nullptr, &graph, &capture_deps, &dep_count))) {
851+
if (graph && (capture_deps != nullptr || dep_count > 0)) {
852+
cudaGraphNode_t free_node;
853+
if (check_cuda(cudaGraphAddMemFreeNode(&free_node, graph, capture_deps, dep_count, ptr))) {
854+
check_cu(cuStreamUpdateCaptureDependencies_f(
855+
caller_cuda_stream, &free_node, 1, cudaStreamSetCaptureDependencies
856+
));
857+
added = true;
858+
}
859+
}
860+
}
861+
862+
// Fallback: if we couldn't obtain per-stream capture dependencies, use
863+
// the graph leaf nodes as before.
864+
if (!added && graph) {
865+
std::vector<cudaGraphNode_t> leaf_nodes;
866+
if (get_graph_leaf_nodes(graph, leaf_nodes)) {
867+
cudaGraphNode_t free_node;
868+
if (check_cuda(cudaGraphAddMemFreeNode(&free_node, graph, leaf_nodes.data(), leaf_nodes.size(), ptr))) {
869+
check_cu(cuStreamUpdateCaptureDependencies_f(
870+
capture->stream, &free_node, 1, cudaStreamSetCaptureDependencies
871+
));
872+
}
842873
}
843874
}
844875

0 commit comments

Comments
 (0)