44Support for serialization of numpy data types with msgpack.
55"""
66
7- # Copyright (c) 2013-2020 , Lev E. Givon
7+ # Copyright (c) 2013-2022 , Lev E. Givon
88# All rights reserved.
99# Distributed under the terms of the BSD license:
1010# http://www.opensource.org/licenses/bsd-license
1111
1212import sys
1313import functools
14+ import pickle
1415import warnings
1516
1617import msgpack
1920import numpy as np
2021
2122if sys .version_info >= (3 , 0 ):
22- if sys .platform == 'darwin' :
23- ndarray_to_bytes = lambda obj : obj .tobytes ()
24- else :
25- ndarray_to_bytes = lambda obj : obj .data if obj .flags ['C_CONTIGUOUS' ] else obj .tobytes ()
23+ def ndarray_to_bytes (obj ):
24+ if obj .dtype == 'O' :
25+ return obj .dumps ()
26+ else :
27+ if sys .platform == 'darwin' :
28+ return obj .tobytes ()
29+ else :
30+ return obj .data if obj .flags ['C_CONTIGUOUS' ] else obj .tobytes ()
2631
2732 num_to_bytes = lambda obj : obj .data
2833
@@ -32,10 +37,14 @@ def tostr(x):
3237 else :
3338 return str (x )
3439else :
35- if sys .platform == 'darwin' :
36- ndarray_to_bytes = lambda obj : obj .tobytes ()
37- else :
38- ndarray_to_bytes = lambda obj : memoryview (obj .data ) if obj .flags ['C_CONTIGUOUS' ] else obj .tobytes ()
40+ def ndarray_to_bytes (obj ):
41+ if obj .dtype == 'O' :
42+ return obj .dumps ()
43+ else :
44+ if sys .platform == 'darwin' :
45+ return obj .tobytes ()
46+ else :
47+ return memoryview (obj .data ) if obj .flags ['C_CONTIGUOUS' ] else obj .tobytes ()
3948
4049 num_to_bytes = lambda obj : memoryview (obj .data )
4150
@@ -50,12 +59,13 @@ def encode(obj, chain=None):
5059 if isinstance (obj , np .ndarray ):
5160 # If the dtype is structured, store the interface description;
5261 # otherwise, store the corresponding array protocol type string:
53- if obj .dtype .kind == 'V' :
54- kind = b'V'
62+ if obj .dtype .kind in ( 'V' , 'O' ) :
63+ kind = bytes ( obj . dtype . kind , 'ascii' )
5564 descr = obj .dtype .descr
5665 else :
5766 kind = b''
5867 descr = obj .dtype .str
68+
5969 return {b'nd' : True ,
6070 b'type' : descr ,
6171 b'kind' : kind ,
@@ -81,14 +91,18 @@ def decode(obj, chain=None):
8191 if obj [b'nd' ] is True :
8292
8393 # Check if b'kind' is in obj to enable decoding of data
84- # serialized with older versions (#20):
94+ # serialized with older versions (#20) or data
95+ # that had dtype == 'O' (#46):
8596 if b'kind' in obj and obj [b'kind' ] == b'V' :
8697 descr = [tuple (tostr (t ) if type (t ) is bytes else t for t in d ) \
8798 for d in obj [b'type' ]]
99+ elif b'kind' in obj and obj [b'kind' ] == b'O' :
100+ return pickle .loads (obj [b'data' ])
88101 else :
89102 descr = obj [b'type' ]
90- return np .frombuffer (obj [b'data' ],
91- dtype = _unpack_dtype (descr )).reshape (obj [b'shape' ])
103+ return np .ndarray (buffer = obj [b'data' ],
104+ dtype = _unpack_dtype (descr ),
105+ shape = obj [b'shape' ])
92106 else :
93107 descr = obj [b'type' ]
94108 return np .frombuffer (obj [b'data' ],
0 commit comments