Commit fff3b4b
authored
[Relax][Frontend][TFLite] Support StableHLO region-based ops and multi-subgraph models (#19587)
## Summary
This PR adds Relax TFLite frontend support for 10 additional StableHLO
builtin
operators from #19519 item I, building on the 29 ops merged in PR
#19536.
The first 5 ops are direct single-subgraph converters: `CBRT`,
`REMAINDER`,
`DYNAMIC_UPDATE_SLICE`, `DOT_GENERAL`, and `CONVOLUTION`. The remaining
5 ops
are region/subgraph-based: `REDUCE`, `REDUCE_WINDOW`, `SORT`, `SCATTER`,
and
`COMPOSITE`. To support these, the TFLite frontend is extended to accept
multi-subgraph models while still converting only `Subgraphs(0)` into
the
Relax main function. Region subgraphs are consumed by their parent op
converters as needed.
Relates to #19519.
## Changes
1. **Single-subgraph ops**
- `CBRT` — sign-preserving composite expression:
`where(x < 0, -power(-x, 1/3), power(x, 1/3))`. Float dtype only.
- `REMAINDER` — truncating remainder via `x - y * trunc(x / y)`,
matching
StableHLO semantics (sign follows dividend). Float dtype only.
- `DYNAMIC_UPDATE_SLICE` — static start indices + static shapes only,
lowered
to `R.scatter_nd` with a coordinate grid generated via `np.indices`.
Runtime starts and out-of-bounds ranges raise `OpNotImplemented`.
- `DOT_GENERAL` — canonical 2D matmul subset: no batching dims,
`lhs_contracting=[1]`, `rhs_contracting=[0]`, lowered to `R.matmul`.
- `CONVOLUTION` — canonical 2D NHWC/HWIO subset with
`BatchGroupCount=1`,
`FeatureGroupCount=1`, lowered to `R.nn.conv2d`. Non-canonical dimension
numbers and grouped/depthwise conv raise `OpNotImplemented`.
2. **Multi-subgraph infrastructure**
- Lift `from_tflite()` assertion from `model.SubgraphsLength() == 1` to
`model.SubgraphsLength() >= 1`. Only `Subgraphs(0)` is converted into
the
Relax main function.
- Limit `_input_type()` to `Subgraphs(0)` inputs, preventing region
parameters from leaking as Relax main function parameters.
- Add `_get_stablehlo_simple_body_op` helper for validating and
extracting
the single operator from a region body subgraph.
- Extend test helper `_finish_tflite_model` with `extra_subgraphs`
parameter
for constructing multi-subgraph TFLite flatbuffers.
3. **Region/subgraph ops**
- `REDUCE` — single-op reducer body subgraph. Supports `ADD` → `R.sum`,
`MAXIMUM` → `R.max`, `MINIMUM` → `R.min`, `MULTIPLY` → `R.prod`.
Init value must match the reducer identity element.
- `SORT` — single-op comparator body subgraph. `LT` → ascending sort,
`GT` → descending sort via `R.sort`. `IsStable` is not mapped.
- `REDUCE_WINDOW` — NHWC 4D 2D-pooling subset with `MAXIMUM` reducer and
identity init, lowered to `R.nn.max_pool2d`. BaseDilations must be all
1.
- `SCATTER` — single-op update computation body subgraph. Supports
`ADD`/`MAXIMUM`/`MINIMUM`/`MULTIPLY` → `R.scatter_nd` with the
corresponding reduction mode. Only canonical point-update semantics
(no window dims).
- `COMPOSITE` — inlines a decomposition subgraph through a recursive
`OperatorConverter` with an isolated `ExprTable`, so decomposition
tensor
bindings cannot overwrite main graph bindings. Only supports composites
without `CompositeAttributes`.
4. **Not included**
- `STABLEHLO_RESHAPE`, `STABLEHLO_TRANSPOSE`, and `STABLEHLO_SLICE` are
left to another contributor.
- `WHILE`, `CUSTOM_CALL`, and `RNG_BIT_GENERATOR` are deferred to
follow-up
PRs.
5. **Bug fix**
- Fixed `DYNAMIC_UPDATE_SLICE` scatter_nd indices layout: `np.indices`
returns `(rank, *update_shape)` but `scatter_nd` expects
`(*update_shape, rank)`. Added `np.moveaxis` to transpose the coordinate
axis from first to last position.
## Testing
All tests use manually-built minimal TFLite flatbuffers with
`tvm.ir.assert_structural_equal`. Region/subgraph tests construct the
smallest
valid body/comparator/update subgraphs. BuiltinOptions2 ops construct
their
options via the FlatBuffers schema API.
```bash
python -m pytest tests/python/relax/test_frontend_tflite.py -k stablehlo -q
```
## Result
- 39 StableHLO operators registered in the Relax TFLite frontend (29
from
PR #19536 + 10 from this PR).
- 77 StableHLO test cases covering all registered ops, including
structural-equal tests and unsupported/error-path checks:
- `REMAINDER` truncating semantics
- `DYNAMIC_UPDATE_SLICE` with dynamic starts and out-of-bounds starts
- `DOT_GENERAL` with non-canonical contracting dimensions
- `CONVOLUTION` with non-canonical dimension numbers and
`FeatureGroupCount > 1`
- `REDUCE` with unsupported reducer and non-identity init value
- `SORT` with unsupported comparator and stable sort
- `REDUCE_WINDOW` with unsupported reducer and base dilation
- `SCATTER` with unsupported reducer and update window dims
- `COMPOSITE` with composite attributes and scope isolation
- Multi-subgraph model with unused subgraphs
- All 77 StableHLO tests pass.
## References
- Issue #19519 item I: StableHLO operators in TFLite
- PR #19536: First batch of 29 StableHLO ops1 parent 1720d30 commit fff3b4b
2 files changed
Lines changed: 1776 additions & 23 deletions
File tree
- python/tvm/relax/frontend/tflite
- tests/python/relax
0 commit comments