Skip to content

Commit 8c6093e

Browse files
authored
Merge branch 'main' into zephyr-mv2-ethosu-sample
2 parents 782d064 + b8f04aa commit 8c6093e

20 files changed

Lines changed: 488 additions & 141 deletions

File tree

backends/mlx/ops.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
AsStridedNode,
5151
AsTypeNode,
5252
Atan2Node,
53+
BitwiseInvertNode,
5354
BroadcastToNode,
5455
CeilNode,
5556
ClipNode,
@@ -3066,27 +3067,40 @@ def _where_handler(P: MLXProgramBuilder, n: Node) -> Slot:
30663067

30673068
@REGISTRY.register(target=[torch.ops.aten.bitwise_not.default])
30683069
def _bitwise_not_handler(P: MLXProgramBuilder, n: Node) -> Slot:
3069-
"""Handle aten.bitwise_not - for boolean tensors, dispatch to logical_not."""
3070+
"""Handle aten.bitwise_not - logical_not for bool, bitwise_invert for integers."""
30703071
args = P.args(n)
30713072
require_args(args, 1, 1, "aten.bitwise_not")
30723073
require_kwargs(P.kwargs(n), set(), "aten.bitwise_not")
30733074
x_meta = n.args[0].meta.get("val")
3075+
out = P.make_or_get_slot(n)
30743076

3075-
if x_meta is not None and x_meta.dtype == torch.bool:
3076-
# For boolean tensors, bitwise_not is equivalent to logical_not
3077-
out = P.make_or_get_slot(n)
3077+
if x_meta is None or not hasattr(x_meta, "dtype"):
3078+
raise NotImplementedError(
3079+
"aten.bitwise_not requires known input dtype metadata for MLX lowering"
3080+
)
3081+
3082+
if x_meta.dtype == torch.bool:
30783083
P.emit(
30793084
LogicalNotNode(
30803085
x=P.slot_to_tid(args[0]),
30813086
out=P.slot_to_tid(out),
30823087
)
30833088
)
3084-
return out
3089+
elif x_meta.dtype in {
3090+
torch.int32,
3091+
torch.int64,
3092+
}:
3093+
P.emit(
3094+
BitwiseInvertNode(
3095+
x=P.slot_to_tid(args[0]),
3096+
out=P.slot_to_tid(out),
3097+
)
3098+
)
30853099
else:
30863100
raise NotImplementedError(
3087-
f"aten.bitwise_not is only supported for boolean tensors. "
3088-
f"Got dtype={x_meta.dtype if x_meta else 'unknown'}"
3101+
f"aten.bitwise_not on dtype {x_meta.dtype} is not supported for MLX lowering"
30893102
)
3103+
return out
30903104

30913105

30923106
@REGISTRY.register(

backends/mlx/runtime/MLXInterpreter.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,6 +1380,13 @@ inline void exec_logical_not(
13801380
st.set_tensor(n.out, logical_not(st.const_tensor_ref(n.x), s));
13811381
}
13821382

1383+
inline void exec_bitwise_invert(
1384+
const BitwiseInvertNode& n,
1385+
ExecutionState& st,
1386+
StreamOrDevice s) {
1387+
st.set_tensor(n.out, bitwise_invert(st.const_tensor_ref(n.x), s));
1388+
}
1389+
13831390
inline void exec_logical_and(
13841391
const LogicalAndNode& n,
13851392
ExecutionState& st,
@@ -2028,6 +2035,10 @@ class Interpreter {
20282035
case OpCode::LOGICAL_NOT:
20292036
ops::exec_logical_not(std::get<LogicalNotNode>(instr.node), st, s);
20302037
break;
2038+
case OpCode::BITWISE_INVERT:
2039+
ops::exec_bitwise_invert(
2040+
std::get<BitwiseInvertNode>(instr.node), st, s);
2041+
break;
20312042
case OpCode::LOGICAL_AND:
20322043
ops::exec_logical_and(std::get<LogicalAndNode>(instr.node), st, s);
20332044
break;

backends/mlx/serialization/schema.fbs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,11 @@ table LogicalNotNode {
562562
out: Tid (required);
563563
}
564564

565+
table BitwiseInvertNode {
566+
x: Tid (required);
567+
out: Tid (required);
568+
}
569+
565570
table LogicalAndNode {
566571
a: Tid (required);
567572
b: Tid (required);
@@ -1113,7 +1118,8 @@ union OpNode {
11131118
GatherMmNode,
11141119
GatherQmmNode,
11151120
ScanNode,
1116-
MetalKernelNode
1121+
MetalKernelNode,
1122+
BitwiseInvertNode
11171123
// BC: Add new op nodes here (append only)
11181124
}
11191125

backends/mlx/test/test_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4111,6 +4111,7 @@ def create_model(self) -> nn.Module:
41114111
{"op_name": "abs", "op_fn": torch.abs},
41124112
{"op_name": "neg", "op_fn": torch.neg},
41134113
{"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},
4114+
{"op_name": "bitwise_not_int", "op_fn": torch.bitwise_not, "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64], "input_fn": _int_input_fn()},
41144115
{"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()},
41154116
# activations
41164117
{"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)},

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import com.facebook.jni.annotations.DoNotStrip;
1313
import com.facebook.soloader.nativeloader.NativeLoader;
1414
import com.facebook.soloader.nativeloader.SystemDelegate;
15+
import java.io.Closeable;
1516
import java.util.HashMap;
1617
import java.util.Map;
1718
import java.util.concurrent.locks.Lock;
@@ -24,7 +25,7 @@
2425
* <p>Warning: These APIs are experimental and subject to change without notice
2526
*/
2627
@Experimental
27-
public class Module {
28+
public class Module implements Closeable {
2829

2930
static {
3031
if (!NativeLoader.isInitialized()) {
@@ -274,12 +275,19 @@ public boolean etdump() {
274275
public void destroy() {
275276
if (mLock.tryLock()) {
276277
try {
277-
mHybridData.resetNative();
278+
if (mHybridData.isValid()) {
279+
mHybridData.resetNative();
280+
}
278281
} finally {
279282
mLock.unlock();
280283
}
281284
} else {
282285
throw new IllegalStateException("Cannot destroy module while method is executing");
283286
}
284287
}
288+
289+
@Override
290+
public void close() {
291+
destroy();
292+
}
285293
}

extension/llm/runner/text_llm_runner.cpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -138,16 +138,16 @@ Error TextLLMRunner::generate(
138138
num_prompt_tokens >= 1,
139139
InvalidArgument,
140140
"Expected at least 1 prompt token");
141-
ET_CHECK_OR_RETURN_ERROR(
142-
num_prompt_tokens <= max_seq_len,
143-
InvalidArgument,
144-
"num_prompt_tokens %d > max_seq_len %" PRId64
145-
", Single prefill chunk too large - please reduce prompt size or increase max_seq_len",
146-
num_prompt_tokens,
147-
max_seq_len);
148-
// For non-sliding-window models, also check that we won't exceed
149-
// KV cache capacity. Sliding window models (where max_seq_len <
150-
// max_context_len) handle position wrapping internally.
141+
// Note: We intentionally do NOT enforce num_prompt_tokens <= max_seq_len
142+
// here. TextPrefiller::prefill() supports chunked prefill: when
143+
// num_prompt_tokens > max_seq_len it splits the prompt into max_seq_len
144+
// chunks and prefills them sequentially. Models that were exported with
145+
// max_seq_len < max_context_len (e.g. a 1024 prefill chunk over a 4096 KV
146+
// cache) rely on this behavior.
147+
// Ensure the prompt fits within total KV cache capacity. For
148+
// sliding-window models (where max_seq_len < max_context_len) the model
149+
// handles position wrapping internally, so pos_ doesn't represent
150+
// consumed capacity and we only need a per-call bound.
151151
if (max_seq_len >= max_context_len) {
152152
ET_CHECK_OR_RETURN_ERROR(
153153
pos_ + num_prompt_tokens < max_context_len,
@@ -158,6 +158,15 @@ Error TextLLMRunner::generate(
158158
pos_,
159159
num_prompt_tokens,
160160
max_context_len);
161+
} else {
162+
ET_CHECK_OR_RETURN_ERROR(
163+
num_prompt_tokens < max_context_len,
164+
InvalidArgument,
165+
"num_prompt_tokens %d >= max_context_len %" PRId64
166+
", Prompt exceeds KV cache capacity - please reduce prompt size or "
167+
"increase max_context_len in your export script",
168+
num_prompt_tokens,
169+
max_context_len);
161170
}
162171

163172
// print prompts

extension/tensor/tensor_ptr.cpp

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <numeric>
1212

13+
#include <c10/util/safe_numerics.h>
14+
1315
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
1416

1517
namespace executorch {
@@ -147,11 +149,26 @@ TensorPtr make_tensor_ptr(
147149
std::vector<executorch::aten::StridesType> strides,
148150
executorch::aten::ScalarType type,
149151
executorch::aten::TensorShapeDynamism dynamism) {
152+
auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size());
153+
ET_CHECK_MSG(
154+
numel_result.ok(),
155+
"safe_numel failed: %d",
156+
static_cast<int>(numel_result.error()));
157+
const ssize_t numel = numel_result.get();
158+
size_t nbytes;
150159
ET_CHECK_MSG(
151-
data.size() ==
152-
executorch::aten::compute_numel(sizes.data(), sizes.size()) *
153-
executorch::aten::elementSize(type),
154-
"Data size does not match tensor size.");
160+
!c10::mul_overflows(
161+
static_cast<size_t>(numel),
162+
executorch::aten::elementSize(type),
163+
&nbytes),
164+
"Overflow computing nbytes: numel=%zd element_size=%zu",
165+
numel,
166+
executorch::aten::elementSize(type));
167+
ET_CHECK_MSG(
168+
data.size() == nbytes,
169+
"Data size (%zu) does not match tensor size (%zu).",
170+
data.size(),
171+
nbytes);
155172
auto data_ptr = data.data();
156173
return make_tensor_ptr(
157174
std::move(sizes),
@@ -205,7 +222,13 @@ TensorPtr clone_tensor_ptr(
205222
runtime::canCast(tensor_type, type),
206223
"Cannot cast tensor type to desired type.");
207224
const auto tensor_numel = static_cast<size_t>(tensor.numel());
208-
std::vector<uint8_t> data(tensor_numel * aten::elementSize(type));
225+
size_t clone_nbytes;
226+
ET_CHECK_MSG(
227+
!c10::mul_overflows(tensor_numel, aten::elementSize(type), &clone_nbytes),
228+
"Overflow computing clone nbytes: numel=%zu element_size=%zu",
229+
tensor_numel,
230+
aten::elementSize(type));
231+
std::vector<uint8_t> data(clone_nbytes);
209232

210233
// Create a minimal context for error handling in ET_SWITCH
211234
struct {

extension/tensor/tensor_ptr.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,13 @@ inline TensorPtr make_tensor_ptr(
110110
executorch::aten::ScalarType type = deduced_type,
111111
executorch::aten::TensorShapeDynamism dynamism =
112112
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
113+
auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size());
113114
ET_CHECK_MSG(
114-
data.size() ==
115-
executorch::aten::compute_numel(sizes.data(), sizes.size()),
115+
numel_result.ok(),
116+
"safe_numel failed: %d",
117+
static_cast<int>(numel_result.error()));
118+
ET_CHECK_MSG(
119+
data.size() == static_cast<size_t>(numel_result.get()),
116120
"Data size does not match tensor size.");
117121
if (type != deduced_type) {
118122
ET_CHECK_MSG(
@@ -368,8 +372,13 @@ inline TensorPtr make_tensor_ptr(
368372
const auto same_rank = sizes.size() == static_cast<size_t>(tensor.dim());
369373
const auto same_shape = same_rank &&
370374
std::equal(sizes.begin(), sizes.end(), tensor.sizes().begin());
371-
const auto element_count =
372-
executorch::aten::compute_numel(sizes.data(), sizes.size());
375+
auto element_count_result =
376+
executorch::aten::safe_numel(sizes.data(), sizes.size());
377+
ET_CHECK_MSG(
378+
element_count_result.ok(),
379+
"safe_numel failed: %d",
380+
static_cast<int>(element_count_result.error()));
381+
const auto element_count = element_count_result.get();
373382
const auto parent_element_count = tensor.numel();
374383
ET_CHECK_MSG(
375384
element_count <= parent_element_count,

extension/tensor/tensor_ptr_maker.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,21 @@ TensorPtr empty_strided(
113113
std::vector<executorch::aten::StridesType> strides,
114114
executorch::aten::ScalarType type,
115115
executorch::aten::TensorShapeDynamism dynamism) {
116-
const auto numel = static_cast<size_t>(
117-
executorch::aten::compute_numel(sizes.data(), sizes.size()));
118-
const auto elem_size =
119-
static_cast<size_t>(executorch::aten::elementSize(type));
120-
size_t nbytes = 0;
116+
auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size());
121117
ET_CHECK_MSG(
122-
!c10::mul_overflows(numel, elem_size, &nbytes),
123-
"empty_strided size overflow: numel %zu * element size %zu",
118+
numel_result.ok(),
119+
"safe_numel failed: %d",
120+
static_cast<int>(numel_result.error()));
121+
const ssize_t numel = numel_result.get();
122+
size_t nbytes;
123+
ET_CHECK_MSG(
124+
!c10::mul_overflows(
125+
static_cast<size_t>(numel),
126+
executorch::aten::elementSize(type),
127+
&nbytes),
128+
"Overflow computing nbytes: numel=%zd element_size=%zu",
124129
numel,
125-
elem_size);
130+
executorch::aten::elementSize(type));
126131
std::vector<uint8_t> data(nbytes);
127132
return make_tensor_ptr(
128133
std::move(sizes),

extension/wasm/wasm_bindings.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,20 +84,22 @@ inline void js_array_push(val_array<T>& array, const T& value) {
8484
_(float, Float) \
8585
_(int64_t, Long)
8686

87-
inline ssize_t compute_expected_numel(
87+
inline ::executorch::runtime::Result<ssize_t> compute_expected_numel(
8888
const std::vector<torch::executor::Tensor::SizesType>& sizes) {
89-
return executorch::aten::compute_numel(sizes.data(), sizes.size());
89+
return executorch::aten::safe_numel(sizes.data(), sizes.size());
9090
}
9191

9292
template <typename T>
9393
inline void assert_valid_numel(
9494
const std::vector<T>& data,
9595
const std::vector<torch::executor::Tensor::SizesType>& sizes) {
9696
auto computed_numel = compute_expected_numel(sizes);
97+
THROW_IF_ERROR(
98+
computed_numel.error(), "Invalid tensor sizes: numel computation failed");
9799
THROW_IF_FALSE(
98-
data.size() >= computed_numel,
100+
data.size() >= static_cast<size_t>(computed_numel.get()),
99101
"Required %ld elements, given %ld",
100-
computed_numel,
102+
computed_numel.get(),
101103
data.size());
102104
}
103105

0 commit comments

Comments
 (0)