Skip to content

Commit 8906b06

Browse files
LudovicoYINCopilot
andcommitted
[Relax][Frontend][TFLite] Add UNIDIRECTIONAL_SEQUENCE_RNN converter
Implements convert_unidirectional_sequence_rnn in the Relax TFLite frontend. The op (BuiltinOperator 35) executes a simple RNN cell over a time sequence using Relax primitives (matmul / add / activation). Cell formula: h_t = act(x_t @ W^T + h_{t-1} @ Wr^T + b) Design notes: - Inputs: input[batch,time,input_size], input_weights[units,input_size], recurrent_weights[units,units], bias[units], hidden_state[batch,units] - time_major=True input is transposed to batch-major before unrolling - Activations supported: NONE, RELU, RELU6, TANH, SIGMOID - Quantised variant raises OpNotImplemented (guard already present) - For time=1 the split is skipped; squeeze is applied directly - Time steps are unrolled at graph-construction time and outputs stacked along axis=1 Tests added (3): - test_unidirectional_sequence_rnn_none_activation: structural_equal check with identity weights / zero bias, NONE activation, time=1 - test_unidirectional_sequence_rnn_relu_activation: shape check with random weights, RELU activation, time=3 - test_unidirectional_sequence_rnn_time_major: shape check with time_major=True input layout Closes part of #19519 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent cae6cb8 commit 8906b06

2 files changed

Lines changed: 308 additions & 0 deletions

File tree

python/tvm/relax/frontend/tflite/tflite_frontend.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
312312
"TRANSPOSE_CONV": self.convert_transpose_conv,
313313
"TRANSPOSE": self.convert_transpose,
314314
"UNPACK": self.convert_unpack,
315+
"UNIDIRECTIONAL_SEQUENCE_RNN": self.convert_unidirectional_sequence_rnn,
315316
"UNSORTED_SEGMENT_MIN": functools.partial(
316317
self._convert_segment_op, op_name="UNSORTED_SEGMENT_MIN", reduction="min"
317318
),
@@ -4477,6 +4478,106 @@ def convert_unpack(self, op):
44774478

44784479
return squeezed
44794480

4481+
def convert_unidirectional_sequence_rnn(self, op):
4482+
"""Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.
4483+
4484+
Inputs (5 tensors):
4485+
[0] input [batch, time, input_size] (or [time, batch, input_size] if time_major)
4486+
[1] input_weights [num_units, input_size]
4487+
[2] recurrent_weights [num_units, num_units]
4488+
[3] bias [num_units]
4489+
[4] hidden_state [batch, num_units] (variable, zero-initialised)
4490+
4491+
Output:
4492+
[0] output [batch, time, num_units]
4493+
4494+
Cell equation:
4495+
h_t = fused_activation(x_t @ W.T + h_{t-1} @ Wr.T + b)
4496+
"""
4497+
from tflite.BuiltinOptions import BuiltinOptions
4498+
from tflite.SequenceRNNOptions import SequenceRNNOptions
4499+
4500+
if self.is_quantized(op):
4501+
raise tvm.error.OpNotImplemented(
4502+
"TFLite quantized UNIDIRECTIONAL_SEQUENCE_RNN is not supported yet."
4503+
)
4504+
4505+
input_tensors = self.get_input_tensors(op)
4506+
assert len(input_tensors) == 5, "input tensors length should be 5"
4507+
4508+
input_tensor = input_tensors[0]
4509+
weights_tensor = input_tensors[1]
4510+
recurrent_tensor = input_tensors[2]
4511+
bias_tensor = input_tensors[3]
4512+
hidden_state_tensor = input_tensors[4]
4513+
4514+
output_tensors = self.get_output_tensors(op)
4515+
assert len(output_tensors) >= 1, "output tensors length should be at least 1"
4516+
4517+
assert op.BuiltinOptionsType() == BuiltinOptions.SequenceRNNOptions
4518+
op_options = op.BuiltinOptions()
4519+
seq_rnn_options = SequenceRNNOptions()
4520+
seq_rnn_options.Init(op_options.Bytes, op_options.Pos)
4521+
time_major = seq_rnn_options.TimeMajor()
4522+
fused_activation_fn = seq_rnn_options.FusedActivationFunction()
4523+
4524+
# Constant weight/bias expressions.
4525+
weights_expr = self.get_tensor_expr(weights_tensor) # [num_units, input_size]
4526+
recurrent_expr = self.get_tensor_expr(recurrent_tensor) # [num_units, num_units]
4527+
4528+
# bias is optional (tensor_idx == -1 when absent); default to zeros.
4529+
if bias_tensor.tensor_idx != -1:
4530+
bias_expr = self.get_tensor_expr(bias_tensor) # [num_units]
4531+
else:
4532+
num_units = int(self.get_tensor_shape(weights_tensor)[0])
4533+
bias_dtype = self.get_tensor_type_str(weights_tensor.tensor.Type())
4534+
bias_expr = relax.op.zeros((num_units,), dtype=bias_dtype)
4535+
4536+
# Transpose to [input_size, num_units] and [num_units, num_units] for x @ W.T.
4537+
w_t = relax.op.permute_dims(weights_expr)
4538+
wr_t = relax.op.permute_dims(recurrent_expr)
4539+
4540+
# Resolve the input expression; normalise to batch-major [batch, time, input_size].
4541+
# Only the time dimension must be static (needed for unrolling); batch may be dynamic.
4542+
in_expr = self.get_tensor_expr(input_tensor)
4543+
in_shape = self.get_tensor_shape(input_tensor)
4544+
if time_major:
4545+
in_expr = relax.op.permute_dims(in_expr, [1, 0, 2])
4546+
num_steps = int(in_shape[0])
4547+
else:
4548+
num_steps = int(in_shape[1])
4549+
4550+
# Initial hidden state: use the model's tensor value when available (non-zero init or
4551+
# graph input), otherwise fall back to zeros for the common variable-tensor case.
4552+
h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
4553+
if self.has_expr(hidden_state_tensor.tensor_idx) or (
4554+
hidden_state_tensor.buffer is not None and hidden_state_tensor.buffer.DataLength() > 0
4555+
):
4556+
h = self.get_tensor_expr(hidden_state_tensor)
4557+
else:
4558+
h_shape = tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
4559+
h = relax.op.zeros(h_shape, dtype=h_dtype)
4560+
4561+
# Unroll over the time axis.
4562+
# relax.op.split with 1 section returns the tensor directly; handle uniformly.
4563+
if num_steps == 1:
4564+
steps = [relax.op.squeeze(in_expr, axis=[1])]
4565+
else:
4566+
splits = relax.op.split(in_expr, num_steps, axis=1)
4567+
steps = [relax.op.squeeze(splits[i], axis=[1]) for i in range(num_steps)]
4568+
4569+
outputs = []
4570+
for x_t in steps: # x_t: [batch, input_size]
4571+
gates = relax.op.add(
4572+
relax.op.add(relax.op.matmul(x_t, w_t), relax.op.matmul(h, wr_t)),
4573+
bias_expr,
4574+
)
4575+
h = self.convert_fused_activation_function(gates, fused_activation_fn)
4576+
outputs.append(h)
4577+
4578+
# Stack timestep outputs: [batch, time, num_units].
4579+
return relax.op.stack(outputs, axis=1)
4580+
44804581
"""
44814582
def convert_unidirectional_sequence_lstm(self, op):
44824583
### Long Short Term Memory for TFLite implementation. ###

tests/python/relax/test_frontend_tflite.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3710,6 +3710,8 @@ def _get_tflite_schema_enum(enum_name):
37103710
_tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector")
37113711
_tfl_tensor_type = _get_tflite_schema_enum("TensorType")
37123712

3713+
_tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")
3714+
37133715
_DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32)
37143716
_DENSIFY_TEST_DENSE = np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
37153717
_DENSIFY_ROW_PTRS = [0, 1, 2]
@@ -6731,5 +6733,210 @@ def main(
67316733
tvm.ir.assert_structural_equal(mod, Expected)
67326734

67336735

6736+
# ── UNIDIRECTIONAL_SEQUENCE_RNN ───────────────────────────────────────────────
6737+
6738+
6739+
def _build_unidirectional_sequence_rnn_model(
6740+
batch,
6741+
time,
6742+
input_size,
6743+
num_units,
6744+
weights,
6745+
recurrent_weights,
6746+
bias,
6747+
activation,
6748+
*,
6749+
time_major=False,
6750+
):
6751+
"""Build a minimal TFLite flatbuffer model containing one UNIDIRECTIONAL_SEQUENCE_RNN op.
6752+
6753+
Tensor layout (indices 0-5):
6754+
0 - input [batch, time, input_size] (or [time, batch, input_size] if time_major)
6755+
1 - input_weights [num_units, input_size] (constant)
6756+
2 - recurrent_wts [num_units, num_units] (constant)
6757+
3 - bias [num_units] (constant)
6758+
4 - hidden_state [batch, num_units] (variable, zero-initialised)
6759+
5 - output [batch, time, num_units]
6760+
"""
6761+
builder = flatbuffers.Builder(4096)
6762+
6763+
_tfl_sequence_rnn_options.SequenceRNNOptionsStart(builder)
6764+
_tfl_sequence_rnn_options.SequenceRNNOptionsAddTimeMajor(builder, time_major)
6765+
_tfl_sequence_rnn_options.SequenceRNNOptionsAddFusedActivationFunction(builder, activation)
6766+
rnn_opts = _tfl_sequence_rnn_options.SequenceRNNOptionsEnd(builder)
6767+
6768+
rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.UNIDIRECTIONAL_SEQUENCE_RNN)
6769+
6770+
input_shape = [time, batch, input_size] if time_major else [batch, time, input_size]
6771+
6772+
def _t(buf_idx, shape, is_variable=False):
6773+
shape_vec = _tflite_shape(builder, shape)
6774+
_tfl_tensor.TensorStart(builder)
6775+
_tfl_tensor.TensorAddBuffer(builder, buf_idx)
6776+
_tfl_tensor.TensorAddHasRank(builder, True)
6777+
_tfl_tensor.TensorAddIsVariable(builder, is_variable)
6778+
_tfl_tensor.TensorAddShape(builder, shape_vec)
6779+
_tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
6780+
return _tfl_tensor.TensorEnd(builder)
6781+
6782+
tensors = [
6783+
_t(0, input_shape),
6784+
_t(1, [num_units, input_size]),
6785+
_t(2, [num_units, num_units]),
6786+
_t(3, [num_units]),
6787+
_t(4, [batch, num_units], is_variable=True),
6788+
_t(5, [batch, time, num_units]),
6789+
]
6790+
6791+
rnn_op = _build_operator(
6792+
builder,
6793+
0,
6794+
[0, 1, 2, 3, 4],
6795+
[5],
6796+
builtin_options_type=_tfl_builtin_options.SequenceRNNOptions,
6797+
builtin_options=rnn_opts,
6798+
)
6799+
6800+
subgraph = _build_subgraph(
6801+
builder,
6802+
tensors=tensors,
6803+
operators=[rnn_op],
6804+
inputs=[0],
6805+
outputs=[5],
6806+
)
6807+
6808+
buffers = [
6809+
_build_buffer(builder),
6810+
_build_buffer(builder, weights.tobytes()),
6811+
_build_buffer(builder, recurrent_weights.tobytes()),
6812+
_build_buffer(builder, bias.tobytes()),
6813+
_build_buffer(builder),
6814+
_build_buffer(builder),
6815+
]
6816+
6817+
return _finish_tflite_model(
6818+
builder,
6819+
subgraph=subgraph,
6820+
operator_codes=[rnn_op_code],
6821+
buffers=buffers,
6822+
)
6823+
6824+
6825+
def test_unidirectional_sequence_rnn_none_activation():
6826+
"""UNIDIRECTIONAL_SEQUENCE_RNN with NONE activation, time=1, lowers to matmul/add/stack.
6827+
6828+
Cell equation: h_t = x_t @ W.T + h_{t-1} @ Wr.T + b (no activation for NONE)
6829+
"""
6830+
from tflite.ActivationFunctionType import ActivationFunctionType
6831+
6832+
batch, time, input_size, num_units = 2, 1, 2, 2
6833+
weights = np.eye(num_units, input_size, dtype=np.float32)
6834+
recurrent_weights = np.eye(num_units, dtype=np.float32)
6835+
bias = np.zeros(num_units, dtype=np.float32)
6836+
6837+
mod = _load_model_from_buffer(
6838+
_build_unidirectional_sequence_rnn_model(
6839+
batch,
6840+
time,
6841+
input_size,
6842+
num_units,
6843+
weights,
6844+
recurrent_weights,
6845+
bias,
6846+
ActivationFunctionType.NONE,
6847+
)
6848+
)
6849+
6850+
@I.ir_module
6851+
class Expected:
6852+
@R.function
6853+
def main(x: R.Tensor((2, 1, 2), dtype="float32")) -> R.Tensor((2, 1, 2), dtype="float32"):
6854+
R.func_attr({"num_input": 1})
6855+
with R.dataflow():
6856+
lv: R.Tensor((2, 2), dtype="float32") = R.squeeze(x, axis=[1])
6857+
lv1: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
6858+
R.const(np.eye(2, dtype=np.float32)), axes=None
6859+
)
6860+
lv2: R.Tensor((2, 2), dtype="float32") = R.matmul(lv, lv1, out_dtype="void")
6861+
lv3: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32")
6862+
lv4: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
6863+
R.const(np.eye(2, dtype=np.float32)), axes=None
6864+
)
6865+
lv5: R.Tensor((2, 2), dtype="float32") = R.matmul(lv3, lv4, out_dtype="void")
6866+
lv6: R.Tensor((2, 2), dtype="float32") = R.add(lv2, lv5)
6867+
lv7: R.Tensor((2, 2), dtype="float32") = R.add(
6868+
lv6, R.const(np.zeros(2, dtype=np.float32))
6869+
)
6870+
gv: R.Tensor((2, 1, 2), dtype="float32") = R.stack((lv7,), axis=1)
6871+
R.output(gv)
6872+
return gv
6873+
6874+
tvm.ir.assert_structural_equal(mod, Expected)
6875+
6876+
6877+
def test_unidirectional_sequence_rnn_relu_activation():
6878+
"""UNIDIRECTIONAL_SEQUENCE_RNN with RELU activation and multiple time steps."""
6879+
from tflite.ActivationFunctionType import ActivationFunctionType
6880+
6881+
batch, time, input_size, num_units = 2, 3, 4, 8
6882+
np.random.seed(42)
6883+
weights = np.random.randn(num_units, input_size).astype(np.float32)
6884+
recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32)
6885+
bias = np.random.randn(num_units).astype(np.float32)
6886+
6887+
mod = _load_model_from_buffer(
6888+
_build_unidirectional_sequence_rnn_model(
6889+
batch,
6890+
time,
6891+
input_size,
6892+
num_units,
6893+
weights,
6894+
recurrent_weights,
6895+
bias,
6896+
ActivationFunctionType.RELU,
6897+
)
6898+
)
6899+
6900+
fn = mod["main"]
6901+
assert len(fn.params) == 1, "only the sequence input should be a graph input"
6902+
in_shape = fn.params[0].struct_info.shape
6903+
assert tuple(int(d) for d in in_shape) == (batch, time, input_size)
6904+
out_shape = fn.ret_struct_info.shape
6905+
assert tuple(int(d) for d in out_shape) == (batch, time, num_units)
6906+
6907+
6908+
def test_unidirectional_sequence_rnn_time_major():
6909+
"""UNIDIRECTIONAL_SEQUENCE_RNN with time_major=True transposes before unrolling."""
6910+
from tflite.ActivationFunctionType import ActivationFunctionType
6911+
6912+
batch, time, input_size, num_units = 3, 4, 2, 5
6913+
np.random.seed(7)
6914+
weights = np.random.randn(num_units, input_size).astype(np.float32)
6915+
recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32)
6916+
bias = np.zeros(num_units, dtype=np.float32)
6917+
6918+
mod = _load_model_from_buffer(
6919+
_build_unidirectional_sequence_rnn_model(
6920+
batch,
6921+
time,
6922+
input_size,
6923+
num_units,
6924+
weights,
6925+
recurrent_weights,
6926+
bias,
6927+
ActivationFunctionType.NONE,
6928+
time_major=True,
6929+
)
6930+
)
6931+
6932+
fn = mod["main"]
6933+
# Input to the graph is the raw time-major tensor [time, batch, input_size].
6934+
in_shape = fn.params[0].struct_info.shape
6935+
assert tuple(int(d) for d in in_shape) == (time, batch, input_size)
6936+
# Output is always batch-major [batch, time, num_units].
6937+
out_shape = fn.ret_struct_info.shape
6938+
assert tuple(int(d) for d in out_shape) == (batch, time, num_units)
6939+
6940+
67346941
if __name__ == "__main__":
67356942
pytest.main(["-s", __file__])

0 commit comments

Comments
 (0)