Commit fa66213
authored
[Relax][Frontend][TFLite] Support control-flow multi-subgraph operators (#19616)
## Summary
This PR adds Relax TFLite frontend support for the TFLite builtin
control-flow / multi-subgraph operator family from #19519 item F:
`CALL`, `IF`, `WHILE`, and `CALL_ONCE`.
It builds on the multi-subgraph import infrastructure merged in PR
#19587.
The frontend already accepts TFLite models with extra subgraphs while
converting
only `Subgraphs(0)` into the Relax `main` function. This PR uses those
extra
subgraphs as callable or control-flow regions for the TFLite
control-flow
operators.
The supported subset is intentionally pure tensor and guard-first:
- `CALL` lowers a referenced TFLite subgraph to a private Relax function
and
emits a direct call.
- `IF` lowers the then/else subgraphs to private Relax functions and
emits a
private wrapper function containing Relax `If`.
- `WHILE` lowers the cond/body subgraphs to private Relax functions and
emits a
recursive private Relax function for the loop.
- `CALL_ONCE` supports the empty-init no-op subset and explicitly
rejects
non-empty or resource-like init patterns.
This PR does not model resource variable side effects. Those cases
remain
explicitly guarded instead of being imported with incorrect pure
functional
semantics.
## Design
### Shared Subgraph Lowering
The frontend now keeps shared conversion state across the main graph and
referenced subgraphs:
- `lowered_subgraphs`
- `lowered_if_functions`
- `lowered_while_functions`
- `lowering_stack`
- `module_builder`
Referenced pure tensor subgraphs are lowered through a recursive
`OperatorConverter` using an isolated `ExprTable`, so subgraph tensor
bindings
cannot overwrite bindings from the main graph. Lowered subgraphs are
cached by
subgraph index and reused when the same region is referenced more than
once.
Generated private functions are registered through the shared parent
`module_builder`, so nested cases such as `main CALL -> subgraph A ->
CALL
subgraph B` keep all private functions in the final IRModule.
Recursive ordinary `CALL` subgraphs are guarded with `OpNotImplemented`.
`WHILE` uses a dedicated recursive wrapper function instead, because
recursion
is part of the intended Relax representation for the loop itself.
### Boundary Validation
The control-flow converters validate subgraph boundaries before
lowering:
- referenced subgraph indices must be valid
- op input/output arity must match the referenced subgraph interface
- branch and loop tensor shape/dtype metadata must match the surrounding
op
- `IF` and `WHILE` conditions must be scalar bool tensors
- `WHILE` loop-carried input/output tensors must have matching metadata
The shared `_check_subgraph_interface` helper is used by `CALL`, `IF`,
and
`WHILE` to keep arity and metadata checks consistent across the
control-flow
operators. `_require_scalar_bool_tensor` accepts both frontend
`TensorWrapper`
objects and raw TFLite tensors so caller and referenced-subgraph
condition
checks use the same path.
These checks keep the first implementation conservative and make
unsupported
cases fail with targeted `OpNotImplemented` diagnostics.
### Tuple Outputs
TFLite `CALL`, `IF`, and `WHILE` may produce multiple output tensors.
The
frontend maps those cases to Relax tuple returns:
```text
single output -> tensor expression
multi output -> Tuple(...)
op outputs -> TupleGetItem(...)
```
This keeps the single-output IR simple while covering multi-output
calls,
multi-output branches, and multi-variable loop state.
## Operator Support
| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `CALL` | `CallOptions.Subgraph()` | private Relax function call | pure
tensor subgraphs, single or multiple outputs |
| `IF` | `IfOptions.ThenSubgraphIndex()`, `ElseSubgraphIndex()` |
private wrapper function containing Relax `If` | scalar bool condition,
matching branch I/O metadata |
| `WHILE` | `WhileOptions.CondSubgraphIndex()`, `BodySubgraphIndex()` |
recursive private Relax function | scalar bool cond output, tensor
loop-carried state |
| `CALL_ONCE` | `CallOnceOptions.InitSubgraphIndex()` | no-op for empty
init subgraph | empty init subgraph only |
## Not Included
- Full `CALL_ONCE` resource/variable initialization semantics.
- Resource, variant, hashtable, or variable tensor support.
- TensorFlow-generated `tf.cond` / `tf.while_loop` smoke tests.
- Dynamic-shape loop-state refinements beyond the current static
metadata
checks.
## Tests
The tests manually build minimal TFLite flatbuffers and compare the
imported
Relax IR with `tvm.ir.assert_structural_equal`. Unsupported-boundary
tests use
`pytest.raises`.
| Test | Coverage |
|---|---|
| `test_call_subgraph` | basic `CALL` to a pure tensor subgraph |
| `test_call_subgraph_multi_output` | `CALL` tuple return and output
binding |
| `test_call_subgraph_nested_call` | nested `CALL` private function
registration |
| `test_call_subgraph_invalid_index_unsupported` | invalid `CALL`
subgraph index |
| `test_call_subgraph_io_mismatch_unsupported` | `CALL` arity mismatch |
| `test_call_subgraph_output_metadata_mismatch_unsupported` | `CALL`
output metadata guard |
| `test_if_subgraphs` | basic `IF` branch selection |
| `test_if_subgraphs_multi_output` | `IF` tuple branch returns |
| `test_if_subgraphs_non_bool_condition_unsupported` | `IF` condition
dtype guard |
| `test_if_subgraphs_invalid_index_unsupported` | invalid then/else
subgraph index |
| `test_if_subgraphs_output_count_mismatch_unsupported` | branch output
count guard |
| `test_if_subgraphs_input_metadata_mismatch_unsupported` | branch input
metadata guard |
| `test_if_subgraphs_output_metadata_mismatch_unsupported` | branch
output metadata guard |
| `test_while_subgraphs` | basic recursive `WHILE` lowering |
| `test_while_subgraphs_repeated_cond_body_pair` | shared cond/body loop
function cache |
| `test_while_subgraphs_two_loop_vars` | multi-variable loop state tuple
path |
| `test_while_subgraphs_non_bool_condition_unsupported` | `WHILE` cond
output dtype guard |
| `test_while_subgraphs_invalid_index_unsupported` | invalid cond/body
subgraph index |
| `test_while_subgraphs_zero_loop_vars_unsupported` | zero-loop-var
guard |
| `test_while_subgraphs_loop_state_metadata_mismatch_unsupported` | loop
state metadata guard |
| `test_while_subgraphs_output_count_mismatch_unsupported` | body output
count guard |
| `test_while_subgraphs_input_metadata_mismatch_unsupported` | cond/body
input metadata guard |
| `test_while_subgraphs_output_metadata_mismatch_unsupported` |
cond/body output metadata guard |
| `test_call_once_empty_init_subgraph` | empty `CALL_ONCE` no-op subset
|
| `test_call_once_non_empty_init_subgraph_unsupported` | non-empty init
subgraph guard |
| `test_call_once_inputs_outputs_unsupported` | `CALL_ONCE` op I/O guard
|
| `test_call_once_init_subgraph_io_unsupported` | init subgraph I/O
guard |
| `test_call_once_invalid_index_unsupported` | invalid init subgraph
index |
Local validation:
```bash
python -m ruff format --check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m ruff check \
python/tvm/relax/frontend/tflite/tflite_frontend.py \
tests/python/relax/test_frontend_tflite.py
python -m pytest \
tests/python/relax/test_frontend_tflite.py \
-k "call_subgraph or if_subgraphs or while_subgraphs or call_once" -q
python -m pytest \
tests/python/relax/test_frontend_tflite.py -q
```
Result:
```text
ruff format --check: 2 files already formatted
ruff check: All checks passed
28 passed, 434 deselected
462 passed
```
## References
- Issue #19519 item F: TFLite control-flow / multi-subgraph operators
- PR #19587: StableHLO region-based ops and multi-subgraph model support1 parent ec3171a commit fa66213
2 files changed
Lines changed: 1782 additions & 4 deletions
File tree
- python/tvm/relax/frontend/tflite
- tests/python/relax
0 commit comments