@@ -13,15 +13,26 @@ use rust_lstm::models::lstm_network::LSTMNetwork;
1313use rust_lstm:: optimizers:: Adam ;
1414use rust_lstm:: training:: { LSTMTrainer , TrainingConfig } ;
1515
16- fn demo_training_config ( ) -> TrainingConfig {
16+ pub ( crate ) const DEMO_EPOCHS : usize = 2 ;
17+ pub ( crate ) const DEMO_PRINT_EVERY : usize = 1 ;
18+ pub ( crate ) const DEMO_STOCK_DAYS : usize = 80 ;
19+ pub ( crate ) const DEMO_SEQUENCE_LENGTH : usize = 5 ;
20+ pub ( crate ) const DEMO_HIDDEN_SIZE : usize = 8 ;
21+ pub ( crate ) const DEMO_RNG_SEED : u64 = 42 ;
22+
23+ pub ( crate ) fn demo_training_config ( ) -> TrainingConfig {
1724 TrainingConfig {
18- epochs : 2 ,
19- print_every : 1 ,
25+ epochs : DEMO_EPOCHS ,
26+ print_every : DEMO_PRINT_EVERY ,
2027 log_lr_changes : false ,
2128 ..TrainingConfig :: default ( )
2229 }
2330}
2431
32+ pub ( crate ) fn demo_stock_trainer ( network : LSTMNetwork ) -> LSTMTrainer < MSELoss , Adam > {
33+ LSTMTrainer :: new ( network, MSELoss , Adam :: new ( 0.001 ) ) . with_config ( demo_training_config ( ) )
34+ }
35+
2536/// Stock data point with OHLCV (Open, High, Low, Close, Volume)
2637#[ derive( Debug , Clone ) ]
2738#[ allow( dead_code) ]
@@ -147,11 +158,7 @@ impl StockPredictor {
147158 val_data. len( )
148159 ) ;
149160
150- // Create trainer with Adam optimizer
151- let loss_function = MSELoss ;
152- let optimizer = Adam :: new ( 0.001 ) ;
153- let mut trainer = LSTMTrainer :: new ( self . network . clone ( ) , loss_function, optimizer)
154- . with_config ( demo_training_config ( ) ) ;
161+ let mut trainer = demo_stock_trainer ( self . network . clone ( ) ) ;
155162
156163 // Train the model
157164 trainer. train ( train_data, Some ( val_data) ) ;
@@ -189,7 +196,7 @@ impl StockPredictor {
189196
190197/// Generate synthetic stock data for demonstration
191198fn generate_stock_data ( days : usize ) -> Vec < StockData > {
192- let mut rng = StdRng :: seed_from_u64 ( 42 ) ;
199+ let mut rng = StdRng :: seed_from_u64 ( DEMO_RNG_SEED ) ;
193200 let mut data = Vec :: new ( ) ;
194201 let mut price = 100.0 ;
195202 let volume_base = 1_000_000.0 ;
@@ -226,13 +233,21 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
226233 data
227234}
228235
236+ #[ allow( dead_code) ]
237+ pub ( crate ) fn demo_stock_closes ( days : usize ) -> Vec < f64 > {
238+ generate_stock_data ( days)
239+ . into_iter ( )
240+ . map ( |stock| stock. close )
241+ . collect ( )
242+ }
243+
229244fn main ( ) {
230245 println ! ( "🏦 Stock Price Prediction with LSTM" ) ;
231246 println ! ( "=====================================\n " ) ;
232247 println ! ( "This bounded demo favors quick execution over prediction quality.\n " ) ;
233248
234249 // Generate synthetic stock data (in practice, you'd load real data)
235- let stock_data = generate_stock_data ( 80 ) ; // Bounded synthetic dataset for a quick demo
250+ let stock_data = generate_stock_data ( DEMO_STOCK_DAYS ) ; // Bounded synthetic dataset for a quick demo
236251 println ! (
237252 "📈 Generated {} days of synthetic stock data" ,
238253 stock_data. len( )
@@ -250,16 +265,16 @@ fn main() {
250265 }
251266
252267 // Create and train predictor
253- let mut predictor = StockPredictor :: new ( 5 , 8 ) ; // 5-day sequences, 8 hidden units
268+ let mut predictor = StockPredictor :: new ( DEMO_SEQUENCE_LENGTH , DEMO_HIDDEN_SIZE ) ;
254269 predictor. train ( & stock_data, 0.2 ) ; // 80% train, 20% validation
255270
256271 // Make predictions on recent data
257272 println ! ( "\n 🔮 Making predictions..." ) ;
258- let recent_data = & stock_data[ stock_data. len ( ) - 10 ..] ; // Last 10 days
273+ let recent_window = DEMO_SEQUENCE_LENGTH * 2 ;
274+ let recent_data = & stock_data[ stock_data. len ( ) - recent_window..] ;
259275
260- for i in 5 ..10 {
261- // Predict for days 6-10 of recent data
262- let input_data = & recent_data[ i - 5 ..i] ;
276+ for i in DEMO_SEQUENCE_LENGTH ..recent_window {
277+ let input_data = & recent_data[ i - DEMO_SEQUENCE_LENGTH ..i] ;
263278 if let Some ( predicted_price) = predictor. predict_next_price ( input_data) {
264279 let actual_price = recent_data[ i] . close ;
265280 let error = ( predicted_price - actual_price) . abs ( ) ;
0 commit comments