|
27 | 27 | from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union |
28 | 28 |
|
29 | 29 | import torch |
30 | | - |
31 | 30 | from executorch.backends.mlx._logging import logger |
32 | 31 | from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type |
33 | 32 | from executorch.backends.mlx.builder.op_registry import ( |
@@ -132,7 +131,9 @@ class MLXProgramBuilder: |
132 | 131 |
|
133 | 132 | def __init__(self, ep: ExportedProgram, named_data_key_prefix: str = ""): |
134 | 133 | self.ep: ExportedProgram = ep |
135 | | - self._instrs: List[Instruction] = [] |
| 134 | + self._chains: List[List[Instruction]] = [[]] # chain 0 = main |
| 135 | + self._current_chain: int = 0 |
| 136 | + self.init_chain_idx: int = -1 |
136 | 137 | self.extra_constants: Dict[str, torch.Tensor] = {} |
137 | 138 | self.slot_manager = SlotManager() |
138 | 139 | self.node_info: DefaultDict[Node, NodeInfo] = defaultdict(NodeInfo) |
@@ -163,7 +164,13 @@ def _prefix_key(self, name: str) -> str: |
163 | 164 | return name |
164 | 165 |
|
165 | 166 | def emit(self, op: OpNodeUnion) -> None: |
166 | | - self._instrs.append(Instruction(op=op)) |
| 167 | + self._chains[self._current_chain].append(Instruction(op=op)) |
| 168 | + |
| 169 | + def emit_init(self, op: OpNodeUnion) -> None: |
| 170 | + if self.init_chain_idx == -1: |
| 171 | + self.init_chain_idx = len(self._chains) |
| 172 | + self._chains.append([]) |
| 173 | + self._chains[self.init_chain_idx].append(Instruction(op=op)) |
167 | 174 |
|
168 | 175 | def args(self, node: Node) -> Tuple[Any, ...]: |
169 | 176 | return self.slot_map(node.args) |
@@ -934,9 +941,11 @@ def _build_mlx_graph(self) -> MLXGraph: |
934 | 941 | num_mutable_buffer_tensors=num_tensors[IdSpace.MutableBuffer], |
935 | 942 | num_temp_tensors=num_temp_tensors, |
936 | 943 | num_values=num_values_count, |
937 | | - instruction_chains=[InstructionChain(instructions=self._instrs)], |
| 944 | + instruction_chains=[ |
| 945 | + InstructionChain(instructions=chain) for chain in self._chains |
| 946 | + ], |
938 | 947 | main_chain_idx=0, |
939 | | - init_chain_idx=-1, |
| 948 | + init_chain_idx=self.init_chain_idx, |
940 | 949 | input_map=input_map, |
941 | 950 | output_map=output_map, |
942 | 951 | mutable_buffer_map=mutable_buffer_map, |
|
0 commit comments