Skip to content

Commit fb5fdeb

Browse files
authored
Merge pull request #25 from theomgdev/copilot/fix-bugs-performance-issues
Harden trainer edge cases, vectorize input mapping, and unify pytest configuration
2 parents 65196b4 + 8f80f6c commit fb5fdeb

4 files changed

Lines changed: 56 additions & 7 deletions

File tree

odyssnet/training/trainer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch.optim as optim
44
import time
55
import math
6+
import numbers
67
from typing import Callable
78
from ..utils.data import prepare_input, to_tensor
89

@@ -194,6 +195,12 @@ def train_batch(self, input_features, target_values, thinking_steps, gradient_ac
194195
"""
195196
Runs a single training step on a batch.
196197
"""
198+
if isinstance(gradient_accumulation_steps, bool) or not isinstance(gradient_accumulation_steps, numbers.Integral):
199+
raise ValueError("gradient_accumulation_steps must be an integer >= 1")
200+
gradient_accumulation_steps = int(gradient_accumulation_steps)
201+
if gradient_accumulation_steps < 1:
202+
raise ValueError("gradient_accumulation_steps must be an integer >= 1")
203+
197204
self.model.train()
198205

199206
self._ensure_scaler()
@@ -386,6 +393,15 @@ def fit(self, input_features, target_values, epochs, batch_size=32, thinking_ste
386393
input_features = to_tensor(input_features, self.device)
387394
target_values = to_tensor(target_values, self.device)
388395

396+
if isinstance(batch_size, bool) or not isinstance(batch_size, int):
397+
raise TypeError("batch_size must be an integer")
398+
if batch_size < 1:
399+
raise ValueError("batch_size must be >= 1")
400+
if len(input_features) != len(target_values):
401+
raise ValueError("input_features and target_values must have the same length")
402+
if len(input_features) == 0:
403+
raise ValueError("input_features and target_values must be non-empty")
404+
389405
history = []
390406

391407
# Prepare Data

odyssnet/utils/data.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ def prepare_input(input_features, model_input_ids, num_neurons, device):
4848
x_input = torch.zeros(batch_size, steps, num_neurons, device=device)
4949

5050
num_assigned = min(num_features, len(model_input_ids))
51-
for k in range(num_assigned):
52-
x_input[:, :, model_input_ids[k]] = input_features[:, :, k]
51+
if num_assigned > 0:
52+
target_ids = model_input_ids[:num_assigned]
53+
x_input[:, :, target_ids] = input_features[:, :, :num_assigned]
5354

5455
return x_input, batch_size
5556

@@ -61,8 +62,9 @@ def prepare_input(input_features, model_input_ids, num_neurons, device):
6162
num_assigned = min(num_features, len(model_input_ids))
6263

6364
# Assign features to neurons
64-
for k in range(num_assigned):
65-
x_input[:, model_input_ids[k]] = input_features[:, k]
65+
if num_assigned > 0:
66+
target_ids = model_input_ids[:num_assigned]
67+
x_input[:, target_ids] = input_features[:, :num_assigned]
6668

6769
return x_input, batch_size
6870

pyproject.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,3 @@ all = [
5050

5151
[tool.setuptools.packages.find]
5252
include = ["odyssnet*"]
53-
54-
[tool.pytest.ini_options]
55-
testpaths = ["tests"]

tests/training/test_trainer.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,16 @@ def test_non_pulse_2d_input_full_sequence(self):
242242
loss = t.train_batch(x, y, thinking_steps=10, full_sequence=True)
243243
assert isinstance(loss, float)
244244

245+
def test_gradient_accumulation_steps_must_be_positive(self):
246+
model = _model()
247+
t = _trainer(model)
248+
x = _batch()
249+
y = _targets()
250+
with pytest.raises(ValueError):
251+
t.train_batch(x, y, thinking_steps=2, gradient_accumulation_steps=0)
252+
with pytest.raises(ValueError):
253+
t.train_batch(x, y, thinking_steps=2, gradient_accumulation_steps=-1)
254+
245255

246256
# ===========================================================================
247257
# predict
@@ -341,6 +351,30 @@ def test_fit_loss_trend_downward_on_simple_data(self):
341351
history = t.fit(x, y, epochs=20, batch_size=n, thinking_steps=5, verbose=False)
342352
assert history[-1] < history[0], "Loss should decrease over training"
343353

354+
def test_fit_empty_dataset_raises(self):
355+
model = _model()
356+
t = _trainer(model)
357+
x = torch.empty(0, 5)
358+
y = torch.empty(0, 2)
359+
with pytest.raises(ValueError):
360+
t.fit(x, y, epochs=1, batch_size=4, thinking_steps=2, verbose=False)
361+
362+
def test_fit_length_mismatch_raises(self):
363+
model = _model()
364+
t = _trainer(model)
365+
x = torch.randn(3, 5)
366+
y = torch.randn(2, 2)
367+
with pytest.raises(ValueError):
368+
t.fit(x, y, epochs=1, batch_size=2, thinking_steps=2, verbose=False)
369+
370+
def test_fit_invalid_batch_size_raises(self):
371+
model = _model()
372+
t = _trainer(model)
373+
x = torch.randn(4, 5)
374+
y = torch.randn(4, 2)
375+
with pytest.raises(ValueError):
376+
t.fit(x, y, epochs=1, batch_size=0, thinking_steps=2, verbose=False)
377+
344378

345379
# ===========================================================================
346380
# regenerate_synapses

0 commit comments

Comments (0)