Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 77 additions & 23 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,7 +1542,13 @@ def convert_tanh(self, op):
return out

def convert_range(self, op):
"""Convert TFLite Range"""
"""Convert TFLite Range.

Constant bounds lower directly to ``relax.op.arange``. Runtime (dynamic)
scalar bounds are handled by computing the element count in-graph,
lifting it to a symbolic output dimension, and rebuilding the values as
``arange(0, count) * delta + start`` (see ``_convert_dynamic_range``).
"""

from tflite.TensorType import TensorType

Expand All @@ -1551,38 +1557,86 @@ def convert_range(self, op):

start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]

def get_scalar_value(tensor):
# out type inference
if delta.tensor.Type() == TensorType.FLOAT32:
out_type = self.get_tensor_type_str(delta.tensor.Type())
else:
out_type = self.get_tensor_type_str(start.tensor.Type())

def is_dynamic(tensor):
return self.has_expr(tensor.tensor_idx) and not isinstance(
self.get_expr(tensor.tensor_idx), relax.Constant
)

def static_scalar(tensor):
if self.has_expr(tensor.tensor_idx):
expr = self.get_expr(tensor.tensor_idx)
if isinstance(expr, relax.Constant):
value = expr.data.numpy()
else:
# relax.op.arange currently expects scalar-like values here.
# Keep dynamic scalar RANGE explicit until frontend support is added.
raise tvm.error.OpNotImplemented(
"TFLite RANGE with dynamic scalar inputs is not supported in"
"Relax frontend yet."
)
value = self.get_expr(tensor.tensor_idx).data.numpy()
else:
value = self.get_tensor_value(tensor)

# TFLite RANGE operands are scalar tensors in the flatbuffer.
assert value.size == 1, "RANGE scalar input must have exactly one element"
return value.item()

start_value = get_scalar_value(start)
limit_value = get_scalar_value(limit)
delta_value = get_scalar_value(delta)
if not (is_dynamic(start) or is_dynamic(limit) or is_dynamic(delta)):
return relax.op.arange(
static_scalar(start), static_scalar(limit), static_scalar(delta), out_type
)

return self._convert_dynamic_range(start, limit, delta, out_type)

# out type inference
if delta.tensor.Type() == TensorType.FLOAT32:
out_type = self.get_tensor_type_str(delta.tensor.Type())
else:
out_type = self.get_tensor_type_str(start.tensor.Type())
def _scalar_tensor_to_dim(self, expr, name):
"""Lift a runtime scalar Relax expr to a symbolic ``tirx.Var`` dimension.

out = relax.op.arange(start_value, limit_value, delta_value, out_type)
Mirrors the ``tensor_to_shape`` + ``match_cast`` bridge used by
``_get_shape_expr_from_tensor`` so a data-dependent scalar can be used as
a ``PrimExpr`` (e.g. an output length). The scalar is cast to int64 first.
"""
expr = self.bb.normalize(relax.op.astype(expr, "int64"))
expr = self.bb.normalize(relax.op.reshape(expr, (1,)))
expr = self.bb.match_cast(expr, relax.TensorType([1], "int64"))
shape_var = self.bb.emit(relax.op.tensor_to_shape(expr))
dim = tirx.Var(name, "int64")
Comment thread
Aharrypotter marked this conversation as resolved.
self.bb.match_cast(shape_var, relax.ShapeType([dim]))
return dim

def _convert_dynamic_range(self, start, limit, delta, out_type):
"""RANGE with dynamic (runtime) scalar bounds, for int and float dtypes.

``relax.op.arange`` only accepts compile-time ``PrimExpr`` bounds, and its
struct-info length formula lacks a negative-step branch, so feeding
symbolic bounds directly would mis-declare descending ranges. Instead the
element count ``max(0, ceil((limit - start) / delta))`` is computed
in-graph and lifted to one symbolic dimension ``L`` (so the declared and
runtime lengths match by construction); values are rebuilt as
``arange(0, L) * delta + start``.
"""
# int ranges work in int64 for an exact, sign-agnostic count; float
# ranges work in the output float dtype.
work_type = out_type if out_type.startswith("float") else "int64"

return out
def scalar_expr(tensor):
return self.bb.normalize(relax.op.astype(self.get_tensor_expr(tensor), work_type))

start_e = scalar_expr(start)
limit_e = scalar_expr(limit)
delta_e = scalar_expr(delta)

if work_type.startswith("float"):
count = relax.op.ceil(relax.op.divide(relax.op.subtract(limit_e, start_e), delta_e))
else:
# ceil((limit - start) / delta) == -floordiv(start - limit, delta),
# which stays exact and handles negative delta without a float cast.
count = relax.op.negative(
relax.op.floor_divide(relax.op.subtract(start_e, limit_e), delta_e)
)
count = relax.op.maximum(count, relax.const(0, work_type))
dim = self._scalar_tensor_to_dim(count, "range_len")

positions = self.bb.normalize(
relax.op.astype(relax.op.arange(0, dim, 1, "int64"), work_type)
)
out = relax.op.add(relax.op.multiply(positions, delta_e), start_e)
return out if work_type == out_type else relax.op.astype(out, out_type)

def convert_rank(self, op):
"""Convert TFLite RANK."""
Expand Down
41 changes: 33 additions & 8 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,22 +717,47 @@ def func(self):
verify(Range)


def test_range_dynamic_scalar_inputs_not_supported():
"""RANGE conversion currently rejects dynamic scalar inputs."""
@pytest.mark.parametrize(
"start, limit, delta, dtype",
[
(2, 13, 3, tf.int32),
(8, 0, -2, tf.int32),
(0.0, 1.0, 0.25, tf.float32),
(1.0, -1.0, -0.5, tf.float32),
],
)
def test_range_dynamic_scalar_inputs(start, limit, delta, dtype):
"""RANGE lowers dynamic (runtime) scalar bounds for both int and float dtypes."""

class RangeDynamic(tf.Module):
@tf.function(
input_signature=[
tf.TensorSpec(shape=(), dtype=tf.int32),
tf.TensorSpec(shape=(), dtype=tf.int32),
tf.TensorSpec(shape=(), dtype=tf.int32),
tf.TensorSpec(shape=(), dtype=dtype),
tf.TensorSpec(shape=(), dtype=dtype),
tf.TensorSpec(shape=(), dtype=dtype),
]
)
def func(self, start, limit, delta):
return tf.range(start, limit, delta, dtype=tf.int32)
return tf.range(start, limit, delta)

cf = RangeDynamic().func.get_concrete_function()
mod = _get_mod_from_cfunc(cf)

np_dtype = np.int32 if dtype == tf.int32 else np.float32
inputs = [
np.array(start, np_dtype),
np.array(limit, np_dtype),
np.array(delta, np_dtype),
]

ex = tvm.compile(mod, tvm.target.Target("llvm"))
vm = relax.VirtualMachine(ex, tvm.cpu())
vm.set_input("main", *inputs)
vm.invoke_stateful("main")
tvm_out = vm.get_outputs("main").numpy()

with pytest.raises(tvm.error.OpNotImplemented, match="dynamic scalar inputs"):
verify(RangeDynamic)
expected = np.arange(start, limit, delta, dtype=np_dtype)
np.testing.assert_allclose(tvm_out, expected, rtol=1e-5, atol=1e-5)


def test_tile_ir():
Expand Down
Loading