@@ -72,10 +72,24 @@ array tensor_to_mlx(
7272
7373 ::mlx::core::Shape shape;
7474 for (int i = 0 ; i < t.dim (); ++i) {
75- shape.push_back (static_cast <int >(t.size (i)));
75+ auto dim_size = t.size (i);
76+ if (dim_size > std::numeric_limits<int >::max () ||
77+ dim_size < std::numeric_limits<int >::min ()) {
78+ throw std::runtime_error (
79+ " tensor_to_mlx: dimension " + std::to_string (i) + " size " +
80+ std::to_string (dim_size) + " exceeds int range" );
81+ }
82+ shape.push_back (static_cast <int >(dim_size));
7683 }
7784
78- void * data_ptr = const_cast <void *>(t.const_data_ptr ());
85+ // SAFETY: MLX reads this data during async_eval() Metal command encoding,
86+ // which completes before the lock is released. The ET tensor must remain
87+ // valid until async_eval returns.
88+ const void * cptr = t.const_data_ptr ();
89+ if (!cptr) {
90+ throw std::runtime_error (" tensor_to_mlx: tensor has null data pointer" );
91+ }
92+ void * data_ptr = const_cast <void *>(cptr);
7993 auto deleter = [](void *) {};
8094 return array (data_ptr, shape, dtype, deleter);
8195}
@@ -115,8 +129,11 @@ void write_output(array& arr, ETTensor& out) {
115129 }
116130
117131 if (!shape_matches) {
118- std::vector<executorch::aten::SizesType> new_sizes (
119- mlx_shape.begin (), mlx_shape.end ());
132+ std::vector<executorch::aten::SizesType> new_sizes;
133+ new_sizes.reserve (mlx_shape.size ());
134+ for (auto d : mlx_shape) {
135+ new_sizes.push_back (static_cast <executorch::aten::SizesType>(d));
136+ }
120137 auto err = resize_tensor (
121138 out,
122139 ArrayRef<executorch::aten::SizesType>(
@@ -134,7 +151,12 @@ void write_output(array& arr, ETTensor& out) {
134151 " bytes, output has " + std::to_string (out_nbytes) + " bytes" );
135152 }
136153
137- std::memcpy (out.mutable_data_ptr (), arr.data <void >(), out_nbytes);
154+ const void * src = arr.data <void >();
155+ if (!src) {
156+ throw std::runtime_error (
157+ " write_output: arr.data<void>() is null after wait()" );
158+ }
159+ std::memcpy (out.mutable_data_ptr (), src, out_nbytes);
138160}
139161
140162} // namespace
@@ -172,7 +194,7 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
172194 ~MLXBackend () override = default ;
173195
174196 bool is_available () const override {
175- return true ;
197+ return :: mlx::core::metal::is_available () ;
176198 }
177199
178200 Result<DelegateHandle*> init (
@@ -189,9 +211,20 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
189211 try {
190212 new (handle) MLXHandle ();
191213
214+ if (!processed || !processed->data () || processed->size () == 0 ) {
215+ throw std::runtime_error (" init: null or empty delegate payload" );
216+ }
217+
192218 handle->program = loader::load_program (
193219 static_cast <const uint8_t *>(processed->data ()), processed->size ());
194220
221+ // Validate schema version
222+ if (handle->program .version != " 1" ) {
223+ throw std::runtime_error (
224+ " Unsupported MLX schema version '" + handle->program .version +
225+ " ' (expected '1'). Rebuild the .pte with a matching SDK version." );
226+ }
227+
195228 // Load constants from named_data_map
196229 // Constants are stored by name in the .pte file and provided by ET at
197230 // runtime
@@ -214,7 +247,9 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
214247 handle->state .bind (
215248 handle->program , handle->constants , handle->mutable_buffers );
216249
217- // Run init chain if present
250+ // Run init chain if present.
251+ // SAFETY: The >= 0 check ensures init_chain_idx is non-negative, so the
252+ // static_cast<uint32_t> cannot produce UINT32_MAX from a -1 sentinel.
218253 if (handle->program .init_chain_idx >= 0 ) {
219254 handle->interpreter .run_chain (
220255 handle->program ,
@@ -258,8 +293,12 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
258293
259294 h->state .reset ();
260295
261- const size_t expected_args =
262- program.input_map .size () + program.output_map .size ();
296+ const size_t n_inputs = program.input_map .size ();
297+ const size_t n_outputs = program.output_map .size ();
298+ if (n_inputs > SIZE_MAX - n_outputs) {
299+ throw std::runtime_error (" execute: input + output count overflow" );
300+ }
301+ const size_t expected_args = n_inputs + n_outputs;
263302 if (args.size () != expected_args) {
264303 ET_LOG (
265304 Error, " Expected %zu args, got %zu" , expected_args, args.size ());
@@ -268,6 +307,12 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
268307
269308 // Bind inputs
270309 for (const auto & slot : program.input_map ) {
310+ if (arg_idx >= args.size ()) {
311+ throw std::runtime_error (
312+ " execute: arg_idx " + std::to_string (arg_idx) +
313+ " out of bounds (args.size()=" + std::to_string (args.size ()) +
314+ " )" );
315+ }
271316 if (slot.slot_type == SlotType::TensorSlot) {
272317 const ETTensor& tensor = args[arg_idx++]->toTensor ();
273318 Tid tid{slot.idx };
0 commit comments