Skip to content

Commit 545c332

Browse files
authored
[Relax][Frontend][TFLite] Fix bool REDUCE_ANY/REDUCE_ALL compile failure (#19415)
## Problem

#19413 registered `REDUCE_ANY` / `REDUCE_ALL` as `_convert_reduce` with `relax.op.max` / `relax.op.min`. These TFLite ops are bool-only (per the TFL op schema: `TFL_ReduceAnyOp` / `TFL_ReduceAllOp` take and return `TFL_BoolTensor`), and `relax.op.max` / `relax.op.min` are not defined on bool, so any real model using these ops fails at compile time with:

```
Cannot decide min_value for type bool
Cannot decide max_value for type bool
```

The existing structural-equality test passed because it never attempted to compile the converted module (E2E is gated on `CI_ENV_NIGHTLY`).

## Fix

Introduce a dedicated `_convert_reduce_bool` handler that casts the input to int8, reduces with max/min, and casts back to bool. Update the test to compile the expected module so this lowering is exercised without `CI_ENV_NIGHTLY`.

## Testing

Verified compile + VM-run (TF converter → Relax → LLVM) across the full shape / axes / keepdims matrix from `test_reduction_bool_ops`: 12 cases, all PASS.

Follow-up to #19413.
1 parent ff127c4 commit 545c332

2 files changed

Lines changed: 52 additions & 3 deletions

File tree

python/tvm/relax/frontend/tflite/tflite_frontend.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ def __init__(self, model, subgraph, exp_tab, ctx):
195195
"PRELU": self.convert_prelu,
196196
"RANGE": self.convert_range,
197197
"QUANTIZE": self.convert_quantize,
198-
"REDUCE_ALL": functools.partial(self._convert_reduce, relax_op=_op.min),
199-
"REDUCE_ANY": functools.partial(self._convert_reduce, relax_op=_op.max),
198+
"REDUCE_ALL": functools.partial(self._convert_reduce_bool, relax_op=_op.min),
199+
"REDUCE_ANY": functools.partial(self._convert_reduce_bool, relax_op=_op.max),
200200
"REDUCE_MAX": functools.partial(self._convert_reduce, relax_op=_op.max),
201201
"REDUCE_MIN": functools.partial(self._convert_reduce, relax_op=_op.min),
202202
"REDUCE_PROD": functools.partial(self._convert_reduce, relax_op=_op.prod),
@@ -1787,6 +1787,36 @@ def _convert_reduce(self, relax_op, op):
17871787

17881788
return out
17891789

1790+
def _convert_reduce_bool(self, relax_op, op):
    """Convert TFLite REDUCE_ANY / REDUCE_ALL (bool-only ops).

    Relax ``max``/``min`` are undefined on bool ("Cannot decide
    min_value/max_value for type bool"), so the input is cast through
    int8, reduced, and the result cast back to bool.

    Parameters
    ----------
    relax_op : callable
        The Relax reduction operator to apply (``_op.max`` for
        REDUCE_ANY, ``_op.min`` for REDUCE_ALL).
    op : tflite.Operator
        The TFLite operator being converted. Expects exactly two input
        tensors: the data tensor and the axes tensor.

    Returns
    -------
    out : relax.Expr
        The reduced expression, cast back to bool.
    """
    from tflite.BuiltinOptions import BuiltinOptions
    from tflite.ReducerOptions import ReducerOptions

    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 2, "input tensors length should be 2"

    input_tensor = input_tensors[0]
    in_expr = self.get_expr(input_tensor.tensor_idx)

    # The second input holds the reduction axes; it may be a 0-d scalar
    # or a 1-d tensor, so normalize both forms to a tuple.
    axis_value = self.get_tensor_value(input_tensors[1])
    axis = tuple(axis_value) if len(axis_value.shape) > 0 else (axis_value.item(),)

    if op.BuiltinOptionsType():
        assert op.BuiltinOptionsType() == BuiltinOptions.ReducerOptions
        reduce_options = ReducerOptions()
        op_options = op.BuiltinOptions()
        reduce_options.Init(op_options.Bytes, op_options.Pos)
        keep_dims = reduce_options.KeepDims()
    else:
        # No ReducerOptions present on the operator: default keep_dims to False.
        keep_dims = False

    # Reduce on int8 (bool max/min is undefined in Relax), then restore bool.
    in_expr = relax.op.astype(in_expr, "int8")
    out = relax_op(in_expr, axis, keep_dims)
    return relax.op.astype(out, "bool")
1819+
17901820
def _convert_arg_min_max(self, op, relax_op):
17911821
"""Generic method converting TFLite arg_min_max"""
17921822

tests/python/relax/test_frontend_tflite.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1977,6 +1977,22 @@ def func(self, x):
19771977
verify(ReduceModule, expected)
19781978

19791979

1980+
def _make_reduce_bool_expected(relax_op, input_shape, axes, keepdims):
    """Build the expected Relax module for a bool reduction.

    Mirrors the frontend lowering: cast to int8, reduce, cast back to
    bool. A value of ``None`` for ``axes`` means "reduce over all axes".
    """
    reduce_axes = list(range(len(input_shape))) if axes is None else axes

    builder = relax.BlockBuilder()
    data = relax.Var("x", relax.TensorStructInfo(input_shape, "bool"))
    with builder.function("main", [data]):
        with builder.dataflow():
            as_int8 = builder.emit(relax.op.astype(data, "int8"))
            reduced = builder.emit(relax_op(as_int8, axis=reduce_axes, keepdims=keepdims))
            result = builder.emit_output(relax.op.astype(reduced, "bool"))
        builder.emit_func_output(result)

    expected_mod = builder.get()
    expected_mod["main"] = expected_mod["main"].with_attr("num_input", 1)
    return expected_mod
1994+
1995+
19801996
@pytest.mark.parametrize(
19811997
"tf_op, relax_op",
19821998
[
@@ -2002,9 +2018,12 @@ class ReduceBoolModule(tf.Module):
20022018
def func(self, x):
20032019
return tf_op(x, axis=axes, keepdims=keepdims)
20042020

2005-
expected = _make_reduce_expected(relax_op, input_shape, axes, keepdims, "bool")
2021+
expected = _make_reduce_bool_expected(relax_op, input_shape, axes, keepdims)
20062022
verify(ReduceBoolModule, expected)
20072023

2024+
# Regression guard: compile to catch a bool max/min lowering path.
2025+
tvm.compile(expected, tvm.target.Target("llvm"))
2026+
20082027

20092028
def test_pad():
20102029
class Pad(tf.Module):

0 commit comments

Comments
 (0)