77#![ allow( unused_comparisons) ]
88
99use ndarray:: { arr2, Array2 } ;
10+ use rand:: { rngs:: StdRng , Rng , SeedableRng } ;
1011use rust_lstm:: loss:: MSELoss ;
1112use rust_lstm:: models:: lstm_network:: LSTMNetwork ;
1213use rust_lstm:: optimizers:: Adam ;
13- use rust_lstm:: training:: LSTMTrainer ;
14+ use rust_lstm:: training:: { LSTMTrainer , TrainingConfig } ;
15+
16+ pub ( crate ) const DEMO_EPOCHS : usize = 2 ;
17+ pub ( crate ) const DEMO_PRINT_EVERY : usize = 1 ;
18+ pub ( crate ) const DEMO_STOCK_DAYS : usize = 80 ;
19+ pub ( crate ) const DEMO_SEQUENCE_LENGTH : usize = 5 ;
20+ pub ( crate ) const DEMO_HIDDEN_SIZE : usize = 8 ;
21+ pub ( crate ) const DEMO_RNG_SEED : u64 = 42 ;
22+
23+ pub ( crate ) fn demo_training_config ( ) -> TrainingConfig {
24+ TrainingConfig {
25+ epochs : DEMO_EPOCHS ,
26+ print_every : DEMO_PRINT_EVERY ,
27+ log_lr_changes : false ,
28+ ..TrainingConfig :: default ( )
29+ }
30+ }
31+
32+ pub ( crate ) fn demo_stock_trainer ( network : LSTMNetwork ) -> LSTMTrainer < MSELoss , Adam > {
33+ LSTMTrainer :: new ( network, MSELoss , Adam :: new ( 0.001 ) ) . with_config ( demo_training_config ( ) )
34+ }
1435
1536/// Stock data point with OHLCV (Open, High, Low, Close, Volume)
1637#[ derive( Debug , Clone ) ]
@@ -137,10 +158,7 @@ impl StockPredictor {
137158 val_data. len( )
138159 ) ;
139160
140- // Create trainer with Adam optimizer
141- let loss_function = MSELoss ;
142- let optimizer = Adam :: new ( 0.001 ) ;
143- let mut trainer = LSTMTrainer :: new ( self . network . clone ( ) , loss_function, optimizer) ;
161+ let mut trainer = demo_stock_trainer ( self . network . clone ( ) ) ;
144162
145163 // Train the model
146164 trainer. train ( train_data, Some ( val_data) ) ;
@@ -178,6 +196,7 @@ impl StockPredictor {
178196
179197/// Generate synthetic stock data for demonstration
180198fn generate_stock_data ( days : usize ) -> Vec < StockData > {
199+ let mut rng = StdRng :: seed_from_u64 ( DEMO_RNG_SEED ) ;
181200 let mut data = Vec :: new ( ) ;
182201 let mut price = 100.0 ;
183202 let volume_base = 1_000_000.0 ;
@@ -186,19 +205,19 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
186205 // Random walk with trend and volatility
187206 let trend = 0.001 ; // Slight upward trend
188207 let volatility = 0.02 ;
189- let random_change = ( rand :: random :: < f64 > ( ) - 0.5 ) * volatility;
208+ let random_change = ( rng . gen :: < f64 > ( ) - 0.5 ) * volatility;
190209
191210 price *= 1.0 + trend + random_change;
192211 price = price. max ( 1.0 ) ; // Prevent negative prices
193212
194213 // Generate OHLC based on closing price
195214 let daily_volatility = 0.005 ;
196- let high = price * ( 1.0 + rand :: random :: < f64 > ( ) * daily_volatility) ;
197- let low = price * ( 1.0 - rand :: random :: < f64 > ( ) * daily_volatility) ;
198- let open = low + ( high - low) * rand :: random :: < f64 > ( ) ;
215+ let high = price * ( 1.0 + rng . gen :: < f64 > ( ) * daily_volatility) ;
216+ let low = price * ( 1.0 - rng . gen :: < f64 > ( ) * daily_volatility) ;
217+ let open = low + ( high - low) * rng . gen :: < f64 > ( ) ;
199218
200219 // Volume with some correlation to price movement
201- let volume_factor = 0.8 + 0.4 * rand :: random :: < f64 > ( ) ;
220+ let volume_factor = 0.8 + 0.4 * rng . gen :: < f64 > ( ) ;
202221 let volume = volume_base * volume_factor;
203222
204223 data. push ( StockData {
@@ -214,12 +233,21 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
214233 data
215234}
216235
236+ #[ allow( dead_code) ]
237+ pub ( crate ) fn demo_stock_closes ( days : usize ) -> Vec < f64 > {
238+ generate_stock_data ( days)
239+ . into_iter ( )
240+ . map ( |stock| stock. close )
241+ . collect ( )
242+ }
243+
217244fn main ( ) {
218245 println ! ( "🏦 Stock Price Prediction with LSTM" ) ;
219246 println ! ( "=====================================\n " ) ;
247+ println ! ( "This bounded demo favors quick execution over prediction quality.\n " ) ;
220248
221249 // Generate synthetic stock data (in practice, you'd load real data)
222- let stock_data = generate_stock_data ( 500 ) ; // 500 days of data
250+ let stock_data = generate_stock_data ( DEMO_STOCK_DAYS ) ; // Bounded synthetic dataset for a quick demo
223251 println ! (
224252 "📈 Generated {} days of synthetic stock data" ,
225253 stock_data. len( )
@@ -237,16 +265,16 @@ fn main() {
237265 }
238266
239267 // Create and train predictor
240- let mut predictor = StockPredictor :: new ( 20 , 50 ) ; // 20-day sequences, 50 hidden units
268+ let mut predictor = StockPredictor :: new ( DEMO_SEQUENCE_LENGTH , DEMO_HIDDEN_SIZE ) ;
241269 predictor. train ( & stock_data, 0.2 ) ; // 80% train, 20% validation
242270
243271 // Make predictions on recent data
244272 println ! ( "\n 🔮 Making predictions..." ) ;
245- let recent_data = & stock_data[ stock_data. len ( ) - 30 ..] ; // Last 30 days
273+ let recent_window = DEMO_SEQUENCE_LENGTH * 2 ;
274+ let recent_data = & stock_data[ stock_data. len ( ) - recent_window..] ;
246275
247- for i in 20 ..25 {
248- // Predict for days 21-25 of recent data
249- let input_data = & recent_data[ i - 20 ..i] ;
276+ for i in DEMO_SEQUENCE_LENGTH ..recent_window {
277+ let input_data = & recent_data[ i - DEMO_SEQUENCE_LENGTH ..i] ;
250278 if let Some ( predicted_price) = predictor. predict_next_price ( input_data) {
251279 let actual_price = recent_data[ i] . close ;
252280 let error = ( predicted_price - actual_price) . abs ( ) ;
0 commit comments