|
9 | 9 | import operator |
10 | 10 | import traceback |
11 | 11 | from inspect import isclass |
12 | | -from typing import cast, List, Optional, Sequence, Tuple |
| 12 | +from typing import cast, Optional, Sequence |
13 | 13 |
|
14 | 14 | import torch |
15 | 15 | import torch.fx |
|
19 | 19 | from executorch.exir import ExportedProgram |
20 | 20 | from executorch.exir.dialects._ops import ops as exir_ops |
21 | 21 | from executorch.exir.dialects.edge._ops import EdgeOpOverload |
22 | | -from executorch.exir.graph_module import ( |
23 | | - _get_control_flow_submodules, |
24 | | - get_control_flow_submodules, |
25 | | -) |
26 | 22 | from executorch.exir.pass_base import NodeMetadata |
27 | 23 |
|
28 | 24 | from torch._export.utils import ( |
|
36 | 32 | from torch._ops import OpOverload |
37 | 33 | from torch._subclasses.fake_tensor import FakeTensor |
38 | 34 | from torch.export.graph_signature import InputKind |
39 | | -from torch.fx import GraphModule, Node |
40 | 35 |
|
41 | 36 |
|
42 | 37 | def is_submodule_node(node: torch.fx.Node): |
@@ -364,48 +359,6 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value): |
364 | 359 | raise RuntimeError("Invalid type") |
365 | 360 |
|
366 | 361 |
|
367 | | -def is_nested_control_flow_graph(graph_module: GraphModule) -> bool: |
368 | | - """Returns True if graph_module is a nested control-flow graph.""" |
369 | | - |
370 | | - # Find all top-level control-flow submodules |
371 | | - top_cf = get_control_flow_submodules(graph_module) |
372 | | - # For each submodule, see if it itself has control-flow inside |
373 | | - for _, submod, _ in top_cf: |
374 | | - if get_control_flow_submodules(submod): |
375 | | - return True |
376 | | - return False |
377 | | - |
378 | | - |
379 | | -def get_cond_while_submodules_nested( |
380 | | - graph_module: GraphModule, |
381 | | - apply_quantization: bool = False, |
382 | | -) -> List[Tuple[str, GraphModule, Node]]: |
383 | | - """Recursively find cond/while_loop submodules in an GraphModule. |
384 | | -
|
385 | | - In nested control flow graphs, FX records the submodule functions |
386 | | - (true/false or cond/body) in reverse order compared to top-level graphs. We |
387 | | - must swap the indices when nested so that cond (first) and body/true_fn |
388 | | - (second) are consistently identified across all nesting levels. |
389 | | -
|
390 | | - """ |
391 | | - |
392 | | - # Determine arg indices based on nesting and whether only cond branch is needed |
393 | | - nested = is_nested_control_flow_graph(graph_module) |
394 | | - # cond: [true_fn, false_fn] or swapped if nested |
395 | | - cond_indices = [2, 1] if nested else [1, 2] |
396 | | - # while_loop: [cond_fn, body_fn] or swapped if nested |
397 | | - while_indices = [1, 0] if nested else [0, 1] |
398 | | - if apply_quantization: |
399 | | - # only keep the cond_fn for while_loop (first index) when quantizing. |
400 | | - while_indices = [while_indices[0]] |
401 | | - mapping = { |
402 | | - torch.ops.higher_order.cond: cond_indices, |
403 | | - torch.ops.higher_order.while_loop: while_indices, |
404 | | - } |
405 | | - # collect cond/while submodules (using mapping indices) |
406 | | - return _get_control_flow_submodules(graph_module, mapping) |
407 | | - |
408 | | - |
409 | 362 | def to_2tuple(value): |
410 | 363 | """Normalizes scalars, and 1-element sequences to a tuple of length 2.""" |
411 | 364 | if isinstance(value, int): |
|
0 commit comments