Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 135 additions & 116 deletions examples/advanced_lr_scheduling.rs

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions examples/basic_usage.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
#![allow(clippy::type_complexity)]
#![allow(clippy::field_reassign_with_default)]
#![allow(clippy::empty_line_after_doc_comments)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::assertions_on_constants)]
#![allow(clippy::absurd_extreme_comparisons)]
#![allow(unused_comparisons)]

use ndarray::Array2;
use rust_lstm::models::lstm_network::LSTMNetwork;

Expand Down
188 changes: 121 additions & 67 deletions examples/batch_processing_example.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
use ndarray::{Array2, arr2};
use rust_lstm::{LSTMNetwork, create_adam_batch_trainer, create_basic_trainer};
#![allow(clippy::type_complexity)]
#![allow(clippy::field_reassign_with_default)]
#![allow(clippy::empty_line_after_doc_comments)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::assertions_on_constants)]
#![allow(clippy::absurd_extreme_comparisons)]
#![allow(unused_comparisons)]

use ndarray::{arr2, Array2};
use rust_lstm::{create_adam_batch_trainer, create_basic_trainer, LSTMNetwork};
use std::time::Instant;

/// Generate synthetic sine wave sequences for batch processing demonstration
fn generate_batch_sine_data(num_sequences: usize, sequence_length: usize, input_size: usize) -> Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)> {
fn generate_batch_sine_data(
num_sequences: usize,
sequence_length: usize,
input_size: usize,
) -> Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)> {
let mut data = Vec::new();

for i in 0..num_sequences {
let mut inputs = Vec::new();
let mut targets = Vec::new();

let start = (i as f64) * 0.05; // Different starting points for variety
let frequency = 1.0 + (i as f64) * 0.1; // Different frequencies

for j in 0..sequence_length {
let t = start + (j as f64) * 0.1;

// Create multi-dimensional input
let mut input_vec = vec![0.0; input_size];
input_vec[0] = (t * frequency * 2.0 * std::f64::consts::PI).sin();
Expand All @@ -25,17 +37,17 @@ fn generate_batch_sine_data(num_sequences: usize, sequence_length: usize, input_
if input_size > 2 {
input_vec[2] = t.sin() * t.cos(); // Some nonlinear combination
}

// Target is the next value in the sine sequence
let target = ((t + 0.1) * frequency * 2.0 * std::f64::consts::PI).sin();

inputs.push(Array2::from_shape_vec((input_size, 1), input_vec).unwrap());
targets.push(arr2(&[[target]]));
}

data.push((inputs, targets));
}

data
}

Expand All @@ -48,77 +60,102 @@ fn benchmark_training_performance() {
let hidden_size = 16;
let num_layers = 2;
let learning_rate = 0.001;

// Generate training data
let train_data = generate_batch_sine_data(100, 10, input_size);
let val_data = generate_batch_sine_data(20, 10, input_size);

println!("Dataset: {} training sequences, {} validation sequences", train_data.len(), val_data.len());
println!("Network: {} -> {} hidden ({} layers)\n", input_size, hidden_size, num_layers);

println!(
"Dataset: {} training sequences, {} validation sequences",
train_data.len(),
val_data.len()
);
println!(
"Network: {} -> {} hidden ({} layers)\n",
input_size, hidden_size, num_layers
);

// Test 1: Single sequence processing (traditional)
println!("Testing Traditional Single-Sequence Processing...");
let network1 = LSTMNetwork::new(input_size, hidden_size, num_layers);
let mut trainer1 = create_basic_trainer(network1, learning_rate);

// Configure for quick demo
trainer1.config.epochs = 5;
trainer1.config.print_every = 1;

let start_time = Instant::now();
trainer1.train(&train_data, Some(&val_data));
let single_time = start_time.elapsed();

let final_metrics1 = trainer1.get_latest_metrics().unwrap();
println!("Single-sequence - Final loss: {:.6}, Time: {:.2}s\n",
final_metrics1.train_loss, single_time.as_secs_f64());
println!(
"Single-sequence - Final loss: {:.6}, Time: {:.2}s\n",
final_metrics1.train_loss,
single_time.as_secs_f64()
);

// Test 2: Batch processing with small batches
println!("Testing Batch Processing (batch size 8)...");
let network2 = LSTMNetwork::new(input_size, hidden_size, num_layers);
let mut trainer2 = create_adam_batch_trainer(network2, learning_rate);

trainer2.config.epochs = 5;
trainer2.config.print_every = 1;

let start_time = Instant::now();
trainer2.train(&train_data, Some(&val_data), 8); // Batch size 8
let batch_time = start_time.elapsed();

let final_metrics2 = trainer2.get_latest_metrics().unwrap();
println!("Batch processing - Final loss: {:.6}, Time: {:.2}s\n",
final_metrics2.train_loss, batch_time.as_secs_f64());
println!(
"Batch processing - Final loss: {:.6}, Time: {:.2}s\n",
final_metrics2.train_loss,
batch_time.as_secs_f64()
);

// Test 3: Larger batch size
println!("Testing Larger Batch Processing (batch size 16)...");
let network3 = LSTMNetwork::new(input_size, hidden_size, num_layers);
let mut trainer3 = create_adam_batch_trainer(network3, learning_rate);

trainer3.config.epochs = 5;
trainer3.config.print_every = 1;

let start_time = Instant::now();
trainer3.train(&train_data, Some(&val_data), 16); // Batch size 16
let large_batch_time = start_time.elapsed();

let final_metrics3 = trainer3.get_latest_metrics().unwrap();
println!("Large batch processing - Final loss: {:.6}, Time: {:.2}s\n",
final_metrics3.train_loss, large_batch_time.as_secs_f64());
println!(
"Large batch processing - Final loss: {:.6}, Time: {:.2}s\n",
final_metrics3.train_loss,
large_batch_time.as_secs_f64()
);

// Performance summary
println!("PERFORMANCE SUMMARY:");
println!("======================");
println!("Single-sequence: {:.2}s (baseline)", single_time.as_secs_f64());
println!("Batch-8: {:.2}s ({:.1}x speedup)",
batch_time.as_secs_f64(),
single_time.as_secs_f64() / batch_time.as_secs_f64());
println!("Batch-16: {:.2}s ({:.1}x speedup)",
large_batch_time.as_secs_f64(),
single_time.as_secs_f64() / large_batch_time.as_secs_f64());

println!(
"Single-sequence: {:.2}s (baseline)",
single_time.as_secs_f64()
);
println!(
"Batch-8: {:.2}s ({:.1}x speedup)",
batch_time.as_secs_f64(),
single_time.as_secs_f64() / batch_time.as_secs_f64()
);
println!(
"Batch-16: {:.2}s ({:.1}x speedup)",
large_batch_time.as_secs_f64(),
single_time.as_secs_f64() / large_batch_time.as_secs_f64()
);

if batch_time < single_time {
println!("Batch processing achieved {:.1}x speedup!",
single_time.as_secs_f64() / batch_time.as_secs_f64());
println!(
"Batch processing achieved {:.1}x speedup!",
single_time.as_secs_f64() / batch_time.as_secs_f64()
);
} else {
println!("Note: For small datasets, overhead may dominate. Try larger datasets for better speedup.");
}
Expand All @@ -132,34 +169,45 @@ fn demonstrate_batch_prediction() {
let input_size = 2;
let hidden_size = 8;
let num_layers = 1;

// Create and train a simple model
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
let mut trainer = create_adam_batch_trainer(network, 0.01);

// Generate small training dataset
let train_data = generate_batch_sine_data(20, 5, input_size);

trainer.config.epochs = 10;
trainer.config.print_every = 5;

println!("Training a small model for prediction demo...");
trainer.train(&train_data, None, 4);

// Create test sequences for batch prediction
let test_sequences = generate_batch_sine_data(3, 3, input_size);
let test_inputs: Vec<_> = test_sequences.iter().map(|(inputs, _)| inputs.clone()).collect();
let _test_targets: Vec<_> = test_sequences.iter().map(|(_, targets)| targets.clone()).collect();

let test_inputs: Vec<_> = test_sequences
.iter()
.map(|(inputs, _)| inputs.clone())
.collect();
let _test_targets: Vec<_> = test_sequences
.iter()
.map(|(_, targets)| targets.clone())
.collect();

println!("\nPerforming batch predictions...");
let predictions = trainer.predict_batch(&test_inputs);

println!("Input sequences vs Predictions:");
for (i, (inputs, preds)) in test_inputs.iter().zip(predictions.iter()).enumerate() {
println!("Sequence {}:", i + 1);
for (j, (input, pred)) in inputs.iter().zip(preds.iter()).enumerate() {
println!(" Step {}: Input=[{:.3}, {:.3}] -> Pred={:.3}",
j + 1, input[[0, 0]], input[[1, 0]], pred[[0, 0]]);
println!(
" Step {}: Input=[{:.3}, {:.3}] -> Pred={:.3}",
j + 1,
input[[0, 0]],
input[[1, 0]],
pred[[0, 0]]
);
}
println!();
}
Expand All @@ -171,48 +219,54 @@ fn demonstrate_scalability() {
println!("=========================\n");

let test_sizes = vec![
(50, 4), // Small: 50 sequences, batch size 4
(200, 8), // Medium: 200 sequences, batch size 8
(500, 16), // Large: 500 sequences, batch size 16
(50, 4), // Small: 50 sequences, batch size 4
(200, 8), // Medium: 200 sequences, batch size 8
(500, 16), // Large: 500 sequences, batch size 16
];

for (num_sequences, batch_size) in test_sizes {
println!("Testing with {} sequences, batch size {}...", num_sequences, batch_size);

println!(
"Testing with {} sequences, batch size {}...",
num_sequences, batch_size
);

let train_data = generate_batch_sine_data(num_sequences, 8, 2);
let network = LSTMNetwork::new(2, 12, 1);
let mut trainer = create_adam_batch_trainer(network, 0.001);

trainer.config.epochs = 3;
trainer.config.print_every = 1;

let start_time = Instant::now();
trainer.train(&train_data, None, batch_size);
let training_time = start_time.elapsed();

let final_loss = trainer.get_latest_metrics().unwrap().train_loss;
println!(" Completed in {:.2}s, final loss: {:.6}\n",
training_time.as_secs_f64(), final_loss);
println!(
" Completed in {:.2}s, final loss: {:.6}\n",
training_time.as_secs_f64(),
final_loss
);
}

println!("All scalability tests completed successfully!");
println!("Batch processing handles varying dataset sizes efficiently.");
}

fn main() {
println!("RUST-LSTM BATCH PROCESSING DEMONSTRATION");
println!("=========================================\n");

println!("This example demonstrates the new batch processing capabilities:");
println!("- Simultaneous processing of multiple sequences");
println!("- Performance improvements over single-sequence training");
println!("- Batch prediction capabilities");
println!("- Scalability with different batch sizes\n");

benchmark_training_performance();
demonstrate_batch_prediction();
demonstrate_batch_prediction();
demonstrate_scalability();

println!("\nBATCH PROCESSING DEMONSTRATION COMPLETED!");
println!("==========================================");
println!("Key Benefits Demonstrated:");
Expand All @@ -221,10 +275,10 @@ fn main() {
println!("- Scalable to different dataset sizes");
println!("- Easy-to-use batch training API");
println!("- Backward compatibility with existing code");

println!("\nNext Steps:");
println!("- Try batch processing with your own datasets");
println!("- Experiment with different batch sizes");
println!("- Compare performance with single-sequence training");
println!("- Use batch processing for faster model development");
}
}
Loading
Loading