Skip to content

Commit 8ec4f45

Browse files
authored
Create cli.py
1 parent d3d5128 commit 8ec4f45

1 file changed

Lines changed: 238 additions & 0 deletions

File tree

synapdrive_ai/neuro/cli.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import sys
5+
from pathlib import Path
6+
7+
import numpy as np
8+
9+
10+
def cmd_demo(args) -> int:
11+
from synapdrive_ai.neuro.band_analyzer import BandPowerAnalyzer
12+
from synapdrive_ai.neuro.eeg_loader import EEGLoader
13+
from synapdrive_ai.neuro.session_analyzer import SessionAnalyzer
14+
15+
print("SynapDrive-AI · Neuroscience demo\n")
16+
17+
sr = 256.0
18+
duration = 10.0
19+
n_samples = int(sr * duration)
20+
t = np.linspace(0, duration, n_samples, endpoint=False)
21+
22+
signal = (
23+
0.5 * np.sin(2 * np.pi * 6 * t)
24+
+ 1.0 * np.sin(2 * np.pi * 10 * t)
25+
+ 0.3 * np.sin(2 * np.pi * 20 * t)
26+
+ 0.1 * np.sin(2 * np.pi * 40 * t)
27+
)
28+
29+
burst = (t >= 4.0) & (t <= 7.0)
30+
signal[burst] = (
31+
0.2 * np.sin(2 * np.pi * 6 * t[burst])
32+
+ 0.2 * np.sin(2 * np.pi * 10 * t[burst])
33+
+ 1.2 * np.sin(2 * np.pi * 20 * t[burst])
34+
+ 0.8 * np.sin(2 * np.pi * 40 * t[burst])
35+
)
36+
signal += np.random.normal(0, 0.05, n_samples)
37+
38+
loader = EEGLoader(sampling_rate=sr)
39+
recording = loader.load_array(signal, sampling_rate=sr, channel_names=["C3"], source_label="demo_synthetic")
40+
41+
print(f"Synthetic recording: {recording.summary()}")
42+
print("Motor intent burst injected at t=4s–7s\n")
43+
44+
analyzer = BandPowerAnalyzer(sampling_rate=sr)
45+
result = analyzer.analyze(signal)
46+
print("Full-signal band power:")
47+
for band, power in result.relative.items():
48+
bar = "█" * int(power * 40)
49+
print(f" {band:6s} {power:.3f} {bar}")
50+
print(f" engagement ratio: {result.engagement_ratio:.3f}")
51+
print(f" intent class: {result.intent_class}")
52+
print(f" confidence: {result.confidence:.3f}\n")
53+
54+
session_analyzer = SessionAnalyzer(channel="C3", window_s=1.0, step_s=0.5)
55+
report = session_analyzer.run(recording)
56+
57+
print(report.summary())
58+
print("\nEpoch detail:")
59+
print(f" {'t_start':>7} {'class':>10} {'conf':>6} {'status':>8}")
60+
for ep in report.epochs:
61+
print(
62+
f" {ep.time_start_s:>7.1f}s {ep.intent_class:>10} "
63+
f"{ep.signal_confidence:>6.3f} {ep.pipeline_status:>8}"
64+
)
65+
66+
return 0
67+
68+
69+
def cmd_analyze(args) -> int:
70+
from synapdrive_ai.neuro.eeg_loader import EEGLoader
71+
from synapdrive_ai.neuro.session_analyzer import SessionAnalyzer
72+
73+
path = Path(args.file)
74+
if not path.exists():
75+
print(f"Error: file not found: {path}", file=sys.stderr)
76+
return 1
77+
78+
print(f"Loading {path}...")
79+
loader = EEGLoader(sampling_rate=args.sr)
80+
recording = loader.load(path)
81+
print(recording.summary())
82+
83+
if args.channels:
84+
print(f"Available channels: {recording.channels}")
85+
return 0
86+
87+
channel = args.channel or recording.channels[0]
88+
print(f"Analyzing channel: {channel} window={args.window}s step={args.step}s\n")
89+
90+
analyzer = SessionAnalyzer(
91+
channel=channel,
92+
window_s=args.window,
93+
step_s=args.step,
94+
image_label=args.image,
95+
)
96+
report = analyzer.run(recording)
97+
print(report.summary())
98+
99+
if args.out:
100+
out_dir = Path(args.out)
101+
out_dir.mkdir(parents=True, exist_ok=True)
102+
stem = path.stem
103+
jsonl_path = out_dir / f"{stem}_analysis.jsonl"
104+
csv_path = out_dir / f"{stem}_epochs.csv"
105+
report.save_jsonl(jsonl_path)
106+
report.save_csv(csv_path)
107+
print(f"\nSaved: {jsonl_path}")
108+
print(f"Saved: {csv_path}")
109+
110+
return 0
111+
112+
113+
def cmd_threshold(args) -> int:
114+
from synapdrive_ai.neuro.eeg_loader import EEGLoader
115+
from synapdrive_ai.neuro.session_analyzer import SessionAnalyzer
116+
117+
path = Path(args.file)
118+
if not path.exists():
119+
print(f"Error: file not found: {path}", file=sys.stderr)
120+
return 1
121+
122+
loader = EEGLoader(sampling_rate=args.sr)
123+
recording = loader.load(path)
124+
channel = args.channel or recording.channels[0]
125+
126+
thresholds = [float(t) for t in args.min_conf]
127+
print(f"Threshold sweep — {path.name} channel={channel}\n")
128+
print(f" {'threshold':>10} {'blocked%':>10} {'mean_conf':>10} {'n_success':>10}")
129+
130+
for thresh in thresholds:
131+
analyzer = SessionAnalyzer(channel=channel, window_s=args.window, step_s=args.step)
132+
analyzer._pipe.guard.min_confidence_threshold = thresh
133+
report = analyzer.run(recording)
134+
print(
135+
f" {thresh:>10.3f} {report.block_rate * 100:>9.1f}% "
136+
f"{report.mean_confidence:>10.3f} {report.n_success:>10}"
137+
)
138+
139+
return 0
140+
141+
142+
def cmd_plan(args) -> int:
143+
from synapdrive_ai.neuro.task_planner import ExecutorBridge, TaskPlan, TaskStep
144+
145+
plans = {
146+
"reach_grasp": TaskPlan(
147+
name="reach and grasp",
148+
steps=[
149+
TaskStep("move forward", min_confidence=0.55, label="approach"),
150+
TaskStep("move left", min_confidence=0.55, label="align"),
151+
TaskStep("pick up", min_confidence=0.70, fallback="freeze", label="grasp"),
152+
],
153+
),
154+
"navigate": TaskPlan(
155+
name="navigate to target",
156+
steps=[
157+
TaskStep("move forward", min_confidence=0.50, label="forward"),
158+
TaskStep("turn left", min_confidence=0.55, label="turn"),
159+
TaskStep("move forward", min_confidence=0.50, label="approach"),
160+
TaskStep("stop", min_confidence=0.45, label="halt"),
161+
],
162+
),
163+
"cognitive_sequence": TaskPlan(
164+
name="cognitive task sequence",
165+
steps=[
166+
TaskStep("calculate", min_confidence=0.55, label="compute"),
167+
TaskStep("recall", min_confidence=0.50, label="retrieve"),
168+
TaskStep("stop", min_confidence=0.45, label="confirm"),
169+
],
170+
),
171+
}
172+
173+
if args.list:
174+
print("Available plans:")
175+
for name, plan in plans.items():
176+
print(f" {name}: {len(plan)} steps — {plan.name}")
177+
return 0
178+
179+
task_name = args.task or "reach_grasp"
180+
if task_name not in plans:
181+
print(f"Unknown plan: {task_name!r}. Use --list to see options.", file=sys.stderr)
182+
return 1
183+
184+
plan = plans[task_name]
185+
bridge = ExecutorBridge(simulate_delay=False)
186+
print(f"Executing plan: {plan.name}\n")
187+
trace = bridge.execute(plan)
188+
print(trace.summary())
189+
return 0
190+
191+
192+
def build_parser() -> argparse.ArgumentParser:
193+
p = argparse.ArgumentParser(
194+
prog="python -m synapdrive_ai.neuro.cli",
195+
description="SynapDrive-AI neuroscience tools",
196+
)
197+
sub = p.add_subparsers(dest="command", required=True)
198+
199+
a = sub.add_parser("analyze", help="Analyze an EEG file session")
200+
a.add_argument("file", help="EEG file (.edf, .bdf, .csv, .npy)")
201+
a.add_argument("--channel", default=None, help="Channel to analyze (default: first)")
202+
a.add_argument("--channels", action="store_true", help="List available channels and exit")
203+
a.add_argument("--window", type=float, default=1.0, help="Epoch window in seconds")
204+
a.add_argument("--step", type=float, default=0.5, help="Sliding step in seconds")
205+
a.add_argument("--image", default=None, help="Optional visual context label")
206+
a.add_argument("--sr", type=float, default=256.0, help="Sampling rate Hz for CSV/NPY")
207+
a.add_argument("--out", default=None, help="Output directory for JSONL + CSV results")
208+
209+
t = sub.add_parser("threshold", help="Test multiple confidence thresholds")
210+
t.add_argument("file")
211+
t.add_argument("--min-conf", nargs="+", default=["0.4", "0.5", "0.6", "0.7"])
212+
t.add_argument("--channel", default=None)
213+
t.add_argument("--window", type=float, default=1.0)
214+
t.add_argument("--step", type=float, default=0.5)
215+
t.add_argument("--sr", type=float, default=256.0)
216+
217+
pl = sub.add_parser("plan", help="Execute a sequential task plan")
218+
pl.add_argument("--task", default=None, help="Plan name")
219+
pl.add_argument("--list", action="store_true", help="List available plans")
220+
221+
sub.add_parser("demo", help="Run a full demo with synthetic EEG data")
222+
223+
return p
224+
225+
226+
def main(argv=None) -> int:
227+
args = build_parser().parse_args(argv)
228+
dispatch = {
229+
"analyze": cmd_analyze,
230+
"threshold": cmd_threshold,
231+
"plan": cmd_plan,
232+
"demo": cmd_demo,
233+
}
234+
return dispatch[args.command](args)
235+
236+
237+
if __name__ == "__main__":
238+
raise SystemExit(main())

0 commit comments

Comments
 (0)