Skip to content

Commit 4b0a039

Browse files
authored
[Relax][TFLite] Add remaining operator tests and reverse_sequence op (#19814)
## Summary This PR adds focused Relax TFLite frontend coverage for the remaining non-quantized builtin operators tracked by #18971: - `SQUEEZE` - `REVERSE_SEQUENCE` - `UNPACK` - `ZEROS_LIKE` The tests manually build minimal TFLite flatbuffers and compare the imported Relax IR with `tvm.ir.assert_structural_equal`. This keeps the coverage on the frontend importer itself, without depending on TensorFlow converter rewrites or constant folding. The PR also adds first-class Relax support for `reverse_sequence`. TFLite `REVERSE_SEQUENCE` was previously routed through: ```text R.call_dps_packed("topi.reverse_sequence", ...) ``` That is not executable as a runtime packed call because `topi.reverse_sequence` is a TE compute and expects TE tensors during lowering. The frontend now emits `R.reverse_sequence`, and `LegalizeOps` lowers it through TOPI to TIR: ```text TFLite REVERSE_SEQUENCE -> R.reverse_sequence -> LegalizeOps -> topi.reverse_sequence -> R.call_tir ``` ## Design ### TFLite Operator Tests The new TFLite tests use hand-built flatbuffers for small importer fixtures: - `SQUEEZE` checks axis handling and direct Relax `squeeze` lowering. - `REVERSE_SEQUENCE` checks import to `R.reverse_sequence`, rejects the old `R.call_dps_packed("topi.reverse_sequence", ...)` path, compiles the module, and runs it with the VM. - `UNPACK` checks multi-output lowering through Relax tuple output handling. - `ZEROS_LIKE` checks direct Relax zero-like tensor creation. ### Relax reverse_sequence Operator The PR adds a public Relax operator: ```python relax.op.reverse_sequence(data, seq_lengths, seq_axis=1, batch_axis=0) ``` The operator uses `ReverseSequenceAttrs` with `seq_axis` and `batch_axis`. Type inference preserves the input tensor's shape, dtype, and vdevice, and validates the statically known constraints: - `data` must be a tensor. - `seq_lengths` must be a 1-D tensor. - `seq_lengths` dtype must be `int32` or `int64`. - `seq_axis` and `batch_axis` must be in `[-ndim, ndim)` when the input rank is known. - `seq_lengths.shape[0]` must match the batch-axis extent when both shapes are statically available. The op is exported through Python as `relax.op.reverse_sequence` and through the script builder as `R.reverse_sequence`. ### Legalization `relax.reverse_sequence` is registered in `LegalizeOps` and lowered with `bb.call_te`: ```python bb.call_te( topi.reverse_sequence, data, seq_lengths, seq_axis, batch_axis, primfunc_name_hint="reverse_sequence", ) ``` This produces `R.call_tir` in the legalized Relax module, keeping runtime execution on the normal TOPI/TIR path. ### TOPI Packed Registration The Python TOPI wrapper already accepts `batch_axis`: ```python topi.reverse_sequence(a, seq_lengths, seq_axis=1, batch_axis=0) ``` The C++ packed registration only forwarded the first three arguments, so Python calls that provided `batch_axis` would drop it before reaching the TOPI compute. The registration now forwards the fourth argument and keeps the old three-argument call form compatible by defaulting `batch_axis=0`. ## Operator Support | Operator | TFLite options | Relax lowering | Supported subset | |---|---|---|---| | `SQUEEZE` | `SqueezeOptions.SqueezeDims()` | `R.squeeze` | static squeeze axes from TFLite options | | `REVERSE_SEQUENCE` | `ReverseSequenceOptions.SeqDim()`, `BatchDim()` | `R.reverse_sequence` legalized to TOPI/TIR | tensor input, 1-D int32/int64 `seq_lengths`, valid `seq_axis` and `batch_axis` | | `UNPACK` | `UnpackOptions.Axis()`, `Num()` | Relax tuple output | static axis and output count from TFLite options | | `ZEROS_LIKE` | none | `R.zeros_like` | tensor input | ## Not Included - Quantized TFLite `REVERSE_SEQUENCE` support. - A runtime DPS packed implementation for `topi.reverse_sequence`. - Changes to TOPI compute semantics. - ONNX `ReverseSequence` importer support. ## Tests The tests cover both the TFLite frontend fixtures and the new Relax op: | Test | Coverage | |---|---| | `test_squeeze` | imports TFLite `SQUEEZE` to Relax `squeeze` | | `test_reverse_sequence` | imports TFLite `REVERSE_SEQUENCE` to `R.reverse_sequence`, avoids the old TOPI DPS packed call, compiles, and runs through VM | | `test_unpack` | imports TFLite `UNPACK` as multi-output Relax tuple handling | | `test_zeros_like` | imports TFLite `ZEROS_LIKE` to Relax `zeros_like` | | `test_op_correctness` | `relax.op.reverse_sequence(...).op` resolves to `relax.reverse_sequence` | | `test_reverse_sequence_infer_ty` | static shape, unknown dtype, unknown ndim, symbolic shape, and vdevice propagation | | `test_reverse_sequence_infer_ty_wrong_inputs` | non-tensor `seq_lengths`, wrong rank, wrong dtype, invalid axes, and static batch mismatch | | `test_reverse_sequence` in `test_transform_legalize_ops_manipulate.py` | `LegalizeOps` emits `R.call_tir` and exercises `seq_axis=0, batch_axis=1` | Local validation: ```bash python -m pytest tests/python/relax/test_op_manipulate.py \ -k reverse_sequence -q python -m pytest tests/python/relax/test_transform_legalize_ops_manipulate.py \ -k reverse_sequence -q python -m pytest --noconftest tests/python/relax/test_frontend_tflite.py \ -k "reverse_sequence or squeeze or unpack or zeros_like" -q ``` Result: ```text cmake build: passed py_compile: passed ruff format --check: 9 files already formatted ruff check: All checks passed clang-format --dry-run --Werror: passed pre-commit run --files: passed test_op_manipulate.py -k reverse_sequence: 3 passed test_transform_legalize_ops_manipulate.py -k reverse_sequence: 1 passed test_frontend_tflite.py -k "reverse_sequence or squeeze or unpack or zeros_like": 4 passed ``` ## References - Issue #18971: TFLite non-quantized operator unit-test coverage - TFLite `REVERSE_SEQUENCE` builtin semantics
1 parent 4a70d96 commit 4b0a039

13 files changed

Lines changed: 569 additions & 5 deletions

File tree

include/tvm/relax/attrs/manipulate.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,23 @@ struct FlipAttrs : public AttrsNode {
195195
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.FlipAttrs", FlipAttrs, AttrsNode);
196196
}; // struct FlipAttrs
197197

198+
/*! \brief Attributes used in reverse_sequence operators */
199+
struct ReverseSequenceAttrs : public AttrsNode {
200+
int64_t seq_axis;
201+
int64_t batch_axis;
202+
203+
static void RegisterReflection() {
204+
namespace refl = tvm::ffi::reflection;
205+
refl::ObjectDef<ReverseSequenceAttrs>()
206+
.def_ro("seq_axis", &ReverseSequenceAttrs::seq_axis,
207+
"The axis along which to reverse variable length slices.")
208+
.def_ro("batch_axis", &ReverseSequenceAttrs::batch_axis,
209+
"The axis that indexes the batch.");
210+
}
211+
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ReverseSequenceAttrs", ReverseSequenceAttrs,
212+
AttrsNode);
213+
}; // struct ReverseSequenceAttrs
214+
198215
/*! \brief Attributes used in gather_elements operators */
199216
struct GatherElementsAttrs : public AttrsNode {
200217
int64_t axis;

python/tvm/relax/frontend/tflite/tflite_frontend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5614,10 +5614,10 @@ def convert_unpack(self, op):
56145614
else:
56155615
splitted = relax.op.split(in_expr, indices_or_sections=num_unpacks, axis=unpack_axis)
56165616
squeezed = relax.Tuple(
5617-
relax.Tuple(
5618-
[_op.squeeze(split_item, axis=squeeze_axis) for split_item in splitted]
5619-
),
5620-
len(splitted),
5617+
[
5618+
_op.squeeze(relax.TupleGetItem(splitted, i), axis=squeeze_axis)
5619+
for i in range(num_unpacks)
5620+
]
56215621
)
56225622

56235623
return squeezed

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
permute_dims,
109109
repeat,
110110
reshape,
111+
reverse_sequence,
111112
scatter_elements,
112113
scatter_nd,
113114
slice_scatter,

python/tvm/relax/op/manipulate.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,31 @@ def flip(data, axis):
460460
return _ffi_api.flip(data, axis) # type: ignore
461461

462462

463+
def reverse_sequence(data: Expr, seq_lengths: Expr, seq_axis: int = 1, batch_axis: int = 0) -> Expr:
464+
"""Reverses variable length slices.
465+
466+
Parameters
467+
----------
468+
data : relax.Expr
469+
The input tensor.
470+
471+
seq_lengths : relax.Expr
472+
A 1-D tensor containing sequence lengths for each batch.
473+
474+
seq_axis : int
475+
The axis along which to reverse variable length slices.
476+
477+
batch_axis : int
478+
The axis that indexes the batch.
479+
480+
Returns
481+
-------
482+
ret : relax.Expr
483+
The computed result.
484+
"""
485+
return _ffi_api.reverse_sequence(data, seq_lengths, seq_axis, batch_axis) # type: ignore
486+
487+
463488
def gather_elements(data: Expr, indices: Expr, axis: int = 0) -> Expr:
464489
"""Gather elements from data according to indices along the specified axis.
465490

python/tvm/relax/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,11 @@ class FlipAttrs(Attrs):
206206
"""Attributes for flip operator"""
207207

208208

209+
@tvm_ffi.register_object("relax.attrs.ReverseSequenceAttrs")
210+
class ReverseSequenceAttrs(Attrs):
211+
"""Attributes for reverse_sequence operator"""
212+
213+
209214
@tvm_ffi.register_object("relax.attrs.PadAttrs")
210215
class PadAttrs(Attrs):
211216
"""Attributes used in pad operator"""

python/tvm/relax/script/builder/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@
153153
quantize,
154154
repeat,
155155
reshape,
156+
reverse_sequence,
156157
right_shift,
157158
round,
158159
rsqrt,
@@ -941,6 +942,7 @@ def dtype(value: py_str | DataType) -> Expr:
941942
"quantize",
942943
"repeat",
943944
"reshape",
945+
"reverse_sequence",
944946
"rewriter",
945947
"right_shift",
946948
"rocm",

python/tvm/relax/transform/legalize_ops/manipulate.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,18 @@ def _flip(bb: BlockBuilder, call: Call) -> Expr:
170170
return bb.call_te(topi.flip, call.args[0], int(call.attrs.axis))
171171

172172

173+
@register_legalize("relax.reverse_sequence")
174+
def _reverse_sequence(bb: BlockBuilder, call: Call) -> Expr:
175+
return bb.call_te(
176+
topi.reverse_sequence,
177+
call.args[0],
178+
call.args[1],
179+
int(call.attrs.seq_axis),
180+
int(call.attrs.batch_axis),
181+
primfunc_name_hint="reverse_sequence",
182+
)
183+
184+
173185
@register_legalize("relax.gather_elements")
174186
def _gather_elements(bb: BlockBuilder, call: Call) -> Expr:
175187
return bb.call_te(topi.gather, call.args[0], int(call.attrs.axis), call.args[1])

src/relax/op/tensor/manipulate.cc

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
5151
RepeatAttrs::RegisterReflection();
5252
TileAttrs::RegisterReflection();
5353
FlipAttrs::RegisterReflection();
54+
ReverseSequenceAttrs::RegisterReflection();
5455
GatherElementsAttrs::RegisterReflection();
5556
GatherNDAttrs::RegisterReflection();
5657
IndexPutAttrs::RegisterReflection();
@@ -2071,6 +2072,96 @@ TVM_REGISTER_OP("relax.flip")
20712072
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutFlip)
20722073
.set_attr<bool>("FPurity", true);
20732074

2075+
/* relax.reverse_sequence */
2076+
2077+
Expr reverse_sequence(Expr data, Expr seq_lengths, int64_t seq_axis, int64_t batch_axis) {
2078+
auto attrs = ffi::make_object<ReverseSequenceAttrs>();
2079+
attrs->seq_axis = seq_axis;
2080+
attrs->batch_axis = batch_axis;
2081+
static const Op& op = Op::Get("relax.reverse_sequence");
2082+
return Call(op, {std::move(data), std::move(seq_lengths)}, Attrs{attrs}, {});
2083+
}
2084+
2085+
TVM_FFI_STATIC_INIT_BLOCK() {
2086+
namespace refl = tvm::ffi::reflection;
2087+
refl::GlobalDef().def("relax.op.reverse_sequence", reverse_sequence);
2088+
}
2089+
2090+
Type InferTypeReverseSequence(const Call& call, const BlockBuilder& ctx) {
2091+
if (call->args.size() != 2) {
2092+
TVM_FFI_VISIT_THROW(ValueError, call) << "ReverseSequence op should take 2 arguments";
2093+
}
2094+
TensorType data_ty = GetInputTensorType(call, 0, ctx);
2095+
TensorType seq_lengths_ty = GetInputTensorType(call, 1, ctx);
2096+
2097+
if (!seq_lengths_ty->IsUnknownNdim() && seq_lengths_ty->ndim != 1) {
2098+
TVM_FFI_VISIT_THROW(ValueError, call)
2099+
<< "ReverseSequence requires seq_lengths to be 1-D. However, seq_lengths has ndim "
2100+
<< seq_lengths_ty->ndim;
2101+
}
2102+
if (!seq_lengths_ty->dtype.is_void() && !seq_lengths_ty->dtype.is_int()) {
2103+
TVM_FFI_VISIT_THROW(ValueError, call)
2104+
<< "ReverseSequence requires seq_lengths to have dtype int32 or int64. However, "
2105+
"seq_lengths has dtype "
2106+
<< seq_lengths_ty->dtype;
2107+
}
2108+
if (seq_lengths_ty->dtype.is_int() && seq_lengths_ty->dtype.bits() != 32 &&
2109+
seq_lengths_ty->dtype.bits() != 64) {
2110+
TVM_FFI_VISIT_THROW(ValueError, call)
2111+
<< "ReverseSequence requires seq_lengths to have dtype int32 or int64. However, "
2112+
"seq_lengths has dtype "
2113+
<< seq_lengths_ty->dtype;
2114+
}
2115+
2116+
const auto* attrs = call->attrs.as<ReverseSequenceAttrs>();
2117+
int64_t seq_axis = attrs->seq_axis;
2118+
int64_t batch_axis = attrs->batch_axis;
2119+
if (!data_ty->IsUnknownNdim()) {
2120+
int ndim = data_ty->ndim;
2121+
auto check_axis = [&](int64_t axis, ffi::String axis_name) {
2122+
if (axis < -ndim || axis >= ndim) {
2123+
TVM_FFI_VISIT_THROW(ValueError, call)
2124+
<< "ReverseSequence requires " << axis_name
2125+
<< " to belong to range [-ndim, ndim). However, the axis is " << axis
2126+
<< ", while ndim is " << ndim;
2127+
}
2128+
};
2129+
check_axis(seq_axis, "seq_axis");
2130+
check_axis(batch_axis, "batch_axis");
2131+
2132+
if (batch_axis < 0) {
2133+
batch_axis += ndim;
2134+
}
2135+
2136+
if (data_ty->shape.defined() && seq_lengths_ty->shape.defined()) {
2137+
const auto* data_shape_ty = GetTypeAs<ShapeTypeNode>(data_ty->shape.value());
2138+
const auto* seq_lengths_shape_ty = GetTypeAs<ShapeTypeNode>(seq_lengths_ty->shape.value());
2139+
if (data_shape_ty != nullptr && seq_lengths_shape_ty != nullptr &&
2140+
data_shape_ty->values.defined() && seq_lengths_shape_ty->values.defined()) {
2141+
PrimExpr batch_extent = data_shape_ty->values.value()[batch_axis];
2142+
PrimExpr seq_lengths_extent = seq_lengths_shape_ty->values.value()[0];
2143+
if (ctx->GetAnalyzer()->CanProve(seq_lengths_extent != batch_extent)) {
2144+
TVM_FFI_VISIT_THROW(ValueError, call)
2145+
<< "ReverseSequence requires seq_lengths.shape[0] to equal the batch axis extent. "
2146+
"However, seq_lengths.shape[0] is "
2147+
<< seq_lengths_extent << ", while data.shape[" << batch_axis << "] is "
2148+
<< batch_extent;
2149+
}
2150+
}
2151+
}
2152+
}
2153+
2154+
return data_ty;
2155+
}
2156+
2157+
TVM_REGISTER_OP("relax.reverse_sequence")
2158+
.set_attrs_type<ReverseSequenceAttrs>()
2159+
.set_num_inputs(2)
2160+
.add_argument("data", "Tensor", "The input tensor.")
2161+
.add_argument("seq_lengths", "Tensor", "The sequence length tensor.")
2162+
.set_attr<FInferType>("FInferType", InferTypeReverseSequence)
2163+
.set_attr<bool>("FPurity", true);
2164+
20742165
/* relax.gather_elements */
20752166

20762167
Expr gather_elements(Expr data, Expr indices, int axis) {

src/relax/op/tensor/manipulate.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,16 @@ Expr tile(Expr data, ffi::Array<int64_t> repeats);
181181
*/
182182
Expr flip(Expr data, int64_t axis);
183183

184+
/*!
185+
* \brief Reverses variable length slices along seq_axis.
186+
* \param data The input tensor.
187+
* \param seq_lengths A 1-D tensor containing sequence lengths for each batch.
188+
* \param seq_axis The axis along which to reverse.
189+
* \param batch_axis The axis that indexes the batch.
190+
* \return The computed result.
191+
*/
192+
Expr reverse_sequence(Expr data, Expr seq_lengths, int64_t seq_axis, int64_t batch_axis);
193+
184194
/*!
185195
* \brief Gather elements from a tensor using indices.
186196
* \param data The input tensor.

src/topi/transform.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
5858
})
5959
.def_packed("topi.reverse_sequence",
6060
[](ffi::PackedArgs args, ffi::Any* rv) {
61+
int batch_axis = args.size() >= 4 ? args[3].cast<int>() : 0;
6162
*rv = reverse_sequence(args[0].cast<te::Tensor>(), args[1].cast<te::Tensor>(),
62-
args[2].cast<int>());
63+
args[2].cast<int>(), batch_axis);
6364
})
6465
.def_packed("topi.reshape",
6566
[](ffi::PackedArgs args, ffi::Any* rv) {

0 commit comments

Comments
 (0)