@@ -62,11 +62,22 @@ def check_jax_version():
6262
6363
6464class GraphMode (IntEnum ):
65- NONE = 0 # don't capture a graph
66- JAX = 1 # let JAX capture a graph
67- WARP = 2 # let Warp capture a graph
68- WARP_STAGED = 3 # use Warp graph with staging buffers, copy inside of the graph
69- WARP_STAGED_EX = 4 # use Warp graph with staging buffers, copy outside of the graph
65+ """CUDA graph capture modes for :func:`warp.jax_experimental.jax_callable`.
66+
67+ These modes control whether JAX or Warp captures a CUDA graph, and whether
68+ staging buffers are used when capturing with Warp.
69+ """
70+
71+ NONE = 0
72+ """Disable graph capture. Use this mode when the callable performs operations that are not CUDA-graph compatible (for example, host synchronization)."""
73+ JAX = 1
74+ """Let JAX capture the graph so the callable can be used as a subgraph within a larger JAX capture."""
75+ WARP = 2
76+ """Let Warp capture the graph and replay it on later calls that use the same input and output buffer addresses."""
77+ WARP_STAGED = 3
78+ """Capture a Warp graph using staging buffers and insert memcpy nodes inside the graph."""
79+ WARP_STAGED_EX = 4
80+ """Capture a Warp graph using staging buffers and perform memcpy outside the graph."""
7081
7182
7283class ModulePreloadMode (IntEnum ):
@@ -682,12 +693,13 @@ def ffi_callback(self, call_frame):
682693 assert num_outputs == self .num_outputs
683694
684695 cuda_stream = get_stream_from_callframe (call_frame .contents )
696+ device_ordinal = get_device_ordinal_from_callframe (call_frame .contents )
685697
686698 if self .graph_mode == GraphMode .WARP :
687699 # check if we already captured an identical call
688700 ip = [inputs [i ].contents .data for i in self .array_input_indices ]
689701 op = [outputs [i ].contents .data for i in self .array_output_indices ]
690- capture_key = hash ((call_id , * ip , * op ))
702+ capture_key = hash ((device_ordinal , call_id , * ip , * op ))
691703 capture = self .captures .get (capture_key )
692704
693705 # launch existing graph
0 commit comments