-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpytorch_monitor.py
More file actions
58 lines (47 loc) · 1.63 KB
/
pytorch_monitor.py
File metadata and controls
58 lines (47 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/usr/bin/env python3
# Copyright (c) 2025 ForgottenForge.xyz
# Licensed under AGPL-3.0-or-later. See LICENSE for details.
# Commercial license available: info@forgottenforge.xyz
"""Example: Monitor susceptibility DURING training.
The callback computes K_c periodically and logs it.
If K_c changes significantly, your training dynamics are shifting.
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from batch_susceptibility.pytorch import SusceptibilityCallback
def main():
    """Train a small classifier on synthetic data, monitoring K_c as it runs.

    Builds a 2-class toy dataset, fits a two-layer MLP for 5 epochs, and
    feeds every batch loss into a SusceptibilityCallback so K_c / regime
    estimates are printed periodically and once more at the end.
    """
    # Synthetic binary task: the label is decided by the sign of the sum
    # of the first two features.
    features = torch.randn(10000, 20)
    labels = (features[:, 0] + features[:, 1] > 0).long()
    train_loader = DataLoader(
        TensorDataset(features, labels), batch_size=64, shuffle=True
    )

    net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # Susceptibility monitor: recomputes K_c every 200 steps over a
    # 500-step loss window.
    monitor = SusceptibilityCallback(check_every=200, window_size=500)

    net.train()
    for _epoch in range(5):
        for inputs, targets in train_loader:
            opt.zero_grad()
            batch_loss = criterion(net(inputs), targets)
            batch_loss.backward()
            opt.step()

            # Hand the scalar loss to the monitor; it returns a report
            # only on the steps where it actually recomputes K_c.
            report = monitor.on_step(batch_loss.item())
            if report:
                print(f" -> K_c={report.K_c:.0f}, regime={report.regime}")

    # Final susceptibility estimate over the whole run (may be None if
    # too few steps were observed).
    final = monitor.result()
    if final:
        print(f"\nFinal: K_c={final.K_c:.0f}, kappa={final.kappa:.2f}")
        print(final.summary())
# Script entry point: run the training example only when executed
# directly, not when this module is imported.
if __name__ == "__main__":
    main()