Skip to content

Commit cb7511b

Browse files
authored
Create test_neuro.py
1 parent 8ec4f45 commit cb7511b

1 file changed

Lines changed: 264 additions & 0 deletions

File tree

synapdrive_ai/tests/test_neuro.py

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
import pytest
5+
6+
from synapdrive_ai.neuro.band_analyzer import BANDS, BandPowerAnalyzer
7+
from synapdrive_ai.neuro.eeg_loader import EEGLoader, EEGRecording
8+
from synapdrive_ai.neuro.session_analyzer import SessionAnalyzer
9+
from synapdrive_ai.neuro.task_planner import ExecutorBridge, TaskPlan, TaskStep
10+
11+
12+
SR = 256.0
13+
DURATION = 2.0
14+
N = int(SR * DURATION)
15+
T = np.linspace(0, DURATION, N, endpoint=False)
16+
17+
18+
def _make_signal(freq_hz: float, noise: float = 0.02) -> np.ndarray:
19+
return np.sin(2 * np.pi * freq_hz * T) + np.random.default_rng(42).normal(0, noise, N)
20+
21+
22+
def _motor_signal() -> np.ndarray:
23+
return (
24+
0.1 * np.sin(2 * np.pi * 10 * T)
25+
+ 1.5 * np.sin(2 * np.pi * 20 * T)
26+
+ 0.8 * np.sin(2 * np.pi * 40 * T)
27+
+ np.random.default_rng(0).normal(0, 0.02, N)
28+
)
29+
30+
31+
def _alpha_signal() -> np.ndarray:
32+
return (
33+
2.0 * np.sin(2 * np.pi * 10 * T)
34+
+ 0.1 * np.sin(2 * np.pi * 20 * T)
35+
+ np.random.default_rng(1).normal(0, 0.02, N)
36+
)
37+
38+
39+
class TestBandPowerAnalyzer:
40+
def test_returns_all_bands(self):
41+
analyzer = BandPowerAnalyzer(sampling_rate=SR)
42+
result = analyzer.analyze(_make_signal(10.0))
43+
assert set(result.absolute.keys()) == set(BANDS.keys())
44+
assert set(result.relative.keys()) == set(BANDS.keys())
45+
46+
def test_relative_power_sums_to_one(self):
47+
analyzer = BandPowerAnalyzer(sampling_rate=SR)
48+
result = analyzer.analyze(_make_signal(10.0))
49+
assert abs(sum(result.relative.values()) - 1.0) < 1e-6
50+
51+
def test_confidence_in_range(self):
52+
analyzer = BandPowerAnalyzer(sampling_rate=SR)
53+
for freq in (6, 10, 20, 40):
54+
result = analyzer.analyze(_make_signal(freq))
55+
assert 0.0 <= result.confidence <= 1.0
56+
57+
def test_motor_signal_classified_motor(self):
58+
analyzer = BandPowerAnalyzer(sampling_rate=SR)
59+
result = analyzer.analyze(_motor_signal())
60+
assert result.intent_class == "motor"
61+
62+
def test_alpha_signal_classified_unclear(self):
63+
analyzer = BandPowerAnalyzer(sampling_rate=SR)
64+
result = analyzer.analyze(_alpha_signal())
65+
assert result.intent_class == "unclear"
66+
67+
def test_motor_signal_higher_confidence_than_alpha(self):
68+
analyzer = BandPowerAnalyzer(sampling_rate=SR)
69+
motor_conf = analyzer.analyze(_motor_signal()).confidence
70+
alpha_conf = analyzer.analyze(_alpha_signal()).confidence
71+
assert motor_conf > alpha_conf
72+
73+
def test_short_signal_returns_zero_result(self):
74+
analyzer = BandPowerAnalyzer(sampling_rate=SR)
75+
result = analyzer.analyze(np.array([0.1, 0.2]))
76+
assert result.confidence == 0.0
77+
assert result.intent_class == "unclear"
78+
79+
def test_engagement_ratio_positive(self):
80+
analyzer = BandPowerAnalyzer(sampling_rate=SR)
81+
result = analyzer.analyze(_motor_signal())
82+
assert result.engagement_ratio > 0.0
83+
84+
85+
class TestEEGLoader:
86+
def test_load_1d_array(self):
87+
loader = EEGLoader(sampling_rate=SR)
88+
recording = loader.load_array(_make_signal(10.0))
89+
assert recording.n_channels == 1
90+
assert recording.n_samples == N
91+
assert recording.sampling_rate == SR
92+
93+
def test_load_2d_array(self):
94+
loader = EEGLoader(sampling_rate=SR)
95+
data = np.stack([_make_signal(10.0), _make_signal(20.0)])
96+
recording = loader.load_array(data)
97+
assert recording.n_channels == 2
98+
assert recording.n_samples == N
99+
100+
def test_channel_lookup_by_name(self):
101+
loader = EEGLoader(sampling_rate=SR)
102+
data = np.stack([_make_signal(10.0), _make_signal(20.0)])
103+
recording = loader.load_array(data, channel_names=["C3", "C4"])
104+
ch = recording.channel("C3")
105+
assert len(ch) == N
106+
107+
def test_channel_lookup_case_insensitive(self):
108+
loader = EEGLoader(sampling_rate=SR)
109+
recording = loader.load_array(_make_signal(10.0), channel_names=["Cz"])
110+
recording.channel("cz")
111+
112+
def test_channel_not_found_raises(self):
113+
loader = EEGLoader(sampling_rate=SR)
114+
recording = loader.load_array(_make_signal(10.0), channel_names=["C3"])
115+
with pytest.raises(KeyError):
116+
recording.channel("Fz")
117+
118+
def test_duration_correct(self):
119+
loader = EEGLoader(sampling_rate=SR)
120+
recording = loader.load_array(_make_signal(10.0))
121+
assert abs(recording.duration_s - DURATION) < 0.01
122+
123+
def test_summary_returns_string(self):
124+
loader = EEGLoader(sampling_rate=SR)
125+
recording = loader.load_array(_make_signal(10.0))
126+
assert isinstance(recording.summary(), str)
127+
128+
129+
class TestSessionAnalyzer:
130+
def _recording(self, signal: np.ndarray) -> EEGRecording:
131+
return EEGLoader(sampling_rate=SR).load_array(signal, channel_names=["C3"])
132+
133+
def test_produces_epochs(self):
134+
analyzer = SessionAnalyzer(channel="C3", window_s=0.5, step_s=0.25)
135+
report = analyzer.run(self._recording(_motor_signal()))
136+
assert report.n_epochs > 0
137+
138+
def test_epoch_count_reasonable(self):
139+
analyzer = SessionAnalyzer(channel="C3", window_s=0.5, step_s=0.5)
140+
report = analyzer.run(self._recording(_motor_signal()))
141+
expected = int(DURATION / 0.5)
142+
assert abs(report.n_epochs - expected) <= 1
143+
144+
def test_success_plus_blocked_equals_total(self):
145+
analyzer = SessionAnalyzer(channel="C3", window_s=0.5, step_s=0.5)
146+
report = analyzer.run(self._recording(_motor_signal()))
147+
assert report.n_success + report.n_blocked == report.n_epochs
148+
149+
def test_block_rate_in_range(self):
150+
analyzer = SessionAnalyzer(channel="C3", window_s=0.5, step_s=0.5)
151+
report = analyzer.run(self._recording(_motor_signal()))
152+
assert 0.0 <= report.block_rate <= 1.0
153+
154+
def test_mean_confidence_in_range(self):
155+
analyzer = SessionAnalyzer(channel="C3", window_s=0.5, step_s=0.5)
156+
report = analyzer.run(self._recording(_motor_signal()))
157+
assert 0.0 <= report.mean_confidence <= 1.0
158+
159+
def test_intent_distribution_sums_to_n_epochs(self):
160+
analyzer = SessionAnalyzer(channel="C3", window_s=0.5, step_s=0.5)
161+
report = analyzer.run(self._recording(_motor_signal()))
162+
assert sum(report.intent_distribution.values()) == report.n_epochs
163+
164+
def test_alpha_signal_has_higher_block_rate_than_motor(self):
165+
analyzer = SessionAnalyzer(channel="C3", window_s=0.5, step_s=0.5)
166+
motor_report = analyzer.run(self._recording(_motor_signal()))
167+
analyzer2 = SessionAnalyzer(channel="C3", window_s=0.5, step_s=0.5)
168+
alpha_report = analyzer2.run(self._recording(_alpha_signal()))
169+
assert alpha_report.block_rate >= motor_report.block_rate
170+
171+
def test_save_jsonl(self, tmp_path):
172+
analyzer = SessionAnalyzer(channel="C3", window_s=0.5, step_s=0.5)
173+
report = analyzer.run(self._recording(_motor_signal()))
174+
out = tmp_path / "report.jsonl"
175+
report.save_jsonl(out)
176+
assert out.exists()
177+
lines = out.read_text().strip().splitlines()
178+
assert len(lines) == report.n_epochs + 1
179+
180+
def test_save_csv(self, tmp_path):
181+
analyzer = SessionAnalyzer(channel="C3", window_s=0.5, step_s=0.5)
182+
report = analyzer.run(self._recording(_motor_signal()))
183+
out = tmp_path / "report.csv"
184+
report.save_csv(out)
185+
assert out.exists()
186+
187+
def test_window_too_short_raises(self):
188+
analyzer = SessionAnalyzer(channel="C3", window_s=0.001, step_s=0.001)
189+
with pytest.raises(ValueError, match="Window too short"):
190+
analyzer.run(self._recording(_motor_signal()))
191+
192+
193+
class TestTaskPlanner:
194+
def _simple_plan(self) -> TaskPlan:
195+
return TaskPlan(
196+
name="test plan",
197+
steps=[
198+
TaskStep("move left", min_confidence=0.0, label="step1"),
199+
TaskStep("stop", min_confidence=0.0, label="step2"),
200+
],
201+
)
202+
203+
def test_basic_plan_executes(self):
204+
bridge = ExecutorBridge(simulate_delay=False)
205+
trace = bridge.execute(self._simple_plan())
206+
assert trace.n_steps == 2
207+
assert trace.outcome in ("completed", "frozen", "partial", "aborted")
208+
209+
def test_trace_has_all_steps(self):
210+
bridge = ExecutorBridge(simulate_delay=False)
211+
trace = bridge.execute(self._simple_plan())
212+
assert len(trace.steps) == 2
213+
214+
def test_step_trace_fields_present(self):
215+
bridge = ExecutorBridge(simulate_delay=False)
216+
trace = bridge.execute(self._simple_plan())
217+
for step in trace.steps:
218+
assert step.pipeline_status in ("success", "blocked", "deferred", "aborted")
219+
assert 0.0 <= step.pipeline_confidence <= 1.0
220+
assert step.elapsed_s >= 0.0
221+
222+
def test_impossible_confidence_defers(self):
223+
plan = TaskPlan(
224+
name="impossible",
225+
steps=[TaskStep("move left", min_confidence=1.1, fallback="freeze")],
226+
)
227+
bridge = ExecutorBridge(simulate_delay=False)
228+
trace = bridge.execute(plan)
229+
assert any(s.pipeline_status == "deferred" for s in trace.steps)
230+
231+
def test_abort_fallback_stops_plan(self):
232+
plan = TaskPlan(
233+
name="abort test",
234+
steps=[
235+
TaskStep("move left", min_confidence=1.1, fallback="abort"),
236+
TaskStep("stop", min_confidence=0.0),
237+
],
238+
)
239+
bridge = ExecutorBridge(simulate_delay=False)
240+
trace = bridge.execute(plan)
241+
assert trace.outcome == "aborted"
242+
assert len(trace.steps) == 1
243+
244+
def test_complete_fallback_proceeds(self):
245+
plan = TaskPlan(
246+
name="complete test",
247+
steps=[TaskStep("move left", min_confidence=1.1, fallback="complete")],
248+
)
249+
bridge = ExecutorBridge(simulate_delay=False)
250+
trace = bridge.execute(plan)
251+
assert trace.steps[0].pipeline_status in ("success", "blocked")
252+
253+
def test_plan_summary_returns_string(self):
254+
bridge = ExecutorBridge(simulate_delay=False)
255+
trace = bridge.execute(self._simple_plan())
256+
assert isinstance(trace.summary(), str)
257+
assert "test plan" in trace.summary()
258+
259+
def test_empty_plan(self):
260+
plan = TaskPlan(name="empty", steps=[])
261+
bridge = ExecutorBridge(simulate_delay=False)
262+
trace = bridge.execute(plan)
263+
assert trace.n_steps == 0
264+
assert trace.outcome == "completed"

0 commit comments

Comments
 (0)