@@ -184,23 +184,47 @@ MLXProgram load_program(const void* data, size_t size) {
184184 check_collection_size (fb_graph->input_map ()->size (), " input_map" );
185185 for (size_t i = 0 ; i < fb_graph->input_map ()->size (); ++i) {
186186 const auto * slot = fb_graph->input_map ()->Get (static_cast <flatbuffers::uoffset_t >(i));
187- program.input_map .push_back (convert_slot_variant (slot));
187+ auto sv = convert_slot_variant (slot);
188+ if (sv.slot_type == SlotType::TensorSlot &&
189+ sv.idx >= program.num_tensors ()) {
190+ throw std::runtime_error (
191+ " input_map: slot index " + std::to_string (sv.idx ) +
192+ " exceeds num_tensors " +
193+ std::to_string (program.num_tensors ()));
194+ }
195+ program.input_map .push_back (sv);
188196 }
189197 }
190198
191199 if (fb_graph->output_map ()) {
192200 check_collection_size (fb_graph->output_map ()->size (), " output_map" );
193201 for (size_t i = 0 ; i < fb_graph->output_map ()->size (); ++i) {
194202 const auto * slot = fb_graph->output_map ()->Get (static_cast <flatbuffers::uoffset_t >(i));
195- program.output_map .push_back (convert_slot_variant (slot));
203+ auto sv = convert_slot_variant (slot);
204+ if (sv.slot_type == SlotType::TensorSlot &&
205+ sv.idx >= program.num_tensors ()) {
206+ throw std::runtime_error (
207+ " output_map: slot index " + std::to_string (sv.idx ) +
208+ " exceeds num_tensors " +
209+ std::to_string (program.num_tensors ()));
210+ }
211+ program.output_map .push_back (sv);
196212 }
197213 }
198214
199215 if (fb_graph->mutable_buffer_map ()) {
200216 check_collection_size (fb_graph->mutable_buffer_map ()->size (), " mutable_buffer_map" );
201217 for (size_t i = 0 ; i < fb_graph->mutable_buffer_map ()->size (); ++i) {
202218 const auto * slot = fb_graph->mutable_buffer_map ()->Get (static_cast <flatbuffers::uoffset_t >(i));
203- program.mutable_buffer_map .push_back (convert_slot_variant (slot));
219+ auto sv = convert_slot_variant (slot);
220+ if (sv.slot_type == SlotType::TensorSlot &&
221+ sv.idx >= program.num_tensors ()) {
222+ throw std::runtime_error (
223+ " mutable_buffer_map: slot index " + std::to_string (sv.idx ) +
224+ " exceeds num_tensors " +
225+ std::to_string (program.num_tensors ()));
226+ }
227+ program.mutable_buffer_map .push_back (sv);
204228 }
205229 }
206230
0 commit comments