Skip to content

Commit fff3b4b

Browse files
authored
[Relax][Frontend][TFLite] Support StableHLO region-based ops and multi-subgraph models (#19587)
## Summary This PR adds Relax TFLite frontend support for 10 additional StableHLO builtin operators from #19519 item I, building on the 29 ops merged in PR #19536. The first 5 ops are direct single-subgraph converters: `CBRT`, `REMAINDER`, `DYNAMIC_UPDATE_SLICE`, `DOT_GENERAL`, and `CONVOLUTION`. The remaining 5 ops are region/subgraph-based: `REDUCE`, `REDUCE_WINDOW`, `SORT`, `SCATTER`, and `COMPOSITE`. To support these, the TFLite frontend is extended to accept multi-subgraph models while still converting only `Subgraphs(0)` into the Relax main function. Region subgraphs are consumed by their parent op converters as needed. Relates to #19519. ## Changes 1. **Single-subgraph ops** - `CBRT` — sign-preserving composite expression: `where(x < 0, -power(-x, 1/3), power(x, 1/3))`. Float dtype only. - `REMAINDER` — truncating remainder via `x - y * trunc(x / y)`, matching StableHLO semantics (sign follows dividend). Float dtype only. - `DYNAMIC_UPDATE_SLICE` — static start indices + static shapes only, lowered to `R.scatter_nd` with a coordinate grid generated via `np.indices`. Runtime starts and out-of-bounds ranges raise `OpNotImplemented`. - `DOT_GENERAL` — canonical 2D matmul subset: no batching dims, `lhs_contracting=[1]`, `rhs_contracting=[0]`, lowered to `R.matmul`. - `CONVOLUTION` — canonical 2D NHWC/HWIO subset with `BatchGroupCount=1`, `FeatureGroupCount=1`, lowered to `R.nn.conv2d`. Non-canonical dimension numbers and grouped/depthwise conv raise `OpNotImplemented`. 2. **Multi-subgraph infrastructure** - Lift `from_tflite()` assertion from `model.SubgraphsLength() == 1` to `model.SubgraphsLength() >= 1`. Only `Subgraphs(0)` is converted into the Relax main function. - Limit `_input_type()` to `Subgraphs(0)` inputs, preventing region parameters from leaking as Relax main function parameters. - Add `_get_stablehlo_simple_body_op` helper for validating and extracting the single operator from a region body subgraph. - Extend test helper `_finish_tflite_model` with `extra_subgraphs` parameter for constructing multi-subgraph TFLite flatbuffers. 3. **Region/subgraph ops** - `REDUCE` — single-op reducer body subgraph. Supports `ADD` → `R.sum`, `MAXIMUM` → `R.max`, `MINIMUM` → `R.min`, `MULTIPLY` → `R.prod`. Init value must match the reducer identity element. - `SORT` — single-op comparator body subgraph. `LT` → ascending sort, `GT` → descending sort via `R.sort`. `IsStable` is not mapped. - `REDUCE_WINDOW` — NHWC 4D 2D-pooling subset with `MAXIMUM` reducer and identity init, lowered to `R.nn.max_pool2d`. BaseDilations must be all 1. - `SCATTER` — single-op update computation body subgraph. Supports `ADD`/`MAXIMUM`/`MINIMUM`/`MULTIPLY` → `R.scatter_nd` with the corresponding reduction mode. Only canonical point-update semantics (no window dims). - `COMPOSITE` — inlines a decomposition subgraph through a recursive `OperatorConverter` with an isolated `ExprTable`, so decomposition tensor bindings cannot overwrite main graph bindings. Only supports composites without `CompositeAttributes`. 4. **Not included** - `STABLEHLO_RESHAPE`, `STABLEHLO_TRANSPOSE`, and `STABLEHLO_SLICE` are left to another contributor. - `WHILE`, `CUSTOM_CALL`, and `RNG_BIT_GENERATOR` are deferred to follow-up PRs. 5. **Bug fix** - Fixed `DYNAMIC_UPDATE_SLICE` scatter_nd indices layout: `np.indices` returns `(rank, *update_shape)` but `scatter_nd` expects `(*update_shape, rank)`. Added `np.moveaxis` to transpose the coordinate axis from first to last position. ## Testing All tests use manually-built minimal TFLite flatbuffers with `tvm.ir.assert_structural_equal`. Region/subgraph tests construct the smallest valid body/comparator/update subgraphs. BuiltinOptions2 ops construct their options via the FlatBuffers schema API. ```bash python -m pytest tests/python/relax/test_frontend_tflite.py -k stablehlo -q ``` ## Result - 39 StableHLO operators registered in the Relax TFLite frontend (29 from PR #19536 + 10 from this PR). - 77 StableHLO test cases covering all registered ops, including structural-equal tests and unsupported/error-path checks: - `REMAINDER` truncating semantics - `DYNAMIC_UPDATE_SLICE` with dynamic starts and out-of-bounds starts - `DOT_GENERAL` with non-canonical contracting dimensions - `CONVOLUTION` with non-canonical dimension numbers and `FeatureGroupCount > 1` - `REDUCE` with unsupported reducer and non-identity init value - `SORT` with unsupported comparator and stable sort - `REDUCE_WINDOW` with unsupported reducer and base dilation - `SCATTER` with unsupported reducer and update window dims - `COMPOSITE` with composite attributes and scope isolation - Multi-subgraph model with unused subgraphs - All 77 StableHLO tests pass. ## References - Issue #19519 item I: StableHLO operators in TFLite - PR #19536: First batch of 29 StableHLO ops
1 parent 1720d30 commit fff3b4b

2 files changed

Lines changed: 1776 additions & 23 deletions

File tree

0 commit comments

Comments
 (0)