Skip to content

Commit c730dce

Browse files
authored
test: bound weather prediction example (#26)
1 parent 5bc2c8f commit c730dce

2 files changed

Lines changed: 117 additions & 28 deletions

File tree

examples/weather_prediction.rs

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,32 @@
77
#![allow(unused_comparisons)]
88

99
use ndarray::{arr2, Array2};
10+
use rand::{rngs::StdRng, Rng, SeedableRng};
1011
use rust_lstm::loss::MSELoss;
1112
use rust_lstm::models::lstm_network::LSTMNetwork;
1213
use rust_lstm::optimizers::Adam;
13-
use rust_lstm::training::LSTMTrainer;
14+
use rust_lstm::training::{LSTMTrainer, TrainingConfig};
15+
16+
pub const DEMO_WEATHER_DAYS: usize = 80;
17+
pub const DEMO_SEQUENCE_LENGTH: usize = 5;
18+
pub const DEMO_HIDDEN_SIZE: usize = 8;
19+
pub const DEMO_EPOCHS: usize = 4;
20+
pub const DEMO_PRINT_EVERY: usize = 1;
21+
pub const DEMO_NUM_PREDICTIONS: usize = 5;
22+
pub const DEMO_RECENT_DAYS: usize = DEMO_SEQUENCE_LENGTH + DEMO_NUM_PREDICTIONS;
23+
pub const DEMO_RANDOM_SEED: u64 = 42;
1424

1525
/// Weather data with multiple meteorological features
16-
#[derive(Debug, Clone)]
26+
#[derive(Debug, Clone, PartialEq)]
1727
#[allow(dead_code)]
18-
struct WeatherData {
19-
date: String,
20-
temperature: f64, // °C
21-
humidity: f64, // %
22-
pressure: f64, // hPa
23-
wind_speed: f64, // km/h
24-
precipitation: f64, // mm
25-
cloud_cover: f64, // %
28+
pub struct WeatherData {
29+
pub date: String,
30+
pub temperature: f64, // °C
31+
pub humidity: f64, // %
32+
pub pressure: f64, // hPa
33+
pub wind_speed: f64, // km/h
34+
pub precipitation: f64, // mm
35+
pub cloud_cover: f64, // %
2636
}
2737

2838
/// Multi-feature weather prediction system
@@ -154,11 +164,7 @@ impl WeatherPredictor {
154164
let optimizer = Adam::new(0.001);
155165
let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer);
156166

157-
// Configure training for quicker demo
158-
let mut config = rust_lstm::training::TrainingConfig::default();
159-
config.epochs = 20; // Reduced from 100 for demo
160-
config.print_every = 5; // Print more frequently
161-
trainer = trainer.with_config(config);
167+
trainer = trainer.with_config(weather_training_config());
162168

163169
trainer.train(train_data, Some(val_data));
164170

@@ -192,8 +198,17 @@ impl WeatherPredictor {
192198
}
193199
}
194200

201+
pub fn weather_training_config() -> TrainingConfig {
202+
TrainingConfig {
203+
epochs: DEMO_EPOCHS,
204+
print_every: DEMO_PRINT_EVERY,
205+
..TrainingConfig::default()
206+
}
207+
}
208+
195209
/// Generate realistic weather data with seasonal patterns
196-
fn generate_weather_data(days: usize) -> Vec<WeatherData> {
210+
pub fn generate_weather_data(days: usize) -> Vec<WeatherData> {
211+
let mut rng = StdRng::seed_from_u64(DEMO_RANDOM_SEED);
197212
let mut data = Vec::new();
198213

199214
for i in 0..days {
@@ -203,23 +218,23 @@ fn generate_weather_data(days: usize) -> Vec<WeatherData> {
203218
let seasonal_temp = 15.0 + 10.0 * (2.0 * std::f64::consts::PI * day_of_year / 365.0).sin();
204219

205220
// Daily temperature variation with some randomness
206-
let daily_variation = (rand::random::<f64>() - 0.5) * 6.0;
221+
let daily_variation = (rng.gen::<f64>() - 0.5) * 6.0;
207222
let temperature = seasonal_temp + daily_variation;
208223

209224
// Humidity inversely correlated with temperature
210-
let humidity = 70.0 - (temperature - 15.0) * 2.0 + (rand::random::<f64>() - 0.5) * 20.0;
225+
let humidity = 70.0 - (temperature - 15.0) * 2.0 + (rng.gen::<f64>() - 0.5) * 20.0;
211226
let humidity = humidity.clamp(20.0, 95.0);
212227

213228
// Pressure with weather patterns
214-
let pressure = 1013.25 + (rand::random::<f64>() - 0.5) * 30.0;
229+
let pressure = 1013.25 + (rng.gen::<f64>() - 0.5) * 30.0;
215230

216231
// Wind speed with some correlation to pressure changes
217-
let wind_speed = 10.0 + (rand::random::<f64>() * 15.0);
232+
let wind_speed = 10.0 + (rng.gen::<f64>() * 15.0);
218233

219234
// Precipitation probability based on humidity and pressure
220235
let precip_prob = (humidity - 50.0) / 100.0 + (1020.0 - pressure) / 50.0;
221-
let precipitation = if rand::random::<f64>() < precip_prob.max(0.0) {
222-
rand::random::<f64>() * 15.0 // 0-15mm
236+
let precipitation = if rng.gen::<f64>() < precip_prob.max(0.0) {
237+
rng.gen::<f64>() * 15.0 // 0-15mm
223238
} else {
224239
0.0
225240
};
@@ -248,7 +263,7 @@ fn main() {
248263
println!("===========================================\n");
249264

250265
// Generate synthetic weather data
251-
let weather_data = generate_weather_data(365); // One year of data
266+
let weather_data = generate_weather_data(DEMO_WEATHER_DAYS);
252267
println!(
253268
"🌍 Generated {} days of synthetic weather data",
254269
weather_data.len()
@@ -268,16 +283,15 @@ fn main() {
268283
}
269284

270285
// Create and train predictor
271-
let mut predictor = WeatherPredictor::new(7, 64); // 7-day sequences, 64 hidden units
286+
let mut predictor = WeatherPredictor::new(DEMO_SEQUENCE_LENGTH, DEMO_HIDDEN_SIZE);
272287
predictor.train(&weather_data, 0.2); // 80% train, 20% validation
273288

274289
// Make temperature predictions
275290
println!("\n🔮 Temperature predictions for next 5 days:");
276-
let recent_data = &weather_data[weather_data.len() - 20..]; // Last 20 days
291+
let recent_data = &weather_data[weather_data.len() - DEMO_RECENT_DAYS..];
277292

278-
for i in 7..12 {
279-
// Predict for days 8-12 of recent data
280-
let input_data = &recent_data[i - 7..i];
293+
for i in DEMO_SEQUENCE_LENGTH..DEMO_SEQUENCE_LENGTH + DEMO_NUM_PREDICTIONS {
294+
let input_data = &recent_data[i - DEMO_SEQUENCE_LENGTH..i];
281295
if let Some(predicted_temp) = predictor.predict_temperature(input_data) {
282296
let actual_temp = recent_data[i].temperature;
283297
let error = (predicted_temp - actual_temp).abs();

tests/example_training_bounds_test.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ mod advanced_lr_scheduling;
2222
#[path = "../examples/dropout_example.rs"]
2323
mod dropout_example;
2424

25+
#[allow(dead_code)]
26+
#[path = "../examples/weather_prediction.rs"]
27+
mod weather_prediction;
28+
2529
use std::hint::black_box;
2630

2731
#[test]
@@ -387,3 +391,74 @@ fn advanced_lr_scheduling_uses_small_deterministic_fixture() {
387391
"generated fixtures should use the public demo sequence-length bound"
388392
);
389393
}
394+
395+
#[test]
396+
fn weather_prediction_applies_bounded_training_config() {
397+
let config = weather_prediction::weather_training_config();
398+
399+
assert!(
400+
black_box(config.epochs) <= 4,
401+
"weather_prediction should avoid long default training runs"
402+
);
403+
assert!(
404+
black_box(config.print_every) > 0,
405+
"weather_prediction progress logging should stay enabled"
406+
);
407+
assert!(
408+
black_box(config.print_every) <= black_box(config.epochs),
409+
"weather_prediction progress logging should not exceed the epoch budget"
410+
);
411+
assert!(
412+
config.early_stopping.is_none(),
413+
"weather_prediction should avoid hidden early-stopping work"
414+
);
415+
assert!(
416+
black_box(weather_prediction::DEMO_WEATHER_DAYS) <= 80,
417+
"weather_prediction should keep its synthetic dataset bounded"
418+
);
419+
assert!(
420+
black_box(weather_prediction::DEMO_SEQUENCE_LENGTH) <= 5,
421+
"weather_prediction should keep each sequence bounded"
422+
);
423+
assert!(
424+
black_box(weather_prediction::DEMO_HIDDEN_SIZE) <= 8,
425+
"weather_prediction should keep hidden size bounded"
426+
);
427+
assert!(
428+
black_box(weather_prediction::DEMO_RECENT_DAYS)
429+
>= black_box(
430+
weather_prediction::DEMO_SEQUENCE_LENGTH + weather_prediction::DEMO_NUM_PREDICTIONS,
431+
),
432+
"weather_prediction should keep enough recent days for every preview prediction"
433+
);
434+
assert!(
435+
black_box(weather_prediction::DEMO_WEATHER_DAYS)
436+
>= black_box(weather_prediction::DEMO_RECENT_DAYS),
437+
"weather_prediction dataset should cover the preview prediction window"
438+
);
439+
}
440+
441+
#[test]
442+
fn weather_prediction_demo_data_is_reproducible() {
443+
let first = weather_prediction::generate_weather_data(black_box(12));
444+
let second = weather_prediction::generate_weather_data(black_box(12));
445+
446+
assert_eq!(
447+
first, second,
448+
"weather demo fixture should be deterministic"
449+
);
450+
assert!(
451+
first.windows(2).any(|window| window[0] != window[1]),
452+
"weather demo fixture should vary across generated days"
453+
);
454+
assert!(
455+
first.iter().all(|weather| {
456+
weather.humidity >= 20.0
457+
&& weather.humidity <= 95.0
458+
&& weather.cloud_cover >= 0.0
459+
&& weather.cloud_cover <= 100.0
460+
&& weather.precipitation >= 0.0
461+
}),
462+
"generated weather fixture should keep bounded meteorological values"
463+
);
464+
}

0 commit comments

Comments
 (0)