Skip to content

Commit fa66213

Browse files
authored
[Relax][Frontend][TFLite] Support control-flow multi-subgraph operators (#19616)
## Summary This PR adds Relax TFLite frontend support for the TFLite builtin control-flow / multi-subgraph operator family from #19519 item F: `CALL`, `IF`, `WHILE`, and `CALL_ONCE`. It builds on the multi-subgraph import infrastructure merged in PR #19587. The frontend already accepts TFLite models with extra subgraphs while converting only `Subgraphs(0)` into the Relax `main` function. This PR uses those extra subgraphs as callable or control-flow regions for the TFLite control-flow operators. The supported subset is intentionally pure tensor and guard-first: - `CALL` lowers a referenced TFLite subgraph to a private Relax function and emits a direct call. - `IF` lowers the then/else subgraphs to private Relax functions and emits a private wrapper function containing Relax `If`. - `WHILE` lowers the cond/body subgraphs to private Relax functions and emits a recursive private Relax function for the loop. - `CALL_ONCE` supports the empty-init no-op subset and explicitly rejects non-empty or resource-like init patterns. This PR does not model resource variable side effects. Those cases remain explicitly guarded instead of being imported with incorrect pure functional semantics. ## Design ### Shared Subgraph Lowering The frontend now keeps shared conversion state across the main graph and referenced subgraphs: - `lowered_subgraphs` - `lowered_if_functions` - `lowered_while_functions` - `lowering_stack` - `module_builder` Referenced pure tensor subgraphs are lowered through a recursive `OperatorConverter` using an isolated `ExprTable`, so subgraph tensor bindings cannot overwrite bindings from the main graph. Lowered subgraphs are cached by subgraph index and reused when the same region is referenced more than once. Generated private functions are registered through the shared parent `module_builder`, so nested cases such as `main CALL -> subgraph A -> CALL subgraph B` keep all private functions in the final IRModule. Recursive ordinary `CALL` subgraphs are guarded with `OpNotImplemented`. `WHILE` uses a dedicated recursive wrapper function instead, because recursion is part of the intended Relax representation for the loop itself. ### Boundary Validation The control-flow converters validate subgraph boundaries before lowering: - referenced subgraph indices must be valid - op input/output arity must match the referenced subgraph interface - branch and loop tensor shape/dtype metadata must match the surrounding op - `IF` and `WHILE` conditions must be scalar bool tensors - `WHILE` loop-carried input/output tensors must have matching metadata The shared `_check_subgraph_interface` helper is used by `CALL`, `IF`, and `WHILE` to keep arity and metadata checks consistent across the control-flow operators. `_require_scalar_bool_tensor` accepts both frontend `TensorWrapper` objects and raw TFLite tensors so caller and referenced-subgraph condition checks use the same path. These checks keep the first implementation conservative and make unsupported cases fail with targeted `OpNotImplemented` diagnostics. ### Tuple Outputs TFLite `CALL`, `IF`, and `WHILE` may produce multiple output tensors. The frontend maps those cases to Relax tuple returns: ```text single output -> tensor expression multi output -> Tuple(...) op outputs -> TupleGetItem(...) ``` This keeps the single-output IR simple while covering multi-output calls, multi-output branches, and multi-variable loop state. ## Operator Support | Operator | TFLite options | Relax lowering | Supported subset | |---|---|---|---| | `CALL` | `CallOptions.Subgraph()` | private Relax function call | pure tensor subgraphs, single or multiple outputs | | `IF` | `IfOptions.ThenSubgraphIndex()`, `ElseSubgraphIndex()` | private wrapper function containing Relax `If` | scalar bool condition, matching branch I/O metadata | | `WHILE` | `WhileOptions.CondSubgraphIndex()`, `BodySubgraphIndex()` | recursive private Relax function | scalar bool cond output, tensor loop-carried state | | `CALL_ONCE` | `CallOnceOptions.InitSubgraphIndex()` | no-op for empty init subgraph | empty init subgraph only | ## Not Included - Full `CALL_ONCE` resource/variable initialization semantics. - Resource, variant, hashtable, or variable tensor support. - TensorFlow-generated `tf.cond` / `tf.while_loop` smoke tests. - Dynamic-shape loop-state refinements beyond the current static metadata checks. ## Tests The tests manually build minimal TFLite flatbuffers and compare the imported Relax IR with `tvm.ir.assert_structural_equal`. Unsupported-boundary tests use `pytest.raises`. | Test | Coverage | |---|---| | `test_call_subgraph` | basic `CALL` to a pure tensor subgraph | | `test_call_subgraph_multi_output` | `CALL` tuple return and output binding | | `test_call_subgraph_nested_call` | nested `CALL` private function registration | | `test_call_subgraph_invalid_index_unsupported` | invalid `CALL` subgraph index | | `test_call_subgraph_io_mismatch_unsupported` | `CALL` arity mismatch | | `test_call_subgraph_output_metadata_mismatch_unsupported` | `CALL` output metadata guard | | `test_if_subgraphs` | basic `IF` branch selection | | `test_if_subgraphs_multi_output` | `IF` tuple branch returns | | `test_if_subgraphs_non_bool_condition_unsupported` | `IF` condition dtype guard | | `test_if_subgraphs_invalid_index_unsupported` | invalid then/else subgraph index | | `test_if_subgraphs_output_count_mismatch_unsupported` | branch output count guard | | `test_if_subgraphs_input_metadata_mismatch_unsupported` | branch input metadata guard | | `test_if_subgraphs_output_metadata_mismatch_unsupported` | branch output metadata guard | | `test_while_subgraphs` | basic recursive `WHILE` lowering | | `test_while_subgraphs_repeated_cond_body_pair` | shared cond/body loop function cache | | `test_while_subgraphs_two_loop_vars` | multi-variable loop state tuple path | | `test_while_subgraphs_non_bool_condition_unsupported` | `WHILE` cond output dtype guard | | `test_while_subgraphs_invalid_index_unsupported` | invalid cond/body subgraph index | | `test_while_subgraphs_zero_loop_vars_unsupported` | zero-loop-var guard | | `test_while_subgraphs_loop_state_metadata_mismatch_unsupported` | loop state metadata guard | | `test_while_subgraphs_output_count_mismatch_unsupported` | body output count guard | | `test_while_subgraphs_input_metadata_mismatch_unsupported` | cond/body input metadata guard | | `test_while_subgraphs_output_metadata_mismatch_unsupported` | cond/body output metadata guard | | `test_call_once_empty_init_subgraph` | empty `CALL_ONCE` no-op subset | | `test_call_once_non_empty_init_subgraph_unsupported` | non-empty init subgraph guard | | `test_call_once_inputs_outputs_unsupported` | `CALL_ONCE` op I/O guard | | `test_call_once_init_subgraph_io_unsupported` | init subgraph I/O guard | | `test_call_once_invalid_index_unsupported` | invalid init subgraph index | Local validation: ```bash python -m ruff format --check \ python/tvm/relax/frontend/tflite/tflite_frontend.py \ tests/python/relax/test_frontend_tflite.py python -m ruff check \ python/tvm/relax/frontend/tflite/tflite_frontend.py \ tests/python/relax/test_frontend_tflite.py python -m pytest \ tests/python/relax/test_frontend_tflite.py \ -k "call_subgraph or if_subgraphs or while_subgraphs or call_once" -q python -m pytest \ tests/python/relax/test_frontend_tflite.py -q ``` Result: ```text ruff format --check: 2 files already formatted ruff check: All checks passed 28 passed, 434 deselected 462 passed ``` ## References - Issue #19519 item F: TFLite control-flow / multi-subgraph operators - PR #19587: StableHLO region-based ops and multi-subgraph model support
1 parent ec3171a commit fa66213

2 files changed

Lines changed: 1782 additions & 4 deletions

File tree

0 commit comments

Comments
 (0)