Skip to content

Commit 1bf65e3

Browse files
authored
Fix CPU dynamic slice copy bound for collapsed shapes (#3739)
1 parent e37e926 commit 1bf65e3

2 files changed

Lines changed: 7 additions & 1 deletion

File tree

mlx/backend/cpu/copy.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ void copy_general_general(
7070
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
7171
auto o_offset_ptr =
7272
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
73-
auto size = src.size();
7473
if (data_shape.empty()) {
7574
auto val = static_cast<DstT>(*src_ptr);
7675
*dst_ptr = val;
@@ -107,6 +106,8 @@ void copy_general_general(
107106
dst_ptr += o_offset_ptr[0];
108107
}
109108

109+
auto size = std::accumulate(
110+
shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
110111
ContiguousIterator in(shape, strides[0], ndim - 3);
111112
ContiguousIterator out(shape, strides[1], ndim - 3);
112113
auto stride = std::accumulate(

python/tests/test_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3244,6 +3244,11 @@ def test_dynamic_slicing(self):
32443244
out = mx.slice(x, mx.array([1, 2, 3]), (0, 1, 2), (3, 2, 1))
32453245
self.assertTrue(mx.array_equal(expected, out))
32463246

3247+
x = mx.arange(5 * 6 * 7 * 8).reshape(5, 6, 7, 8)
3248+
expected = x[1:3, 2:4, 3:5, 4:6]
3249+
out = mx.slice(x, mx.array([1, 2, 3, 4]), (0, 1, 2, 3), (2, 2, 2, 2))
3250+
self.assertTrue(mx.array_equal(expected, out))
3251+
32473252
x = mx.zeros(shape=(4, 4, 4))
32483253
update = mx.random.randint(0, 100, shape=(3, 2, 1))
32493254
out = mx.slice_update(x, update, mx.array([1, 2, 3]), (0, 1, 2))

0 commit comments

Comments
 (0)