Skip to content

Commit 0c76afa

Browse files
committed
up
1 parent dc25245 commit 0c76afa

1 file changed

Lines changed: 27 additions & 3 deletions

File tree

backends/mlx/serialization/MLXLoader.cpp.tmpl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,23 +184,47 @@ MLXProgram load_program(const void* data, size_t size) {
184184
check_collection_size(fb_graph->input_map()->size(), "input_map");
185185
for (size_t i = 0; i < fb_graph->input_map()->size(); ++i) {
186186
const auto* slot = fb_graph->input_map()->Get(static_cast<flatbuffers::uoffset_t>(i));
187-
program.input_map.push_back(convert_slot_variant(slot));
187+
auto sv = convert_slot_variant(slot);
188+
if (sv.slot_type == SlotType::TensorSlot &&
189+
sv.idx >= program.num_tensors()) {
190+
throw std::runtime_error(
191+
"input_map: slot index " + std::to_string(sv.idx) +
192+
" exceeds num_tensors " +
193+
std::to_string(program.num_tensors()));
194+
}
195+
program.input_map.push_back(sv);
188196
}
189197
}
190198

191199
if (fb_graph->output_map()) {
192200
check_collection_size(fb_graph->output_map()->size(), "output_map");
193201
for (size_t i = 0; i < fb_graph->output_map()->size(); ++i) {
194202
const auto* slot = fb_graph->output_map()->Get(static_cast<flatbuffers::uoffset_t>(i));
195-
program.output_map.push_back(convert_slot_variant(slot));
203+
auto sv = convert_slot_variant(slot);
204+
if (sv.slot_type == SlotType::TensorSlot &&
205+
sv.idx >= program.num_tensors()) {
206+
throw std::runtime_error(
207+
"output_map: slot index " + std::to_string(sv.idx) +
208+
" exceeds num_tensors " +
209+
std::to_string(program.num_tensors()));
210+
}
211+
program.output_map.push_back(sv);
196212
}
197213
}
198214

199215
if (fb_graph->mutable_buffer_map()) {
200216
check_collection_size(fb_graph->mutable_buffer_map()->size(), "mutable_buffer_map");
201217
for (size_t i = 0; i < fb_graph->mutable_buffer_map()->size(); ++i) {
202218
const auto* slot = fb_graph->mutable_buffer_map()->Get(static_cast<flatbuffers::uoffset_t>(i));
203-
program.mutable_buffer_map.push_back(convert_slot_variant(slot));
219+
auto sv = convert_slot_variant(slot);
220+
if (sv.slot_type == SlotType::TensorSlot &&
221+
sv.idx >= program.num_tensors()) {
222+
throw std::runtime_error(
223+
"mutable_buffer_map: slot index " + std::to_string(sv.idx) +
224+
" exceeds num_tensors " +
225+
std::to_string(program.num_tensors()));
226+
}
227+
program.mutable_buffer_map.push_back(sv);
204228
}
205229
}
206230

0 commit comments

Comments
 (0)