
Commit fe07106

Toggleable scheduler for financial oracle, and tweak batch
1 parent 0f831c8

1 file changed: 19 additions & 11 deletions

examples/advanced/experiment_financial_oracle.py
@@ -170,9 +170,10 @@
 
 # --- Training ---
 EPOCHS = 100
-BATCH_SIZE = 32
-EVAL_BATCH = 1024  # Larger batch for eval (no gradients)
+BATCH_SIZE = 1024
+EVAL_BATCH = 16384  # Larger batch for eval (no gradients)
 LR = 1e-4
+USE_SCHEDULER = True
 LR_MIN = 1e-6  # Cosine annealing floor
 VAL_EVERY = 3  # Validate every N epochs (validation is expensive)
 PATIENCE = 20  # Early stopping patience (in validation checks)
@@ -185,7 +186,7 @@
 
 # Feature scaling multiplier applied to return channels before float32 training.
 # Higher values increase numerical visibility of tiny moves.
-RETURN_SCALE = 100.0
+RETURN_SCALE = 1.0
 # Base outlier threshold in unscaled return space.
 # Threshold in feature space is this base multiplied by RETURN_SCALE.
 MAX_ABS_NORMALIZED_BASE = 20.0
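
A quick note on the RETURN_SCALE change: as the comments above say, the outlier threshold in feature space is the base threshold multiplied by RETURN_SCALE, so the effective cutoff drops from 2000.0 to 20.0 with this commit. A minimal worked illustration in plain Python; the clipping step is hypothetical, shown only to make the arithmetic concrete, and is not the project's preprocessing code:

# Feature-space threshold = MAX_ABS_NORMALIZED_BASE * RETURN_SCALE (per the comments above).
MAX_ABS_NORMALIZED_BASE = 20.0

for return_scale in (100.0, 1.0):  # previous value vs. this commit's value
    threshold = MAX_ABS_NORMALIZED_BASE * return_scale  # 2000.0, then 20.0
    feature_value = 0.012 * return_scale  # a 1.2% raw return after scaling
    clipped = max(-threshold, min(threshold, feature_value))  # hypothetical outlier clip
    print(f"RETURN_SCALE={return_scale}: threshold=+/-{threshold}, feature value={clipped}")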
@@ -540,7 +541,6 @@ def build_model():
         weight_init='resonant',        # Edge of Chaos initialization
         activation='tanh',             # Bounded oscillations
         hebb_type='synapse',           # Per-synapse plasticity -> workbenches
-        dropout_rate=0.05,             # Light regularization
         gradient_checkpointing=False,  # VRAM plentiful; skip recompute for ~2x speed
         device=DEVICE,
     )
@@ -1068,10 +1068,12 @@ def main():
     trainer = OdyssNetTrainer(model, device=DEVICE, lr=LR)
     trainer.loss_fn = HeteroscedasticNLL()
 
-    # Cosine annealing: LR decays smoothly from LR to LR_MIN
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-        trainer.optimizer, T_max=EPOCHS, eta_min=LR_MIN,
-    )
+    # Optional cosine annealing scheduler.
+    scheduler = None
+    if USE_SCHEDULER:
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            trainer.optimizer, T_max=EPOCHS, eta_min=LR_MIN,
+        )
 
     history = TrainingHistory()
     best_val_mse = float('inf')
@@ -1081,7 +1083,10 @@ def main():
 
     print(f"\n{'='*70}")
     print(f" TRAINING")
-    print(f" Epochs: {EPOCHS} | Batch: {BATCH_SIZE} | LR: {LR} -> {LR_MIN} (cosine)")
+    if USE_SCHEDULER:
+        print(f" Epochs: {EPOCHS} | Batch: {BATCH_SIZE} | LR: {LR} -> {LR_MIN} (cosine)")
+    else:
+        print(f" Epochs: {EPOCHS} | Batch: {BATCH_SIZE} | LR: {LR} (fixed)")
     print(f" Loss: Heteroscedastic Gaussian NLL (Kendall & Gal, 2017)")
     print(f" Confident signal threshold: P(up) > {CONF_THRESH:.2f} or < {1.0 - CONF_THRESH:.2f}")
     print(f" Validation every {VAL_EVERY} epochs | Early stop patience: {PATIENCE}")
@@ -1092,8 +1097,11 @@ def main():
     for epoch in range(1, EPOCHS + 1):
         # Train
         avg_loss = run_epoch(trainer, train_x, train_y)
-        scheduler.step()
-        current_lr = scheduler.get_last_lr()[0]
+        if scheduler is not None:
+            scheduler.step()
+            current_lr = scheduler.get_last_lr()[0]
+        else:
+            current_lr = trainer.optimizer.param_groups[0]['lr']
 
         # Validate (every VAL_EVERY epochs + first + last)
         do_val = (epoch == 1 or epoch % VAL_EVERY == 0 or epoch == EPOCHS)
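
For readers who want to see the toggle end to end without the project's trainer, here is a minimal standalone sketch. It assumes only torch; the torch.nn.Linear model and the bare optimizer.step() are hypothetical stand-ins for the real network and a full training epoch. With USE_SCHEDULER = True the learning rate anneals from LR down to LR_MIN over EPOCHS steps; with False it stays fixed and is read straight from the optimizer's param_groups, mirroring the diff above.

import torch

EPOCHS, LR, LR_MIN = 100, 1e-4, 1e-6
USE_SCHEDULER = True  # flip to False for a fixed learning rate

model = torch.nn.Linear(8, 1)                            # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=LR)  # stand-in for trainer.optimizer

# Same pattern as the diff: the scheduler exists only when the toggle is on.
scheduler = None
if USE_SCHEDULER:
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=EPOCHS, eta_min=LR_MIN,
    )

for epoch in range(1, EPOCHS + 1):
    optimizer.step()  # placeholder for one real training epoch
    if scheduler is not None:
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
    else:
        current_lr = optimizer.param_groups[0]['lr']

print(f"final LR: {current_lr:.2e}")  # LR_MIN with the scheduler, LR without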
