Skip to content

Commit d50e8d6

Browse files
committed
up
1 parent 2f45116 commit d50e8d6

5 files changed

Lines changed: 11 additions & 6 deletions

File tree

.github/workflows/mlx.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ jobs:
9595
fi
9696
9797
if [ "${{ matrix.suite }}" = "operators" ]; then
98-
MAX_FAILURES=1
98+
MAX_FAILURES=0
9999
else
100100
MAX_FAILURES=3
101101
fi

backends/mlx/runtime/MLXBackend.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
390390
write_output(prepared_outputs[info.prepared_idx], out_tensor);
391391
}
392392

393+
h->state.reset(); // Release temp GPU buffers back to MLX cache
394+
393395
return Error::Ok;
394396
} catch (const std::exception& e) {
395397
ET_LOG(Error, "MLX execute failed: %s", e.what());

backends/mlx/runtime/MLXInterpreter.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,9 @@ inline void exec_slice_update(
924924
ExecutionState& st,
925925
StreamOrDevice s) {
926926
// When out == dst, use direct assignment to preserve MLX buffer donation.
927+
// TODO: I'm not sure if this is needed as a special case since the standard
928+
// st.set_tensor does a std::move. Keeping for now, but should investigate and
929+
// possibly remove in future.
927930
const bool in_place = (n.out.idx == n.dst.idx);
928931
array& dst = st.tensor_ref(n.dst);
929932
const array& upd = st.const_tensor_ref(n.update);
@@ -1060,6 +1063,10 @@ exec_index_copy(const IndexCopyNode& n, ExecutionState& st, StreamOrDevice s) {
10601063
idx_vec[i] = static_cast<int>(idx);
10611064
}
10621065

1066+
// When out == dst, use direct assignment to preserve MLX buffer donation.
1067+
// TODO: I'm not sure if this is needed as a special case since the standard
1068+
// st.set_tensor does a std::move. Keeping for now, but should investigate and
1069+
// possibly remove in future.
10631070
const bool in_place = (n.out.idx == n.dst.idx);
10641071

10651072
if (idx_vec.empty()) {
@@ -1995,8 +2002,6 @@ class Interpreter {
19952002
case OpCode::ARG_PARTITION:
19962003
ops::exec_argpartition(std::get<ArgPartitionNode>(instr.node), st, s);
19972004
break;
1998-
case OpCode::SENTINEL:
1999-
break;
20002005
default:
20012006
throw std::runtime_error(
20022007
"Unknown opcode: " + std::to_string(static_cast<int>(instr.op)));

backends/mlx/serialization/MLXLoader.h.tmpl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,12 @@ struct IntOrVidOrTid {
7979

8080
enum class OpCode : uint8_t {
8181
{{OPCODE_ENUM_VALUES}}
82-
SENTINEL
8382
};
8483

8584
// OpCode to string conversion (for logging)
8685
inline const char* op_name(OpCode op) {
8786
switch (op) {
8887
{{OP_NAME_CASES}}
89-
case OpCode::SENTINEL:
90-
return "SENTINEL";
9188
}
9289
return "UNKNOWN";
9390
}

extension/llm/export/config/llm_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,7 @@ class EthosUConfig:
558558
system_config: str = "default"
559559

560560

561+
@dataclass
561562
class MLXConfig:
562563
"""
563564
Configures the MLX backend for Apple Silicon.

0 commit comments

Comments
 (0)