Skip to content

Commit 4ab83cc

Browse files
authored
Fixed LSTM (#219)
This PR fixes two critical issues in `arm_lstm_unidirectional_s8` and `s16` that prevent state persistence in streaming models and cause out-of-bounds reads during non-time-major inference. These issues are closely related to in tensorflow/tflite-micro#3564. Problem: - State Wiping: By default, `arm_lstm_unidirectional_*` unconditionally sets `hidden_in` to `NULL` and memsets `cell_state` to 0. This discards the `HiddenStateTensor` and `CellStateTensor` that TFLM relies on to persist state across `Invoke()` calls for streaming models. - Striding Bug: In the `time_major` = `false` block of `arm_lstm_unidirectional_*`, CMSIS-NN attempts to jump between batches by passing `batch_offset` = `params->time_steps` to `arm_nn_lstm_step_*`. However, `arm_nn_lstm_step_*` forwards this `batch_offset` to `arm_nn_vec_mat_mul_result_acc_s8_s16` for both the `data_in` and `hidden_in` pointers. Since the `hidden_state` buffer is contiguous (stride 1) and not strided like `data_in`, passing `batch_offset` = `params->time_steps` causes out-of-bounds reads on the hidden_in buffer at `timestep` t=0. Solution: - Adding a `hidden_state` pointer to `cmsis_nn_lstm_context`. - Forwarding this `hidden_state` as `hidden_in` when present, skipping the `cell_state` wiping if so. - Explicitly iterating over the `batch_size` in the `time_major` = `false` case when computing step sizes, which forces `batch_offset` = 1 and avoids the buggy out-of-bounds stride entirely while writing to the final memory buffer sequentially.
1 parent 91f84c8 commit 4ab83cc

41 files changed

Lines changed: 764 additions & 32 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Include/arm_nn_types.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
* Description: Public header file to contain the CMSIS-NN structs for the
2323
* TensorFlowLite micro compliant functions
2424
*
25-
* $Date: 27 March 2026
26-
* $Revision: V.3.6.0
25+
* $Date: 21 May 2026
26+
* $Revision: V.3.6.1
2727
*
2828
* Target : Arm(R) M-Profile Architecture
2929
* -------------------------------------------------------------------- */
@@ -274,6 +274,7 @@ typedef struct
274274
void *temp1;
275275
void *temp2;
276276
void *cell_state;
277+
void *hidden_state;
277278
} cmsis_nn_lstm_context;
278279

279280
/**

Source/LSTMFunctions/arm_lstm_unidirectional_s16.c

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
* Title: arm_lstm_unidirectional_s16.c
2222
* Description: S16 LSTM function with S16 gate output
2323
*
24-
* $Date: 26 March 2024
25-
* $Revision: V.1.0.0
24+
* $Date: 21 May 2026
25+
* $Revision: V.1.0.1
2626
*
2727
* Target Processor: Cortex-M processors
2828
*
@@ -52,8 +52,13 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s16(const int16_t *input,
5252
cmsis_nn_lstm_context *buffers)
5353
{
5454

55-
int16_t *hidden_in = NULL;
56-
memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t));
55+
int16_t *hidden_in = (int16_t *)buffers->hidden_state;
56+
57+
if (buffers->hidden_state == NULL)
58+
{
59+
memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t));
60+
}
61+
5762
if (params->time_major)
5863
{
5964
// First dimension is time, input/output for each time step is stored continously in memory
@@ -69,22 +74,50 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s16(const int16_t *input,
6974
// Output is used as recurrent input/hidden state for the next timestep.
7075
hidden_in = &hidden_out[0];
7176
}
77+
78+
if (buffers->hidden_state != NULL && params->time_steps > 0)
79+
{
80+
memcpy(buffers->hidden_state, hidden_in, params->batch_size * params->hidden_size * sizeof(int16_t));
81+
}
7282
}
7383
else
7484
{
75-
// First dimension is time, add batch_offset to jump in memory for each batch
76-
for (int t = 0; t < params->time_steps; t++)
85+
// Batch major: [batch, time, size]
86+
// arm_nn_lstm_step_s16 expects data_in and hidden_in to have the same batch_offset.
87+
// Since the initial hidden_state is contiguous, we must process one batch at a time.
88+
cmsis_nn_lstm_params step_params = *params;
89+
step_params.batch_size = 1;
90+
91+
for (int b = 0; b < params->batch_size; b++)
7792
{
78-
const int16_t *data_in = input + (t * params->input_size);
79-
int16_t *hidden_out = output + (t * params->hidden_size);
80-
arm_cmsis_nn_status status =
81-
arm_nn_lstm_step_s16(data_in, hidden_in, hidden_out, params, buffers, params->time_steps);
82-
if (status != ARM_CMSIS_NN_SUCCESS)
93+
int16_t *step_hidden_in =
94+
(buffers->hidden_state != NULL) ? ((int16_t *)buffers->hidden_state + b * params->hidden_size) : NULL;
95+
96+
cmsis_nn_lstm_context step_buffers = *buffers;
97+
step_buffers.cell_state = (int16_t *)buffers->cell_state + b * params->hidden_size;
98+
99+
for (int t = 0; t < params->time_steps; t++)
83100
{
84-
return status;
101+
const int16_t *data_in = input + (b * params->time_steps + t) * params->input_size;
102+
int16_t *hidden_out = output + (b * params->time_steps + t) * params->hidden_size;
103+
104+
arm_cmsis_nn_status status =
105+
arm_nn_lstm_step_s16(data_in, step_hidden_in, hidden_out, &step_params, &step_buffers, 1);
106+
107+
if (status != ARM_CMSIS_NN_SUCCESS)
108+
{
109+
return status;
110+
}
111+
112+
step_hidden_in = hidden_out;
113+
}
114+
115+
if (buffers->hidden_state != NULL && params->time_steps > 0)
116+
{
117+
memcpy((int16_t *)buffers->hidden_state + b * params->hidden_size,
118+
step_hidden_in,
119+
params->hidden_size * sizeof(int16_t));
85120
}
86-
// Output is used as recurrent input/hidden state for the next timestep.
87-
hidden_in = &hidden_out[0];
88121
}
89122
}
90123
return ARM_CMSIS_NN_SUCCESS;

Source/LSTMFunctions/arm_lstm_unidirectional_s8.c

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
* Title: arm_lstm_unidirectional_s8.c
2222
* Description: S8 LSTM function with S16 gate output
2323
*
24-
* $Date: 08 February 2024
25-
* $Revision: V.1.1.0
24+
* $Date: 21 May 2026
25+
* $Revision: V.1.1.1
2626
*
2727
* Target Processor: Cortex-M processors
2828
*
@@ -52,8 +52,13 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input,
5252
cmsis_nn_lstm_context *buffers)
5353
{
5454

55-
int8_t *hidden_in = NULL;
56-
memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t));
55+
int8_t *hidden_in = (int8_t *)buffers->hidden_state;
56+
57+
if (buffers->hidden_state == NULL)
58+
{
59+
memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t));
60+
}
61+
5762
if (params->time_major)
5863
{
5964
// First dimension is time, input/output for each time step is stored continously in memory
@@ -69,22 +74,50 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input,
6974
// Output is used as recurrent input/hidden state for the next timestep.
7075
hidden_in = &hidden_out[0];
7176
}
77+
78+
if (buffers->hidden_state != NULL && params->time_steps > 0)
79+
{
80+
memcpy(buffers->hidden_state, hidden_in, params->batch_size * params->hidden_size * sizeof(int8_t));
81+
}
7282
}
7383
else
7484
{
75-
// First dimension is time, add batch_offset to jump in memory for each batch
76-
for (int t = 0; t < params->time_steps; t++)
85+
// Batch major: [batch, time, size]
86+
// arm_nn_lstm_step_s8 expects data_in and hidden_in to have the same batch_offset.
87+
// Since the initial hidden_state is contiguous, we must process one batch at a time.
88+
cmsis_nn_lstm_params step_params = *params;
89+
step_params.batch_size = 1;
90+
91+
for (int b = 0; b < params->batch_size; b++)
7792
{
78-
const int8_t *data_in = input + (t * params->input_size);
79-
int8_t *hidden_out = output + (t * params->hidden_size);
80-
arm_cmsis_nn_status status =
81-
arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, params->time_steps);
82-
if (status != ARM_CMSIS_NN_SUCCESS)
93+
int8_t *step_hidden_in =
94+
(buffers->hidden_state != NULL) ? ((int8_t *)buffers->hidden_state + b * params->hidden_size) : NULL;
95+
96+
cmsis_nn_lstm_context step_buffers = *buffers;
97+
step_buffers.cell_state = (int16_t *)buffers->cell_state + b * params->hidden_size;
98+
99+
for (int t = 0; t < params->time_steps; t++)
83100
{
84-
return status;
101+
const int8_t *data_in = input + (b * params->time_steps + t) * params->input_size;
102+
int8_t *hidden_out = output + (b * params->time_steps + t) * params->hidden_size;
103+
104+
arm_cmsis_nn_status status =
105+
arm_nn_lstm_step_s8(data_in, step_hidden_in, hidden_out, &step_params, &step_buffers, 1);
106+
107+
if (status != ARM_CMSIS_NN_SUCCESS)
108+
{
109+
return status;
110+
}
111+
112+
step_hidden_in = hidden_out;
113+
}
114+
115+
if (buffers->hidden_state != NULL && params->time_steps > 0)
116+
{
117+
memcpy((int8_t *)buffers->hidden_state + b * params->hidden_size,
118+
step_hidden_in,
119+
params->hidden_size * sizeof(int8_t));
85120
}
86-
// Output is used as recurrent input/hidden state for the next timestep.
87-
hidden_in = &hidden_out[0];
88121
}
89122
}
90123
return ARM_CMSIS_NN_SUCCESS;

Tests/UnitTest/RefactoredTestGen/test_plan.json

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,14 @@
542542
"input_size" : 22,
543543
"hidden_size" : 3,
544544
"json_template": "lstm_s16.json"
545+
},
546+
{"name" : "lstm_stateful_batch_major_multibatch_s16",
547+
"time_major" : false,
548+
"batch_size" : 2,
549+
"time_steps" : 2,
550+
"input_size" : 6,
551+
"hidden_size" : 7,
552+
"json_template": "lstm_s16.json"
545553
}
546554
]
547555
},
@@ -574,6 +582,13 @@
574582
"time_steps" : 1,
575583
"input_size" : 22,
576584
"hidden_size" : 3
585+
},
586+
{"name" : "lstm_stateful_batch_major_multibatch",
587+
"time_major" : false,
588+
"batch_size" : 2,
589+
"time_steps" : 2,
590+
"input_size" : 6,
591+
"hidden_size" : 7
577592
}
578593
]
579594
},
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
2+
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
3+
#pragma once
4+
#include <stdint.h>
5+
6+
const int32_t lstm_stateful_batch_major_multibatch_cell_gate_bias[7] = {0, 0, 0, 0, 0, 0, 0};
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
2+
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
3+
#pragma once
4+
#include <stdint.h>
5+
6+
const int8_t lstm_stateful_batch_major_multibatch_cell_gate_hidden_weights[49] = {
7+
-20, -88, 87, 87, 109, -54, -12, 21, 2, -112, 44, -79, -97, -15, 123, 105, -122,
8+
-29, -83, 36, 58, 33, 59, -115, 127, -106, 101, -57, -97, -64, -39, 71, 4, -114,
9+
-94, 74, 34, -12, -118, -64, 104, 102, -36, -114, 117, 95, -1, 67, 81};
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
2+
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
3+
#pragma once
4+
#include <stdint.h>
5+
6+
const int8_t lstm_stateful_batch_major_multibatch_cell_gate_input_weights[42] = {
7+
92, 100, -29, 127, 76, -74, 115, -81, 63, 4, 69, -81, 8, -25, 42, 99, 44, 101, 12, -25, 99,
8+
-70, -88, -41, 107, 65, 67, -31, 87, -54, -104, 95, 35, 21, 125, -87, 27, 78, -113, -114, 61, -101};
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
2+
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
3+
#pragma once
4+
5+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_TIME_MAJOR false
6+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_BATCH_SIZE 2
7+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_TIME_STEPS 2
8+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_SIZE 6
9+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE 7
10+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_SCALE_POWER -15
11+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_ZERO_POINT 128
12+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_ZERO_POINT 4
13+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_CLIP 32767
14+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_TO_CELL_MULTIPLIER 1073741824
15+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_TO_CELL_SHIFT -14
16+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_TO_CELL_MULTIPLIER 1073741824
17+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_TO_CELL_SHIFT -14
18+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_MULTIPLIER 1993694592
19+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_SHIFT -21
20+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_HIDDEN_MULTIPLIER 1143723136
21+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_HIDDEN_SHIFT -3
22+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_HIDDEN_MULTIPLIER 1164696448
23+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_HIDDEN_SHIFT -3
24+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_HIDDEN_MULTIPLIER 1134438656
25+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_HIDDEN_SHIFT -3
26+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_HIDDEN_MULTIPLIER 2044599040
27+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_HIDDEN_SHIFT -4
28+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_INPUT_MULTIPLIER 2130456576
29+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_INPUT_SHIFT -3
30+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_INPUT_MULTIPLIER 2096726656
31+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_INPUT_SHIFT -3
32+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_INPUT_MULTIPLIER 2025295488
33+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_INPUT_SHIFT -3
34+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_INPUT_MULTIPLIER 2146124928
35+
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_INPUT_SHIFT -3
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
2+
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
3+
#pragma once
4+
#include <stdint.h>
5+
6+
const int32_t lstm_stateful_batch_major_multibatch_forget_gate_bias[7] = {0, 0, 0, 0, 0, 0, 0};
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
2+
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
3+
#pragma once
4+
#include <stdint.h>
5+
6+
const int8_t lstm_stateful_batch_major_multibatch_forget_gate_hidden_weights[49] = {
7+
8, -96, 106, 66, -101, 22, -70, 86, -37, -1, -127, 52, 9, 79, 111, -94, 126,
8+
-21, -25, -79, 42, -57, -42, -3, 126, 51, -49, -28, -10, 50, -104, -48, -11, -78,
9+
36, 121, 6, -56, -24, -75, -104, -103, -119, -63, -69, -51, 11, -43, 19};

0 commit comments

Comments
 (0)