Skip to content

Commit b0748ad

Browse files
authored
Fix regression in array creation (#3353)
1 parent 2ffafe0 commit b0748ad

2 files changed

Lines changed: 10 additions & 1 deletion

File tree

python/src/convert.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ mx::array create_array(nb::object v, std::optional<mx::Dtype> t) {
491491
std::optional<nb::dlpack::dtype> nb_dtype;
492492
// Nanobind does not recognize bfloat16 numpy array:
493493
// https://github.com/wjakob/nanobind/discussions/560
494-
if (v.attr("dtype").equal(nb::str("bfloat16"))) {
494+
if (nb::hasattr(v, "dtype") && v.attr("dtype").equal(nb::str("bfloat16"))) {
495495
nd = nb::cast<ContigArray>(v.attr("view")("uint16"));
496496
nb_dtype = nb::dtype<mx::bfloat16_t>();
497497
} else {

python/tests/test_array.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1782,6 +1782,15 @@ def test_array_view_ref_counting(self):
17821782
a_np = None
17831783
self.assertIsNone(wr())
17841784

1785+
def test_create_from_buffer(self):
1786+
x = mx.array(b"Hello")
1787+
self.assertEqual(x.dtype, mx.uint8)
1788+
self.assertEqual(x.tolist(), [72, 101, 108, 108, 111])
1789+
1790+
x = mx.array(bytearray([1, 2, 3]))
1791+
self.assertEqual(x.dtype, mx.uint8)
1792+
self.assertEqual(x.tolist(), [1, 2, 3])
1793+
17851794
@unittest.skipIf(not has_tf, "requires TensorFlow")
17861795
def test_buffer_protocol_tf(self):
17871796
dtypes_list = [

0 commit comments

Comments
 (0)