Skip to content

Commit 10ff69f

Browse files
committed
up
1 parent 1842ad3 commit 10ff69f

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

backends/mlx/runtime/MLXInterpreter.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,14 @@ inline void
15481548
exec_repeat(const RepeatNode& n, ExecutionState& st, StreamOrDevice s) {
15491549
const auto& x = st.const_tensor_ref(n.x);
15501550
int repeats = static_cast<int>(resolve_int(n.repeats, st));
1551+
if (repeats < 0) {
1552+
throw std::invalid_argument(
1553+
"repeat: repeats must be non-negative, got " + std::to_string(repeats));
1554+
}
1555+
auto out_shape = x.shape();
1556+
int axis = n.axis < 0 ? n.axis + static_cast<int>(x.ndim()) : n.axis;
1557+
out_shape[static_cast<size_t>(axis)] *= repeats;
1558+
check_allocation_bounded(out_shape, x.dtype(), "repeat");
15511559
st.set_tensor(n.out, repeat(x, repeats, n.axis, s));
15521560
}
15531561

0 commit comments

Comments
 (0)