@@ -74,6 +74,7 @@ def to_dict(self) -> dict[str, JSON]:
7474 def _zstd_codec (self ) -> nvcomp .Codec :
7575 device = cp .cuda .Device () # Select the current default device
7676 stream = cp .cuda .get_current_stream () # Use the current default stream
77+ # Note: this returns an array with dtype=np.dtype("int8")
7778 return nvcomp .Codec (
7879 algorithm = "Zstd" ,
7980 bitstream_kind = nvcomp .BitstreamKind .RAW ,
@@ -95,12 +96,18 @@ def _convert_from_nvcomp_arrays(
9596 arrays : Iterable [nvcomp .Array ],
9697 chunks_and_specs : Iterable [tuple [Buffer | None , ArraySpec ]],
9798 ) -> Iterable [Buffer | None ]:
98- return [
99- spec .prototype .buffer .from_array_like (cp .array (a , dtype = np .dtype ("B" ), copy = False ))
100- if a
101- else None
102- for a , (_ , spec ) in zip (arrays , chunks_and_specs , strict = True )
103- ]
99+ result : list [Buffer | None ] = []
100+
101+ for a , (_ , spec ) in zip (arrays , chunks_and_specs , strict = True ):
102+ if a is None :
103+ result .append (None )
104+ else :
105+ a2 = cp .array (a , dtype = a .dtype , copy = False )
106+ if a2 .dtype != np .dtype ("B" ):
107+ a2 = a2 .view (dtype = np .dtype ("B" ))
108+ result .append (spec .prototype .buffer .from_array_like (a2 ))
109+
110+ return result
104111
105112 async def decode (
106113 self ,
0 commit comments