Skip to content

Commit b336700

Browse files
authored
cudax/stf: migrate stackable/ from cuda_safe_call to cuda_try (#9165)
* cudax/stf: migrate stackable/ from cuda_safe_call to cuda_try In stackable_ctx_impl.cuh, replace cuda_safe_call with cuda_try in the graph_ctx_node constructor and finalize() so CUDA errors are reported as exceptions rather than aborting the process. The constructor builds a CUDA graph in stages, so add transactional cleanup: - In the nested non-conditional branch, the freshly created dummy_graph is destroyed intentionally mid-block. Guard it with a SCOPE(fail) that frees it only while dummy_graph_owned is true, and disarm the flag right after the intentional destroy. - The outer `graph` is owned by us only in the non-nested case (in the nested cases it is either parent_graph or a child of parent_graph, both owned upstream). A SCOPE(fail) destroys it on early throw and is disarmed the instant graph_ctx adopts it via `auto gctx = graph_ctx(sub_graph, ...);`, matching graph_ctx's documented ownership contract ("User code is not supposed to destroy the graph later"). - The conditional handle (cudaGraphConditionalHandleCreate) and any nodes added to `graph` (cudaGraphAddNode, cudaGraphAddKernelNode) are implicitly cleaned up by the outer SCOPE(fail) destroying `graph`. Two residual hazards are intentionally documented inline rather than fixed in this commit: - cudaGraphAddChildGraphNode leaves an orphaned child node inside parent_graph on later throw; cleanly removing it would need cudaGraphDestroyNode and dependency rewiring. - cudaGraphConditionalHandleCreate writes a handle into a caller-owned pointer; CUDA has no destroy API for conditional handles, so on throw the handle is left invalid (its backing graph is destroyed). Both are no worse than the prior behavior (which aborted). The four cuda_safe_call sites in finalize() (cudaGraphAddDependencies on both CTK branches, cudaGraphDebugDotPrint, cudaGraphLaunch) become plain cuda_try; no resource rollback applies. The two cuda_safe_call sites inside the new SCOPE(fail) bodies are intentional: SCOPE bodies are noexcept, so cuda_safe_call is the correct tool there. In stackable_ctx.cuh, the two cuda_safe_call sites inside UNITTEST host-task lambdas are kept and annotated. Those lambdas are dispatched by the STF host-task path, whose exception-safety has not been audited, so an abort remains safer than an unannotated throw. * cudax/stf: use templated cuda_try<F> form for graph out-params in stackable/ Where a graph_ctx_node site captures a CUDA out-parameter, switch from the runtime-status form `cuda_try(cudaFn(&out, ...))` to the templated form `out = cuda_try<cudaFn>(...)`. This allows the captured handle to be a single const-initialized local instead of declare-then-fill. Converted (each capturing one output handle): - cudaGraphCreate (x2) -> dummy_graph / graph - cudaGraphAddChildGraphNode -> const n - cudaGraphChildGraphNodeGetGraph -> graph (last-output-param form) - cudaGraphAddNode (both CTK -> const conditionalNode branches) - cudaGraphAddKernelNode -> const reset_node dummy_graph stays non-const because it is reset to nullptr to disarm its SCOPE(fail) guard after the intentional destroy. Left in runtime-status form on purpose: - cudaGraphConditionalHandleCreate: its output is written into the caller-owned config.conditional_handle, not a synthesizable local; the templated form would create a throwaway local and lose the write. - cudaGraphDestroy, cudaGraphAddDependencies, cudaGraphDebugDotPrint, cudaGraphLaunch: no captured output, so the templated form adds nothing. cudaGraphAddNode is convertible despite its trailing non-const cudaGraphNodeParams* because the last-output interpretation fails to typecheck (cudaGraph_t is not convertible to cudaGraphNode_t* once the synthesized pointer is appended), so cuda_try selects the first-output form unambiguously.
1 parent ee9f95b commit b336700

2 files changed

Lines changed: 77 additions & 27 deletions

File tree

cudax/include/cuda/experimental/__stf/stackable/stackable_ctx.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,6 +1638,10 @@ UNITTEST("stackable task on exec_place::host()")
16381638
stackable_ctx ctx;
16391639
auto lA = ctx.logical_data(shape_of<slice<int>>(1024));
16401640
ctx.task(exec_place::host(), lA.write())->*[](cudaStream_t stream, auto) {
1641+
// cuda_safe_call (not cuda_try) on purpose: this lambda body is invoked from
1642+
// the STF runtime under host-task dispatch, where exception safety has not
1643+
// been audited. An abort here is preferable to an unannotated throw escaping
1644+
// into the runtime.
16411645
cuda_safe_call(cudaStreamSynchronize(stream));
16421646
};
16431647
ctx.finalize();
@@ -1648,6 +1652,8 @@ UNITTEST("stackable task with set_symbol and set_exec_place")
16481652
stackable_ctx ctx;
16491653
auto lA = ctx.logical_data(shape_of<slice<int>>(1024));
16501654
ctx.task(lA.write()).set_symbol("task").set_exec_place(exec_place::host())->*[](cudaStream_t stream, auto) {
1655+
// Same rationale as the previous test: keep cuda_safe_call inside this
1656+
// host-task lambda until the dispatch path is audited for exception safety.
16511657
cuda_safe_call(cudaStreamSynchronize(stream));
16521658
};
16531659
ctx.finalize();

cudax/include/cuda/experimental/__stf/stackable/stackable_ctx_impl.cuh

Lines changed: 71 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "cuda/experimental/__stf/stackable/stackable_node_hierarchy.cuh"
3838
#include "cuda/experimental/__stf/stackable/stackable_task_dep.cuh"
3939
#include "cuda/experimental/__stf/utility/hash.cuh"
40+
#include "cuda/experimental/__stf/utility/scope_guard.cuh"
4041
#include "cuda/experimental/__stf/utility/source_location.cuh"
4142

4243
namespace cuda::experimental::stf
@@ -453,38 +454,76 @@ public:
453454
else
454455
#endif // _CCCL_CTK_AT_LEAST(12, 4) && !defined(CUDASTF_DISABLE_CODE_GENERATION) && defined(__CUDACC__)
455456
{
456-
cudaGraph_t dummy_graph;
457-
cuda_safe_call(cudaGraphCreate(&dummy_graph, 0));
457+
cudaGraph_t dummy_graph = cuda_try<cudaGraphCreate>(0);
458+
459+
// dummy_graph is intentionally destroyed below. Until that destroy
460+
// succeeds, we own it; this guard releases it if any of the following
461+
// calls throws. The handle itself carries the ownership bit -- we null
462+
// it after the intentional destroy to disarm.
463+
SCOPE(fail)
464+
{
465+
if (dummy_graph != nullptr)
466+
{
467+
cuda_safe_call(cudaGraphDestroy(dummy_graph));
468+
}
469+
};
458470

459-
// The dependencies to this child graph will be added later
460-
cudaGraphNode_t n;
461-
cuda_safe_call(cudaGraphAddChildGraphNode(&n, parent_graph, nullptr, 0, dummy_graph));
471+
// The dependencies to this child graph will be added later.
472+
// NOTE: cudaGraphAddChildGraphNode adds `n` into parent_graph. If the
473+
// constructor throws after this point, `n` remains as an orphaned child
474+
// node inside parent_graph. Removing it cleanly would require
475+
// cudaGraphDestroyNode plus careful dependency rewiring; we accept the
476+
// orphan because parent_graph's lifetime extends well beyond this scope
477+
// and the orphan is harmless until parent_graph is destroyed.
478+
const cudaGraphNode_t n = cuda_try<cudaGraphAddChildGraphNode>(parent_graph, nullptr, 0, dummy_graph);
462479

463480
// cudaGraphAddChildGraphNode clones dummy_graph into the parent
464-
cuda_safe_call(cudaGraphDestroy(dummy_graph));
481+
cuda_try(cudaGraphDestroy(dummy_graph));
482+
dummy_graph = nullptr;
465483

466484
// Get the graph described by the child, not the graph that was
467485
// cloned into the child graph node so that changes are reflected
468486
// in it.
469-
cuda_safe_call(cudaGraphChildGraphNodeGetGraph(n, &graph));
487+
graph = cuda_try<cudaGraphChildGraphNodeGetGraph>(n);
470488
input_node = n;
471489
output_node = n;
472490
}
473491
}
474492
else
475493
{
476-
cuda_safe_call(cudaGraphCreate(&graph, 0));
494+
graph = cuda_try<cudaGraphCreate>(0);
477495
}
478496

497+
// We own `graph` (the raw cudaGraph_t member) only when we freshly allocated
498+
// it above (the non-nested case). In the nested case, `graph` is either
499+
// parent_graph itself or a child graph inside parent_graph, both owned by
500+
// parent_graph -- destroying it would be incorrect. The graph_ctx built
501+
// below (`gctx`) takes ownership when constructed successfully (see
502+
// graph_ctx's ctor doc: "User code is not supposed to destroy the graph
503+
// later."), so we disarm this guard right after that point.
504+
bool graph_owned_by_us = !nested_graph;
505+
SCOPE(fail)
506+
{
507+
if (graph_owned_by_us)
508+
{
509+
cuda_safe_call(cudaGraphDestroy(graph));
510+
}
511+
};
512+
479513
// This is the graph which will be used by our STF context. If we need
480514
// to use a conditional node later, will will change that value.
481515
cudaGraph_t sub_graph = graph;
482516

483517
#if _CCCL_CTK_AT_LEAST(12, 4) && !defined(CUDASTF_DISABLE_CODE_GENERATION) && defined(__CUDACC__)
484518
if (config.conditional_handle != nullptr)
485519
{
486-
// Create the conditional handle and store it in the provided pointer
487-
cuda_safe_call(cudaGraphConditionalHandleCreate(
520+
// Create the conditional handle and store it in the provided pointer.
521+
// NOTE: on a thrown cuda_try below, *config.conditional_handle is left
522+
// pointing at a handle whose backing graph will be destroyed by the
523+
// SCOPE(fail) above. CUDA has no destroy API for conditional handles
524+
// (they are tied to their graph), so the handle simply becomes invalid
525+
// and the caller must not use it after catching the exception.
526+
cuda_try(cudaGraphConditionalHandleCreate(
488527
config.conditional_handle, graph, config.default_launch_value, config.flags));
489528

490529
// Create conditional node parameters
@@ -494,12 +533,13 @@ public:
494533
cParams.conditional.type = config.conditional_type;
495534
cParams.conditional.size = 1;
496535

497-
// Add conditional node to parent graph
498-
cudaGraphNode_t conditionalNode;
536+
// Add conditional node to parent graph. The node lives inside `graph`;
537+
// if construction fails later, the SCOPE(fail) above destroys `graph`
538+
// and the node is cleaned up implicitly.
499539
# if _CCCL_CTK_AT_LEAST(13, 0)
500-
cuda_safe_call(cudaGraphAddNode(&conditionalNode, graph, nullptr, nullptr, 0, &cParams));
540+
const cudaGraphNode_t conditionalNode = cuda_try<cudaGraphAddNode>(graph, nullptr, nullptr, 0, &cParams);
501541
# else
502-
cuda_safe_call(cudaGraphAddNode(&conditionalNode, graph, nullptr, 0, &cParams));
542+
const cudaGraphNode_t conditionalNode = cuda_try<cudaGraphAddNode>(graph, nullptr, 0, &cParams);
503543
# endif
504544

505545
// Get the body graph from the conditional node
@@ -518,8 +558,7 @@ public:
518558
kconfig.kernelParams = kconfig_args;
519559
kconfig.sharedMemBytes = 0;
520560

521-
cudaGraphNode_t reset_node;
522-
cuda_safe_call(cudaGraphAddKernelNode(&reset_node, graph, &conditionalNode, 1, &kconfig));
561+
const cudaGraphNode_t reset_node = cuda_try<cudaGraphAddKernelNode>(graph, &conditionalNode, 1, &kconfig);
523562

524563
input_node = conditionalNode;
525564
output_node = reset_node;
@@ -540,6 +579,11 @@ public:
540579
}
541580

542581
auto gctx = graph_ctx(sub_graph, support_stream, handle);
582+
// graph_ctx's ctor took ownership of sub_graph (and, in the non-nested
583+
// non-conditional case, of `graph` since sub_graph == graph). From here on,
584+
// gctx's destructor handles cleanup, including if any of the following lines
585+
// throws before `ctx = gctx`.
586+
graph_owned_by_us = false;
543587

544588
// Set up context properties
545589
gctx.set_parent_ctx(parent_ctx);
@@ -605,7 +649,7 @@ public:
605649
{
606650
static ::std::atomic<int> debug_graph_cnt{0};
607651
::std::string filename = "instantiated_graph" + ::std::to_string(debug_graph_cnt++) + ".dot";
608-
cuda_safe_call(cudaGraphDebugDotPrint(graph, filename.c_str(), cudaGraphDebugDotFlags(0)));
652+
cuda_try(cudaGraphDebugDotPrint(graph, filename.c_str(), cudaGraphDebugDotFlags(0)));
609653
::std::cout << "Debug: Stackable graph DOT output written to " << filename << '\n';
610654
}
611655
}
@@ -628,11 +672,11 @@ public:
628672

629673
size_t nnodes;
630674
size_t nedges;
631-
cuda_safe_call(cudaGraphGetNodes(graph, nullptr, &nnodes));
675+
cuda_try(cudaGraphGetNodes(graph, nullptr, &nnodes));
632676
#if _CCCL_CTK_AT_LEAST(13, 0)
633-
cuda_safe_call(cudaGraphGetEdges(graph, nullptr, nullptr, nullptr, &nedges));
677+
cuda_try(cudaGraphGetEdges(graph, nullptr, nullptr, nullptr, &nedges));
634678
#else
635-
cuda_safe_call(cudaGraphGetEdges(graph, nullptr, nullptr, &nedges));
679+
cuda_try(cudaGraphGetEdges(graph, nullptr, nullptr, &nedges));
636680
#endif
637681

638682
auto [cached_exec, cache_hit] = ctx.async_resources().cached_graphs_query(nnodes, nedges, graph);
@@ -680,7 +724,7 @@ public:
680724
_CCCL_ASSERT(exec_graph_, "launch_once called before ensure_instantiated");
681725
_CCCL_ASSERT(stream == support_stream, "launch_once only supports the node's support stream");
682726
ensure_prereqs_synced();
683-
cuda_safe_call(cudaGraphLaunch(*exec_graph_, stream));
727+
cuda_try(cudaGraphLaunch(*exec_graph_, stream));
684728
}
685729

686730
// Release resources and build the finalize_prereqs event list that the
@@ -740,11 +784,11 @@ public:
740784
auto* cache_stat = ctx.graph_get_cache_stat();
741785
if (cache_stat)
742786
{
743-
cuda_safe_call(cudaGraphGetNodes(body, nullptr, &cache_stat->nnodes));
787+
cuda_try(cudaGraphGetNodes(body, nullptr, &cache_stat->nnodes));
744788
#if _CCCL_CTK_AT_LEAST(13, 0)
745-
cuda_safe_call(cudaGraphGetEdges(body, nullptr, nullptr, nullptr, &cache_stat->nedges));
789+
cuda_try(cudaGraphGetEdges(body, nullptr, nullptr, nullptr, &cache_stat->nedges));
746790
#else
747-
cuda_safe_call(cudaGraphGetEdges(body, nullptr, nullptr, &cache_stat->nedges));
791+
cuda_try(cudaGraphGetEdges(body, nullptr, nullptr, &cache_stat->nedges));
748792
#endif
749793
}
750794

@@ -754,7 +798,7 @@ public:
754798
{
755799
static ::std::atomic<int> nested_cnt{0};
756800
::std::string filename = "nested_graph" + ::std::to_string(nested_cnt++) + ".dot";
757-
cuda_safe_call(cudaGraphDebugDotPrint(body, filename.c_str(), cudaGraphDebugDotFlags(0)));
801+
cuda_try(cudaGraphDebugDotPrint(body, filename.c_str(), cudaGraphDebugDotFlags(0)));
758802
}
759803

760804
cudaGraph_t support_graph = parent_ctx.graph();
@@ -773,10 +817,10 @@ public:
773817
// Create a vector of input_node repeated for each dependency
774818
::std::vector<cudaGraphNode_t> to_nodes(ctx_ready_nodes.size(), input_node);
775819
#if _CCCL_CTK_AT_LEAST(13, 0)
776-
cuda_safe_call(cudaGraphAddDependencies(
820+
cuda_try(cudaGraphAddDependencies(
777821
support_graph, ctx_ready_nodes.data(), to_nodes.data(), nullptr, ctx_ready_nodes.size()));
778822
#else // _CCCL_CTK_AT_LEAST(13, 0)
779-
cuda_safe_call(
823+
cuda_try(
780824
cudaGraphAddDependencies(support_graph, ctx_ready_nodes.data(), to_nodes.data(), ctx_ready_nodes.size()));
781825
#endif // _CCCL_CTK_AT_LEAST(13, 0)
782826
}

0 commit comments

Comments
 (0)