Skip to content

Commit 8ef1b12

Browse files
committed
[bug] fix numpy 1d array length encoding bug, fix #4
1 parent f0a0269 commit 8ef1b12

3 files changed

Lines changed: 22 additions & 20 deletions

File tree

bjdata/encoder.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020-2022 Qianqian Fang <q.fang at neu.edu>. All rights reserved.
1+
# Copyright (c) 2020-2023 Qianqian Fang <q.fang at neu.edu>. All rights reserved.
22
# Copyright (c) 2016-2019 Iotic Labs Ltd. All rights reserved.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -288,13 +288,10 @@ def __encode_numpy(fp_write, item, islittle, default):
288288
item = np.array( item, order = 'C') # currently, BJData ND-array syntax only support row-major
289289

290290
fp_write(ARRAY_START + CONTAINER_TYPE + __map_dtype(item.dtype.str) + CONTAINER_COUNT)
291-
if item.ndim == 1:
292-
__encode_int(fp_write, len(item), islittle)
293-
else:
294-
fp_write(ARRAY_START)
295-
for value in item.shape:
296-
__encode_int(fp_write, value, islittle)
297-
fp_write(ARRAY_END)
291+
fp_write(ARRAY_START)
292+
for value in item.shape:
293+
__encode_int(fp_write, value, islittle)
294+
fp_write(ARRAY_END)
298295

299296
fp_write(item.data)
300297

src/encoder.c

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2020-2022 Qianqian Fang <q.fang at neu.edu>. All rights reserved.
2+
* Copyright (c) 2020-2023 Qianqian Fang <q.fang at neu.edu>. All rights reserved.
33
* Copyright (c) 2016-2019 Iotic Labs Ltd. All rights reserved.
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -317,16 +317,14 @@ static int _encode_NDarray(PyObject *obj, _bjdata_encoder_buffer_t *buffer) {
317317
WRITE_CHAR_OR_BAIL((char)marker);
318318
}
319319
WRITE_CHAR_OR_BAIL(CONTAINER_COUNT);
320-
if(ndim == 1) {
321-
BAIL_ON_NONZERO(_encode_longlong(bytes, buffer));
322-
} else {
323-
WRITE_CHAR_OR_BAIL(ARRAY_START);
324-
for(int i=0 ; i<ndim; i++)
325-
_encode_longlong(dims[i], buffer);
326-
if(type == NPY_UNICODE)
327-
_encode_longlong(4, buffer);
328-
WRITE_CHAR_OR_BAIL(ARRAY_END);
329-
}
320+
321+
WRITE_CHAR_OR_BAIL(ARRAY_START);
322+
for(int i=0 ; i<ndim; i++)
323+
_encode_longlong(dims[i], buffer);
324+
if(type == NPY_UNICODE)
325+
_encode_longlong(4, buffer);
326+
WRITE_CHAR_OR_BAIL(ARRAY_END);
327+
330328
WRITE_OR_BAIL(PyArray_BYTES(arr), bytes*total);
331329
Py_DECREF(arr);
332330
// no ARRAY_END since length was specified

test/test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020-2022 Qianqian Fang <q.fang at neu.edu>. All rights reserved.
1+
# Copyright (c) 2020-2023 Qianqian Fang <q.fang at neu.edu>. All rights reserved.
22
# Copyright (c) 2016-2019 Iotic Labs Ltd. All rights reserved.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -314,6 +314,13 @@ def test_nd_array(self):
314314
self.assertEqual((self.bjdloadb(self.bjddumpb(np.float16(2.2))) == 16486), True)
315315
self.assertEqual((self.bjdloadb(self.bjddumpb(np.float32(2.2))) == np.float32(2.2)), True)
316316

317+
self.assertEqual((self.bjdloadb(self.bjddumpb(np.array([1.3,-0.5,0.7,1000,11], dtype=np.float32))) == np.array([1.3,-0.5,0.7,1000,11], dtype=np.float32)).all(), True)
318+
self.assertEqual((self.bjdloadb(self.bjddumpb(np.array([1,2,3,4], dtype=np.uint16))) == np.array([1,2,3,4], dtype=np.uint16)).all(), True)
319+
self.assertEqual((self.bjdloadb(self.bjddumpb(np.array([1,2,3,4], dtype=np.uint8))) == np.array([1,2,3,4], dtype=np.uint8)).all(), True)
320+
self.assertEqual((self.bjdloadb(self.bjddumpb(np.array([], dtype=np.int8))) == np.array([], dtype=np.int8)).all(), True)
321+
self.assertEqual((self.bjdloadb(self.bjddumpb(np.array([-1,-2,-3,-4], dtype=np.float32))) == np.array([-1,-2,-3,-4], dtype=np.float32)).all(), True)
322+
self.assertEqual((self.bjdloadb(self.bjddumpb(np.array([[-1,-2,5],[-3,-4,-6]], dtype=np.float64))) == np.array([[-1,-2,5],[-3,-4,-6]], dtype=np.float64)).all(), True)
323+
317324
raw_start = (ARRAY_START + CONTAINER_TYPE + TYPE_INT8 + CONTAINER_COUNT + ARRAY_START + \
318325
TYPE_UINT8 + b'\x03' + TYPE_UINT16 + b'\x02' + b'\x00' + ARRAY_END + \
319326
b'\x01'+ b'\x02'+ b'\x03'+ b'\x04'+ b'\x05'+ b'\x06')

0 commit comments

Comments
 (0)