Skip to content

Commit 416f2a2

Browse files
authored
chore: restore Rust quality baseline (#17)
1 parent 1037dee commit 416f2a2

40 files changed

Lines changed: 2978 additions & 1797 deletions

examples/advanced_lr_scheduling.rs

Lines changed: 135 additions & 116 deletions
Large diffs are not rendered by default.

examples/basic_usage.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
#![allow(clippy::type_complexity)]
2+
#![allow(clippy::field_reassign_with_default)]
3+
#![allow(clippy::empty_line_after_doc_comments)]
4+
#![allow(clippy::needless_range_loop)]
5+
#![allow(clippy::assertions_on_constants)]
6+
#![allow(clippy::absurd_extreme_comparisons)]
7+
#![allow(unused_comparisons)]
8+
19
use ndarray::Array2;
210
use rust_lstm::models::lstm_network::LSTMNetwork;
311

Lines changed: 121 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,33 @@
1-
use ndarray::{Array2, arr2};
2-
use rust_lstm::{LSTMNetwork, create_adam_batch_trainer, create_basic_trainer};
1+
#![allow(clippy::type_complexity)]
2+
#![allow(clippy::field_reassign_with_default)]
3+
#![allow(clippy::empty_line_after_doc_comments)]
4+
#![allow(clippy::needless_range_loop)]
5+
#![allow(clippy::assertions_on_constants)]
6+
#![allow(clippy::absurd_extreme_comparisons)]
7+
#![allow(unused_comparisons)]
8+
9+
use ndarray::{arr2, Array2};
10+
use rust_lstm::{create_adam_batch_trainer, create_basic_trainer, LSTMNetwork};
311
use std::time::Instant;
412

513
/// Generate synthetic sine wave sequences for batch processing demonstration
6-
fn generate_batch_sine_data(num_sequences: usize, sequence_length: usize, input_size: usize) -> Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)> {
14+
fn generate_batch_sine_data(
15+
num_sequences: usize,
16+
sequence_length: usize,
17+
input_size: usize,
18+
) -> Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)> {
719
let mut data = Vec::new();
8-
20+
921
for i in 0..num_sequences {
1022
let mut inputs = Vec::new();
1123
let mut targets = Vec::new();
12-
24+
1325
let start = (i as f64) * 0.05; // Different starting points for variety
1426
let frequency = 1.0 + (i as f64) * 0.1; // Different frequencies
15-
27+
1628
for j in 0..sequence_length {
1729
let t = start + (j as f64) * 0.1;
18-
30+
1931
// Create multi-dimensional input
2032
let mut input_vec = vec![0.0; input_size];
2133
input_vec[0] = (t * frequency * 2.0 * std::f64::consts::PI).sin();
@@ -25,17 +37,17 @@ fn generate_batch_sine_data(num_sequences: usize, sequence_length: usize, input_
2537
if input_size > 2 {
2638
input_vec[2] = t.sin() * t.cos(); // Some nonlinear combination
2739
}
28-
40+
2941
// Target is the next value in the sine sequence
3042
let target = ((t + 0.1) * frequency * 2.0 * std::f64::consts::PI).sin();
31-
43+
3244
inputs.push(Array2::from_shape_vec((input_size, 1), input_vec).unwrap());
3345
targets.push(arr2(&[[target]]));
3446
}
35-
47+
3648
data.push((inputs, targets));
3749
}
38-
50+
3951
data
4052
}
4153

@@ -48,77 +60,102 @@ fn benchmark_training_performance() {
4860
let hidden_size = 16;
4961
let num_layers = 2;
5062
let learning_rate = 0.001;
51-
63+
5264
// Generate training data
5365
let train_data = generate_batch_sine_data(100, 10, input_size);
5466
let val_data = generate_batch_sine_data(20, 10, input_size);
55-
56-
println!("Dataset: {} training sequences, {} validation sequences", train_data.len(), val_data.len());
57-
println!("Network: {} -> {} hidden ({} layers)\n", input_size, hidden_size, num_layers);
67+
68+
println!(
69+
"Dataset: {} training sequences, {} validation sequences",
70+
train_data.len(),
71+
val_data.len()
72+
);
73+
println!(
74+
"Network: {} -> {} hidden ({} layers)\n",
75+
input_size, hidden_size, num_layers
76+
);
5877

5978
// Test 1: Single sequence processing (traditional)
6079
println!("Testing Traditional Single-Sequence Processing...");
6180
let network1 = LSTMNetwork::new(input_size, hidden_size, num_layers);
6281
let mut trainer1 = create_basic_trainer(network1, learning_rate);
63-
82+
6483
// Configure for quick demo
6584
trainer1.config.epochs = 5;
6685
trainer1.config.print_every = 1;
67-
86+
6887
let start_time = Instant::now();
6988
trainer1.train(&train_data, Some(&val_data));
7089
let single_time = start_time.elapsed();
71-
90+
7291
let final_metrics1 = trainer1.get_latest_metrics().unwrap();
73-
println!("Single-sequence - Final loss: {:.6}, Time: {:.2}s\n",
74-
final_metrics1.train_loss, single_time.as_secs_f64());
92+
println!(
93+
"Single-sequence - Final loss: {:.6}, Time: {:.2}s\n",
94+
final_metrics1.train_loss,
95+
single_time.as_secs_f64()
96+
);
7597

7698
// Test 2: Batch processing with small batches
7799
println!("Testing Batch Processing (batch size 8)...");
78100
let network2 = LSTMNetwork::new(input_size, hidden_size, num_layers);
79101
let mut trainer2 = create_adam_batch_trainer(network2, learning_rate);
80-
102+
81103
trainer2.config.epochs = 5;
82104
trainer2.config.print_every = 1;
83-
105+
84106
let start_time = Instant::now();
85107
trainer2.train(&train_data, Some(&val_data), 8); // Batch size 8
86108
let batch_time = start_time.elapsed();
87-
109+
88110
let final_metrics2 = trainer2.get_latest_metrics().unwrap();
89-
println!("Batch processing - Final loss: {:.6}, Time: {:.2}s\n",
90-
final_metrics2.train_loss, batch_time.as_secs_f64());
111+
println!(
112+
"Batch processing - Final loss: {:.6}, Time: {:.2}s\n",
113+
final_metrics2.train_loss,
114+
batch_time.as_secs_f64()
115+
);
91116

92117
// Test 3: Larger batch size
93118
println!("Testing Larger Batch Processing (batch size 16)...");
94119
let network3 = LSTMNetwork::new(input_size, hidden_size, num_layers);
95120
let mut trainer3 = create_adam_batch_trainer(network3, learning_rate);
96-
121+
97122
trainer3.config.epochs = 5;
98123
trainer3.config.print_every = 1;
99-
124+
100125
let start_time = Instant::now();
101126
trainer3.train(&train_data, Some(&val_data), 16); // Batch size 16
102127
let large_batch_time = start_time.elapsed();
103-
128+
104129
let final_metrics3 = trainer3.get_latest_metrics().unwrap();
105-
println!("Large batch processing - Final loss: {:.6}, Time: {:.2}s\n",
106-
final_metrics3.train_loss, large_batch_time.as_secs_f64());
130+
println!(
131+
"Large batch processing - Final loss: {:.6}, Time: {:.2}s\n",
132+
final_metrics3.train_loss,
133+
large_batch_time.as_secs_f64()
134+
);
107135

108136
// Performance summary
109137
println!("PERFORMANCE SUMMARY:");
110138
println!("======================");
111-
println!("Single-sequence: {:.2}s (baseline)", single_time.as_secs_f64());
112-
println!("Batch-8: {:.2}s ({:.1}x speedup)",
113-
batch_time.as_secs_f64(),
114-
single_time.as_secs_f64() / batch_time.as_secs_f64());
115-
println!("Batch-16: {:.2}s ({:.1}x speedup)",
116-
large_batch_time.as_secs_f64(),
117-
single_time.as_secs_f64() / large_batch_time.as_secs_f64());
118-
139+
println!(
140+
"Single-sequence: {:.2}s (baseline)",
141+
single_time.as_secs_f64()
142+
);
143+
println!(
144+
"Batch-8: {:.2}s ({:.1}x speedup)",
145+
batch_time.as_secs_f64(),
146+
single_time.as_secs_f64() / batch_time.as_secs_f64()
147+
);
148+
println!(
149+
"Batch-16: {:.2}s ({:.1}x speedup)",
150+
large_batch_time.as_secs_f64(),
151+
single_time.as_secs_f64() / large_batch_time.as_secs_f64()
152+
);
153+
119154
if batch_time < single_time {
120-
println!("Batch processing achieved {:.1}x speedup!",
121-
single_time.as_secs_f64() / batch_time.as_secs_f64());
155+
println!(
156+
"Batch processing achieved {:.1}x speedup!",
157+
single_time.as_secs_f64() / batch_time.as_secs_f64()
158+
);
122159
} else {
123160
println!("Note: For small datasets, overhead may dominate. Try larger datasets for better speedup.");
124161
}
@@ -132,34 +169,45 @@ fn demonstrate_batch_prediction() {
132169
let input_size = 2;
133170
let hidden_size = 8;
134171
let num_layers = 1;
135-
172+
136173
// Create and train a simple model
137174
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
138175
let mut trainer = create_adam_batch_trainer(network, 0.01);
139-
176+
140177
// Generate small training dataset
141178
let train_data = generate_batch_sine_data(20, 5, input_size);
142-
179+
143180
trainer.config.epochs = 10;
144181
trainer.config.print_every = 5;
145-
182+
146183
println!("Training a small model for prediction demo...");
147184
trainer.train(&train_data, None, 4);
148-
185+
149186
// Create test sequences for batch prediction
150187
let test_sequences = generate_batch_sine_data(3, 3, input_size);
151-
let test_inputs: Vec<_> = test_sequences.iter().map(|(inputs, _)| inputs.clone()).collect();
152-
let _test_targets: Vec<_> = test_sequences.iter().map(|(_, targets)| targets.clone()).collect();
153-
188+
let test_inputs: Vec<_> = test_sequences
189+
.iter()
190+
.map(|(inputs, _)| inputs.clone())
191+
.collect();
192+
let _test_targets: Vec<_> = test_sequences
193+
.iter()
194+
.map(|(_, targets)| targets.clone())
195+
.collect();
196+
154197
println!("\nPerforming batch predictions...");
155198
let predictions = trainer.predict_batch(&test_inputs);
156-
199+
157200
println!("Input sequences vs Predictions:");
158201
for (i, (inputs, preds)) in test_inputs.iter().zip(predictions.iter()).enumerate() {
159202
println!("Sequence {}:", i + 1);
160203
for (j, (input, pred)) in inputs.iter().zip(preds.iter()).enumerate() {
161-
println!(" Step {}: Input=[{:.3}, {:.3}] -> Pred={:.3}",
162-
j + 1, input[[0, 0]], input[[1, 0]], pred[[0, 0]]);
204+
println!(
205+
" Step {}: Input=[{:.3}, {:.3}] -> Pred={:.3}",
206+
j + 1,
207+
input[[0, 0]],
208+
input[[1, 0]],
209+
pred[[0, 0]]
210+
);
163211
}
164212
println!();
165213
}
@@ -171,48 +219,54 @@ fn demonstrate_scalability() {
171219
println!("=========================\n");
172220

173221
let test_sizes = vec![
174-
(50, 4), // Small: 50 sequences, batch size 4
175-
(200, 8), // Medium: 200 sequences, batch size 8
176-
(500, 16), // Large: 500 sequences, batch size 16
222+
(50, 4), // Small: 50 sequences, batch size 4
223+
(200, 8), // Medium: 200 sequences, batch size 8
224+
(500, 16), // Large: 500 sequences, batch size 16
177225
];
178226

179227
for (num_sequences, batch_size) in test_sizes {
180-
println!("Testing with {} sequences, batch size {}...", num_sequences, batch_size);
181-
228+
println!(
229+
"Testing with {} sequences, batch size {}...",
230+
num_sequences, batch_size
231+
);
232+
182233
let train_data = generate_batch_sine_data(num_sequences, 8, 2);
183234
let network = LSTMNetwork::new(2, 12, 1);
184235
let mut trainer = create_adam_batch_trainer(network, 0.001);
185-
236+
186237
trainer.config.epochs = 3;
187238
trainer.config.print_every = 1;
188-
239+
189240
let start_time = Instant::now();
190241
trainer.train(&train_data, None, batch_size);
191242
let training_time = start_time.elapsed();
192-
243+
193244
let final_loss = trainer.get_latest_metrics().unwrap().train_loss;
194-
println!(" Completed in {:.2}s, final loss: {:.6}\n",
195-
training_time.as_secs_f64(), final_loss);
245+
println!(
246+
" Completed in {:.2}s, final loss: {:.6}\n",
247+
training_time.as_secs_f64(),
248+
final_loss
249+
);
196250
}
197-
251+
198252
println!("All scalability tests completed successfully!");
199253
println!("Batch processing handles varying dataset sizes efficiently.");
200254
}
201255

202256
fn main() {
203257
println!("RUST-LSTM BATCH PROCESSING DEMONSTRATION");
204258
println!("=========================================\n");
205-
259+
206260
println!("This example demonstrates the new batch processing capabilities:");
207261
println!("- Simultaneous processing of multiple sequences");
208262
println!("- Performance improvements over single-sequence training");
209263
println!("- Batch prediction capabilities");
210264
println!("- Scalability with different batch sizes\n");
211265

212266
benchmark_training_performance();
213-
demonstrate_batch_prediction();
267+
demonstrate_batch_prediction();
214268
demonstrate_scalability();
215-
269+
216270
println!("\nBATCH PROCESSING DEMONSTRATION COMPLETED!");
217271
println!("==========================================");
218272
println!("Key Benefits Demonstrated:");
@@ -221,10 +275,10 @@ fn main() {
221275
println!("- Scalable to different dataset sizes");
222276
println!("- Easy-to-use batch training API");
223277
println!("- Backward compatibility with existing code");
224-
278+
225279
println!("\nNext Steps:");
226280
println!("- Try batch processing with your own datasets");
227281
println!("- Experiment with different batch sizes");
228282
println!("- Compare performance with single-sequence training");
229283
println!("- Use batch processing for faster model development");
230-
}
284+
}

0 commit comments

Comments
 (0)