Skip to content

Commit dcbebe7

Browse files
LudovicoYINCopilot
andauthored
[Relax][Frontend][TFLite] Add UNIDIRECTIONAL_SEQUENCE_RNN converter (#19601)
## Summary This PR adds Relax TFLite frontend support for `UNIDIRECTIONAL_SEQUENCE_RNN` (BuiltinOperator 35), claimed in [#19519](#19519) Group A. The op executes a simple RNN cell over a time sequence. The converter unrolls the time steps at graph-construction time using Relax primitives. Cell equation: ``` h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b) ``` ## Changes - **Handler**: `convert_unidirectional_sequence_rnn` registered in `convert_map` (alphabetical, U-region after `UNPACK`) - **Inputs** (5): `input [batch, time, input_size]`, `input_weights [num_units, input_size]`, `recurrent_weights [num_units, num_units]`, `bias [num_units]`, `hidden_state [batch, num_units]` (variable, zero-initialised) - **Output**: `[batch, time, num_units]` (always batch-major) - **time_major=True**: input is transposed to batch-major before unrolling - **Activations**: NONE, RELU, RELU6, TANH, SIGMOID (via `convert_fused_activation_function`) - **Quantized**: raises `OpNotImplemented` (not yet supported) ## Testing Modern TF/Keras (2.x, Keras 3) no longer emits `UNIDIRECTIONAL_SEQUENCE_RNN`; `SimpleRNN` with `unroll=False` lowers to `WHILE`+TensorList ops, and `unroll=True` expands to elementwise ops. Tests therefore follow the same flatbuffer-construction pattern used by the StableHLO op PRs (#19536, #19587). Three tests added to `tests/python/relax/test_frontend_tflite.py`: - `test_unidirectional_sequence_rnn_none_activation` — `tvm.ir.assert_structural_equal` with identity weights / zero bias, NONE activation, time=1 - `test_unidirectional_sequence_rnn_relu_activation` — shape check, random weights, RELU activation, time=3 - `test_unidirectional_sequence_rnn_time_major` — shape check, `time_major=True` input layout ```bash python -m pytest tests/python/relax/test_frontend_tflite.py -k unidirectional_sequence_rnn -v ``` All 3 tests pass. pre-commit (ASF header, ruff check, ruff format) all pass. ## References - Issue [#19519](#19519) Group A: Sequence / recurrent model operators Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent fa66213 commit dcbebe7

2 files changed

Lines changed: 308 additions & 0 deletions

File tree

python/tvm/relax/frontend/tflite/tflite_frontend.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
381381
"TRANSPOSE_CONV": self.convert_transpose_conv,
382382
"TRANSPOSE": self.convert_transpose,
383383
"UNPACK": self.convert_unpack,
384+
"UNIDIRECTIONAL_SEQUENCE_RNN": self.convert_unidirectional_sequence_rnn,
384385
"UNSORTED_SEGMENT_MIN": functools.partial(
385386
self._convert_segment_op, op_name="UNSORTED_SEGMENT_MIN", reduction="min"
386387
),
@@ -4877,6 +4878,106 @@ def convert_unpack(self, op):
48774878

48784879
return squeezed
48794880

4881+
def convert_unidirectional_sequence_rnn(self, op):
4882+
"""Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
4883+
4884+
Inputs (5 tensors):
4885+
[0] input [batch, time, input_size] (or [time, batch, input_size] if time_major)
4886+
[1] input_weights [num_units, input_size]
4887+
[2] recurrent_weights [num_units, num_units]
4888+
[3] bias [num_units]
4889+
[4] hidden_state [batch, num_units] (variable, zero-initialised)
4890+
4891+
Output:
4892+
[0] output [batch, time, num_units]
4893+
4894+
Cell equation:
4895+
h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b)
4896+
"""
4897+
from tflite.BuiltinOptions import BuiltinOptions
4898+
from tflite.SequenceRNNOptions import SequenceRNNOptions
4899+
4900+
if self.is_quantized(op):
4901+
raise tvm.error.OpNotImplemented(
4902+
"TFLite quantized UNIDIRECTIONAL_SEQUENCE_RNN is not supported yet."
4903+
)
4904+
4905+
input_tensors = self.get_input_tensors(op)
4906+
assert len(input_tensors) == 5, "input tensors length should be 5"
4907+
4908+
input_tensor = input_tensors[0]
4909+
weights_tensor = input_tensors[1]
4910+
recurrent_tensor = input_tensors[2]
4911+
bias_tensor = input_tensors[3]
4912+
hidden_state_tensor = input_tensors[4]
4913+
4914+
output_tensors = self.get_output_tensors(op)
4915+
assert len(output_tensors) >= 1, "output tensors length should be at least 1"
4916+
4917+
assert op.BuiltinOptionsType() == BuiltinOptions.SequenceRNNOptions
4918+
op_options = op.BuiltinOptions()
4919+
seq_rnn_options = SequenceRNNOptions()
4920+
seq_rnn_options.Init(op_options.Bytes, op_options.Pos)
4921+
time_major = seq_rnn_options.TimeMajor()
4922+
fused_activation_fn = seq_rnn_options.FusedActivationFunction()
4923+
4924+
# Constant weight/bias expressions.
4925+
weights_expr = self.get_tensor_expr(weights_tensor) # [num_units, input_size]
4926+
recurrent_expr = self.get_tensor_expr(recurrent_tensor) # [num_units, num_units]
4927+
4928+
# bias is optional (tensor_idx == -1 when absent); default to zeros.
4929+
if bias_tensor.tensor_idx != -1:
4930+
bias_expr = self.get_tensor_expr(bias_tensor) # [num_units]
4931+
else:
4932+
num_units = int(self.get_tensor_shape(weights_tensor)[0])
4933+
bias_dtype = self.get_tensor_type_str(weights_tensor.tensor.Type())
4934+
bias_expr = relax.op.zeros((num_units,), dtype=bias_dtype)
4935+
4936+
# Transpose to [input_size, num_units] and [num_units, num_units] for x @ W.T.
4937+
w_t = relax.op.permute_dims(weights_expr)
4938+
wr_t = relax.op.permute_dims(recurrent_expr)
4939+
4940+
# Resolve the input expression; normalise to batch-major [batch, time, input_size].
4941+
# Only the time dimension must be static (needed for unrolling); batch may be dynamic.
4942+
in_expr = self.get_tensor_expr(input_tensor)
4943+
in_shape = self.get_tensor_shape(input_tensor)
4944+
if time_major:
4945+
in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
4946+
num_steps = int(in_shape[0])
4947+
else:
4948+
num_steps = int(in_shape[1])
4949+
4950+
# Initial hidden state: use the model's tensor value when available (non-zero init or
4951+
# graph input), otherwise fall back to zeros for the common variable-tensor case.
4952+
h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
4953+
if self.has_expr(hidden_state_tensor.tensor_idx) or (
4954+
hidden_state_tensor.buffer is not None and hidden_state_tensor.buffer.DataLength() > 0
4955+
):
4956+
h = self.get_tensor_expr(hidden_state_tensor)
4957+
else:
4958+
h_shape = tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
4959+
h = relax.op.zeros(h_shape, dtype=h_dtype)
4960+
4961+
# Unroll over the time axis.
4962+
# relax.op.split with 1 section returns the tensor directly; handle uniformly.
4963+
if num_steps == 1:
4964+
steps = [relax.op.squeeze(in_expr, axis=[1])]
4965+
else:
4966+
splits = relax.op.split(in_expr, num_steps, axis=1)
4967+
steps = [relax.op.squeeze(splits[i], axis=[1]) for i in range(num_steps)]
4968+
4969+
outputs = []
4970+
for x_t in steps: # x_t: [batch, input_size]
4971+
gates = relax.op.add(
4972+
relax.op.add(relax.op.matmul(x_t, w_t), relax.op.matmul(h, wr_t)),
4973+
bias_expr,
4974+
)
4975+
h = self.convert_fused_activation_function(gates, fused_activation_fn)
4976+
outputs.append(h)
4977+
4978+
# Stack timestep outputs: [batch, time, num_units].
4979+
return relax.op.stack(outputs, axis=1)
4980+
48804981
"""
48814982
def convert_unidirectional_sequence_lstm(self, op):
48824983
### Long Short Term Memory for TFLite implementation. ###

tests/python/relax/test_frontend_tflite.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3720,6 +3720,8 @@ def _get_tflite_schema_enum(enum_name):
37203720
_tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector")
37213721
_tfl_tensor_type = _get_tflite_schema_enum("TensorType")
37223722

3723+
_tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")
3724+
37233725
_DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32)
37243726
_DENSIFY_TEST_DENSE = np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
37253727
_DENSIFY_ROW_PTRS = [0, 1, 2]
@@ -9719,5 +9721,210 @@ def main(
97199721
tvm.ir.assert_structural_equal(mod, Expected)
97209722

97219723

9724+
# ── UNIDIRECTIONAL_SEQUENCE_RNN ───────────────────────────────────────────────
9725+
9726+
9727+
def _build_unidirectional_sequence_rnn_model(
9728+
batch,
9729+
time,
9730+
input_size,
9731+
num_units,
9732+
weights,
9733+
recurrent_weights,
9734+
bias,
9735+
activation,
9736+
*,
9737+
time_major=False,
9738+
):
9739+
"""Build a minimal TFLite flatbuffer model containing one UNIDIRECTIONAL_SEQUENCE_RNN op.
9740+
9741+
Tensor layout (indices 0-5):
9742+
0 - input [batch, time, input_size] (or [time, batch, input_size] if time_major)
9743+
1 - input_weights [num_units, input_size] (constant)
9744+
2 - recurrent_wts [num_units, num_units] (constant)
9745+
3 - bias [num_units] (constant)
9746+
4 - hidden_state [batch, num_units] (variable, zero-initialised)
9747+
5 - output [batch, time, num_units]
9748+
"""
9749+
builder = flatbuffers.Builder(4096)
9750+
9751+
_tfl_sequence_rnn_options.SequenceRNNOptionsStart(builder)
9752+
_tfl_sequence_rnn_options.SequenceRNNOptionsAddTimeMajor(builder, time_major)
9753+
_tfl_sequence_rnn_options.SequenceRNNOptionsAddFusedActivationFunction(builder, activation)
9754+
rnn_opts = _tfl_sequence_rnn_options.SequenceRNNOptionsEnd(builder)
9755+
9756+
rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.UNIDIRECTIONAL_SEQUENCE_RNN)
9757+
9758+
input_shape = [time, batch, input_size] if time_major else [batch, time, input_size]
9759+
9760+
def _t(buf_idx, shape, is_variable=False):
9761+
shape_vec = _tflite_shape(builder, shape)
9762+
_tfl_tensor.TensorStart(builder)
9763+
_tfl_tensor.TensorAddBuffer(builder, buf_idx)
9764+
_tfl_tensor.TensorAddHasRank(builder, True)
9765+
_tfl_tensor.TensorAddIsVariable(builder, is_variable)
9766+
_tfl_tensor.TensorAddShape(builder, shape_vec)
9767+
_tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
9768+
return _tfl_tensor.TensorEnd(builder)
9769+
9770+
tensors = [
9771+
_t(0, input_shape),
9772+
_t(1, [num_units, input_size]),
9773+
_t(2, [num_units, num_units]),
9774+
_t(3, [num_units]),
9775+
_t(4, [batch, num_units], is_variable=True),
9776+
_t(5, [batch, time, num_units]),
9777+
]
9778+
9779+
rnn_op = _build_operator(
9780+
builder,
9781+
0,
9782+
[0, 1, 2, 3, 4],
9783+
[5],
9784+
builtin_options_type=_tfl_builtin_options.SequenceRNNOptions,
9785+
builtin_options=rnn_opts,
9786+
)
9787+
9788+
subgraph = _build_subgraph(
9789+
builder,
9790+
tensors=tensors,
9791+
operators=[rnn_op],
9792+
inputs=[0],
9793+
outputs=[5],
9794+
)
9795+
9796+
buffers = [
9797+
_build_buffer(builder),
9798+
_build_buffer(builder, weights.tobytes()),
9799+
_build_buffer(builder, recurrent_weights.tobytes()),
9800+
_build_buffer(builder, bias.tobytes()),
9801+
_build_buffer(builder),
9802+
_build_buffer(builder),
9803+
]
9804+
9805+
return _finish_tflite_model(
9806+
builder,
9807+
subgraph=subgraph,
9808+
operator_codes=[rnn_op_code],
9809+
buffers=buffers,
9810+
)
9811+
9812+
9813+
def test_unidirectional_sequence_rnn_none_activation():
9814+
"""UNIDIRECTIONAL_SEQUENCE_RNN with NONE activation, time=1, lowers to matmul/add/stack.
9815+
9816+
Cell equation: h_t = x_t @ W.T + h_{t-1} @ Wr.T + b (no activation for NONE)
9817+
"""
9818+
from tflite.ActivationFunctionType import ActivationFunctionType
9819+
9820+
batch, time, input_size, num_units = 2, 1, 2, 2
9821+
weights = np.eye(num_units, input_size, dtype=np.float32)
9822+
recurrent_weights = np.eye(num_units, dtype=np.float32)
9823+
bias = np.zeros(num_units, dtype=np.float32)
9824+
9825+
mod = _load_model_from_buffer(
9826+
_build_unidirectional_sequence_rnn_model(
9827+
batch,
9828+
time,
9829+
input_size,
9830+
num_units,
9831+
weights,
9832+
recurrent_weights,
9833+
bias,
9834+
ActivationFunctionType.NONE,
9835+
)
9836+
)
9837+
9838+
@I.ir_module
9839+
class Expected:
9840+
@R.function
9841+
def main(x: R.Tensor((2, 1, 2), dtype="float32")) -> R.Tensor((2, 1, 2), dtype="float32"):
9842+
R.func_attr({"num_input": 1})
9843+
with R.dataflow():
9844+
lv: R.Tensor((2, 2), dtype="float32") = R.squeeze(x, axis=[1])
9845+
lv1: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
9846+
R.const(np.eye(2, dtype=np.float32)), axes=None
9847+
)
9848+
lv2: R.Tensor((2, 2), dtype="float32") = R.matmul(lv, lv1, out_dtype="void")
9849+
lv3: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32")
9850+
lv4: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
9851+
R.const(np.eye(2, dtype=np.float32)), axes=None
9852+
)
9853+
lv5: R.Tensor((2, 2), dtype="float32") = R.matmul(lv3, lv4, out_dtype="void")
9854+
lv6: R.Tensor((2, 2), dtype="float32") = R.add(lv2, lv5)
9855+
lv7: R.Tensor((2, 2), dtype="float32") = R.add(
9856+
lv6, R.const(np.zeros(2, dtype=np.float32))
9857+
)
9858+
gv: R.Tensor((2, 1, 2), dtype="float32") = R.stack((lv7,), axis=1)
9859+
R.output(gv)
9860+
return gv
9861+
9862+
tvm.ir.assert_structural_equal(mod, Expected)
9863+
9864+
9865+
def test_unidirectional_sequence_rnn_relu_activation():
9866+
"""UNIDIRECTIONAL_SEQUENCE_RNN with RELU activation and multiple time steps."""
9867+
from tflite.ActivationFunctionType import ActivationFunctionType
9868+
9869+
batch, time, input_size, num_units = 2, 3, 4, 8
9870+
np.random.seed(42)
9871+
weights = np.random.randn(num_units, input_size).astype(np.float32)
9872+
recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32)
9873+
bias = np.random.randn(num_units).astype(np.float32)
9874+
9875+
mod = _load_model_from_buffer(
9876+
_build_unidirectional_sequence_rnn_model(
9877+
batch,
9878+
time,
9879+
input_size,
9880+
num_units,
9881+
weights,
9882+
recurrent_weights,
9883+
bias,
9884+
ActivationFunctionType.RELU,
9885+
)
9886+
)
9887+
9888+
fn = mod["main"]
9889+
assert len(fn.params) == 1, "only the sequence input should be a graph input"
9890+
in_shape = fn.params[0].struct_info.shape
9891+
assert tuple(int(d) for d in in_shape) == (batch, time, input_size)
9892+
out_shape = fn.ret_struct_info.shape
9893+
assert tuple(int(d) for d in out_shape) == (batch, time, num_units)
9894+
9895+
9896+
def test_unidirectional_sequence_rnn_time_major():
9897+
"""UNIDIRECTIONAL_SEQUENCE_RNN with time_major=True transposes before unrolling."""
9898+
from tflite.ActivationFunctionType import ActivationFunctionType
9899+
9900+
batch, time, input_size, num_units = 3, 4, 2, 5
9901+
np.random.seed(7)
9902+
weights = np.random.randn(num_units, input_size).astype(np.float32)
9903+
recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32)
9904+
bias = np.zeros(num_units, dtype=np.float32)
9905+
9906+
mod = _load_model_from_buffer(
9907+
_build_unidirectional_sequence_rnn_model(
9908+
batch,
9909+
time,
9910+
input_size,
9911+
num_units,
9912+
weights,
9913+
recurrent_weights,
9914+
bias,
9915+
ActivationFunctionType.NONE,
9916+
time_major=True,
9917+
)
9918+
)
9919+
9920+
fn = mod["main"]
9921+
# Input to the graph is the raw time-major tensor [time, batch, input_size].
9922+
in_shape = fn.params[0].struct_info.shape
9923+
assert tuple(int(d) for d in in_shape) == (time, batch, input_size)
9924+
# Output is always batch-major [batch, time, num_units].
9925+
out_shape = fn.ret_struct_info.shape
9926+
assert tuple(int(d) for d in out_shape) == (batch, time, num_units)
9927+
9928+
97229929
if __name__ == "__main__":
97239930
pytest.main(["-s", __file__])

0 commit comments

Comments
 (0)