Skip to content

Commit f08f1dd

Browse files
committed
up
1 parent f66885d commit f08f1dd

12 files changed

Lines changed: 556 additions & 193 deletions

File tree

.github/workflows/mlx.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
3333
echo "::group::Install ExecuTorch and configure build"
3434
${CONDA_RUN} python install_executorch.py > /dev/null
35-
${CONDA_RUN} cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON
35+
${CONDA_RUN} cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON -DEXECUTORCH_MLX_ENABLE_SANITIZERS=ON
3636
echo "::endgroup::"
3737
3838
${CONDA_RUN} pip list

backends/mlx/CMakeLists.txt

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,32 @@ endif()
1616

1717
# Source root directory for executorch.
1818
if(NOT EXECUTORCH_ROOT)
19-
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
19+
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
2020
endif()
2121

2222
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
2323

24-
set(_common_compile_options -Wno-deprecated-declarations)
24+
set(_common_compile_options -Wall -Werror -Wno-deprecated-declarations)
25+
26+
# Sanitizer flags (asan + ubsan) for security hardening — CI only. Enable via:
27+
# cmake --preset mlx-release -DEXECUTORCH_MLX_ENABLE_SANITIZERS=ON
28+
option(EXECUTORCH_MLX_ENABLE_SANITIZERS
29+
"Enable ASan + UBSan for MLX delegate and tests" OFF
30+
)
31+
if(EXECUTORCH_MLX_ENABLE_SANITIZERS)
32+
list(APPEND _common_compile_options -fsanitize=address,undefined
33+
-fno-omit-frame-pointer
34+
)
35+
set(_mlx_sanitizer_link_options -fsanitize=address,undefined)
36+
endif()
2537

2638
# -----------------------------------------------------------------------------
2739
# Code generation from schema.fbs
2840
# -----------------------------------------------------------------------------
2941
#
30-
# The generate.py script generates all code from schema.fbs: - Python:
31-
# mlx_graph_schema.py, _generated_serializers.py, _generated/ - C++:
32-
# MLXLoader.h, MLXLoader.cpp, schema_generated.h
42+
# The generate.py script generates all code from schema.fbs: Python:
43+
# mlx_graph_schema.py, _generated_serializers.py, _generated/ C++: MLXLoader.h,
44+
# MLXLoader.cpp, schema_generated.h
3345
#
3446
# We run generate.py at build time so these files don't need to be checked in.
3547
# -----------------------------------------------------------------------------
@@ -200,25 +212,6 @@ set(MLX_METAL_JIT
200212
CACHE BOOL "Use JIT compiled Metal kernels"
201213
)
202214

203-
# Apply JSON patch to prevent conflict with ExecuTorch's nlohmann_json MLX uses
204-
# FetchContent for json; ExecuTorch already has it as submodule
205-
execute_process(
206-
COMMAND git apply --check ${CMAKE_CURRENT_SOURCE_DIR}/patches/mlx_json.patch
207-
WORKING_DIRECTORY ${MLX_SOURCE_DIR}
208-
RESULT_VARIABLE _patch_check_result
209-
OUTPUT_QUIET ERROR_QUIET
210-
)
211-
if(_patch_check_result EQUAL 0)
212-
execute_process(
213-
COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/patches/mlx_json.patch
214-
WORKING_DIRECTORY ${MLX_SOURCE_DIR}
215-
RESULT_VARIABLE _patch_result
216-
)
217-
if(_patch_result EQUAL 0)
218-
message(STATUS "Applied MLX JSON patch")
219-
endif()
220-
endif()
221-
222215
# Add MLX subdirectory
223216
message(STATUS "Adding MLX from submodule: ${MLX_SOURCE_DIR}")
224217
add_subdirectory(${MLX_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/mlx)
@@ -237,7 +230,7 @@ set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
237230
add_library(mlxdelegate ${_mlx_backend__srcs})
238231

239232
# Ensure schema is generated before compiling
240-
add_dependencies(mlxdelegate mlx_schema flatc)
233+
add_dependencies(mlxdelegate mlx_schema)
241234

242235
# Add logging flag if enabled
243236
if(ET_MLX_ENABLE_OP_LOGGING)
@@ -247,7 +240,6 @@ endif()
247240

248241
target_include_directories(
249242
mlxdelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime
250-
${_mlx_schema__include_dir} ${MLX_SOURCE_DIR}
251243
)
252244

253245
# Link against MLX and executorch mlx is only available at BUILD_INTERFACE -
@@ -257,7 +249,10 @@ target_link_libraries(
257249
)
258250

259251
executorch_target_link_options_shared_lib(mlxdelegate)
260-
target_compile_options(mlxdelegate PUBLIC ${_common_compile_options})
252+
target_compile_options(mlxdelegate PRIVATE ${_common_compile_options})
253+
if(EXECUTORCH_MLX_ENABLE_SANITIZERS)
254+
target_link_options(mlxdelegate PRIVATE ${_mlx_sanitizer_link_options})
255+
endif()
261256

262257
install(
263258
TARGETS mlxdelegate mlx_schema

backends/mlx/runtime/MLXBackend.cpp

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,24 @@ array tensor_to_mlx(
7272

7373
::mlx::core::Shape shape;
7474
for (int i = 0; i < t.dim(); ++i) {
75-
shape.push_back(static_cast<int>(t.size(i)));
75+
auto dim_size = t.size(i);
76+
if (dim_size > std::numeric_limits<int>::max() ||
77+
dim_size < std::numeric_limits<int>::min()) {
78+
throw std::runtime_error(
79+
"tensor_to_mlx: dimension " + std::to_string(i) + " size " +
80+
std::to_string(dim_size) + " exceeds int range");
81+
}
82+
shape.push_back(static_cast<int>(dim_size));
7683
}
7784

78-
void* data_ptr = const_cast<void*>(t.const_data_ptr());
85+
// SAFETY: MLX reads this data during async_eval() Metal command encoding,
86+
// which completes before the lock is released. The ET tensor must remain
87+
// valid until async_eval returns.
88+
const void* cptr = t.const_data_ptr();
89+
if (!cptr) {
90+
throw std::runtime_error("tensor_to_mlx: tensor has null data pointer");
91+
}
92+
void* data_ptr = const_cast<void*>(cptr);
7993
auto deleter = [](void*) {};
8094
return array(data_ptr, shape, dtype, deleter);
8195
}
@@ -115,8 +129,11 @@ void write_output(array& arr, ETTensor& out) {
115129
}
116130

117131
if (!shape_matches) {
118-
std::vector<executorch::aten::SizesType> new_sizes(
119-
mlx_shape.begin(), mlx_shape.end());
132+
std::vector<executorch::aten::SizesType> new_sizes;
133+
new_sizes.reserve(mlx_shape.size());
134+
for (auto d : mlx_shape) {
135+
new_sizes.push_back(static_cast<executorch::aten::SizesType>(d));
136+
}
120137
auto err = resize_tensor(
121138
out,
122139
ArrayRef<executorch::aten::SizesType>(
@@ -134,7 +151,12 @@ void write_output(array& arr, ETTensor& out) {
134151
" bytes, output has " + std::to_string(out_nbytes) + " bytes");
135152
}
136153

137-
std::memcpy(out.mutable_data_ptr(), arr.data<void>(), out_nbytes);
154+
const void* src = arr.data<void>();
155+
if (!src) {
156+
throw std::runtime_error(
157+
"write_output: arr.data<void>() is null after wait()");
158+
}
159+
std::memcpy(out.mutable_data_ptr(), src, out_nbytes);
138160
}
139161

140162
} // namespace
@@ -172,7 +194,7 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
172194
~MLXBackend() override = default;
173195

174196
bool is_available() const override {
175-
return true;
197+
return ::mlx::core::metal::is_available();
176198
}
177199

178200
Result<DelegateHandle*> init(
@@ -189,9 +211,20 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
189211
try {
190212
new (handle) MLXHandle();
191213

214+
if (!processed || !processed->data() || processed->size() == 0) {
215+
throw std::runtime_error("init: null or empty delegate payload");
216+
}
217+
192218
handle->program = loader::load_program(
193219
static_cast<const uint8_t*>(processed->data()), processed->size());
194220

221+
// Validate schema version
222+
if (handle->program.version != "1") {
223+
throw std::runtime_error(
224+
"Unsupported MLX schema version '" + handle->program.version +
225+
"' (expected '1'). Rebuild the .pte with a matching SDK version.");
226+
}
227+
195228
// Load constants from named_data_map
196229
// Constants are stored by name in the .pte file and provided by ET at
197230
// runtime
@@ -214,7 +247,9 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
214247
handle->state.bind(
215248
handle->program, handle->constants, handle->mutable_buffers);
216249

217-
// Run init chain if present
250+
// Run init chain if present.
251+
// SAFETY: The >= 0 check ensures init_chain_idx is non-negative, so the
252+
// static_cast<uint32_t> cannot produce UINT32_MAX from a -1 sentinel.
218253
if (handle->program.init_chain_idx >= 0) {
219254
handle->interpreter.run_chain(
220255
handle->program,
@@ -258,8 +293,12 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
258293

259294
h->state.reset();
260295

261-
const size_t expected_args =
262-
program.input_map.size() + program.output_map.size();
296+
const size_t n_inputs = program.input_map.size();
297+
const size_t n_outputs = program.output_map.size();
298+
if (n_inputs > SIZE_MAX - n_outputs) {
299+
throw std::runtime_error("execute: input + output count overflow");
300+
}
301+
const size_t expected_args = n_inputs + n_outputs;
263302
if (args.size() != expected_args) {
264303
ET_LOG(
265304
Error, "Expected %zu args, got %zu", expected_args, args.size());
@@ -268,6 +307,12 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
268307

269308
// Bind inputs
270309
for (const auto& slot : program.input_map) {
310+
if (arg_idx >= args.size()) {
311+
throw std::runtime_error(
312+
"execute: arg_idx " + std::to_string(arg_idx) +
313+
" out of bounds (args.size()=" + std::to_string(args.size()) +
314+
")");
315+
}
271316
if (slot.slot_type == SlotType::TensorSlot) {
272317
const ETTensor& tensor = args[arg_idx++]->toTensor();
273318
Tid tid{slot.idx};

0 commit comments

Comments
 (0)