Skip to content

Commit 65196b4

Browse files
theomgdev and claude committed
Refactor state management: internalize recurrent state into model (v2.4.1)
Previously, callers were responsible for carrying `final_state` between train_batch calls and passing it back as `initial_state`. This created fragile boilerplate in experiment_llm.py and leaked an implementation detail into the public API. Changes: - OdyssNet now always persists `self.state` after every forward pass, not only when Hebbian learning is active - `train_batch` drops `initial_state` / `return_state` in favour of a single `keep_state` flag; callers no longer hold state tensors - experiment_llm.py TBPTT loop updated to use `keep_state=(t_start > 0)` - generate() uses `model.reset_state()` before warm-up instead of threading a state variable through the function - Tests updated to assert `model.state` directly and cover the new API - Add `.claude/` to .gitignore Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 5e0490f commit 65196b4

6 files changed

Lines changed: 34 additions & 46 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Thumbs.db
4242
# AI
4343
CLAUDE.md
4444
GEMINI.md
45+
.claude/
4546

4647
# Migrations & Plans
4748
*[mM]igration*.md

examples/advanced/experiment_llm.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
# OPTIMIZER CONFIG
6363
RESET_OPTIMIZER_ON_LOAD = False
6464
OVERWRITE_LR_OF_CKPT = True
65-
LEARNING_RATE = 1e-4
65+
LEARNING_RATE = 1e-5
6666

6767
# TIE EMBEDDINGS (VRAM Saving & Parameter Sharing)
6868
TIE_EMBEDDINGS = False
@@ -213,14 +213,13 @@ def generate(model, tokenizer, start_str="The", length=None, temperature=0.8, to
213213
encoded = tokenizer.encode(start_str)
214214
input_seq = encoded.ids
215215

216-
current_state = None
217-
218216
# Warm up state (Native Thinking)
217+
model.reset_state(batch_size=1)
219218
x_in = torch.tensor(input_seq, dtype=torch.long, device=model.device).unsqueeze(0)
220219
steps_total = x_in.shape[1] * (THINK_GAP + 1)
221-
220+
222221
with torch.no_grad():
223-
_, current_state = model(x_in, steps=steps_total)
222+
model(x_in, steps=steps_total)
224223

225224
last_token_idx = input_seq[-1]
226225

@@ -230,7 +229,7 @@ def generate(model, tokenizer, start_str="The", length=None, temperature=0.8, to
230229

231230
x_next = torch.tensor([[last_token_idx]], dtype=torch.long, device=model.device)
232231

233-
preds, current_state = model(x_next, steps=total_step_single, current_state=current_state)
232+
preds, _ = model(x_next, steps=total_step_single)
234233

235234
logits = preds[0, 0, :]
236235

@@ -293,7 +292,8 @@ def initialize_system(vocab_size, num_neurons, device, input_count=-1, output_co
293292
vocab_size=vocab_size,
294293
vocab_mode='discrete',
295294
tie_embeddings=TIE_EMBEDDINGS,
296-
debug=debug
295+
debug=debug,
296+
hebb_type='synapse'
297297
)
298298

299299
trainer = OdyssNetTrainer(
@@ -608,34 +608,29 @@ def flatten_logits(out):
608608
total_thinking_steps = seq_len * (THINK_GAP + 1)
609609

610610
if TRUNCATED_BPTT_SEQ_LEN != -1 and TRUNCATED_BPTT_SEQ_LEN > 0:
611-
current_state = None
612611
batch_loss = 0
613612
steps_count = 0
614-
613+
615614
chunk_len = TRUNCATED_BPTT_SEQ_LEN
616-
615+
617616
for t_start in range(0, seq_len, chunk_len):
618617
t_end = min(t_start + chunk_len, seq_len)
619-
620-
# Extract sequence chunk
618+
621619
x_chunk = x[:, t_start:t_end]
622620
y_chunk_flat = y[:, t_start:t_end].reshape(-1)
623-
624-
# Thinking steps for the current chunk
621+
625622
actual_tokens = t_end - t_start
626623
chunk_thinking_steps = actual_tokens * (THINK_GAP + 1)
627624

628-
loss, current_state = trainer.train_batch(
625+
loss = trainer.train_batch(
629626
x_chunk,
630627
y_chunk_flat,
631628
thinking_steps=chunk_thinking_steps,
632629
full_sequence=True,
633630
output_transform=flatten_logits,
634-
initial_state=current_state,
635-
return_state=True
631+
keep_state=(t_start > 0),
636632
)
637633

638-
current_state = current_state.detach()
639634
batch_loss += loss
640635
steps_count += 1
641636

odyssnet/core/network.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -696,9 +696,10 @@ def _single_step(h_t_in, t_idx, x_input_info, hebb_W_contrib, hebb_mem_contrib):
696696
if return_sequence and (t + 1) % ratio == 0 and len(outputs) < max_outputs:
697697
outputs.append(h_t)
698698

699-
# Persist accumulated Hebbian state for the next forward call.
700-
if self.hebb_type is not None:
701-
with torch.no_grad():
699+
# Persist the recurrent state and Hebbian correlations for the next forward call.
700+
with torch.no_grad():
701+
self.state = h_t.detach()
702+
if self.hebb_type is not None:
702703
self.hebb_state_W.copy_(local_hebb_W.detach())
703704
self.hebb_state_mem.copy_(local_hebb_mem.detach())
704705

odyssnet/training/trainer.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def load_state_dict(self, state):
190190
continue
191191
self._persistent_grads[id(param)] = persisted.to(device=param.device, dtype=param.dtype)
192192

193-
def train_batch(self, input_features, target_values, thinking_steps, gradient_accumulation_steps=1, full_sequence=False, mask=None, output_transform=None, initial_state=None, return_state=False):
193+
def train_batch(self, input_features, target_values, thinking_steps, gradient_accumulation_steps=1, full_sequence=False, mask=None, output_transform=None, keep_state=False):
194194
"""
195195
Runs a single training step on a batch.
196196
"""
@@ -227,16 +227,12 @@ def train_batch(self, input_features, target_values, thinking_steps, gradient_ac
227227

228228
# Forward Pass (with AMP)
229229
with self._get_autocast_ctx():
230-
# Use initial_state if provided, otherwise reset
231-
if initial_state is not None:
232-
current_state_in = initial_state
233-
else:
230+
if not keep_state:
234231
self.model.reset_state(batch_size)
235-
current_state_in = None
236232

237-
all_states, final_state = self.model(x_input, steps=thinking_steps, current_state=current_state_in, return_sequence=full_sequence)
233+
all_states, h_t = self.model(x_input, steps=thinking_steps, return_sequence=full_sequence)
238234

239-
predicted_outputs = self._extract_outputs(all_states, final_state, full_sequence)
235+
predicted_outputs = self._extract_outputs(all_states, h_t, full_sequence)
240236

241237
# Optional Transform
242238
if output_transform:
@@ -345,8 +341,6 @@ def train_batch(self, input_features, target_values, thinking_steps, gradient_ac
345341
else:
346342
self._plateau_hook_triggered = False
347343

348-
if return_state:
349-
return loss_val, final_state
350344
return loss_val
351345

352346
def predict(self, input_features, thinking_steps, full_sequence=False):

tests/core/test_network.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,9 @@ def test_dropout_disabled_in_eval(self):
254254
model = OdyssNet(num_neurons=4, input_ids=[0], output_ids=[3], device="cpu", dropout_rate=0.5)
255255
model.eval()
256256
x = torch.randn(4, 4)
257+
model.reset_state(batch_size=4)
257258
out1, _ = model(x, steps=3)
259+
model.reset_state(batch_size=4)
258260
out2, _ = model(x, steps=3)
259261
assert torch.allclose(out1, out2), "Eval mode must produce deterministic outputs"
260262

tests/training/test_trainer.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -175,34 +175,29 @@ def test_synaptic_noise_applies_without_error(self):
175175
loss = t.train_batch(x, y, thinking_steps=2)
176176
assert isinstance(loss, float)
177177

178-
def test_return_state_flag(self):
178+
def test_state_persisted_after_train_batch(self):
179179
model = _model()
180180
t = _trainer(model)
181181
x = _batch()
182182
y = _targets()
183-
result = t.train_batch(x, y, thinking_steps=2, return_state=True)
184-
assert isinstance(result, tuple)
185-
loss, state = result
183+
loss = t.train_batch(x, y, thinking_steps=2)
186184
assert isinstance(loss, float)
187-
assert state.shape == (4, model.num_neurons)
185+
assert model.state.shape == (4, model.num_neurons)
188186

189-
def test_tbptt_chained_initial_state(self):
190-
# experiment_llm.py: return_state=True feeds the final state back as
191-
# initial_state for the next chunk (Truncated BPTT).
187+
def test_tbptt_keep_state(self):
188+
# keep_state=True carries model.state across chunks without reset (Truncated BPTT).
192189
model = _model()
193190
t = _trainer(model)
194191
x = _batch()
195192
y = _targets()
196193

197-
loss1, state1 = t.train_batch(x, y, thinking_steps=2, return_state=True)
198-
state1 = state1.detach()
194+
loss1 = t.train_batch(x, y, thinking_steps=2)
195+
assert isinstance(loss1, float)
199196

200-
# Second chunk starts from where the first chunk ended
201-
loss2, state2 = t.train_batch(
202-
x, y, thinking_steps=2, initial_state=state1, return_state=True
203-
)
197+
# Second chunk continues from where the first chunk ended
198+
loss2 = t.train_batch(x, y, thinking_steps=2, keep_state=True)
204199
assert isinstance(loss2, float)
205-
assert state2.shape == (4, model.num_neurons)
200+
assert model.state.shape == (4, model.num_neurons)
206201

207202
def test_output_transform_applied(self):
208203
# convergence_mnist_reverse_record.py uses output_transform to slice warmup

0 commit comments

Comments (0)