Skip to content

Commit 80dc694

Browse files
authored
docs: verify README examples (#18)
1 parent 416f2a2 commit 80dc694

2 files changed

Lines changed: 183 additions & 10 deletions

File tree

README.md

Lines changed: 47 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ Add to your `Cargo.toml`:
5151

5252
```toml
5353
[dependencies]
54-
rust-lstm = "0.6"
54+
rust-lstm = "0.8"
5555
```
5656

5757
### Basic Usage
@@ -78,6 +78,7 @@ fn main() {
7878
### Training Example
7979

8080
```rust
81+
use ndarray::Array2;
8182
use rust_lstm::{LSTMNetwork, create_basic_trainer, TrainingConfig};
8283

8384
fn main() {
@@ -94,15 +95,23 @@ fn main() {
9495
..Default::default()
9596
});
9697

97-
// Train (train_data is slice of (input_sequence, target_sequence) tuples)
98-
// Each input_sequence and target_sequence is Vec<Array2<f64>>
98+
// Train data is a slice of (input_sequence, target_sequence) tuples.
99+
// Each input_sequence and target_sequence is Vec<Array2<f64>>.
100+
let train_data = vec![(
101+
vec![Array2::from_shape_vec((1, 1), vec![0.0]).unwrap()],
102+
vec![Array2::from_shape_vec((10, 1), vec![0.0; 10]).unwrap()],
103+
)];
104+
// Keep validation data separate from training data in real applications.
105+
let validation_data = train_data.clone();
106+
99107
trainer.train(&train_data, Some(&validation_data));
100108
}
101109
```
102110

103111
### Early Stopping
104112

105113
```rust
114+
use ndarray::Array2;
106115
use rust_lstm::{
107116
LSTMNetwork, create_basic_trainer, TrainingConfig,
108117
EarlyStoppingConfig, EarlyStoppingMetric
@@ -128,6 +137,13 @@ fn main() {
128137
let mut trainer = create_basic_trainer(network, 0.001)
129138
.with_config(config);
130139

140+
let train_data = vec![(
141+
vec![Array2::from_shape_vec((1, 1), vec![0.0]).unwrap()],
142+
vec![Array2::from_shape_vec((10, 1), vec![0.0; 10]).unwrap()],
143+
)];
144+
// Keep validation data separate from training data in real applications.
145+
let validation_data = train_data.clone();
146+
131147
// Training will stop early if validation loss stops improving
132148
trainer.train(&train_data, Some(&validation_data));
133149
}
@@ -136,7 +152,16 @@ fn main() {
136152
### Bidirectional LSTM
137153

138154
```rust
139-
use rust_lstm::layers::bilstm_network::{BiLSTMNetwork, CombineMode};
155+
use ndarray::Array2;
156+
use rust_lstm::layers::bilstm_network::BiLSTMNetwork;
157+
158+
let input_size = 3;
159+
let hidden_size = 5;
160+
let num_layers = 1;
161+
let sequence = vec![
162+
Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap(),
163+
Array2::from_shape_vec((input_size, 1), vec![0.2, -0.4, 0.7]).unwrap(),
164+
];
140165

141166
// BiLSTM with concatenated outputs (output_size = 2 * hidden_size)
142167
let mut bilstm = BiLSTMNetwork::new_concat(input_size, hidden_size, num_layers);
@@ -170,23 +195,38 @@ graph TD
170195
### GRU Networks
171196

172197
```rust
198+
use ndarray::Array2;
173199
use rust_lstm::models::gru_network::GRUNetwork;
174200

201+
let input_size = 3;
202+
let hidden_size = 5;
203+
let num_layers = 2;
204+
175205
// Create GRU network (alternative to LSTM)
176206
let mut gru = GRUNetwork::new(input_size, hidden_size, num_layers)
177207
.with_input_dropout(0.2, true)
178208
.with_recurrent_dropout(0.3, true);
179209

180-
// Forward pass
181-
let (output, _) = gru.forward(&input, &hidden_state);
210+
let input = Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap();
211+
let hidden_states = vec![Array2::zeros((hidden_size, 1)); num_layers];
212+
213+
// Forward pass returns one hidden state per layer
214+
let outputs = gru.forward(&input, &hidden_states);
215+
let output = outputs.last().unwrap();
182216
```
183217

184218
### Linear Layer
185219

186220
```rust
221+
use ndarray::Array2;
187222
use rust_lstm::layers::linear::LinearLayer;
188223
use rust_lstm::optimizers::Adam;
189224

225+
let hidden_size = 4;
226+
let num_classes = 3;
227+
let lstm_output = Array2::ones((hidden_size, 1));
228+
let grad_output = Array2::ones((num_classes, 1));
229+
190230
// Create linear layer for classification: hidden_size -> num_classes
191231
let mut classifier = LinearLayer::new(hidden_size, num_classes);
192232
let mut optimizer = Adam::new(0.001);
@@ -251,7 +291,7 @@ use rust_lstm::{
251291
let network = LSTMNetwork::new(1, 10, 2);
252292

253293
// Step decay: reduce LR by 50% every 10 epochs
254-
let mut trainer = create_step_lr_trainer(network, 0.01, 10, 0.5);
294+
let mut trainer = create_step_lr_trainer(network.clone(), 0.01, 10, 0.5);
255295

256296
// OneCycle policy for modern deep learning
257297
let mut trainer = create_one_cycle_trainer(network.clone(), 0.1, 100);

tests/readme_examples_test.rs

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,20 @@
88

99
use ndarray::Array2;
1010
use rust_lstm::{
11+
layers::bilstm_network::BiLSTMNetwork,
1112
layers::dropout::{Dropout, Zoneout},
13+
layers::linear::LinearLayer,
1214
layers::peephole_lstm_cell::PeepholeLSTMCell,
1315
loss::{CrossEntropyLoss, MAELoss, MSELoss},
14-
optimizers::{Adam, RMSprop, SGD},
15-
training::create_basic_trainer,
16-
LSTMNetwork, LSTMTrainer, LayerDropoutConfig, TrainingConfig,
16+
models::gru_network::GRUNetwork,
17+
optimizers::{Adam, RMSprop, ScheduledOptimizer, SGD},
18+
schedulers::{CyclicalLR, LRScheduleVisualizer, PolynomialLR, WarmupScheduler},
19+
training::{
20+
create_basic_trainer, create_cosine_annealing_trainer, create_one_cycle_trainer,
21+
create_step_lr_trainer,
22+
},
23+
EarlyStoppingConfig, EarlyStoppingMetric, LSTMNetwork, LSTMTrainer, LayerDropoutConfig,
24+
TrainingConfig,
1725
};
1826

1927
#[test]
@@ -119,6 +127,131 @@ fn test_training_example() {
119127
assert_eq!(predictions[0].shape(), &[4, 1]);
120128
}
121129

130+
#[test]
131+
fn test_readme_early_stopping_example() {
132+
let network = LSTMNetwork::new(1, 4, 1);
133+
134+
// Configure early stopping
135+
let early_stopping = EarlyStoppingConfig {
136+
patience: 2,
137+
min_delta: 1e-4,
138+
restore_best_weights: true,
139+
monitor: EarlyStoppingMetric::ValidationLoss,
140+
};
141+
142+
let config = TrainingConfig {
143+
epochs: 2,
144+
early_stopping: Some(early_stopping),
145+
..Default::default()
146+
};
147+
148+
let mut trainer = create_basic_trainer(network, 0.001).with_config(config);
149+
let train_data = generate_test_data();
150+
let validation_data = generate_test_data();
151+
152+
trainer.train(&train_data, Some(&validation_data));
153+
154+
assert_eq!(trainer.config.early_stopping.as_ref().unwrap().patience, 2);
155+
}
156+
157+
#[test]
158+
fn test_readme_bilstm_example() {
159+
let input_size = 3;
160+
let hidden_size = 5;
161+
let num_layers = 1;
162+
163+
// BiLSTM with concatenated outputs (output_size = 2 * hidden_size)
164+
let mut bilstm = BiLSTMNetwork::new_concat(input_size, hidden_size, num_layers);
165+
166+
// Process sequence with both past and future context
167+
let sequence = vec![
168+
Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap(),
169+
Array2::from_shape_vec((input_size, 1), vec![0.2, -0.4, 0.7]).unwrap(),
170+
];
171+
let outputs = bilstm.forward_sequence(&sequence);
172+
173+
assert_eq!(outputs.len(), sequence.len());
174+
for output in outputs {
175+
assert_eq!(output.shape(), &[2 * hidden_size, 1]);
176+
}
177+
}
178+
179+
#[test]
180+
fn test_readme_gru_example() {
181+
let input_size = 3;
182+
let hidden_size = 5;
183+
let num_layers = 2;
184+
185+
// Create GRU network (alternative to LSTM)
186+
let mut gru = GRUNetwork::new(input_size, hidden_size, num_layers)
187+
.with_input_dropout(0.2, true)
188+
.with_recurrent_dropout(0.3, true);
189+
190+
let input = Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap();
191+
let hidden_states = vec![Array2::zeros((hidden_size, 1)); num_layers];
192+
193+
// Forward pass returns one hidden state per layer
194+
let outputs = gru.forward(&input, &hidden_states);
195+
let output = outputs.last().unwrap();
196+
197+
assert_eq!(outputs.len(), num_layers);
198+
assert_eq!(output.shape(), &[hidden_size, 1]);
199+
}
200+
201+
#[test]
202+
fn test_readme_linear_layer_example() {
203+
let hidden_size = 4;
204+
let num_classes = 3;
205+
206+
// Create linear layer for classification: hidden_size -> num_classes
207+
let mut classifier = LinearLayer::new(hidden_size, num_classes);
208+
let mut optimizer = Adam::new(0.001);
209+
210+
// Forward pass
211+
let lstm_output = Array2::ones((hidden_size, 1));
212+
let logits = classifier.forward(&lstm_output);
213+
214+
// Backward pass
215+
let grad_output = Array2::ones((num_classes, 1));
216+
let (gradients, input_grad) = classifier.backward(&grad_output);
217+
classifier.update_parameters(&gradients, &mut optimizer, "classifier");
218+
219+
assert_eq!(logits.shape(), &[num_classes, 1]);
220+
assert_eq!(input_grad.shape(), &[hidden_size, 1]);
221+
}
222+
223+
#[test]
224+
fn test_readme_advanced_learning_rate_scheduling_example() {
225+
// Create a network
226+
let network = LSTMNetwork::new(1, 4, 1);
227+
228+
// Step decay: reduce LR by 50% every 10 epochs
229+
let mut step_trainer = create_step_lr_trainer(network.clone(), 0.01, 10, 0.5);
230+
231+
// OneCycle policy for modern deep learning
232+
let mut one_cycle_trainer = create_one_cycle_trainer(network.clone(), 0.1, 100);
233+
234+
// Cosine annealing with warm restarts
235+
let mut cosine_trainer = create_cosine_annealing_trainer(network.clone(), 0.01, 20, 1e-6);
236+
237+
// Advanced combinations - Warmup + Cyclical scheduling
238+
let base_scheduler = CyclicalLR::new(0.001, 0.01, 10);
239+
let warmup_scheduler = WarmupScheduler::new(5, base_scheduler, 0.0001);
240+
let mut optimizer = ScheduledOptimizer::new(Adam::new(0.01), warmup_scheduler, 0.01);
241+
242+
// Polynomial decay with visualization
243+
let poly_scheduler = PolynomialLR::new(100, 2.0, 0.001);
244+
let schedule = LRScheduleVisualizer::generate_schedule(poly_scheduler, 0.01, 100);
245+
246+
step_trainer.optimizer.step();
247+
one_cycle_trainer.optimizer.step();
248+
cosine_trainer.optimizer.step();
249+
optimizer.step();
250+
251+
assert_eq!(schedule.len(), 100);
252+
assert!(optimizer.get_current_lr() > 0.0);
253+
}
254+
122255
#[test]
123256
fn test_dropout_types_example() {
124257
// Standard dropout

0 commit comments

Comments
 (0)