|
44 | 44 | ) |
45 | 45 | from executorch.backends.mlx.serialization.mlx_graph_schema import ( |
46 | 46 | FloatOrVid, |
| 47 | + IdCopyNode, |
47 | 48 | Instruction, |
48 | 49 | InstructionChain, |
49 | 50 | IntOrVid, |
@@ -557,11 +558,51 @@ def check_support_only(self) -> None: |
557 | 558 | # SymInts and corrupts the shape_env. This method is used for |
558 | 559 | # ops_to_not_decompose() where we only need support status. |
559 | 560 |
|
| 561 | + def _emit_buffer_mutation_writebacks(self): |
| 562 | + """Emit copy-back instructions for BUFFER_MUTATION outputs. |
| 563 | +
|
| 564 | + When a model mutates a buffer (e.g., via .copy_() or .mul_()), |
| 565 | + torch.export functionalizes it: the new value is a computation result, |
| 566 | + and the output spec marks it as BUFFER_MUTATION with a target buffer. |
| 567 | +
|
| 568 | + This method emits an IdCopyNode for each BUFFER_MUTATION output, |
| 569 | + copying the computation result back to the mutable buffer slot so |
| 570 | + the updated value persists across execution calls. |
| 571 | + """ |
| 572 | + from torch.export.graph_signature import InputKind, OutputKind |
| 573 | + |
| 574 | + # Map buffer target name -> input placeholder name |
| 575 | + target_to_placeholder = {} |
| 576 | + for ispec in self.ep.graph_signature.input_specs: |
| 577 | + if ispec.kind == InputKind.BUFFER and ispec.target is not None: |
| 578 | + target_to_placeholder[ispec.target] = ispec.arg.name |
| 579 | + |
| 580 | + for ospec in self.ep.graph_signature.output_specs: |
| 581 | + if ospec.kind != OutputKind.BUFFER_MUTATION: |
| 582 | + continue |
| 583 | + |
| 584 | + result_slot = self.slot_manager.get_slot(ospec.arg.name) |
| 585 | + placeholder_name = target_to_placeholder.get(ospec.target) |
| 586 | + if result_slot is None or placeholder_name is None: |
| 587 | + continue |
| 588 | + |
| 589 | + buffer_slot = self.slot_manager.get_slot(placeholder_name) |
| 590 | + if buffer_slot is None or buffer_slot.id_space != IdSpace.MutableBuffer: |
| 591 | + continue |
| 592 | + |
| 593 | + self.emit( |
| 594 | + IdCopyNode( |
| 595 | + x=self.slot_to_tid(result_slot), |
| 596 | + out=self.slot_to_tid(buffer_slot), |
| 597 | + ) |
| 598 | + ) |
| 599 | + |
560 | 600 | def build(self) -> MLXGraph: |
561 | 601 | if self._mlx_graph is not None: |
562 | 602 | return self._mlx_graph |
563 | 603 |
|
564 | 604 | self._process_nodes() |
| 605 | + self._emit_buffer_mutation_writebacks() |
565 | 606 | self._verify_build() |
566 | 607 | self._mlx_graph = self._build_mlx_graph() |
567 | 608 | return self._mlx_graph |
|
0 commit comments