Skip to content

Commit 8dfc70e

Browse files
committed
[Relax][Frontend][TFLite] Support dynamic RANGE scalar bounds
Previously convert_range raised OpNotImplemented when start/limit/delta were runtime (non-constant) scalar tensors, so any TFLite model that computes RANGE bounds at runtime failed to import. This was one of the "partial implementation" items tracked in #19412. relax.op.arange only takes compile-time PrimExpr bounds, and its struct-info length formula (InferTypeArange) has no negative-step branch, so feeding symbolic bounds straight in would mis-declare descending ranges. Instead, compute the element count in-graph and lift it to one symbolic output dimension via relax.op.tensor_to_shape + match_cast (the bridge already used by _get_shape_expr_from_tensor), so the declared and runtime lengths match by construction. Values are rebuilt as arange(0, count) * delta + start. One unified path covers both dtypes: - int: count = -floor_divide(start - limit, delta), exact and sign-agnostic (no float-precision loss), equal to ceil((limit - start) / delta); - float: count = ceil((limit - start) / delta). No new Relax op is needed. Replace the "not supported" test with a compile-and-run test covering ascending/descending integer and float dynamic bounds.
1 parent 9808108 commit 8dfc70e

2 files changed

Lines changed: 110 additions & 31 deletions

File tree

python/tvm/relax/frontend/tflite/tflite_frontend.py

Lines changed: 77 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,7 +1542,13 @@ def convert_tanh(self, op):
15421542
return out
15431543

15441544
def convert_range(self, op):
1545-
"""Convert TFLite Range"""
1545+
"""Convert TFLite Range.
1546+
1547+
Constant bounds lower directly to ``relax.op.arange``. Runtime (dynamic)
1548+
scalar bounds are handled by computing the element count in-graph,
1549+
lifting it to a symbolic output dimension, and rebuilding the values as
1550+
``arange(0, count) * delta + start`` (see ``_convert_dynamic_range``).
1551+
"""
15461552

15471553
from tflite.TensorType import TensorType
15481554

@@ -1551,38 +1557,86 @@ def convert_range(self, op):
15511557

15521558
start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]
15531559

1554-
def get_scalar_value(tensor):
1560+
# out type inference
1561+
if delta.tensor.Type() == TensorType.FLOAT32:
1562+
out_type = self.get_tensor_type_str(delta.tensor.Type())
1563+
else:
1564+
out_type = self.get_tensor_type_str(start.tensor.Type())
1565+
1566+
def is_dynamic(tensor):
1567+
return self.has_expr(tensor.tensor_idx) and not isinstance(
1568+
self.get_expr(tensor.tensor_idx), relax.Constant
1569+
)
1570+
1571+
def static_scalar(tensor):
15551572
if self.has_expr(tensor.tensor_idx):
1556-
expr = self.get_expr(tensor.tensor_idx)
1557-
if isinstance(expr, relax.Constant):
1558-
value = expr.data.numpy()
1559-
else:
1560-
# relax.op.arange currently expects scalar-like values here.
1561-
# Keep dynamic scalar RANGE explicit until frontend support is added.
1562-
raise tvm.error.OpNotImplemented(
1563-
"TFLite RANGE with dynamic scalar inputs is not supported in"
1564-
"Relax frontend yet."
1565-
)
1573+
value = self.get_expr(tensor.tensor_idx).data.numpy()
15661574
else:
15671575
value = self.get_tensor_value(tensor)
1568-
15691576
# TFLite RANGE operands are scalar tensors in the flatbuffer.
15701577
assert value.size == 1, "RANGE scalar input must have exactly one element"
15711578
return value.item()
15721579

1573-
start_value = get_scalar_value(start)
1574-
limit_value = get_scalar_value(limit)
1575-
delta_value = get_scalar_value(delta)
1580+
if not (is_dynamic(start) or is_dynamic(limit) or is_dynamic(delta)):
1581+
return relax.op.arange(
1582+
static_scalar(start), static_scalar(limit), static_scalar(delta), out_type
1583+
)
1584+
1585+
return self._convert_dynamic_range(start, limit, delta, out_type)
15761586

1577-
# out type inference
1578-
if delta.tensor.Type() == TensorType.FLOAT32:
1579-
out_type = self.get_tensor_type_str(delta.tensor.Type())
1580-
else:
1581-
out_type = self.get_tensor_type_str(start.tensor.Type())
1587+
def _scalar_tensor_to_dim(self, expr, name):
1588+
"""Lift a runtime scalar Relax expr to a symbolic ``tirx.Var`` dimension.
15821589
1583-
out = relax.op.arange(start_value, limit_value, delta_value, out_type)
1590+
Mirrors the ``tensor_to_shape`` + ``match_cast`` bridge used by
1591+
``_get_shape_expr_from_tensor`` so a data-dependent scalar can be used as
1592+
a ``PrimExpr`` (e.g. an output length). The scalar is cast to int64 first.
1593+
"""
1594+
expr = self.bb.normalize(relax.op.astype(expr, "int64"))
1595+
expr = self.bb.normalize(relax.op.reshape(expr, (1,)))
1596+
expr = self.bb.match_cast(expr, relax.TensorType([1], "int64"))
1597+
shape_var = self.bb.emit(relax.op.tensor_to_shape(expr))
1598+
dim = tirx.Var(name, "int64")
1599+
self.bb.match_cast(shape_var, relax.ShapeType([dim]))
1600+
return dim
1601+
1602+
def _convert_dynamic_range(self, start, limit, delta, out_type):
1603+
"""RANGE with dynamic (runtime) scalar bounds, for int and float dtypes.
1604+
1605+
``relax.op.arange`` only accepts compile-time ``PrimExpr`` bounds, and its
1606+
struct-info length formula lacks a negative-step branch, so feeding
1607+
symbolic bounds directly would mis-declare descending ranges. Instead the
1608+
element count ``max(0, ceil((limit - start) / delta))`` is computed
1609+
in-graph and lifted to one symbolic dimension ``L`` (so the declared and
1610+
runtime lengths match by construction); values are rebuilt as
1611+
``arange(0, L) * delta + start``.
1612+
"""
1613+
# int ranges work in int64 for an exact, sign-agnostic count; float
1614+
# ranges work in the output float dtype.
1615+
work_type = out_type if out_type.startswith("float") else "int64"
15841616

1585-
return out
1617+
def scalar_expr(tensor):
1618+
return self.bb.normalize(relax.op.astype(self.get_tensor_expr(tensor), work_type))
1619+
1620+
start_e = scalar_expr(start)
1621+
limit_e = scalar_expr(limit)
1622+
delta_e = scalar_expr(delta)
1623+
1624+
if work_type.startswith("float"):
1625+
count = relax.op.ceil(relax.op.divide(relax.op.subtract(limit_e, start_e), delta_e))
1626+
else:
1627+
# ceil((limit - start) / delta) == -floordiv(start - limit, delta),
1628+
# which stays exact and handles negative delta without a float cast.
1629+
count = relax.op.negative(
1630+
relax.op.floor_divide(relax.op.subtract(start_e, limit_e), delta_e)
1631+
)
1632+
count = relax.op.maximum(count, relax.const(0, work_type))
1633+
dim = self._scalar_tensor_to_dim(count, "range_len")
1634+
1635+
positions = self.bb.normalize(
1636+
relax.op.astype(relax.op.arange(0, dim, 1, "int64"), work_type)
1637+
)
1638+
out = relax.op.add(relax.op.multiply(positions, delta_e), start_e)
1639+
return out if work_type == out_type else relax.op.astype(out, out_type)
15861640

15871641
def convert_rank(self, op):
15881642
"""Convert TFLite RANK."""

tests/python/relax/test_frontend_tflite.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -717,22 +717,47 @@ def func(self):
717717
verify(Range)
718718

719719

720-
def test_range_dynamic_scalar_inputs_not_supported():
721-
"""RANGE conversion currently rejects dynamic scalar inputs."""
720+
@pytest.mark.parametrize(
721+
"start, limit, delta, dtype",
722+
[
723+
(2, 13, 3, tf.int32),
724+
(8, 0, -2, tf.int32),
725+
(0.0, 1.0, 0.25, tf.float32),
726+
(1.0, -1.0, -0.5, tf.float32),
727+
],
728+
)
729+
def test_range_dynamic_scalar_inputs(start, limit, delta, dtype):
730+
"""RANGE lowers dynamic (runtime) scalar bounds for both int and float dtypes."""
722731

723732
class RangeDynamic(tf.Module):
724733
@tf.function(
725734
input_signature=[
726-
tf.TensorSpec(shape=(), dtype=tf.int32),
727-
tf.TensorSpec(shape=(), dtype=tf.int32),
728-
tf.TensorSpec(shape=(), dtype=tf.int32),
735+
tf.TensorSpec(shape=(), dtype=dtype),
736+
tf.TensorSpec(shape=(), dtype=dtype),
737+
tf.TensorSpec(shape=(), dtype=dtype),
729738
]
730739
)
731740
def func(self, start, limit, delta):
732-
return tf.range(start, limit, delta, dtype=tf.int32)
741+
return tf.range(start, limit, delta)
742+
743+
cf = RangeDynamic().func.get_concrete_function()
744+
mod = _get_mod_from_cfunc(cf)
745+
746+
np_dtype = np.int32 if dtype == tf.int32 else np.float32
747+
inputs = [
748+
np.array(start, np_dtype),
749+
np.array(limit, np_dtype),
750+
np.array(delta, np_dtype),
751+
]
752+
753+
ex = tvm.compile(mod, tvm.target.Target("llvm"))
754+
vm = relax.VirtualMachine(ex, tvm.cpu())
755+
vm.set_input("main", *inputs)
756+
vm.invoke_stateful("main")
757+
tvm_out = vm.get_outputs("main").numpy()
733758

734-
with pytest.raises(tvm.error.OpNotImplemented, match="dynamic scalar inputs"):
735-
verify(RangeDynamic)
759+
expected = np.arange(start, limit, delta, dtype=np_dtype)
760+
np.testing.assert_allclose(tvm_out, expected, rtol=1e-5, atol=1e-5)
736761

737762

738763
def test_tile_ir():

0 commit comments

Comments
 (0)