Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3349,7 +3349,13 @@ def _const_1d(values, dtype="int64"):
return self.bb.normalize(relax.op.dynamic_strided_slice(operand, begin, end, strides))

def _convert_stablehlo_dynamic_update_slice(self, op):
"""Convert STABLEHLO_DYNAMIC_UPDATE_SLICE to Relax for static starts."""
"""Convert STABLEHLO_DYNAMIC_UPDATE_SLICE to Relax.

Lowers to ``relax.op.scatter_nd``. Constant start indices build the index
grid at compile time; runtime (dynamic) start indices build it in-graph
with ``arange`` + broadcast, clamping each start to
``[0, operand_dim - update_dim]`` per StableHLO semantics.
"""
input_tensors = self.get_input_tensors(op)
# operand + update + N start-index scalars
assert len(input_tensors) >= 3, "input tensors length should be >= 3"
Expand All @@ -3368,11 +3374,21 @@ def _convert_stablehlo_dynamic_update_slice(self, op):
"STABLEHLO_DYNAMIC_UPDATE_SLICE requires operand, update, "
"and start-index ranks to match"
)
for dim, size in zip(operand_shape, update_shape):
if size > dim:
raise tvm.error.OpNotImplemented(
"STABLEHLO_DYNAMIC_UPDATE_SLICE update shape must be smaller than "
"or equal to operand shape for all dimensions"
)

operand = self.get_tensor_expr(operand_tensor)
update = self.get_tensor_expr(update_tensor)
Comment thread
Aharrypotter marked this conversation as resolved.

if any(self.has_expr(t.tensor_idx) for t in start_tensors):
raise tvm.error.OpNotImplemented(
"STABLEHLO_DYNAMIC_UPDATE_SLICE with dynamic start indices is not supported"
indices = self._build_dynamic_update_slice_indices(
start_tensors, operand_shape, update_shape, rank
)
return self.bb.normalize(relax.op.scatter_nd(operand, indices, update, "update"))

start_vals = [int(np.asarray(self.get_tensor_value(t)).item()) for t in start_tensors]
for start, size, dim in zip(start_vals, update_shape, operand_shape):
Expand All @@ -3387,11 +3403,37 @@ def _convert_stablehlo_dynamic_update_slice(self, op):
update_indices[axis] += start
update_indices = np.moveaxis(update_indices, 0, -1)

operand = self.get_tensor_expr(operand_tensor)
update = self.get_tensor_expr(update_tensor)
indices = self.bb.normalize(relax.const(update_indices, dtype="int64"))
return self.bb.normalize(relax.op.scatter_nd(operand, indices, update, "update"))

def _build_dynamic_update_slice_indices(self, start_tensors, operand_shape, update_shape, rank):
"""Build the scatter_nd index grid for runtime DYNAMIC_UPDATE_SLICE starts.

Returns an int64 tensor of shape ``(*update_shape, rank)`` where axis ``a``
holds ``arange(update_shape[a]) + clamp(start[a], 0, operand_dim - update_dim)``,
broadcast over the other axes (StableHLO clamps out-of-range starts).
"""
axis_indices = []
for axis in range(rank):
start_expr = self.bb.normalize(
relax.op.astype(self.get_tensor_expr(start_tensors[axis]), "int64")
)
max_start = operand_shape[axis] - update_shape[axis]
start_expr = relax.op.maximum(start_expr, relax.const(0, "int64"))
start_expr = relax.op.minimum(start_expr, relax.const(max_start, "int64"))

base = relax.op.arange(0, update_shape[axis], 1, "int64")
idx = relax.op.add(base, start_expr)

broadcast_shape = [1] * rank
broadcast_shape[axis] = update_shape[axis]
idx = self.bb.normalize(relax.op.reshape(idx, broadcast_shape))
idx = self.bb.normalize(relax.op.broadcast_to(idx, update_shape))
idx = self.bb.normalize(relax.op.expand_dims(idx, axis=-1))
axis_indices.append(idx)

return self.bb.normalize(relax.op.concat(axis_indices, axis=-1))

def _convert_stablehlo_dot_general(self, op):
"""Convert the canonical 2D STABLEHLO_DOT_GENERAL subset to Relax matmul."""
from tflite.StablehloDotGeneralOptions import StablehloDotGeneralOptions
Expand Down
49 changes: 40 additions & 9 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9851,16 +9851,47 @@ def main(
tvm.ir.assert_structural_equal(mod, Expected)


def test_stablehlo_dynamic_update_slice_dynamic_starts_unsupported():
"""TFLite StableHLO DYNAMIC_UPDATE_SLICE with runtime starts is unsupported."""
buf = _build_stablehlo_dynamic_update_slice_model([0, 0], dynamic_starts=True)
if hasattr(tflite.Model, "Model"):
tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
else:
tflite_model = tflite.Model.GetRootAsModel(buf, 0)
def test_stablehlo_dynamic_update_slice_dynamic_starts():
"""TFLite StableHLO DYNAMIC_UPDATE_SLICE with runtime starts lowers structurally."""
mod = _load_model_from_buffer(
_build_stablehlo_dynamic_update_slice_model([0, 0], dynamic_starts=True)
)

with pytest.raises(tvm.error.OpNotImplemented, match="dynamic start"):
from_tflite(tflite_model)
@I.ir_module
class Expected:
@R.function
def main(
operand: R.Tensor((3, 4), dtype="float32"),
update: R.Tensor((2, 2), dtype="float32"),
s0: R.Tensor((), dtype="int32"),
s1: R.Tensor((), dtype="int32"),
) -> R.Tensor((3, 4), dtype="float32"):
R.func_attr({"num_input": 4})
with R.dataflow():
lv: R.Tensor((2,), dtype="int64") = R.arange(0, 2, 1, dtype="int64")
lv1: R.Tensor((), dtype="int64") = R.astype(s0, dtype="int64")
lv2: R.Tensor((), dtype="int64") = R.maximum(lv1, R.const(0, "int64"))
lv3: R.Tensor((), dtype="int64") = R.minimum(lv2, R.const(1, "int64"))
lv4: R.Tensor((2,), dtype="int64") = R.add(lv, lv3)
lv5: R.Tensor((2, 1), dtype="int64") = R.reshape(lv4, (2, 1))
lv6: R.Tensor((2, 2), dtype="int64") = R.broadcast_to(lv5, (2, 2))
lv7: R.Tensor((2,), dtype="int64") = R.arange(0, 2, 1, dtype="int64")
lv8: R.Tensor((), dtype="int64") = R.astype(s1, dtype="int64")
lv9: R.Tensor((), dtype="int64") = R.maximum(lv8, R.const(0, "int64"))
lv10: R.Tensor((), dtype="int64") = R.minimum(lv9, R.const(2, "int64"))
lv11: R.Tensor((2,), dtype="int64") = R.add(lv7, lv10)
lv12: R.Tensor((1, 2), dtype="int64") = R.reshape(lv11, (1, 2))
lv13: R.Tensor((2, 2), dtype="int64") = R.broadcast_to(lv12, (2, 2))
lv14: R.Tensor((2, 2, 1), dtype="int64") = R.expand_dims(lv6, axis=[-1])
lv15: R.Tensor((2, 2, 1), dtype="int64") = R.expand_dims(lv13, axis=[-1])
lv16: R.Tensor((2, 2, 2), dtype="int64") = R.concat((lv14, lv15), axis=-1)
gv: R.Tensor((3, 4), dtype="float32") = R.scatter_nd(
operand, lv16, update, reduction="update"
)
R.output(gv)
return gv

tvm.ir.assert_structural_equal(mod, Expected)


def test_stablehlo_dynamic_update_slice_out_of_bounds_unsupported():
Expand Down
Loading