Skip to content

Commit d3d5128

Browse files
authored
Create task_planner.py
1 parent 74ba6eb commit d3d5128

1 file changed

Lines changed: 171 additions & 0 deletions

File tree

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
from __future__ import annotations
2+
3+
import time
4+
from dataclasses import dataclass, field
5+
from typing import List, Literal, Optional
6+
7+
from synapdrive_ai.bci.intent_generator import generate_intent
8+
from synapdrive_ai.pipeline import SynapDrivePipeline
9+
10+
11+
FallbackPolicy = Literal["freeze", "abort", "complete"]
12+
13+
14+
@dataclass
15+
class TaskStep:
16+
intent_text: str
17+
min_confidence: float = 0.55
18+
fallback: FallbackPolicy = "freeze"
19+
image_label: Optional[str] = None
20+
label: Optional[str] = None
21+
22+
def __post_init__(self) -> None:
23+
if self.label is None:
24+
self.label = self.intent_text
25+
26+
27+
@dataclass
28+
class TaskPlan:
29+
name: str
30+
steps: List[TaskStep]
31+
32+
def __len__(self) -> int:
33+
return len(self.steps)
34+
35+
36+
@dataclass
37+
class StepTrace:
38+
step_index: int
39+
label: str
40+
intent_text: str
41+
pipeline_status: str
42+
pipeline_confidence: float
43+
min_confidence: float
44+
fallback_applied: Optional[FallbackPolicy]
45+
block_reason: Optional[str]
46+
evaluation_score: float
47+
elapsed_s: float
48+
49+
50+
@dataclass
51+
class PlanTrace:
52+
plan_name: str
53+
outcome: str
54+
n_steps: int
55+
n_completed: int
56+
n_deferred: int
57+
n_aborted: int
58+
steps: List[StepTrace]
59+
total_elapsed_s: float
60+
created_utc: float = field(default_factory=time.time)
61+
62+
def summary(self) -> str:
63+
lines = [
64+
f"Plan: {self.plan_name}{self.outcome.upper()}",
65+
f" Steps: {self.n_steps} completed: {self.n_completed} deferred: {self.n_deferred} aborted: {self.n_aborted}",
66+
f" Total time: {self.total_elapsed_s:.3f}s",
67+
]
68+
for s in self.steps:
69+
tag = {
70+
"success": "✓",
71+
"blocked": "✗",
72+
"deferred": "⏸",
73+
"aborted": "⊘",
74+
}.get(s.pipeline_status, "?")
75+
lines.append(
76+
f" [{tag}] step {s.step_index}: {s.label!r} conf={s.pipeline_confidence:.2f} score={s.evaluation_score:.2f}"
77+
+ (f" → {s.block_reason}" if s.block_reason else "")
78+
)
79+
return "\n".join(lines)
80+
81+
82+
class ExecutorBridge:
83+
def __init__(
84+
self,
85+
simulate_delay: bool = False,
86+
pipeline: Optional[SynapDrivePipeline] = None,
87+
) -> None:
88+
self._pipe = pipeline or SynapDrivePipeline(simulate_delay=simulate_delay)
89+
90+
def execute(self, plan: TaskPlan) -> PlanTrace:
91+
step_traces: List[StepTrace] = []
92+
plan_start = time.time()
93+
overall_outcome = "completed"
94+
n_completed = n_deferred = n_aborted = 0
95+
96+
for idx, step in enumerate(plan.steps):
97+
step_start = time.time()
98+
trace = self._execute_step(idx, step)
99+
trace.elapsed_s = round(time.time() - step_start, 4)
100+
step_traces.append(trace)
101+
102+
if trace.pipeline_status == "success":
103+
n_completed += 1
104+
elif trace.pipeline_status == "deferred":
105+
n_deferred += 1
106+
if overall_outcome == "completed":
107+
overall_outcome = "frozen"
108+
elif trace.pipeline_status == "aborted":
109+
n_aborted += 1
110+
overall_outcome = "aborted"
111+
break
112+
else:
113+
if step.fallback == "abort":
114+
n_aborted += 1
115+
overall_outcome = "aborted"
116+
break
117+
n_deferred += 1
118+
if overall_outcome == "completed":
119+
overall_outcome = "frozen"
120+
121+
if overall_outcome == "completed" and (n_deferred > 0 or n_aborted > 0):
122+
overall_outcome = "partial"
123+
124+
return PlanTrace(
125+
plan_name=plan.name,
126+
outcome=overall_outcome,
127+
n_steps=len(plan.steps),
128+
n_completed=n_completed,
129+
n_deferred=n_deferred,
130+
n_aborted=n_aborted,
131+
steps=step_traces,
132+
total_elapsed_s=round(time.time() - plan_start, 4),
133+
)
134+
135+
def _execute_step(self, idx: int, step: TaskStep) -> StepTrace:
136+
base_packet = generate_intent(step.intent_text)
137+
out = self._pipe.run_intent_packet(base_packet, image_label=step.image_label)
138+
139+
intent_out = out.get("intent", {}) or {}
140+
eval_out = out.get("evaluation", {}) or {}
141+
pipeline_status = out.get("status", "blocked")
142+
confidence = float(intent_out.get("confidence", 0.0))
143+
block_reason = out.get("reason")
144+
eval_score = float(eval_out.get("score", 0.0))
145+
fallback_applied: Optional[FallbackPolicy] = None
146+
147+
if pipeline_status == "success" and confidence < step.min_confidence:
148+
fallback_applied = step.fallback
149+
if step.fallback == "freeze":
150+
pipeline_status = "deferred"
151+
block_reason = (
152+
f"Step confidence {confidence:.2f} < required {step.min_confidence:.2f} → freeze"
153+
)
154+
elif step.fallback == "abort":
155+
pipeline_status = "aborted"
156+
block_reason = (
157+
f"Step confidence {confidence:.2f} < required {step.min_confidence:.2f} → abort"
158+
)
159+
160+
return StepTrace(
161+
step_index=idx,
162+
label=step.label or step.intent_text,
163+
intent_text=step.intent_text,
164+
pipeline_status=pipeline_status,
165+
pipeline_confidence=round(confidence, 4),
166+
min_confidence=step.min_confidence,
167+
fallback_applied=fallback_applied,
168+
block_reason=block_reason,
169+
evaluation_score=eval_score,
170+
elapsed_s=0.0,
171+
)

0 commit comments

Comments
 (0)