---
title: "LSTMs: Long Short-Term Memory"
sidebar_label: LSTM
description: A deep dive into the LSTM architecture, cell states, and the gating mechanisms that prevent vanishing gradients.
tags:
---
Standard RNNs have a major weakness: they have a very short memory. Because of the Vanishing Gradient problem, they struggle to connect information that is far apart in a sequence.
LSTMs, introduced by Hochreiter & Schmidhuber, were specifically designed to overcome this. They introduce a "Cell State" (a long-term memory track) and a series of "Gates" that control what information is kept and what is discarded.
The "secret sauce" of the LSTM is the Cell State ($C_t$). It runs straight through the entire chain like a conveyor belt, with only minor linear interactions, so information can flow along it largely unchanged.
An LSTM uses three specialized gates to protect and control the cell state. Each gate is composed of a Sigmoid neural net layer and a point-wise multiplication operation.
The **Forget Gate** decides what information to throw away from the cell state.
- **Input:** $h_{t-1}$ (previous hidden state) and $x_t$ (current input).
- **Output:** A number between 0 (completely forget) and 1 (completely keep).
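Concretely, the forget gate is a single sigmoid layer applied to the concatenated previous hidden state and current input:

$$
f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)
$$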
The **Input Gate** decides which new information we are going to store in the cell state. It works in tandem with a $\tanh$ layer that creates a vector of new candidate values, $\tilde{C}_t$, that could be added to the state.
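In equations, the input gate and candidate vector are computed from the same inputs as the forget gate:

$$
i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i), \qquad \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)
$$

The cell state is then updated by forgetting and adding:

$$
C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t
$$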
The **Output Gate** decides what our next hidden state ($h_t$) will be. The cell state is passed through $\tanh$ and filtered by the gate, so only the selected parts are emitted.
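Written out, the output gate and the resulting hidden state are:

$$
o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o), \qquad h_t = o_t \odot \tanh(C_t)
$$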
The flow within a single LSTM cell is highly structured. The "Cell State" acts as the horizontal spine, while gates regulate the vertical flow of information.
```mermaid
graph LR
    subgraph LSTM_Cell [LSTM Cell at Time $$\ t$$]
        direction LR
        X(($$x_t$$)) --> ForgetGate{Forget Gate}
        X --> InputGate{Input Gate}
        X --> OutputGate{Output Gate}
        H_prev(($$h_{t-1}$$)) --> ForgetGate
        H_prev --> InputGate
        H_prev --> OutputGate
        C_prev(($$C_{t-1}$$)) --> Forget_Mult(($$\times$$))
        ForgetGate -- "$$f_t$$" --> Forget_Mult
        InputGate -- "$$i_t$$" --> Input_Mult(($$\times$$))
        X --> Candidate[$$\tanh$$]
        Candidate --> Input_Mult
        Forget_Mult --> State_Add((+))
        Input_Mult --> State_Add
        State_Add --> C_out(($$C_t$$))
        C_out --> Tanh_Final[$$\tanh$$]
        OutputGate -- "$$o_t$$" --> Output_Mult(($$\times$$))
        Tanh_Final --> Output_Mult
        Output_Mult --> H_out(($$h_t$$))
    end
```
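The same dataflow can be traced in code. The sketch below implements one LSTM step from the gate equations (the names `lstm_step` and `W` are illustrative, not a library API); like PyTorch internally, it stacks the four gate projections into a single weight matrix.

```python
import torch

torch.manual_seed(0)
input_size, hidden_size = 10, 20

def lstm_step(x_t, h_prev, c_prev, W, b):
    # W projects the concatenated [h_{t-1}, x_t] to the four gate pre-activations.
    z = torch.cat([h_prev, x_t], dim=-1) @ W.T + b
    i, f, g, o = z.chunk(4, dim=-1)
    f_t = torch.sigmoid(f)          # forget gate
    i_t = torch.sigmoid(i)          # input gate
    g_t = torch.tanh(g)             # candidate values (C~_t)
    o_t = torch.sigmoid(o)          # output gate
    c_t = f_t * c_prev + i_t * g_t  # forget old state, add new candidates
    h_t = o_t * torch.tanh(c_t)     # filtered cell state becomes the output
    return h_t, c_t

# Randomly initialized weights, zero initial states (illustration only).
W = torch.randn(4 * hidden_size, hidden_size + input_size)
b = torch.zeros(4 * hidden_size)
x_t = torch.randn(1, input_size)
h0 = torch.zeros(1, hidden_size)
c0 = torch.zeros(1, hidden_size)

h1, c1 = lstm_step(x_t, h0, c0, W, b)
print(h1.shape, c1.shape)  # torch.Size([1, 20]) torch.Size([1, 20])
```

Note how the cell state update is purely element-wise multiplication and addition; this is the near-linear "spine" that lets gradients flow across many timesteps.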
| Feature | Standard RNN | LSTM |
|---|---|---|
| Architecture | Simple (Single Tanh layer) | Complex (4 interacting layers) |
| Memory | Short-term only | Long and Short-term |
| Gradient Flow | Suffers from Vanishing Gradient | Resists Vanishing Gradient via the Cell State |
| Complexity | Low | High (More parameters to train) |
In PyTorch, the `nn.LSTM` module automatically handles the complex gating logic and cell state management.

```python
import torch
import torch.nn as nn

# input_size=10, hidden_size=20, num_layers=1
lstm = nn.LSTM(10, 20, batch_first=True)

# Input shape: (batch_size, seq_len, input_size)
input_seq = torch.randn(1, 5, 10)

# Initial hidden state (h0) and cell state (c0): (num_layers, batch_size, hidden_size)
h0 = torch.zeros(1, 1, 20)
c0 = torch.zeros(1, 1, 20)

# Forward pass returns the output sequence and a tuple (hn, cn)
output, (hn, cn) = lstm(input_seq, (h0, c0))
print(f"Output shape: {output.shape}")        # [1, 5, 20]
print(f"Final Cell State shape: {cn.shape}")  # [1, 1, 20]
```

- Colah's Blog: Understanding LSTM Networks (Essential Reading)
- Stanford CS224N: RNNs and LSTMs
LSTMs are powerful but computationally expensive because of their three gates. Is there a way to simplify this without losing the memory benefits?