@@ -170,9 +170,10 @@
 
 # --- Training ---
 EPOCHS = 100
-BATCH_SIZE = 32
-EVAL_BATCH = 1024  # Larger batch for eval (no gradients)
+BATCH_SIZE = 1024
+EVAL_BATCH = 16384  # Larger batch for eval (no gradients)
 LR = 1e-4
+USE_SCHEDULER = True
 LR_MIN = 1e-6  # Cosine annealing floor
 VAL_EVERY = 3  # Validate every N epochs (validation is expensive)
 PATIENCE = 20  # Early stopping patience (in validation checks)
@@ -185,7 +186,7 @@
 
 # Feature scaling multiplier applied to return channels before float32 training.
 # Higher values increase numerical visibility of tiny moves.
-RETURN_SCALE = 100.0
+RETURN_SCALE = 1.0
 # Base outlier threshold in unscaled return space.
 # Threshold in feature space is this base multiplied by RETURN_SCALE.
 MAX_ABS_NORMALIZED_BASE = 20.0
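For reference, here is how these two settings compose. This is a minimal standalone sketch, not code from this script (the scale_returns helper and its column layout are hypothetical); it illustrates that the feature-space clip threshold is scaled by RETURN_SCALE, so changing the scale does not change which raw returns count as outliers:

import numpy as np

RETURN_SCALE = 1.0
MAX_ABS_NORMALIZED_BASE = 20.0

def scale_returns(raw: np.ndarray) -> np.ndarray:
    """Hypothetical helper: scale return channels, then clip outliers in feature space."""
    scaled = raw.astype(np.float32) * RETURN_SCALE
    limit = MAX_ABS_NORMALIZED_BASE * RETURN_SCALE  # clip threshold tracks the scale
    return np.clip(scaled, -limit, limit)

# A raw value of 25.0 clips to 20.0 in raw-return terms whatever RETURN_SCALE is,
# since the value and the threshold are multiplied by the same factor.
print(scale_returns(np.array([0.5, 25.0])))  # [ 0.5 20. ]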
@@ -540,7 +541,6 @@ def build_model():
         weight_init='resonant',        # Edge of Chaos initialization
         activation='tanh',             # Bounded oscillations
         hebb_type='synapse',           # Per-synapse plasticity -> workbenches
-        dropout_rate=0.05,             # Light regularization
         gradient_checkpointing=False,  # VRAM plentiful; skip recompute for ~2x speed
         device=DEVICE,
     )
@@ -1068,10 +1068,12 @@ def main():
10681068 trainer = OdyssNetTrainer (model , device = DEVICE , lr = LR )
10691069 trainer .loss_fn = HeteroscedasticNLL ()
10701070
1071- # Cosine annealing: LR decays smoothly from LR to LR_MIN
1072- scheduler = torch .optim .lr_scheduler .CosineAnnealingLR (
1073- trainer .optimizer , T_max = EPOCHS , eta_min = LR_MIN ,
1074- )
1071+ # Optional cosine annealing scheduler.
1072+ scheduler = None
1073+ if USE_SCHEDULER :
1074+ scheduler = torch .optim .lr_scheduler .CosineAnnealingLR (
1075+ trainer .optimizer , T_max = EPOCHS , eta_min = LR_MIN ,
1076+ )
10751077
10761078 history = TrainingHistory ()
10771079 best_val_mse = float ('inf' )
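For intuition on what the scheduler does when enabled: CosineAnnealingLR follows the closed-form curve LR_MIN + (LR - LR_MIN) * (1 + cos(pi * t / T_max)) / 2. The standalone sketch below (a throwaway one-parameter optimizer, not the repo's trainer) prints the scheduled rate against that formula at a few checkpoints:

import math
import torch

EPOCHS, LR, LR_MIN = 100, 1e-4, 1e-6

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=LR)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS, eta_min=LR_MIN)

for epoch in range(1, EPOCHS + 1):
    # ... one training epoch would run here ...
    sched.step()
    if epoch in (1, 50, 100):
        closed_form = LR_MIN + (LR - LR_MIN) * (1 + math.cos(math.pi * epoch / EPOCHS)) / 2
        print(f"epoch {epoch:3d}: lr = {sched.get_last_lr()[0]:.3e} (closed form {closed_form:.3e})")

At the midpoint (epoch 50) the rate is roughly halfway between LR and LR_MIN, and at T_max it bottoms out at LR_MIN exactly.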
@@ -1081,7 +1083,10 @@ def main():
 
     print(f"\n{'=' * 70}")
     print(f" TRAINING")
-    print(f" Epochs: {EPOCHS} | Batch: {BATCH_SIZE} | LR: {LR} -> {LR_MIN} (cosine)")
+    if USE_SCHEDULER:
+        print(f" Epochs: {EPOCHS} | Batch: {BATCH_SIZE} | LR: {LR} -> {LR_MIN} (cosine)")
+    else:
+        print(f" Epochs: {EPOCHS} | Batch: {BATCH_SIZE} | LR: {LR} (fixed)")
     print(f" Loss: Heteroscedastic Gaussian NLL (Kendall & Gal, 2017)")
     print(f" Confident signal threshold: P(up) > {CONF_THRESH:.2f} or < {1.0 - CONF_THRESH:.2f}")
     print(f" Validation every {VAL_EVERY} epochs | Early stop patience: {PATIENCE}")
@@ -1092,8 +1097,11 @@ def main():
     for epoch in range(1, EPOCHS + 1):
         # Train
         avg_loss = run_epoch(trainer, train_x, train_y)
-        scheduler.step()
-        current_lr = scheduler.get_last_lr()[0]
+        if scheduler is not None:
+            scheduler.step()
+            current_lr = scheduler.get_last_lr()[0]
+        else:
+            current_lr = trainer.optimizer.param_groups[0]['lr']
 
         # Validate (every VAL_EVERY epochs + first + last)
         do_val = (epoch == 1 or epoch % VAL_EVERY == 0 or epoch == EPOCHS)
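On the fixed-LR branch there is no scheduler to query, so the loop reads the rate off the optimizer itself. A minimal sketch of that lookup (dummy Adam optimizer, assuming a single parameter group as the loop above does):

import torch

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
scheduler = None  # the USE_SCHEDULER = False path

# Nothing mutates param_groups without a scheduler, so this stays constant.
current_lr = opt.param_groups[0]['lr']
print(current_lr)  # 0.0001

Note that param_groups[0] only reflects the whole model when all parameters share one group; an optimizer built with per-layer groups would need to report each group's rate separately.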