Skip to content

Commit 0f36b0f

Browse files
btabacopybara-github
authored andcommitted
Import NVIDIA/warp from GitHub.
PiperOrigin-RevId: 872969263 Change-Id: Ic890cb53c32593c158a1f3926906b492c7f85348
1 parent e5a2367 commit 0f36b0f

3 files changed

Lines changed: 24 additions & 12 deletions

File tree

mjx/cuda_requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ jax-cuda12-pjrt==0.5.3; python_version >= '3.10' \
1616
jax-cuda12-pjrt==0.4.30; python_version == '3.9' \
1717
--hash=sha256:895d0198ad99638fcaf976c47592e2a543eef79ea15fabd24a402d055390c328 \
1818
--hash=sha256:c36fb1e0c236563bf3a87e70f4d1ab28a31d7cf5d722c9ede30c4172116e8bcb
19-
warp-lang==1.11.0 \
20-
--hash=sha256:3a4f1c9a6e721d7de7d6dad6b242c54afaf20c6e14a767c0da03e5e963fcc13c \
21-
--hash=sha256:524dce20de6162ba25333552168ebf430973050e00d9f8116b8df41a60d25d6e \
22-
--hash=sha256:1ae6cfc226107f96e4d495b41a3dab32488e8ee8f074b0e1bcaf22e7fb8c904d \
23-
--hash=sha256:80d8493cbe243a3510134f3af289646d7bd7484217a30ecf565d676466ef8a5e
19+
warp-lang==1.11.1 \
20+
--hash=sha256:1ad11f1fa775269e991a3d55039152c8a504baf86701c849b485cb8e66c49d15 \
21+
--hash=sha256:8b098f41e71d421d80ee7562e38aa8380ff6b0d3b4c6ee866cfbdef733ac5bdc \
22+
--hash=sha256:5d0904b0eefcc81f39ba65375427a3de99006088aa43e24a9011263f07d0cd07 \
23+
--hash=sha256:15dc10aa51fb0fdbe1ca16d52e5fadca35a47ffd9d0c636826506f96bb2e7c41

mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,22 @@ def check_jax_version():
6262

6363

6464
class GraphMode(IntEnum):
65-
NONE = 0 # don't capture a graph
66-
JAX = 1 # let JAX capture a graph
67-
WARP = 2 # let Warp capture a graph
68-
WARP_STAGED = 3 # use Warp graph with staging buffers, copy inside of the graph
69-
WARP_STAGED_EX = 4 # use Warp graph with staging buffers, copy outside of the graph
65+
"""CUDA graph capture modes for :func:`warp.jax_experimental.jax_callable`.
66+
67+
These modes control whether JAX or Warp captures a CUDA graph, and whether
68+
staging buffers are used when capturing with Warp.
69+
"""
70+
71+
NONE = 0
72+
"""Disable graph capture. Use when operations are not CUDA-graph compatible (for example, host synchronization)."""
73+
JAX = 1
74+
"""Let JAX capture the graph so the callable can be used as a subgraph within a larger JAX capture."""
75+
WARP = 2
76+
"""Let Warp capture the graph and replay it for matching buffer addresses."""
77+
WARP_STAGED = 3
78+
"""Capture a Warp graph using staging buffers and insert memcpy nodes inside the graph."""
79+
WARP_STAGED_EX = 4
80+
"""Capture a Warp graph using staging buffers and perform memcpy outside the graph."""
7081

7182

7283
class ModulePreloadMode(IntEnum):
@@ -682,12 +693,13 @@ def ffi_callback(self, call_frame):
682693
assert num_outputs == self.num_outputs
683694

684695
cuda_stream = get_stream_from_callframe(call_frame.contents)
696+
device_ordinal = get_device_ordinal_from_callframe(call_frame.contents)
685697

686698
if self.graph_mode == GraphMode.WARP:
687699
# check if we already captured an identical call
688700
ip = [inputs[i].contents.data for i in self.array_input_indices]
689701
op = [outputs[i].contents.data for i in self.array_output_indices]
690-
capture_key = hash((call_id, *ip, *op))
702+
capture_key = hash((device_ordinal, call_id, *ip, *op))
691703
capture = self.captures.get(capture_key)
692704

693705
# launch existing graph

mjx/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ dependencies = [
3636

3737
[project.optional-dependencies]
3838
warp = [
39-
"warp-lang==1.11.0",
39+
"warp-lang==1.11.1",
4040
]
4141

4242
[project.scripts]

0 commit comments

Comments
 (0)