66import io
77import pickle as _pickle
88from ast import literal_eval
9- from typing import Any
9+ from collections .abc import Callable , Iterable , Sequence
10+ from typing import BinaryIO , TypeAlias , TypeGuard , cast
11+
12+ from typing_extensions import override
1013
1114from vortex ._lib .arrays import Array # pyright: ignore[reportMissingModuleSource]
1215from vortex ._lib .serde import ( # pyright: ignore[reportMissingModuleSource]
1821_ARRAY_PERSISTENT_ID = "vortex.array"
1922_ARRAY_PERSISTENT_ID_VERSION = 1
2023
# A sequence of raw buffers, each either ``bytes`` or a ``memoryview``.
_BufferSequence: TypeAlias = Sequence[bytes | memoryview]
# Shape of a parsed array persistent id: (tag, version, array buffers, dtype buffers).
_ArrayPersistentId: TypeAlias = tuple[str, int, _BufferSequence, _BufferSequence]
# Type of the ``buffer_callback`` argument taken by ``Pickler``, ``dump`` and ``dumps``.
_BufferCallback: TypeAlias = Callable[[_pickle.PickleBuffer], object | None]
# Type of the ``buffers`` argument taken by ``Unpickler``, ``load`` and ``loads``.
_OutOfBandBuffers: TypeAlias = Iterable[bytes | bytearray | memoryview | _pickle.PickleBuffer]
28+
29+
def _is_buffer_sequence(obj: object) -> TypeGuard[_BufferSequence]:
    """Return whether ``obj`` is a sequence whose items are all ``bytes`` or ``memoryview``.

    ``str``, ``bytes``, ``bytearray`` and ``memoryview`` are themselves
    ``collections.abc.Sequence`` instances, so an *empty* one would satisfy the
    per-item check vacuously. They are rejected up front: a single buffer (or a
    string) is not a sequence of buffers.

    Args:
        obj: Any object, typically one field of a parsed persistent id.

    Returns:
        ``True`` only for a non-string, non-buffer sequence in which every item
        is a ``bytes`` or ``memoryview`` instance (an empty list/tuple counts).
    """
    if isinstance(obj, str | bytes | bytearray | memoryview):
        return False
    return isinstance(obj, Sequence) and all(isinstance(item, bytes | memoryview) for item in obj)
32+
33+
def _parse_array_persistent_id(pid: object) -> _ArrayPersistentId:
    """Validate a persistent id and return it as a typed array persistent id.

    A ``str`` pid is first evaluated with ``ast.literal_eval``. The resulting
    value (or a non-string pid, used as given) must be a 4-tuple
    ``(tag, version, array_buffers, dtype_buffers)`` whose tag and version match
    ``_ARRAY_PERSISTENT_ID`` / ``_ARRAY_PERSISTENT_ID_VERSION`` and whose last
    two items pass ``_is_buffer_sequence``.

    Args:
        pid: The persistent id handed to ``Unpickler.persistent_load``.

    Returns:
        ``(_ARRAY_PERSISTENT_ID, _ARRAY_PERSISTENT_ID_VERSION, array_buffers,
        dtype_buffers)``.

    Raises:
        pickle.UnpicklingError: If ``pid`` cannot be parsed or fails any of the
            checks above. The message always shows the pid as originally given.
    """

    def unsupported() -> _pickle.UnpicklingError:
        # Built lazily: ``repr`` of a pid that carries large buffers is costly and
        # should only be paid on the failure path.
        return _pickle.UnpicklingError(f"unsupported persistent id: {pid!r}")

    parsed_pid: object = pid
    if isinstance(parsed_pid, str):
        try:
            parsed_pid = cast(object, literal_eval(parsed_pid))
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as err:
            # ``literal_eval`` is documented to raise any of these on malformed
            # input (e.g. ``"{[1]: 2}"`` raises ``TypeError``); report all of
            # them uniformly as an unpickling failure.
            raise unsupported() from err

    if not isinstance(parsed_pid, tuple):
        raise unsupported()

    parsed_tuple = cast(tuple[object, ...], parsed_pid)
    if len(parsed_tuple) != 4:
        raise unsupported()

    tag, version, array_buffers, dtype_buffers = parsed_tuple
    if tag != _ARRAY_PERSISTENT_ID or version != _ARRAY_PERSISTENT_ID_VERSION:
        raise unsupported()
    if not _is_buffer_sequence(array_buffers) or not _is_buffer_sequence(dtype_buffers):
        raise unsupported()

    # Return the module constants rather than the parsed tag/version so the
    # result has the precise ``(str, int, ...)`` element types.
    return (_ARRAY_PERSISTENT_ID, _ARRAY_PERSISTENT_ID_VERSION, array_buffers, dtype_buffers)
56+
2157
2258class Pickler (_pickle .Pickler ):
2359 """Pickler that serializes Vortex arrays using an explicit session."""
2460
def __init__(
    self,
    file: BinaryIO,
    *,
    session: Session,
    protocol: int | None = None,
    fix_imports: bool = True,
    buffer_callback: _BufferCallback | None = None,
) -> None:
    """Create a pickler that writes to ``file`` and remembers ``session``.

    ``file``, ``protocol``, ``fix_imports`` and ``buffer_callback`` are
    forwarded unchanged to ``pickle.Pickler.__init__``. ``session`` is stored
    on the instance for use when encoding Vortex arrays.
    """
    super().__init__(
        file,
        protocol=protocol,
        fix_imports=fix_imports,
        buffer_callback=buffer_callback,
    )
    # Kept on the instance so ``persistent_id`` can pass it to the array encoder.
    self._session: Session = session
4177
78+ @override
4279 def persistent_id (self , obj : object ) -> object | None :
4380 if isinstance (obj , Array ):
4481 array_buffers , dtype_buffers = encode_ipc_array_buffers (obj , session = self ._session )
@@ -51,13 +88,13 @@ class Unpickler(_pickle.Unpickler):
5188
5289 def __init__ (
5390 self ,
54- file : Any , # pyright: ignore[reportExplicitAny]
91+ file : BinaryIO ,
5592 * ,
5693 session : Session ,
5794 fix_imports : bool = True ,
5895 encoding : str = "ASCII" ,
5996 errors : str = "strict" ,
60- buffers : Any | None = None , # pyright: ignore[reportExplicitAny]
97+ buffers : _OutOfBandBuffers | None = None ,
6198 ) -> None :
6299 super ().__init__ (
63100 file ,
@@ -66,33 +103,22 @@ def __init__(
66103 errors = errors ,
67104 buffers = buffers ,
68105 )
69- self ._session = session
106+ self ._session : Session = session
70107
@override
def persistent_load(self, pid: object) -> object:
    """Rebuild a Vortex array from the persistent id ``pid``.

    ``pid`` is validated by ``_parse_array_persistent_id``, which raises
    ``pickle.UnpicklingError`` for anything it does not recognise. The array
    and dtype buffers it yields are decoded with this unpickler's session.
    """
    parsed = _parse_array_persistent_id(pid)
    # Items 0 and 1 are the already-validated tag and version; only the
    # buffers are needed for decoding.
    array_buffers = parsed[2]
    dtype_buffers = parsed[3]
    return decode_ipc_array_buffers(array_buffers, dtype_buffers, session=self._session)
86112
87113
88114def dump (
89115 obj : object ,
90- file : Any , # pyright: ignore[reportExplicitAny]
116+ file : BinaryIO ,
91117 * ,
92118 session : Session ,
93119 protocol : int | None = None ,
94120 fix_imports : bool = True ,
95- buffer_callback : Any | None = None , # pyright: ignore[reportExplicitAny]
121+ buffer_callback : _BufferCallback | None = None ,
96122) -> None :
97123 Pickler (
98124 file ,
@@ -109,7 +135,7 @@ def dumps(
109135 session : Session ,
110136 protocol : int | None = None ,
111137 fix_imports : bool = True ,
112- buffer_callback : Any | None = None , # pyright: ignore[reportExplicitAny]
138+ buffer_callback : _BufferCallback | None = None ,
113139) -> bytes :
114140 file = io .BytesIO ()
115141 dump (
@@ -124,22 +150,25 @@ def dumps(
124150
125151
def load(
    file: BinaryIO,
    *,
    session: Session,
    fix_imports: bool = True,
    encoding: str = "ASCII",
    errors: str = "strict",
    buffers: _OutOfBandBuffers | None = None,
) -> object:
    """Read one pickled object from ``file`` using an ``Unpickler`` bound to ``session``.

    All arguments are forwarded unchanged to ``Unpickler``; the result of its
    ``load()`` call is returned.

    As with ``pickle.load``, only unpickle data from a source you trust.
    """
    unpickler = Unpickler(
        file,
        session=session,
        fix_imports=fix_imports,
        encoding=encoding,
        errors=errors,
        buffers=buffers,
    )
    # The cast keeps the declared ``object`` return type for the type checker.
    loaded = cast(object, unpickler.load())
    return loaded
143172
144173
145174def loads (
@@ -149,7 +178,7 @@ def loads(
149178 fix_imports : bool = True ,
150179 encoding : str = "ASCII" ,
151180 errors : str = "strict" ,
152- buffers : Any | None = None , # pyright: ignore[reportExplicitAny]
181+ buffers : _OutOfBandBuffers | None = None ,
153182) -> object :
154183 return load (
155184 io .BytesIO (data ),
0 commit comments