Skip to content

Commit d1003c5

Browse files
committed
permit numpy as _ArrayData_/_ArrayZipData_, speed up ndarray encoding/decoding
1 parent ce25fa5 commit d1003c5

1 file changed

Lines changed: 19 additions & 12 deletions

File tree

jdata/jdata.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@
2828
'bool':'uint8','byte':'int8','short':'int16','ubyte':'uint8',
2929
'ushort':'uint16','int_':'int32','uint':'uint32','complex_':'double','complex128':'double',
3030
'complex64':'single','longlong':'int64','ulonglong':'uint64',
31-
'csingle':'single','cdouble':'double'};
31+
'csingle':'single','cdouble':'double'}
3232

33-
_zipper=['zlib','gzip','lzma','lz4','blosc2blosclz','blosc2lz4','blosc2lz4hc','blosc2zlib','blosc2zstd','base64'];
33+
_zipper=('zlib','gzip','lzma','lz4','blosc2blosclz','blosc2lz4','blosc2lz4hc','blosc2zlib','blosc2zstd','base64')
34+
35+
_allownumpy=('_ArraySize_','_ArrayData_','_ArrayZipSize_','_ArrayZipData_')
3436

3537
##====================================================================================
3638
## Python to JData encoding function
@@ -44,7 +46,10 @@ def encode(d, opt={}):
4446
files
4547
4648
@param[in,out] d: an arbitrary Python data
47-
@param[in] opt: options, can contain 'compression'=['zlib','lzma','gzip'] for data compression
49+
@param[in] opt: options, a dict that can contain
50+
'compression': choose one of ['zlib','lzma','gzip','lz4','blosc2blosclz','blosc2lz4',
51+
'blosc2lz4hc','blosc2zlib','blosc2zstd'] for compression codec, default is None
52+
'nthread': number of compression threads if the codec is of the blosc2 class, default is 1
4853
"""
4954
if('compression' in opt):
5055
if(opt['compression']=='lzma'):
@@ -93,16 +98,16 @@ def encode(d, opt={}):
9398
newobj["_ArraySize_"]=list(d.shape);
9499
if(d.dtype==np.complex64 or d.dtype==np.complex128 or d.dtype==np.csingle or d.dtype==np.cdouble):
95100
newobj['_ArrayIsComplex_']=True;
96-
newobj['_ArrayData_']=[list(d.flatten().real), list(d.flatten().imag)];
101+
newobj['_ArrayData_']=np.stack(d.ravel().real, d.ravel().imag);
97102
else:
98-
newobj["_ArrayData_"]=list(d.flatten());
103+
newobj["_ArrayData_"]=d.ravel();
99104

100105
if('compression' in opt):
101106
if(opt['compression'] not in _zipper):
102107
raise Exception('JData', 'compression method is not supported')
103108
newobj['_ArrayZipType_']=opt['compression'];
104109
newobj['_ArrayZipSize_']=[1+int('_ArrayIsComplex_' in newobj), d.size];
105-
newobj['_ArrayZipData_']=np.asarray(newobj['_ArrayData_'],dtype=d.dtype).tostring();
110+
newobj['_ArrayZipData_']=newobj["_ArrayData_"].data;
106111
if(opt['compression']=='zlib'):
107112
newobj['_ArrayZipData_']=zlib.compress(newobj['_ArrayZipData_']);
108113
elif(opt['compression']=='gzip'):
@@ -115,7 +120,7 @@ def encode(d, opt={}):
115120
pass
116121
elif(opt['compression']=='lz4'):
117122
try:
118-
newobj['_ArrayZipData_']=lz4.frame.compress(newobj['_ArrayZipData_']);
123+
newobj['_ArrayZipData_']=lz4.frame.compress(newobj['_ArrayZipData_'].tobytes());
119124
except ImportError:
120125
print('you must install "lz4" module to compress with this format, ignoring')
121126
pass
@@ -129,7 +134,7 @@ def encode(d, opt={}):
129134
blosc2nthread = 1
130135
if('nthread' in opt):
131136
blosc2nthread = opt['nthread']
132-
newobj['_ArrayZipData_']=blosc2.compress2(newobj['_ArrayZipData_'], compcode=BLOSC2CODEC[opt['compression']], typesize=d.dtype.itemsize, nthread=blosc2nthread)
137+
newobj['_ArrayZipData_']=blosc2.compress2(newobj['_ArrayZipData_'], compcode=BLOSC2CODEC[opt['compression']], typesize=d.dtype.itemsize, nthreads=blosc2nthread)
133138
except ImportError:
134139
print('you must install "blosc2" module to compress with this format, ignoring')
135140
pass
@@ -166,7 +171,7 @@ def decode(d, opt={}):
166171
elif isinstance(d, dict):
167172
if('_ArrayType_' in d):
168173
if(isinstance(d['_ArraySize_'],str)):
169-
d['_ArraySize_']=np.array(bytearray(d['_ArraySize_']));
174+
d['_ArraySize_']=np.frombuffer(bytearray(d['_ArraySize_']));
170175
if('_ArrayZipData_' in d):
171176
newobj=d['_ArrayZipData_']
172177
if(('base64' in opt) and (opt['base64'])) or ('_ArrayZipType_' in d and d['_ArrayZipType_']=='base64'):
@@ -198,11 +203,11 @@ def decode(d, opt={}):
198203
blosc2nthread = 1
199204
if('nthread' in opt):
200205
blosc2nthread = opt['nthread']
201-
newobj=blosc2.decompress2(bytes(newobj), as_bytearray=False, nthread=blosc2nthread)
206+
newobj=blosc2.decompress2(bytes(newobj), as_bytearray=False, nthreads=blosc2nthread)
202207
except Exception:
203208
print('Warning: you must install "blosc2" module to decompress a data record in this file, ignoring')
204209
pass
205-
newobj=np.fromstring(newobj,dtype=np.dtype(d['_ArrayType_'])).reshape(d['_ArrayZipSize_']);
210+
newobj=np.frombuffer(newobj,dtype=np.dtype(d['_ArrayType_'])).reshape(d['_ArrayZipSize_']);
206211
if('_ArrayIsComplex_' in d and newobj.shape[0]==2):
207212
newobj=newobj[0]+1j*newobj[1];
208213
if('_ArrayOrder_' in d and (d['_ArrayOrder_'].lower()=='c' or d['_ArrayOrder_'].lower()=='col' or d['_ArrayOrder_'].lower()=='column')):
@@ -219,7 +224,7 @@ def decode(d, opt={}):
219224
newobj=np.asarray(d['_ArrayData_'],dtype=np.dtype(d['_ArrayType_']));
220225
if('_ArrayZipSize_' in d and newobj.shape[0]==1):
221226
if(isinstance(d['_ArrayZipSize_'],str)):
222-
d['_ArrayZipSize_']=np.array(bytearray(d['_ArrayZipSize_']));
227+
d['_ArrayZipSize_']=np.frombuffer(bytearray(d['_ArrayZipSize_']));
223228
newobj=newobj.reshape(d['_ArrayZipSize_']);
224229
if('_ArrayIsComplex_' in d and newobj.shape[0]==2):
225230
newobj=newobj[0]+1j*newobj[1];
@@ -259,6 +264,8 @@ def jsonfilter(obj):
259264
def encodedict(d0, opt={}):
260265
d=dict(d0);
261266
for k, v in d0.items():
267+
if isinstance(v, np.ndarray) and isinstance(k, str) and (k in _allownumpy):
268+
continue
262269
newkey=encode(k,opt)
263270
d[newkey]=encode(v,opt);
264271
if(k!=newkey):

0 commit comments

Comments
 (0)