Skip to content

Commit 80a0dc0

Browse files
prady0tlucascolley
andauthored
ENH: testing: add dlpack conversion support (#749)
* Adding dlpack conversion support Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * precommit error fix Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * precommit error fix Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * fix lint * remove comment --------- Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Co-authored-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
1 parent 56fb8d8 commit 80a0dc0

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

src/array_api_extra/_lib/_testing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,12 @@ def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
136136
cpu = cast(Device, jax.devices("cpu")[0])
137137
array = to_device(array, cpu)
138138

139+
if hasattr(array, "__dlpack__"):
140+
try:
141+
return np.from_dlpack(array) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
142+
except (TypeError, BufferError):
143+
pass
144+
139145
return np.asarray(array)
140146

141147

0 commit comments

Comments
 (0)