@@ -772,41 +772,27 @@ def _cupy_to_device(
772772 stream : int | Any | None = None ,
773773) -> cp .ndarray :
774774 import cupy as cp
775- from cupy .cuda import Device as _Device # pyright: ignore
776- from cupy .cuda import stream as stream_module # pyright: ignore
777- from cupy_backends .cuda .api import runtime # pyright: ignore
778775
779- if device == x .device :
780- return x
781- elif device == "cpu" :
776+ if device == "cpu" :
782777 # allowing us to use `to_device(x, "cpu")`
783778 # is useful for portable test swapping between
784779 # host and device backends
785780 return x .get ()
786- elif not isinstance (device , _Device ):
787- raise ValueError (f"Unsupported device { device !r} " )
788- else :
789- # see cupy/cupy#5985 for the reason how we handle device/stream here
790- prev_device : Device = runtime .getDevice () # pyright: ignore[reportUnknownMemberType]
791- prev_stream = None
792- if stream is not None :
793- prev_stream = stream_module .get_current_stream () # pyright: ignore
794- # stream can be an int as specified in __dlpack__, or a CuPy stream
795- if isinstance (stream , int ):
796- stream = cp .cuda .ExternalStream (stream ) # pyright: ignore
797- elif isinstance (stream , cp .cuda .Stream ): # pyright: ignore[reportUnknownMemberType]
798- pass
799- else :
800- raise ValueError ("the input stream is not recognized" )
801- stream .use () # pyright: ignore[reportUnknownMemberType]
802- try :
803- runtime .setDevice (device .id ) # pyright: ignore[reportUnknownMemberType]
804- arr = x .copy ()
805- finally :
806- runtime .setDevice (prev_device ) # pyright: ignore[reportUnknownMemberType]
807- if prev_stream is not None :
808- prev_stream .use ()
809- return arr
781+ if not isinstance (device , cp .cuda .Device ):
782+ raise TypeError (f"Unsupported device type { device !r} " )
783+
784+ if stream is None :
785+ with device :
786+ return cp .asarray (x )
787+
788+ # stream can be an int as specified in __dlpack__, or a CuPy stream
789+ if isinstance (stream , int ):
790+ stream = cp .cuda .ExternalStream (stream )
791+ elif not isinstance (stream , cp .cuda .Stream ):
792+ raise TypeError (f"Unsupported stream type { stream !r} " )
793+
794+ with device , stream :
795+ return cp .asarray (x )
810796
811797
812798def _torch_to_device (
0 commit comments