Skip to content

Commit 0510351

Browse files
committed
Adding dlpack conversion support
Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com>
1 parent 5da15d8 commit 0510351

1 file changed

Lines changed: 7 additions & 0 deletions

File tree

src/array_api_extra/_lib/_testing.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,13 @@ def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
136136
cpu = cast(Device, jax.devices("cpu")[0])
137137
array = to_device(array, cpu)
138138

139+
# Try DLPack (works for JAX and other backends)
140+
if hasattr(array, "__dlpack__"):
141+
try:
142+
return np.from_dlpack(array)
143+
except (TypeError, BufferError):
144+
pass
145+
139146
return np.asarray(array)
140147

141148

0 commit comments

Comments
 (0)