Skip to content

Commit db94e2e

Browse files
authored
Update test_pipeline.py
1 parent ed1751b commit db94e2e

1 file changed

Lines changed: 101 additions & 104 deletions

File tree

Lines changed: 101 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,109 +1,106 @@
11
from __future__ import annotations
22

3-
from typing import Any, Dict, Optional
4-
import random
5-
6-
from synapdrive_ai.bci.signal_simulator import BrainSignalSimulator
7-
from synapdrive_ai.agi.core_reasoning import AGICoreReasoner
8-
from synapdrive_ai.agi.cognitive_optimizer import CognitiveOptimizer
9-
from synapdrive_ai.agi.meta_evaluator import MetaEvaluator
10-
from synapdrive_ai.action.decision_router import DecisionRouter
11-
from synapdrive_ai.memory.episodic_memory import EpisodicMemory
12-
from synapdrive_ai.safety.safety_guard import SafetyGuard
13-
from synapdrive_ai.vision.visual_inference import VisualInferenceEngine
3+
import pytest
4+
145
from synapdrive_ai.bci.intent_generator import generate_intent
6+
from synapdrive_ai.pipeline import SynapDrivePipeline
7+
8+
9+
# ---------------------------------------------------------------------------
10+
# Text command path
11+
# ---------------------------------------------------------------------------
12+
13+
def test_text_command_known_intent_succeeds(pipeline: SynapDrivePipeline) -> None:
14+
out = pipeline.run_text_command("move left", image_label="road")
15+
assert out["status"] in {"success", "blocked"}
16+
assert "intent" in out
17+
assert "evaluation" in out
18+
19+
20+
def test_text_command_stop_succeeds(pipeline: SynapDrivePipeline) -> None:
21+
out = pipeline.run_text_command("stop")
22+
assert out["status"] in {"success", "blocked"}
23+
24+
25+
def test_text_command_unknown_input_is_blocked(pipeline: SynapDrivePipeline) -> None:
26+
# Unknown inputs resolve to low confidence and should be blocked by safety.
27+
out = pipeline.run_text_command("xyzzy nonsense")
28+
assert out["status"] == "blocked"
29+
30+
31+
def test_text_command_returns_required_keys(pipeline: SynapDrivePipeline) -> None:
32+
out = pipeline.run_text_command("move right")
33+
for key in ("status", "intent", "result", "evaluation"):
34+
assert key in out, f"Missing key: {key}"
35+
36+
37+
# ---------------------------------------------------------------------------
38+
# Signal path
39+
# ---------------------------------------------------------------------------
40+
41+
def test_signal_event_known_labels(pipeline: SynapDrivePipeline) -> None:
42+
for label in ("walk", "stop", "left_arm", "right_arm", "calculate", "recall", "explore"):
43+
out = pipeline.run_signal_event(label=label)
44+
assert out["status"] in {"success", "blocked"}
45+
46+
47+
def test_signal_event_random_label(pipeline: SynapDrivePipeline) -> None:
48+
out = pipeline.run_signal_event()
49+
assert out["status"] in {"success", "blocked"}
50+
51+
52+
def test_signal_event_unknown_label_raises(pipeline: SynapDrivePipeline) -> None:
53+
with pytest.raises(ValueError, match="Unknown signal label"):
54+
pipeline.run_signal_event(label="brain_blast")
55+
56+
57+
# ---------------------------------------------------------------------------
58+
# Safety gate
59+
# ---------------------------------------------------------------------------
60+
61+
def test_blocked_result_has_reason(pipeline: SynapDrivePipeline) -> None:
62+
out = pipeline.run_text_command("xyzzy")
63+
assert out["status"] == "blocked"
64+
assert "reason" in out
65+
66+
67+
def test_blocked_result_has_zero_evaluation_score(pipeline: SynapDrivePipeline) -> None:
68+
out = pipeline.run_text_command("xyzzy")
69+
assert out["status"] == "blocked"
70+
assert out["evaluation"]["score"] == 0.0
71+
72+
73+
# ---------------------------------------------------------------------------
74+
# Action log
75+
# ---------------------------------------------------------------------------
76+
77+
def test_action_log_grows(pipeline: SynapDrivePipeline) -> None:
78+
pipeline.run_text_command("move left", image_label="road")
79+
pipeline.run_signal_event(label="walk")
80+
log = pipeline.get_action_log()
81+
assert len(log) >= 1 # blocked intents do not reach actuation
82+
83+
84+
def test_action_log_entries_have_schema(pipeline: SynapDrivePipeline) -> None:
85+
pipeline.run_text_command("move right")
86+
for entry in pipeline.get_action_log():
87+
for key in ("intent", "confidence", "status", "duration"):
88+
assert key in entry
89+
90+
91+
# ---------------------------------------------------------------------------
92+
# run_intent_packet (public entrypoint used by replay + integrations)
93+
# ---------------------------------------------------------------------------
94+
95+
def test_run_intent_packet_roundtrip(pipeline: SynapDrivePipeline) -> None:
96+
packet = generate_intent("move left")
97+
out = pipeline.run_intent_packet(packet, image_label="road")
98+
assert out["status"] in {"success", "blocked"}
99+
assert out["result"]["intent"] is not None
15100

16101

17-
class SynapDrivePipeline:
18-
"""
19-
Canonical end-to-end simulation pipeline.
20-
21-
Determinism support:
22-
simulate_delay=False disables actuation sleep so replay/tests are fast and repeatable.
23-
"""
24-
25-
def __init__(self, simulate_delay: bool = True) -> None:
26-
self.simulator = BrainSignalSimulator()
27-
self.reasoner = AGICoreReasoner()
28-
29-
self.memory = EpisodicMemory()
30-
self.visual = VisualInferenceEngine()
31-
32-
self.optimizer = CognitiveOptimizer()
33-
self.optimizer.memory = self.memory
34-
self.optimizer.visual = self.visual
35-
36-
self.guard = SafetyGuard()
37-
self.router = DecisionRouter(simulate_delay=simulate_delay)
38-
self.evaluator = MetaEvaluator()
39-
40-
def run_text_command(self, command_text: str, image_label: Optional[str] = None) -> Dict[str, Any]:
41-
intent_packet = generate_intent(command_text)
42-
return self.run_intent_packet(intent_packet, image_label=image_label)
43-
44-
def run_signal_event(self, label: Optional[str] = None, image_label: Optional[str] = None) -> Dict[str, Any]:
45-
label = label or random.choice(["left_arm", "right_arm", "walk", "stop", "calculate", "recall", "explore"])
46-
signal = self._generate_signal_for_label(label)
47-
intent_packet = self.reasoner.receive_signal(label, signal)
48-
return self.run_intent_packet(intent_packet, image_label=image_label)
49-
50-
def run_intent_packet(self, intent_packet: Dict[str, Any], image_label: Optional[str] = None) -> Dict[str, Any]:
51-
"""
52-
Public entrypoint for integrations + replay.
53-
"""
54-
return self._run_common(intent_packet, image_label=image_label)
55-
56-
def _generate_signal_for_label(self, label: str):
57-
patterns = {
58-
"left_arm": 10,
59-
"right_arm": 12,
60-
"walk": 8,
61-
"stop": 3,
62-
"calculate": 25,
63-
"recall": 18,
64-
"explore": 30,
65-
}
66-
if label not in patterns:
67-
raise ValueError(f"Unknown signal label: {label}")
68-
return self.simulator.generate_waveform(patterns[label])
69-
70-
def _run_common(self, intent_packet: Dict[str, Any], image_label: Optional[str]) -> Dict[str, Any]:
71-
optimized = self.optimizer.optimize(intent_packet, image_label=image_label)
72-
73-
is_safe, reason = self.guard.evaluate_safety(optimized)
74-
if not is_safe:
75-
return {
76-
"status": "blocked",
77-
"reason": reason,
78-
"intent": optimized,
79-
"result": {
80-
"status": "blocked",
81-
"intent": optimized.get("intent", "unknown"),
82-
"confidence": optimized.get("confidence", 0.0),
83-
"duration": 0.0,
84-
},
85-
"evaluation": {
86-
"score": 0.0,
87-
"total_actions": 0,
88-
"avg_score": 0.0,
89-
},
90-
}
91-
92-
result = self.router.route(optimized)
93-
94-
try:
95-
self.memory.record_episode(optimized, result)
96-
except Exception:
97-
pass
98-
99-
evaluation = self.evaluator.evaluate(optimized, result)
100-
101-
return {
102-
"status": result["status"],
103-
"intent": optimized,
104-
"result": result,
105-
"evaluation": evaluation,
106-
}
107-
108-
def get_action_log(self):
109-
return self.router.get_action_log()
102+
def test_run_intent_packet_with_visual_context(pipeline: SynapDrivePipeline) -> None:
103+
for image_label in ("road", "hazard", "person", "vehicle"):
104+
packet = generate_intent("move forward")
105+
out = pipeline.run_intent_packet(packet, image_label=image_label)
106+
assert out["status"] in {"success", "blocked"}

0 commit comments

Comments
 (0)