77#![ allow( unused_comparisons) ]
88
99use ndarray:: { arr2, Array2 } ;
10+ use rand:: { rngs:: StdRng , Rng , SeedableRng } ;
1011use rust_lstm:: loss:: MSELoss ;
1112use rust_lstm:: models:: lstm_network:: LSTMNetwork ;
1213use rust_lstm:: optimizers:: Adam ;
13- use rust_lstm:: training:: LSTMTrainer ;
14+ use rust_lstm:: training:: { LSTMTrainer , TrainingConfig } ;
15+
16+ fn demo_training_config ( ) -> TrainingConfig {
17+ TrainingConfig {
18+ epochs : 2 ,
19+ print_every : 1 ,
20+ log_lr_changes : false ,
21+ ..TrainingConfig :: default ( )
22+ }
23+ }
1424
1525/// Stock data point with OHLCV (Open, High, Low, Close, Volume)
1626#[ derive( Debug , Clone ) ]
@@ -140,7 +150,8 @@ impl StockPredictor {
140150 // Create trainer with Adam optimizer
141151 let loss_function = MSELoss ;
142152 let optimizer = Adam :: new ( 0.001 ) ;
143- let mut trainer = LSTMTrainer :: new ( self . network . clone ( ) , loss_function, optimizer) ;
153+ let mut trainer = LSTMTrainer :: new ( self . network . clone ( ) , loss_function, optimizer)
154+ . with_config ( demo_training_config ( ) ) ;
144155
145156 // Train the model
146157 trainer. train ( train_data, Some ( val_data) ) ;
@@ -178,6 +189,7 @@ impl StockPredictor {
178189
179190/// Generate synthetic stock data for demonstration
180191fn generate_stock_data ( days : usize ) -> Vec < StockData > {
192+ let mut rng = StdRng :: seed_from_u64 ( 42 ) ;
181193 let mut data = Vec :: new ( ) ;
182194 let mut price = 100.0 ;
183195 let volume_base = 1_000_000.0 ;
@@ -186,19 +198,19 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
186198 // Random walk with trend and volatility
187199 let trend = 0.001 ; // Slight upward trend
188200 let volatility = 0.02 ;
189- let random_change = ( rand :: random :: < f64 > ( ) - 0.5 ) * volatility;
201+ let random_change = ( rng . gen :: < f64 > ( ) - 0.5 ) * volatility;
190202
191203 price *= 1.0 + trend + random_change;
192204 price = price. max ( 1.0 ) ; // Prevent negative prices
193205
194206 // Generate OHLC based on closing price
195207 let daily_volatility = 0.005 ;
196- let high = price * ( 1.0 + rand :: random :: < f64 > ( ) * daily_volatility) ;
197- let low = price * ( 1.0 - rand :: random :: < f64 > ( ) * daily_volatility) ;
198- let open = low + ( high - low) * rand :: random :: < f64 > ( ) ;
208+ let high = price * ( 1.0 + rng . gen :: < f64 > ( ) * daily_volatility) ;
209+ let low = price * ( 1.0 - rng . gen :: < f64 > ( ) * daily_volatility) ;
210+ let open = low + ( high - low) * rng . gen :: < f64 > ( ) ;
199211
200212 // Volume with some correlation to price movement
201- let volume_factor = 0.8 + 0.4 * rand :: random :: < f64 > ( ) ;
213+ let volume_factor = 0.8 + 0.4 * rng . gen :: < f64 > ( ) ;
202214 let volume = volume_base * volume_factor;
203215
204216 data. push ( StockData {
@@ -217,9 +229,10 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
217229fn main ( ) {
218230 println ! ( "🏦 Stock Price Prediction with LSTM" ) ;
219231 println ! ( "=====================================\n " ) ;
232+ println ! ( "This bounded demo favors quick execution over prediction quality.\n " ) ;
220233
221234 // Generate synthetic stock data (in practice, you'd load real data)
222- let stock_data = generate_stock_data ( 500 ) ; // 500 days of data
235+ let stock_data = generate_stock_data ( 80 ) ; // Bounded synthetic dataset for a quick demo
223236 println ! (
224237 "📈 Generated {} days of synthetic stock data" ,
225238 stock_data. len( )
@@ -237,16 +250,16 @@ fn main() {
237250 }
238251
239252 // Create and train predictor
240- let mut predictor = StockPredictor :: new ( 20 , 50 ) ; // 20 -day sequences, 50 hidden units
253+ let mut predictor = StockPredictor :: new ( 5 , 8 ) ; // 5 -day sequences, 8 hidden units
241254 predictor. train ( & stock_data, 0.2 ) ; // 80% train, 20% validation
242255
243256 // Make predictions on recent data
244257 println ! ( "\n 🔮 Making predictions..." ) ;
245- let recent_data = & stock_data[ stock_data. len ( ) - 30 ..] ; // Last 30 days
258+ let recent_data = & stock_data[ stock_data. len ( ) - 10 ..] ; // Last 10 days
246259
247- for i in 20 .. 25 {
248- // Predict for days 21-25 of recent data
249- let input_data = & recent_data[ i - 20 ..i] ;
260+ for i in 5 .. 10 {
261+ // Predict for days 6-10 of recent data
262+ let input_data = & recent_data[ i - 5 ..i] ;
250263 if let Some ( predicted_price) = predictor. predict_next_price ( input_data) {
251264 let actual_price = recent_data[ i] . close ;
252265 let error = ( predicted_price - actual_price) . abs ( ) ;
0 commit comments