1- use ndarray:: { Array2 , arr2} ;
2- use rust_lstm:: { LSTMNetwork , create_adam_batch_trainer, create_basic_trainer} ;
1+ #![ allow( clippy:: type_complexity) ]
2+ #![ allow( clippy:: field_reassign_with_default) ]
3+ #![ allow( clippy:: empty_line_after_doc_comments) ]
4+ #![ allow( clippy:: needless_range_loop) ]
5+ #![ allow( clippy:: assertions_on_constants) ]
6+ #![ allow( clippy:: absurd_extreme_comparisons) ]
7+ #![ allow( unused_comparisons) ]
8+
9+ use ndarray:: { arr2, Array2 } ;
10+ use rust_lstm:: { create_adam_batch_trainer, create_basic_trainer, LSTMNetwork } ;
311use std:: time:: Instant ;
412
513/// Generate synthetic sine wave sequences for batch processing demonstration
6- fn generate_batch_sine_data ( num_sequences : usize , sequence_length : usize , input_size : usize ) -> Vec < ( Vec < Array2 < f64 > > , Vec < Array2 < f64 > > ) > {
14+ fn generate_batch_sine_data (
15+ num_sequences : usize ,
16+ sequence_length : usize ,
17+ input_size : usize ,
18+ ) -> Vec < ( Vec < Array2 < f64 > > , Vec < Array2 < f64 > > ) > {
719 let mut data = Vec :: new ( ) ;
8-
20+
921 for i in 0 ..num_sequences {
1022 let mut inputs = Vec :: new ( ) ;
1123 let mut targets = Vec :: new ( ) ;
12-
24+
1325 let start = ( i as f64 ) * 0.05 ; // Different starting points for variety
1426 let frequency = 1.0 + ( i as f64 ) * 0.1 ; // Different frequencies
15-
27+
1628 for j in 0 ..sequence_length {
1729 let t = start + ( j as f64 ) * 0.1 ;
18-
30+
1931 // Create multi-dimensional input
2032 let mut input_vec = vec ! [ 0.0 ; input_size] ;
2133 input_vec[ 0 ] = ( t * frequency * 2.0 * std:: f64:: consts:: PI ) . sin ( ) ;
@@ -25,17 +37,17 @@ fn generate_batch_sine_data(num_sequences: usize, sequence_length: usize, input_
2537 if input_size > 2 {
2638 input_vec[ 2 ] = t. sin ( ) * t. cos ( ) ; // Some nonlinear combination
2739 }
28-
40+
2941 // Target is the next value in the sine sequence
3042 let target = ( ( t + 0.1 ) * frequency * 2.0 * std:: f64:: consts:: PI ) . sin ( ) ;
31-
43+
3244 inputs. push ( Array2 :: from_shape_vec ( ( input_size, 1 ) , input_vec) . unwrap ( ) ) ;
3345 targets. push ( arr2 ( & [ [ target] ] ) ) ;
3446 }
35-
47+
3648 data. push ( ( inputs, targets) ) ;
3749 }
38-
50+
3951 data
4052}
4153
@@ -48,77 +60,102 @@ fn benchmark_training_performance() {
4860 let hidden_size = 16 ;
4961 let num_layers = 2 ;
5062 let learning_rate = 0.001 ;
51-
63+
5264 // Generate training data
5365 let train_data = generate_batch_sine_data ( 100 , 10 , input_size) ;
5466 let val_data = generate_batch_sine_data ( 20 , 10 , input_size) ;
55-
56- println ! ( "Dataset: {} training sequences, {} validation sequences" , train_data. len( ) , val_data. len( ) ) ;
57- println ! ( "Network: {} -> {} hidden ({} layers)\n " , input_size, hidden_size, num_layers) ;
67+
68+ println ! (
69+ "Dataset: {} training sequences, {} validation sequences" ,
70+ train_data. len( ) ,
71+ val_data. len( )
72+ ) ;
73+ println ! (
74+ "Network: {} -> {} hidden ({} layers)\n " ,
75+ input_size, hidden_size, num_layers
76+ ) ;
5877
5978 // Test 1: Single sequence processing (traditional)
6079 println ! ( "Testing Traditional Single-Sequence Processing..." ) ;
6180 let network1 = LSTMNetwork :: new ( input_size, hidden_size, num_layers) ;
6281 let mut trainer1 = create_basic_trainer ( network1, learning_rate) ;
63-
82+
6483 // Configure for quick demo
6584 trainer1. config . epochs = 5 ;
6685 trainer1. config . print_every = 1 ;
67-
86+
6887 let start_time = Instant :: now ( ) ;
6988 trainer1. train ( & train_data, Some ( & val_data) ) ;
7089 let single_time = start_time. elapsed ( ) ;
71-
90+
7291 let final_metrics1 = trainer1. get_latest_metrics ( ) . unwrap ( ) ;
73- println ! ( "Single-sequence - Final loss: {:.6}, Time: {:.2}s\n " ,
74- final_metrics1. train_loss, single_time. as_secs_f64( ) ) ;
92+ println ! (
93+ "Single-sequence - Final loss: {:.6}, Time: {:.2}s\n " ,
94+ final_metrics1. train_loss,
95+ single_time. as_secs_f64( )
96+ ) ;
7597
7698 // Test 2: Batch processing with small batches
7799 println ! ( "Testing Batch Processing (batch size 8)..." ) ;
78100 let network2 = LSTMNetwork :: new ( input_size, hidden_size, num_layers) ;
79101 let mut trainer2 = create_adam_batch_trainer ( network2, learning_rate) ;
80-
102+
81103 trainer2. config . epochs = 5 ;
82104 trainer2. config . print_every = 1 ;
83-
105+
84106 let start_time = Instant :: now ( ) ;
85107 trainer2. train ( & train_data, Some ( & val_data) , 8 ) ; // Batch size 8
86108 let batch_time = start_time. elapsed ( ) ;
87-
109+
88110 let final_metrics2 = trainer2. get_latest_metrics ( ) . unwrap ( ) ;
89- println ! ( "Batch processing - Final loss: {:.6}, Time: {:.2}s\n " ,
90- final_metrics2. train_loss, batch_time. as_secs_f64( ) ) ;
111+ println ! (
112+ "Batch processing - Final loss: {:.6}, Time: {:.2}s\n " ,
113+ final_metrics2. train_loss,
114+ batch_time. as_secs_f64( )
115+ ) ;
91116
92117 // Test 3: Larger batch size
93118 println ! ( "Testing Larger Batch Processing (batch size 16)..." ) ;
94119 let network3 = LSTMNetwork :: new ( input_size, hidden_size, num_layers) ;
95120 let mut trainer3 = create_adam_batch_trainer ( network3, learning_rate) ;
96-
121+
97122 trainer3. config . epochs = 5 ;
98123 trainer3. config . print_every = 1 ;
99-
124+
100125 let start_time = Instant :: now ( ) ;
101126 trainer3. train ( & train_data, Some ( & val_data) , 16 ) ; // Batch size 16
102127 let large_batch_time = start_time. elapsed ( ) ;
103-
128+
104129 let final_metrics3 = trainer3. get_latest_metrics ( ) . unwrap ( ) ;
105- println ! ( "Large batch processing - Final loss: {:.6}, Time: {:.2}s\n " ,
106- final_metrics3. train_loss, large_batch_time. as_secs_f64( ) ) ;
130+ println ! (
131+ "Large batch processing - Final loss: {:.6}, Time: {:.2}s\n " ,
132+ final_metrics3. train_loss,
133+ large_batch_time. as_secs_f64( )
134+ ) ;
107135
108136 // Performance summary
109137 println ! ( "PERFORMANCE SUMMARY:" ) ;
110138 println ! ( "======================" ) ;
111- println ! ( "Single-sequence: {:.2}s (baseline)" , single_time. as_secs_f64( ) ) ;
112- println ! ( "Batch-8: {:.2}s ({:.1}x speedup)" ,
113- batch_time. as_secs_f64( ) ,
114- single_time. as_secs_f64( ) / batch_time. as_secs_f64( ) ) ;
115- println ! ( "Batch-16: {:.2}s ({:.1}x speedup)" ,
116- large_batch_time. as_secs_f64( ) ,
117- single_time. as_secs_f64( ) / large_batch_time. as_secs_f64( ) ) ;
118-
139+ println ! (
140+ "Single-sequence: {:.2}s (baseline)" ,
141+ single_time. as_secs_f64( )
142+ ) ;
143+ println ! (
144+ "Batch-8: {:.2}s ({:.1}x speedup)" ,
145+ batch_time. as_secs_f64( ) ,
146+ single_time. as_secs_f64( ) / batch_time. as_secs_f64( )
147+ ) ;
148+ println ! (
149+ "Batch-16: {:.2}s ({:.1}x speedup)" ,
150+ large_batch_time. as_secs_f64( ) ,
151+ single_time. as_secs_f64( ) / large_batch_time. as_secs_f64( )
152+ ) ;
153+
119154 if batch_time < single_time {
120- println ! ( "Batch processing achieved {:.1}x speedup!" ,
121- single_time. as_secs_f64( ) / batch_time. as_secs_f64( ) ) ;
155+ println ! (
156+ "Batch processing achieved {:.1}x speedup!" ,
157+ single_time. as_secs_f64( ) / batch_time. as_secs_f64( )
158+ ) ;
122159 } else {
123160 println ! ( "Note: For small datasets, overhead may dominate. Try larger datasets for better speedup." ) ;
124161 }
@@ -132,34 +169,45 @@ fn demonstrate_batch_prediction() {
132169 let input_size = 2 ;
133170 let hidden_size = 8 ;
134171 let num_layers = 1 ;
135-
172+
136173 // Create and train a simple model
137174 let network = LSTMNetwork :: new ( input_size, hidden_size, num_layers) ;
138175 let mut trainer = create_adam_batch_trainer ( network, 0.01 ) ;
139-
176+
140177 // Generate small training dataset
141178 let train_data = generate_batch_sine_data ( 20 , 5 , input_size) ;
142-
179+
143180 trainer. config . epochs = 10 ;
144181 trainer. config . print_every = 5 ;
145-
182+
146183 println ! ( "Training a small model for prediction demo..." ) ;
147184 trainer. train ( & train_data, None , 4 ) ;
148-
185+
149186 // Create test sequences for batch prediction
150187 let test_sequences = generate_batch_sine_data ( 3 , 3 , input_size) ;
151- let test_inputs: Vec < _ > = test_sequences. iter ( ) . map ( |( inputs, _) | inputs. clone ( ) ) . collect ( ) ;
152- let _test_targets: Vec < _ > = test_sequences. iter ( ) . map ( |( _, targets) | targets. clone ( ) ) . collect ( ) ;
153-
188+ let test_inputs: Vec < _ > = test_sequences
189+ . iter ( )
190+ . map ( |( inputs, _) | inputs. clone ( ) )
191+ . collect ( ) ;
192+ let _test_targets: Vec < _ > = test_sequences
193+ . iter ( )
194+ . map ( |( _, targets) | targets. clone ( ) )
195+ . collect ( ) ;
196+
154197 println ! ( "\n Performing batch predictions..." ) ;
155198 let predictions = trainer. predict_batch ( & test_inputs) ;
156-
199+
157200 println ! ( "Input sequences vs Predictions:" ) ;
158201 for ( i, ( inputs, preds) ) in test_inputs. iter ( ) . zip ( predictions. iter ( ) ) . enumerate ( ) {
159202 println ! ( "Sequence {}:" , i + 1 ) ;
160203 for ( j, ( input, pred) ) in inputs. iter ( ) . zip ( preds. iter ( ) ) . enumerate ( ) {
161- println ! ( " Step {}: Input=[{:.3}, {:.3}] -> Pred={:.3}" ,
162- j + 1 , input[ [ 0 , 0 ] ] , input[ [ 1 , 0 ] ] , pred[ [ 0 , 0 ] ] ) ;
204+ println ! (
205+ " Step {}: Input=[{:.3}, {:.3}] -> Pred={:.3}" ,
206+ j + 1 ,
207+ input[ [ 0 , 0 ] ] ,
208+ input[ [ 1 , 0 ] ] ,
209+ pred[ [ 0 , 0 ] ]
210+ ) ;
163211 }
164212 println ! ( ) ;
165213 }
@@ -171,48 +219,54 @@ fn demonstrate_scalability() {
171219 println ! ( "=========================\n " ) ;
172220
173221 let test_sizes = vec ! [
174- ( 50 , 4 ) , // Small: 50 sequences, batch size 4
175- ( 200 , 8 ) , // Medium: 200 sequences, batch size 8
176- ( 500 , 16 ) , // Large: 500 sequences, batch size 16
222+ ( 50 , 4 ) , // Small: 50 sequences, batch size 4
223+ ( 200 , 8 ) , // Medium: 200 sequences, batch size 8
224+ ( 500 , 16 ) , // Large: 500 sequences, batch size 16
177225 ] ;
178226
179227 for ( num_sequences, batch_size) in test_sizes {
180- println ! ( "Testing with {} sequences, batch size {}..." , num_sequences, batch_size) ;
181-
228+ println ! (
229+ "Testing with {} sequences, batch size {}..." ,
230+ num_sequences, batch_size
231+ ) ;
232+
182233 let train_data = generate_batch_sine_data ( num_sequences, 8 , 2 ) ;
183234 let network = LSTMNetwork :: new ( 2 , 12 , 1 ) ;
184235 let mut trainer = create_adam_batch_trainer ( network, 0.001 ) ;
185-
236+
186237 trainer. config . epochs = 3 ;
187238 trainer. config . print_every = 1 ;
188-
239+
189240 let start_time = Instant :: now ( ) ;
190241 trainer. train ( & train_data, None , batch_size) ;
191242 let training_time = start_time. elapsed ( ) ;
192-
243+
193244 let final_loss = trainer. get_latest_metrics ( ) . unwrap ( ) . train_loss ;
194- println ! ( " Completed in {:.2}s, final loss: {:.6}\n " ,
195- training_time. as_secs_f64( ) , final_loss) ;
245+ println ! (
246+ " Completed in {:.2}s, final loss: {:.6}\n " ,
247+ training_time. as_secs_f64( ) ,
248+ final_loss
249+ ) ;
196250 }
197-
251+
198252 println ! ( "All scalability tests completed successfully!" ) ;
199253 println ! ( "Batch processing handles varying dataset sizes efficiently." ) ;
200254}
201255
202256fn main ( ) {
203257 println ! ( "RUST-LSTM BATCH PROCESSING DEMONSTRATION" ) ;
204258 println ! ( "=========================================\n " ) ;
205-
259+
206260 println ! ( "This example demonstrates the new batch processing capabilities:" ) ;
207261 println ! ( "- Simultaneous processing of multiple sequences" ) ;
208262 println ! ( "- Performance improvements over single-sequence training" ) ;
209263 println ! ( "- Batch prediction capabilities" ) ;
210264 println ! ( "- Scalability with different batch sizes\n " ) ;
211265
212266 benchmark_training_performance ( ) ;
213- demonstrate_batch_prediction ( ) ;
267+ demonstrate_batch_prediction ( ) ;
214268 demonstrate_scalability ( ) ;
215-
269+
216270 println ! ( "\n BATCH PROCESSING DEMONSTRATION COMPLETED!" ) ;
217271 println ! ( "==========================================" ) ;
218272 println ! ( "Key Benefits Demonstrated:" ) ;
@@ -221,10 +275,10 @@ fn main() {
221275 println ! ( "- Scalable to different dataset sizes" ) ;
222276 println ! ( "- Easy-to-use batch training API" ) ;
223277 println ! ( "- Backward compatibility with existing code" ) ;
224-
278+
225279 println ! ( "\n Next Steps:" ) ;
226280 println ! ( "- Try batch processing with your own datasets" ) ;
227281 println ! ( "- Experiment with different batch sizes" ) ;
228282 println ! ( "- Compare performance with single-sequence training" ) ;
229283 println ! ( "- Use batch processing for faster model development" ) ;
230- }
284+ }
0 commit comments