77#![ allow( unused_comparisons) ]
88
99use ndarray:: { arr2, Array2 } ;
10+ use rand:: { rngs:: StdRng , Rng , SeedableRng } ;
1011use rust_lstm:: loss:: MSELoss ;
1112use rust_lstm:: models:: lstm_network:: LSTMNetwork ;
1213use rust_lstm:: optimizers:: Adam ;
13- use rust_lstm:: training:: LSTMTrainer ;
14+ use rust_lstm:: training:: { LSTMTrainer , TrainingConfig } ;
15+
16+ pub const DEMO_WEATHER_DAYS : usize = 80 ;
17+ pub const DEMO_SEQUENCE_LENGTH : usize = 5 ;
18+ pub const DEMO_HIDDEN_SIZE : usize = 8 ;
19+ pub const DEMO_EPOCHS : usize = 4 ;
20+ pub const DEMO_PRINT_EVERY : usize = 1 ;
21+ pub const DEMO_NUM_PREDICTIONS : usize = 5 ;
22+ pub const DEMO_RECENT_DAYS : usize = DEMO_SEQUENCE_LENGTH + DEMO_NUM_PREDICTIONS ;
23+ pub const DEMO_RANDOM_SEED : u64 = 42 ;
1424
1525/// Weather data with multiple meteorological features
16- #[ derive( Debug , Clone ) ]
26+ #[ derive( Debug , Clone , PartialEq ) ]
1727#[ allow( dead_code) ]
18- struct WeatherData {
19- date : String ,
20- temperature : f64 , // °C
21- humidity : f64 , // %
22- pressure : f64 , // hPa
23- wind_speed : f64 , // km/h
24- precipitation : f64 , // mm
25- cloud_cover : f64 , // %
28+ pub struct WeatherData {
29+ pub date : String ,
30+ pub temperature : f64 , // °C
31+ pub humidity : f64 , // %
32+ pub pressure : f64 , // hPa
33+ pub wind_speed : f64 , // km/h
34+ pub precipitation : f64 , // mm
35+ pub cloud_cover : f64 , // %
2636}
2737
2838/// Multi-feature weather prediction system
@@ -154,11 +164,7 @@ impl WeatherPredictor {
154164 let optimizer = Adam :: new ( 0.001 ) ;
155165 let mut trainer = LSTMTrainer :: new ( self . network . clone ( ) , loss_function, optimizer) ;
156166
157- // Configure training for quicker demo
158- let mut config = rust_lstm:: training:: TrainingConfig :: default ( ) ;
159- config. epochs = 20 ; // Reduced from 100 for demo
160- config. print_every = 5 ; // Print more frequently
161- trainer = trainer. with_config ( config) ;
167+ trainer = trainer. with_config ( weather_training_config ( ) ) ;
162168
163169 trainer. train ( train_data, Some ( val_data) ) ;
164170
@@ -192,8 +198,17 @@ impl WeatherPredictor {
192198 }
193199}
194200
201+ pub fn weather_training_config ( ) -> TrainingConfig {
202+ TrainingConfig {
203+ epochs : DEMO_EPOCHS ,
204+ print_every : DEMO_PRINT_EVERY ,
205+ ..TrainingConfig :: default ( )
206+ }
207+ }
208+
195209/// Generate realistic weather data with seasonal patterns
196- fn generate_weather_data ( days : usize ) -> Vec < WeatherData > {
210+ pub fn generate_weather_data ( days : usize ) -> Vec < WeatherData > {
211+ let mut rng = StdRng :: seed_from_u64 ( DEMO_RANDOM_SEED ) ;
197212 let mut data = Vec :: new ( ) ;
198213
199214 for i in 0 ..days {
@@ -203,23 +218,23 @@ fn generate_weather_data(days: usize) -> Vec<WeatherData> {
203218 let seasonal_temp = 15.0 + 10.0 * ( 2.0 * std:: f64:: consts:: PI * day_of_year / 365.0 ) . sin ( ) ;
204219
205220 // Daily temperature variation with some randomness
206- let daily_variation = ( rand :: random :: < f64 > ( ) - 0.5 ) * 6.0 ;
221+ let daily_variation = ( rng . gen :: < f64 > ( ) - 0.5 ) * 6.0 ;
207222 let temperature = seasonal_temp + daily_variation;
208223
209224 // Humidity inversely correlated with temperature
210- let humidity = 70.0 - ( temperature - 15.0 ) * 2.0 + ( rand :: random :: < f64 > ( ) - 0.5 ) * 20.0 ;
225+ let humidity = 70.0 - ( temperature - 15.0 ) * 2.0 + ( rng . gen :: < f64 > ( ) - 0.5 ) * 20.0 ;
211226 let humidity = humidity. clamp ( 20.0 , 95.0 ) ;
212227
213228 // Pressure with weather patterns
214- let pressure = 1013.25 + ( rand :: random :: < f64 > ( ) - 0.5 ) * 30.0 ;
229+ let pressure = 1013.25 + ( rng . gen :: < f64 > ( ) - 0.5 ) * 30.0 ;
215230
216231 // Wind speed with some correlation to pressure changes
217- let wind_speed = 10.0 + ( rand :: random :: < f64 > ( ) * 15.0 ) ;
232+ let wind_speed = 10.0 + ( rng . gen :: < f64 > ( ) * 15.0 ) ;
218233
219234 // Precipitation probability based on humidity and pressure
220235 let precip_prob = ( humidity - 50.0 ) / 100.0 + ( 1020.0 - pressure) / 50.0 ;
221- let precipitation = if rand :: random :: < f64 > ( ) < precip_prob. max ( 0.0 ) {
222- rand :: random :: < f64 > ( ) * 15.0 // 0-15mm
236+ let precipitation = if rng . gen :: < f64 > ( ) < precip_prob. max ( 0.0 ) {
237+ rng . gen :: < f64 > ( ) * 15.0 // 0-15mm
223238 } else {
224239 0.0
225240 } ;
@@ -248,7 +263,7 @@ fn main() {
248263 println ! ( "===========================================\n " ) ;
249264
250265 // Generate synthetic weather data
251- let weather_data = generate_weather_data ( 365 ) ; // One year of data
266+ let weather_data = generate_weather_data ( DEMO_WEATHER_DAYS ) ;
252267 println ! (
253268 "🌍 Generated {} days of synthetic weather data" ,
254269 weather_data. len( )
@@ -268,16 +283,15 @@ fn main() {
268283 }
269284
270285 // Create and train predictor
271- let mut predictor = WeatherPredictor :: new ( 7 , 64 ) ; // 7-day sequences, 64 hidden units
286+ let mut predictor = WeatherPredictor :: new ( DEMO_SEQUENCE_LENGTH , DEMO_HIDDEN_SIZE ) ;
272287 predictor. train ( & weather_data, 0.2 ) ; // 80% train, 20% validation
273288
274289 // Make temperature predictions
275290 println ! ( "\n 🔮 Temperature predictions for next 5 days:" ) ;
276- let recent_data = & weather_data[ weather_data. len ( ) - 20 ..] ; // Last 20 days
291+ let recent_data = & weather_data[ weather_data. len ( ) - DEMO_RECENT_DAYS ..] ;
277292
278- for i in 7 ..12 {
279- // Predict for days 8-12 of recent data
280- let input_data = & recent_data[ i - 7 ..i] ;
293+ for i in DEMO_SEQUENCE_LENGTH ..DEMO_SEQUENCE_LENGTH + DEMO_NUM_PREDICTIONS {
294+ let input_data = & recent_data[ i - DEMO_SEQUENCE_LENGTH ..i] ;
281295 if let Some ( predicted_temp) = predictor. predict_temperature ( input_data) {
282296 let actual_temp = recent_data[ i] . temperature ;
283297 let error = ( predicted_temp - actual_temp) . abs ( ) ;
0 commit comments