Skip to content

Commit aabb07d

Browse files
RandomOscillationsDeveshParagiri
authored andcommitted
feat(chat): default targets, list command, and robust OpenAI structured replies
1 parent 5b2d294 commit aabb07d

2 files changed

Lines changed: 47 additions & 162 deletions

File tree

extropy/cli/commands/chat.py

Lines changed: 46 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -5,141 +5,20 @@
55
import json
66
import sqlite3
77
import time
8-
import uuid
9-
from datetime import datetime
108
from pathlib import Path
119
from typing import Any
1210

1311
import typer
1412

1513
from ...config import get_config
1614
from ...core.llm import simple_call
15+
from ...storage import open_study_db
1716
from ..app import app, console, get_json_mode
1817

1918
chat_app = typer.Typer(help="Chat with simulated agents using DB-backed history")
2019
app.add_typer(chat_app, name="chat")
2120

2221

23-
def _now_iso() -> str:
24-
return datetime.now().isoformat()
25-
26-
27-
def _ensure_chat_tables(conn: sqlite3.Connection) -> None:
28-
cur = conn.cursor()
29-
cur.executescript(
30-
"""
31-
CREATE TABLE IF NOT EXISTS chat_sessions (
32-
session_id TEXT PRIMARY KEY,
33-
run_id TEXT NOT NULL,
34-
agent_id TEXT NOT NULL,
35-
mode TEXT NOT NULL,
36-
created_at TEXT NOT NULL,
37-
closed_at TEXT,
38-
meta_json TEXT
39-
);
40-
41-
CREATE TABLE IF NOT EXISTS chat_messages (
42-
session_id TEXT NOT NULL,
43-
turn_index INTEGER NOT NULL,
44-
role TEXT NOT NULL,
45-
content TEXT NOT NULL,
46-
citations_json TEXT,
47-
token_usage_json TEXT,
48-
created_at TEXT NOT NULL,
49-
PRIMARY KEY (session_id, turn_index)
50-
);
51-
"""
52-
)
53-
conn.commit()
54-
55-
56-
def _create_chat_session(
57-
conn: sqlite3.Connection,
58-
run_id: str,
59-
agent_id: str,
60-
mode: str,
61-
meta: dict[str, Any] | None = None,
62-
session_id: str | None = None,
63-
) -> str:
64-
_ensure_chat_tables(conn)
65-
sid = session_id or str(uuid.uuid4())
66-
cur = conn.cursor()
67-
cur.execute(
68-
"""
69-
INSERT OR REPLACE INTO chat_sessions
70-
(session_id, run_id, agent_id, mode, created_at, meta_json)
71-
VALUES (?, ?, ?, ?, ?, ?)
72-
""",
73-
(sid, run_id, agent_id, mode, _now_iso(), json.dumps(meta or {})),
74-
)
75-
conn.commit()
76-
return sid
77-
78-
79-
def _append_chat_message(
80-
conn: sqlite3.Connection,
81-
session_id: str,
82-
role: str,
83-
content: str,
84-
citations: dict[str, Any] | None = None,
85-
token_usage: dict[str, Any] | None = None,
86-
) -> int:
87-
_ensure_chat_tables(conn)
88-
cur = conn.cursor()
89-
cur.execute(
90-
"SELECT COALESCE(MAX(turn_index), -1) AS max_turn FROM chat_messages WHERE session_id = ?",
91-
(session_id,),
92-
)
93-
turn = int(cur.fetchone()["max_turn"]) + 1
94-
cur.execute(
95-
"""
96-
INSERT INTO chat_messages
97-
(session_id, turn_index, role, content, citations_json, token_usage_json, created_at)
98-
VALUES (?, ?, ?, ?, ?, ?, ?)
99-
""",
100-
(
101-
session_id,
102-
turn,
103-
role,
104-
content,
105-
json.dumps(citations or {}),
106-
json.dumps(token_usage or {}),
107-
_now_iso(),
108-
),
109-
)
110-
conn.commit()
111-
return turn
112-
113-
114-
def _get_chat_messages(
115-
conn: sqlite3.Connection, session_id: str
116-
) -> list[dict[str, Any]]:
117-
_ensure_chat_tables(conn)
118-
cur = conn.cursor()
119-
cur.execute(
120-
"""
121-
SELECT turn_index, role, content, citations_json, token_usage_json, created_at
122-
FROM chat_messages
123-
WHERE session_id = ?
124-
ORDER BY turn_index
125-
""",
126-
(session_id,),
127-
)
128-
rows = []
129-
for row in cur.fetchall():
130-
rows.append(
131-
{
132-
"turn_index": int(row["turn_index"]),
133-
"role": str(row["role"]),
134-
"content": str(row["content"]),
135-
"citations": json.loads(row["citations_json"] or "{}"),
136-
"token_usage": json.loads(row["token_usage_json"] or "{}"),
137-
"created_at": str(row["created_at"]),
138-
}
139-
)
140-
return rows
141-
142-
14322
def _load_agent_chat_context(
14423
conn: sqlite3.Connection,
14524
run_id: str,
@@ -489,14 +368,13 @@ def chat_interactive(
489368
console.print(f"[red]✗[/red] {e}")
490369
raise typer.Exit(1)
491370

492-
sid = _create_chat_session(
493-
conn=conn,
494-
run_id=resolved_run_id,
495-
agent_id=resolved_agent_id,
496-
mode="interactive",
497-
meta={"entrypoint": "repl"},
498-
session_id=session_id,
499-
)
371+
with open_study_db(study_db) as db:
372+
sid = session_id or db.create_chat_session(
373+
run_id=resolved_run_id,
374+
agent_id=resolved_agent_id,
375+
mode="interactive",
376+
meta={"entrypoint": "repl"},
377+
)
500378

501379
console.print(f"[bold]Chat session[/bold] {sid}")
502380
console.print(
@@ -516,7 +394,8 @@ def chat_interactive(
516394
if prompt == "/exit":
517395
break
518396
if prompt == "/history":
519-
messages = _get_chat_messages(conn, sid)
397+
with open_study_db(study_db) as db:
398+
messages = db.get_chat_messages(sid)
520399
for m in messages:
521400
console.print(f"[{m['role']}] {m['content']}")
522401
continue
@@ -545,7 +424,8 @@ def chat_interactive(
545424
context, citations = _load_agent_chat_context(
546425
conn, resolved_run_id, resolved_agent_id, timeline_n=12
547426
)
548-
history = _get_chat_messages(conn, sid)
427+
with open_study_db(study_db) as db:
428+
history = db.get_chat_messages(sid)
549429
try:
550430
answer, model_used = _generate_agent_chat_reply(
551431
context=context,
@@ -557,19 +437,19 @@ def chat_interactive(
557437
continue
558438
latency_ms = int((time.time() - started) * 1000)
559439

560-
_append_chat_message(conn, sid, "user", prompt)
561-
_append_chat_message(
562-
conn,
563-
sid,
564-
"assistant",
565-
answer,
566-
citations={"sources": citations, "model": model_used},
567-
token_usage={
568-
"input_tokens": 0,
569-
"output_tokens": 0,
570-
"latency_ms": latency_ms,
571-
},
572-
)
440+
with open_study_db(study_db) as db:
441+
db.append_chat_message(sid, "user", prompt)
442+
db.append_chat_message(
443+
sid,
444+
"assistant",
445+
answer,
446+
citations={"sources": citations, "model": model_used},
447+
token_usage={
448+
"input_tokens": 0,
449+
"output_tokens": 0,
450+
"latency_ms": latency_ms,
451+
},
452+
)
573453

574454
console.print(answer)
575455

@@ -602,15 +482,24 @@ def chat_ask(
602482
resolved_run_id, resolved_agent_id = _resolve_run_and_agent(
603483
conn, run_id, agent_id
604484
)
605-
sid = _create_chat_session(
606-
conn=conn,
485+
except ValueError as e:
486+
console.print(f"[red]✗[/red] {e}")
487+
raise typer.Exit(1)
488+
finally:
489+
conn.close()
490+
491+
with open_study_db(study_db) as db:
492+
sid = session_id or db.create_chat_session(
607493
run_id=resolved_run_id,
608494
agent_id=resolved_agent_id,
609495
mode="machine",
610496
meta={"entrypoint": "ask"},
611-
session_id=session_id,
612497
)
613-
history = _get_chat_messages(conn, sid)
498+
history = db.get_chat_messages(sid)
499+
500+
conn = sqlite3.connect(str(study_db))
501+
conn.row_factory = sqlite3.Row
502+
try:
614503
context, citations = _load_agent_chat_context(
615504
conn, resolved_run_id, resolved_agent_id, timeline_n=12
616505
)
@@ -619,10 +508,14 @@ def chat_ask(
619508
user_prompt=prompt,
620509
history=history,
621510
)
622-
latency_ms = int((time.time() - started) * 1000)
623-
user_turn = _append_chat_message(conn, sid, "user", prompt)
624-
assistant_turn = _append_chat_message(
625-
conn,
511+
finally:
512+
conn.close()
513+
514+
latency_ms = int((time.time() - started) * 1000)
515+
516+
with open_study_db(study_db) as db:
517+
user_turn = db.append_chat_message(sid, "user", prompt)
518+
assistant_turn = db.append_chat_message(
626519
sid,
627520
"assistant",
628521
answer,
@@ -633,11 +526,6 @@ def chat_ask(
633526
"latency_ms": latency_ms,
634527
},
635528
)
636-
except ValueError as e:
637-
console.print(f"[red]✗[/red] {e}")
638-
raise typer.Exit(1)
639-
finally:
640-
conn.close()
641529

642530
payload = {
643531
"session_id": sid,

tests/test_providers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,10 +198,7 @@ def test_retries_when_incomplete_due_to_max_output_tokens(self, mock_get_client)
198198
complete_response.incomplete_details = None
199199

200200
mock_client = MagicMock()
201-
mock_client.responses.create.side_effect = [
202-
incomplete_response,
203-
complete_response,
204-
]
201+
mock_client.responses.create.side_effect = [incomplete_response, complete_response]
205202
mock_get_client.return_value = mock_client
206203

207204
result = provider.simple_call(

0 commit comments

Comments
 (0)