76 | 76 | "ToDevicePass", |
77 | 77 | "EdgeToBackendOpsPass", |
78 | 78 | "MemoryFormatOpsPass", |
| 79 | + "InPlaceElemWiseLikeOpsPass", |
| 80 | + "ElemWiseInPlaceAwareMemoryPlanningPass", |
79 | 81 | "MemoryPlanningPass", |
80 | 82 | "HintBasedSymShapeEvalPass", |
81 | 83 | "insert_write_back_for_buffers_pass", |
@@ -260,6 +262,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None: |
260 | 262 | # we won't see it in the input graph to the to_out_variant pass, unless |
261 | 263 | # it's retraced after running to_out_variant with the first trace. |
262 | 264 | memory.alloc, |
| 265 | + memory.alloc_inplace, |
263 | 266 | memory.view, |
264 | 267 | executorch_call_delegate, |
265 | 268 | } |
@@ -444,6 +447,109 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule: |
444 | 447 | return PassResult(graph_module, True) |
445 | 448 | |
446 | 449 | |
| 450 | +class InPlaceElemWiseLikeOpsPass(PassBase): |
| 451 | + """Replace memory.alloc with memory.alloc_inplace for element-wise-like ops. |
| 452 | + |
| 453 | + For out-variant ops that are element-wise, the output can be allocated in |
| 454 | + the same memory as the input when: |
| 455 | + 1. output_bytes <= input_bytes |
| 456 | + 2. The input tensor has no other users after this op (dead after consumption) |
| 457 | + |
| 458 | + This pass replaces the memory.alloc node for the output with |
| 459 | + memory.alloc_inplace(input_node, spec), which signals to the memory planner |
| 460 | + to place the output at the same offset as the input. |
| 461 | + |
| 462 | + Eligible ops are specified via the constructor's eligible_ops parameter; if |
| 463 | + none is given, the default set from _default_eligible_ops is used. |
| 464 | + """ |
| 465 | + |
| 466 | + @staticmethod |
| 467 | + def _default_eligible_ops() -> Set[Callable[..., Any]]: |
| 468 | + return { |
| 469 | + torch.ops.cortex_m.quantize_per_tensor.out, |
| 470 | + } |
| 471 | + |
| 472 | + def __init__(self, eligible_ops: Optional[Set[Callable[..., Any]]] = None) -> None: |
| 473 | + self._eligible_ops = ( |
| 474 | + eligible_ops if eligible_ops is not None else self._default_eligible_ops() |
| 475 | + ) |
| 476 | + |
| 477 | + def _is_eligible(self, target: Any) -> bool: |
| 478 | + return target in self._eligible_ops |
| 479 | + |
| 480 | + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
| 481 | + changed = False |
| 482 | + for node in graph_module.graph.nodes: |
| 483 | + if node.op != "call_function": |
| 484 | + continue |
| 485 | + if not self._is_eligible(node.target): |
| 486 | + continue |
| 487 | + if not memory_planning._is_out_var_node(node): |
| 488 | + continue |
| 489 | + |
| 490 | + out_arg_names = get_out_args_from_opoverload(node.target) |
| 491 | + if len(out_arg_names) != 1: |
| 492 | + continue |
| 493 | + |
| 494 | + out_alloc_node = node.kwargs.get(out_arg_names[0]) |
| 495 | + if out_alloc_node is None or out_alloc_node.target != memory.alloc: |
| 496 | + continue |
| 497 | + |
| 498 | + out_val = out_alloc_node.meta.get("val") |
| 499 | + if out_val is None or not isinstance(out_val, torch.Tensor): |
| 500 | + continue |
| 501 | + out_nbytes = out_val.nelement() * out_val.element_size() |
| 502 | + |
| 503 | + input_node = None |
| 504 | + for arg in node.args: |
| 505 | + if not isinstance(arg, torch.fx.Node): |
| 506 | + continue |
| 507 | + in_val = arg.meta.get("val") |
| 508 | + if in_val is None or not isinstance(in_val, torch.Tensor): |
| 509 | + continue |
| 510 | + if in_val.nelement() * in_val.element_size() < out_nbytes: |
| 511 | + continue |
| 512 | + if any(u != node and u.target != memory.free for u in arg.users): |
| 513 | + continue |
| 514 | + input_node = arg |
| 515 | + break |
| 516 | + if input_node is None: |
| 517 | + continue |
| 518 | + |
| 519 | + with graph_module.graph.inserting_before(out_alloc_node): |
| 520 | + inplace_node = graph_module.graph.call_function( |
| 521 | + memory.alloc_inplace, |
| 522 | + (input_node, out_alloc_node.args[0]), |
| 523 | + ) |
| 524 | + inplace_node.meta = out_alloc_node.meta.copy() |
| 525 | + |
| 526 | + out_alloc_node.replace_all_uses_with(inplace_node) |
| 527 | + graph_module.graph.erase_node(out_alloc_node) |
| 528 | + changed = True |
| 529 | + |
| 530 | + return PassResult(graph_module, changed) |
| 531 | + |
| 532 | + |
| 533 | +class ElemWiseInPlaceAwareMemoryPlanningPass(MemoryPlanningPass): |
| 534 | + """MemoryPlanningPass that first runs InPlaceElemWiseLikeOpsPass.""" |
| 535 | + |
| 536 | + def __init__( |
| 537 | + self, |
| 538 | + eligible_ops: Optional[Set[Callable[..., Any]]] = None, |
| 539 | + **kwargs: Any, |
| 540 | + ) -> None: |
| 541 | + super().__init__(**kwargs) |
| 542 | + self._inplace_pass = InPlaceElemWiseLikeOpsPass(eligible_ops) |
| 543 | + |
| 544 | + def run( |
| 545 | + self, |
| 546 | + graph_module: torch.fx.GraphModule, |
| 547 | + graph_signature=None, |
| 548 | + ) -> PassResult: |
| 549 | + self._inplace_pass(graph_module) |
| 550 | + return super().run(graph_module, graph_signature) |
| 551 | + |
| 552 | + |
447 | 553 | def to_scratch_op_pass(graph_module: torch.fx.GraphModule) -> PassResult: |
448 | 554 | for node in graph_module.graph.nodes: |
449 | 555 | if node.op != "call_function": |
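
To illustrate the rewrite the `InPlaceElemWiseLikeOpsPass` docstring describes, here is a minimal sketch (not part of the patch) of applying the pass to an out-variant graph and counting the rewritten alloc nodes. It assumes the pass is exported from `executorch.exir.passes` as the `__all__` change above suggests, that `memory.alloc_inplace` is the new node target this patch adds to `executorch.exir.memory`, and that `gm` is a hypothetical out-variant `torch.fx.GraphModule` that already contains `memory.alloc` nodes for the outputs of eligible (registered) `cortex_m` ops.

```python
from executorch.exir import memory
from executorch.exir.passes import InPlaceElemWiseLikeOpsPass

# `gm` is assumed to be an out-variant torch.fx.GraphModule produced by the
# usual ExecuTorch to-out-variant / alloc-insertion passes.
result = InPlaceElemWiseLikeOpsPass()(gm)

# Any output alloc rewritten to memory.alloc_inplace will be placed at the
# same offset as its input by the memory planner.
rewritten = [
    n
    for n in result.graph_module.graph.nodes
    if n.op == "call_function" and n.target == memory.alloc_inplace
]
print(f"{len(rewritten)} output alloc(s) rewritten to alloc_inplace")
```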
|
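A similarly hedged sketch of wiring the combined pass into a lowering flow, assuming the standard `ExecutorchBackendConfig(memory_planning_pass=...)` hook; `edge_program` is a hypothetical `EdgeProgramManager` (e.g. from `to_edge`), and the `cortex_m` op library is assumed to be loaded so `torch.ops.cortex_m.quantize_per_tensor.out` resolves. The eligible-op set shown matches the patch's default.

```python
import torch
from executorch.exir import ExecutorchBackendConfig
from executorch.exir.passes import ElemWiseInPlaceAwareMemoryPlanningPass

# eligible_ops is optional; omitting it falls back to _default_eligible_ops().
config = ExecutorchBackendConfig(
    memory_planning_pass=ElemWiseInPlaceAwareMemoryPlanningPass(
        eligible_ops={torch.ops.cortex_m.quantize_per_tensor.out},
    ),
)

# `edge_program` is an EdgeProgramManager produced earlier in the flow.
executorch_program = edge_program.to_executorch(config)
```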