diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 44e977397307..0b7393088de4 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -1542,7 +1542,13 @@ def convert_tanh(self, op): return out def convert_range(self, op): - """Convert TFLite Range""" + """Convert TFLite Range. + + Constant bounds lower directly to ``relax.op.arange``. Runtime (dynamic) + scalar bounds are handled by computing the element count in-graph, + lifting it to a symbolic output dimension, and rebuilding the values as + ``arange(0, count) * delta + start`` (see ``_convert_dynamic_range``). + """ from tflite.TensorType import TensorType @@ -1551,38 +1557,86 @@ def convert_range(self, op): start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2] - def get_scalar_value(tensor): + # out type inference + if delta.tensor.Type() == TensorType.FLOAT32: + out_type = self.get_tensor_type_str(delta.tensor.Type()) + else: + out_type = self.get_tensor_type_str(start.tensor.Type()) + + def is_dynamic(tensor): + return self.has_expr(tensor.tensor_idx) and not isinstance( + self.get_expr(tensor.tensor_idx), relax.Constant + ) + + def static_scalar(tensor): if self.has_expr(tensor.tensor_idx): - expr = self.get_expr(tensor.tensor_idx) - if isinstance(expr, relax.Constant): - value = expr.data.numpy() - else: - # relax.op.arange currently expects scalar-like values here. - # Keep dynamic scalar RANGE explicit until frontend support is added. - raise tvm.error.OpNotImplemented( - "TFLite RANGE with dynamic scalar inputs is not supported in" - "Relax frontend yet." - ) + value = self.get_expr(tensor.tensor_idx).data.numpy() else: value = self.get_tensor_value(tensor) - # TFLite RANGE operands are scalar tensors in the flatbuffer. assert value.size == 1, "RANGE scalar input must have exactly one element" return value.item() - start_value = get_scalar_value(start) - limit_value = get_scalar_value(limit) - delta_value = get_scalar_value(delta) + if not (is_dynamic(start) or is_dynamic(limit) or is_dynamic(delta)): + return relax.op.arange( + static_scalar(start), static_scalar(limit), static_scalar(delta), out_type + ) + + return self._convert_dynamic_range(start, limit, delta, out_type) - # out type inference - if delta.tensor.Type() == TensorType.FLOAT32: - out_type = self.get_tensor_type_str(delta.tensor.Type()) - else: - out_type = self.get_tensor_type_str(start.tensor.Type()) + def _scalar_tensor_to_dim(self, expr, name): + """Lift a runtime scalar Relax expr to a symbolic ``tirx.Var`` dimension. - out = relax.op.arange(start_value, limit_value, delta_value, out_type) + Mirrors the ``tensor_to_shape`` + ``match_cast`` bridge used by + ``_get_shape_expr_from_tensor`` so a data-dependent scalar can be used as + a ``PrimExpr`` (e.g. an output length). The scalar is cast to int64 first. + """ + expr = self.bb.normalize(relax.op.astype(expr, "int64")) + expr = self.bb.normalize(relax.op.reshape(expr, (1,))) + expr = self.bb.match_cast(expr, relax.TensorType([1], "int64")) + shape_var = self.bb.emit(relax.op.tensor_to_shape(expr)) + dim = tirx.Var(name, "int64") + self.bb.match_cast(shape_var, relax.ShapeType([dim])) + return dim + + def _convert_dynamic_range(self, start, limit, delta, out_type): + """RANGE with dynamic (runtime) scalar bounds, for int and float dtypes. + + ``relax.op.arange`` only accepts compile-time ``PrimExpr`` bounds, and its + struct-info length formula lacks a negative-step branch, so feeding + symbolic bounds directly would mis-declare descending ranges. Instead the + element count ``max(0, ceil((limit - start) / delta))`` is computed + in-graph and lifted to one symbolic dimension ``L`` (so the declared and + runtime lengths match by construction); values are rebuilt as + ``arange(0, L) * delta + start``. + """ + # int ranges work in int64 for an exact, sign-agnostic count; float + # ranges work in the output float dtype. + work_type = out_type if out_type.startswith("float") else "int64" - return out + def scalar_expr(tensor): + return self.bb.normalize(relax.op.astype(self.get_tensor_expr(tensor), work_type)) + + start_e = scalar_expr(start) + limit_e = scalar_expr(limit) + delta_e = scalar_expr(delta) + + if work_type.startswith("float"): + count = relax.op.ceil(relax.op.divide(relax.op.subtract(limit_e, start_e), delta_e)) + else: + # ceil((limit - start) / delta) == -floordiv(start - limit, delta), + # which stays exact and handles negative delta without a float cast. + count = relax.op.negative( + relax.op.floor_divide(relax.op.subtract(start_e, limit_e), delta_e) + ) + count = relax.op.maximum(count, relax.const(0, work_type)) + dim = self._scalar_tensor_to_dim(count, "range_len") + + positions = self.bb.normalize( + relax.op.astype(relax.op.arange(0, dim, 1, "int64"), work_type) + ) + out = relax.op.add(relax.op.multiply(positions, delta_e), start_e) + return out if work_type == out_type else relax.op.astype(out, out_type) def convert_rank(self, op): """Convert TFLite RANK.""" diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 590cc4ac459f..2d0f497e94c1 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -717,22 +717,47 @@ def func(self): verify(Range) -def test_range_dynamic_scalar_inputs_not_supported(): - """RANGE conversion currently rejects dynamic scalar inputs.""" +@pytest.mark.parametrize( + "start, limit, delta, dtype", + [ + (2, 13, 3, tf.int32), + (8, 0, -2, tf.int32), + (0.0, 1.0, 0.25, tf.float32), + (1.0, -1.0, -0.5, tf.float32), + ], +) +def test_range_dynamic_scalar_inputs(start, limit, delta, dtype): + """RANGE lowers dynamic (runtime) scalar bounds for both int and float dtypes.""" class RangeDynamic(tf.Module): @tf.function( input_signature=[ - tf.TensorSpec(shape=(), dtype=tf.int32), - tf.TensorSpec(shape=(), dtype=tf.int32), - tf.TensorSpec(shape=(), dtype=tf.int32), + tf.TensorSpec(shape=(), dtype=dtype), + tf.TensorSpec(shape=(), dtype=dtype), + tf.TensorSpec(shape=(), dtype=dtype), ] ) def func(self, start, limit, delta): - return tf.range(start, limit, delta, dtype=tf.int32) + return tf.range(start, limit, delta) + + cf = RangeDynamic().func.get_concrete_function() + mod = _get_mod_from_cfunc(cf) + + np_dtype = np.int32 if dtype == tf.int32 else np.float32 + inputs = [ + np.array(start, np_dtype), + np.array(limit, np_dtype), + np.array(delta, np_dtype), + ] + + ex = tvm.compile(mod, tvm.target.Target("llvm")) + vm = relax.VirtualMachine(ex, tvm.cpu()) + vm.set_input("main", *inputs) + vm.invoke_stateful("main") + tvm_out = vm.get_outputs("main").numpy() - with pytest.raises(tvm.error.OpNotImplemented, match="dynamic scalar inputs"): - verify(RangeDynamic) + expected = np.arange(start, limit, delta, dtype=np_dtype) + np.testing.assert_allclose(tvm_out, expected, rtol=1e-5, atol=1e-5) def test_tile_ir():