@@ -72,22 +72,28 @@ class EdgeProgramToIRConverter:
7272 _default_target_spec = NeutronTargetSpec ("imxrt700" )
7373 _default_delegation_options = CustomDelegationOptions ()
7474
75+ def __init__ (self ):
76+ self .edge_to_tflite_map = {}
77+
7578 def convert_program (
7679 self ,
7780 edge_program : ExportedProgram ,
7881 conversion_config : ConversionConfig = _default_conversion_config ,
7982 neutron_target_spec : NeutronTargetSpec = _default_target_spec ,
8083 custom_delegation_options : CustomDelegationOptions = _default_delegation_options ,
81- ) -> tuple [bytes , dict [str , DataFormat ]]:
84+ ) -> tuple [bytes , dict [str , DataFormat ], dict [ int , tuple [ int , ...]] ]:
8285 """
8386 Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes.
8487
8588 :param edge_program: Converter ExportedProgram.
8689 :param conversion_config: ConversionConfig instance.
8790 :param neutron_target_spec: Object for querying the target platform to retrieve its properties.
8891 :param custom_delegation_options: Custom user options which affect node delegation.
89- :return: TFLite flatbuffers as bytes.
92+ :return: TFLite flatbuffers as bytes, I/O formats, and edge-to-tflite mapping .
9093 """
94+ # Reset the edge to tflite map for each conversion
95+ self .edge_to_tflite_map = {}
96+
9197 parameters_mapping = self .map_inputs_to_parameters (edge_program )
9298 dim_order_map = self .map_nodes_to_dim_order (edge_program )
9399
@@ -110,14 +116,17 @@ def convert_program(
110116 # Apply optimizations and finalize the model.
111117 internal_tflite_model = cc .tflite_builder .finish ()
112118
119+ # Get the final edge to tflite mapping after optimization
120+ self .edge_to_tflite_map = cc .tflite_builder .edge_to_tflite_map
121+
113122 # Extract the formats of the model's inputs and outputs.
114123 io_formats = cc .tflite_builder .get_io_formats (edge_program .graph_signature )
115124
116125 # TFLite model generation
117126 flatbuffers_builder = flatbuffers .Builder ()
118127 internal_tflite_model .gen_tflite (flatbuffers_builder )
119128
120- return bytes (flatbuffers_builder .Output ()), io_formats
129+ return bytes (flatbuffers_builder .Output ()), io_formats , self . edge_to_tflite_map
121130
122131 @staticmethod
123132 def append_placeholders_and_tensors (nodes : list [Node ], context : ConversionContext ):
@@ -159,7 +168,6 @@ def _process_nodes(self, nodes: list[Node], conversion_context: ConversionContex
159168 exir_ops .edge .quantized_decomposed .dequantize_per_channel .default ,
160169 exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
161170 ]
162-
163171 for node in nodes :
164172 if node .op == "call_function" :
165173 if node .target in qdq_related_functions and "cluster" in node .meta :
@@ -171,7 +179,22 @@ def _process_nodes(self, nodes: list[Node], conversion_context: ConversionContex
171179 # The node was already processed alongside the Q/DQ ops.
172180 pass
173181 elif node .target in functions_converters :
182+ # Get TFLite op count BEFORE conversion
183+ tflite_op_count_before = len (conversion_context .tflite_builder .get_operators ().vector )
184+ # Convert the node
174185 functions_converters [node .target ](conversion_context ).convert (node )
186+ # Get TFLite op count AFTER conversion
187+ tflite_op_count_after = len (conversion_context .tflite_builder .get_operators ().vector )
188+
189+ # Track the mapping - store edge debug handle in operators
190+ edge_debug_handle = node .meta .get ("debug_handle" , None )
191+ if edge_debug_handle is not None and tflite_op_count_after > tflite_op_count_before :
192+ operators = conversion_context .tflite_builder .get_operators ().vector
193+ for i in range (tflite_op_count_before , tflite_op_count_after ):
194+ # Store edge debug handle in operator's temporary attribute
195+ operators [i ].tmp_edge_debug_handle = edge_debug_handle
196+ logger .i (f"Tagged TFLite ops { list (range (tflite_op_count_before , tflite_op_count_after ))} with edge debug_handle={ edge_debug_handle } for node '{ node .name } '" )
197+
175198 else :
176199 logger .e (
177200 logger .Code .NOT_IMPLEMENTED ,
0 commit comments