-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconvergence_identity.py
More file actions
48 lines (38 loc) · 1.33 KB
/
convergence_identity.py
File metadata and controls
48 lines (38 loc) · 1.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
from odyssnet import OdyssNet, OdyssNetTrainer, TrainingHistory, set_seed
def main():
    """Train a 2-neuron OdyssNet to reproduce its input (the identity map).

    Builds a network with one input neuron and one output neuron, trains it on
    random samples drawn from {-1.0, +1.0} with the target equal to the input,
    prints the prediction for +1.0 and -1.0, and plots the loss values that
    ``trainer.fit`` returned.
    """
    print("OdyssNet: The Atomic Identity...")
    set_seed(42)
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ATOMIC UNIT OF CHAOS
    # 1 Input, 1 Output.
    # Minimum possible configuration for OdyssNet.
    NUM_NEURONS = 2
    INPUT_ID = 0
    OUTPUT_ID = 1

    # Passed as `thinking_steps` to both fit() and predict(); kept in one
    # place so training and inference cannot silently drift apart.
    THINKING_STEPS = 50

    # CRITICAL CONFIG FOR TINY NETWORKS (every neuron is vital).
    # NOTE(review): beyond ids/device the constructor below only sets
    # pulse_mode=True; confirm which setting this comment is meant to flag.
    model = OdyssNet(
        num_neurons=NUM_NEURONS,
        input_ids=[INPUT_ID],
        output_ids=[OUTPUT_ID],
        pulse_mode=True,
        device=DEVICE,
    )
    trainer = OdyssNetTrainer(model, device=DEVICE, lr=1e-4)

    # Data: a (100, 1) tensor of -1.0/+1.0 values (randint yields {0, 1};
    # "* 2 - 1" maps that to {-1, +1}). The target is the input itself.
    # NOTE(review): these tensors are created on the default (CPU) device,
    # while test_inputs below is placed on DEVICE; presumably fit() moves
    # batches to the trainer's device — confirm.
    inputs_val = torch.randint(0, 2, (100, 1)).float() * 2 - 1
    targets_val = inputs_val

    print("Training...")
    history = TrainingHistory()
    loss_list = trainer.fit(
        inputs_val,
        targets_val,
        epochs=50,
        batch_size=32,
        thinking_steps=THINKING_STEPS,
    )
    # Record every loss value returned by fit() so it can be plotted below.
    for loss in loss_list:
        history.record(loss=loss)

    print("\nTest Result:")
    test_inputs = torch.tensor([[1.0], [-1.0]], device=DEVICE)
    preds = trainer.predict(test_inputs, thinking_steps=THINKING_STEPS)
    # Pair each test input with its prediction instead of indexing both.
    for test_in, pred in zip(test_inputs, preds):
        print(f"In: {test_in.item()} -> Out: {pred.item():.4f}")

    history.plot(title="Identity Convergence")
# Entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()