Skip to content

Commit 7181f63

Browse files
RandomOscillationsDeveshParagiri
authored andcommitted
fix(chat): support legacy study.db without schema migration
1 parent aabb07d commit 7181f63

1 file changed

Lines changed: 156 additions & 46 deletions

File tree

extropy/cli/commands/chat.py

Lines changed: 156 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,139 @@
55
import json
66
import sqlite3
77
import time
8+
import uuid
9+
from datetime import datetime
810
from pathlib import Path
911
from typing import Any
1012

1113
import typer
1214

1315
from ...config import get_config
1416
from ...core.llm import simple_call
15-
from ...storage import open_study_db
1617
from ..app import app, console, get_json_mode
1718

1819
chat_app = typer.Typer(help="Chat with simulated agents using DB-backed history")
1920
app.add_typer(chat_app, name="chat")
2021

2122

23+
def _now_iso() -> str:
24+
return datetime.now().isoformat()
25+
26+
27+
def _ensure_chat_tables(conn: sqlite3.Connection) -> None:
28+
cur = conn.cursor()
29+
cur.executescript(
30+
"""
31+
CREATE TABLE IF NOT EXISTS chat_sessions (
32+
session_id TEXT PRIMARY KEY,
33+
run_id TEXT NOT NULL,
34+
agent_id TEXT NOT NULL,
35+
mode TEXT NOT NULL,
36+
created_at TEXT NOT NULL,
37+
closed_at TEXT,
38+
meta_json TEXT
39+
);
40+
41+
CREATE TABLE IF NOT EXISTS chat_messages (
42+
session_id TEXT NOT NULL,
43+
turn_index INTEGER NOT NULL,
44+
role TEXT NOT NULL,
45+
content TEXT NOT NULL,
46+
citations_json TEXT,
47+
token_usage_json TEXT,
48+
created_at TEXT NOT NULL,
49+
PRIMARY KEY (session_id, turn_index)
50+
);
51+
"""
52+
)
53+
conn.commit()
54+
55+
56+
def _create_chat_session(
57+
conn: sqlite3.Connection,
58+
run_id: str,
59+
agent_id: str,
60+
mode: str,
61+
meta: dict[str, Any] | None = None,
62+
session_id: str | None = None,
63+
) -> str:
64+
_ensure_chat_tables(conn)
65+
sid = session_id or str(uuid.uuid4())
66+
cur = conn.cursor()
67+
cur.execute(
68+
"""
69+
INSERT OR REPLACE INTO chat_sessions
70+
(session_id, run_id, agent_id, mode, created_at, meta_json)
71+
VALUES (?, ?, ?, ?, ?, ?)
72+
""",
73+
(sid, run_id, agent_id, mode, _now_iso(), json.dumps(meta or {})),
74+
)
75+
conn.commit()
76+
return sid
77+
78+
79+
def _append_chat_message(
80+
conn: sqlite3.Connection,
81+
session_id: str,
82+
role: str,
83+
content: str,
84+
citations: dict[str, Any] | None = None,
85+
token_usage: dict[str, Any] | None = None,
86+
) -> int:
87+
_ensure_chat_tables(conn)
88+
cur = conn.cursor()
89+
cur.execute(
90+
"SELECT COALESCE(MAX(turn_index), -1) AS max_turn FROM chat_messages WHERE session_id = ?",
91+
(session_id,),
92+
)
93+
turn = int(cur.fetchone()["max_turn"]) + 1
94+
cur.execute(
95+
"""
96+
INSERT INTO chat_messages
97+
(session_id, turn_index, role, content, citations_json, token_usage_json, created_at)
98+
VALUES (?, ?, ?, ?, ?, ?, ?)
99+
""",
100+
(
101+
session_id,
102+
turn,
103+
role,
104+
content,
105+
json.dumps(citations or {}),
106+
json.dumps(token_usage or {}),
107+
_now_iso(),
108+
),
109+
)
110+
conn.commit()
111+
return turn
112+
113+
114+
def _get_chat_messages(conn: sqlite3.Connection, session_id: str) -> list[dict[str, Any]]:
115+
_ensure_chat_tables(conn)
116+
cur = conn.cursor()
117+
cur.execute(
118+
"""
119+
SELECT turn_index, role, content, citations_json, token_usage_json, created_at
120+
FROM chat_messages
121+
WHERE session_id = ?
122+
ORDER BY turn_index
123+
""",
124+
(session_id,),
125+
)
126+
rows = []
127+
for row in cur.fetchall():
128+
rows.append(
129+
{
130+
"turn_index": int(row["turn_index"]),
131+
"role": str(row["role"]),
132+
"content": str(row["content"]),
133+
"citations": json.loads(row["citations_json"] or "{}"),
134+
"token_usage": json.loads(row["token_usage_json"] or "{}"),
135+
"created_at": str(row["created_at"]),
136+
}
137+
)
138+
return rows
139+
140+
22141
def _load_agent_chat_context(
23142
conn: sqlite3.Connection,
24143
run_id: str,
@@ -368,13 +487,14 @@ def chat_interactive(
368487
console.print(f"[red]✗[/red] {e}")
369488
raise typer.Exit(1)
370489

371-
with open_study_db(study_db) as db:
372-
sid = session_id or db.create_chat_session(
373-
run_id=resolved_run_id,
374-
agent_id=resolved_agent_id,
375-
mode="interactive",
376-
meta={"entrypoint": "repl"},
377-
)
490+
sid = _create_chat_session(
491+
conn=conn,
492+
run_id=resolved_run_id,
493+
agent_id=resolved_agent_id,
494+
mode="interactive",
495+
meta={"entrypoint": "repl"},
496+
session_id=session_id,
497+
)
378498

379499
console.print(f"[bold]Chat session[/bold] {sid}")
380500
console.print(
@@ -394,8 +514,7 @@ def chat_interactive(
394514
if prompt == "/exit":
395515
break
396516
if prompt == "/history":
397-
with open_study_db(study_db) as db:
398-
messages = db.get_chat_messages(sid)
517+
messages = _get_chat_messages(conn, sid)
399518
for m in messages:
400519
console.print(f"[{m['role']}] {m['content']}")
401520
continue
@@ -424,8 +543,7 @@ def chat_interactive(
424543
context, citations = _load_agent_chat_context(
425544
conn, resolved_run_id, resolved_agent_id, timeline_n=12
426545
)
427-
with open_study_db(study_db) as db:
428-
history = db.get_chat_messages(sid)
546+
history = _get_chat_messages(conn, sid)
429547
try:
430548
answer, model_used = _generate_agent_chat_reply(
431549
context=context,
@@ -437,19 +555,19 @@ def chat_interactive(
437555
continue
438556
latency_ms = int((time.time() - started) * 1000)
439557

440-
with open_study_db(study_db) as db:
441-
db.append_chat_message(sid, "user", prompt)
442-
db.append_chat_message(
443-
sid,
444-
"assistant",
445-
answer,
446-
citations={"sources": citations, "model": model_used},
447-
token_usage={
448-
"input_tokens": 0,
449-
"output_tokens": 0,
450-
"latency_ms": latency_ms,
451-
},
452-
)
558+
_append_chat_message(conn, sid, "user", prompt)
559+
_append_chat_message(
560+
conn,
561+
sid,
562+
"assistant",
563+
answer,
564+
citations={"sources": citations, "model": model_used},
565+
token_usage={
566+
"input_tokens": 0,
567+
"output_tokens": 0,
568+
"latency_ms": latency_ms,
569+
},
570+
)
453571

454572
console.print(answer)
455573

@@ -482,24 +600,15 @@ def chat_ask(
482600
resolved_run_id, resolved_agent_id = _resolve_run_and_agent(
483601
conn, run_id, agent_id
484602
)
485-
except ValueError as e:
486-
console.print(f"[red]✗[/red] {e}")
487-
raise typer.Exit(1)
488-
finally:
489-
conn.close()
490-
491-
with open_study_db(study_db) as db:
492-
sid = session_id or db.create_chat_session(
603+
sid = _create_chat_session(
604+
conn=conn,
493605
run_id=resolved_run_id,
494606
agent_id=resolved_agent_id,
495607
mode="machine",
496608
meta={"entrypoint": "ask"},
609+
session_id=session_id,
497610
)
498-
history = db.get_chat_messages(sid)
499-
500-
conn = sqlite3.connect(str(study_db))
501-
conn.row_factory = sqlite3.Row
502-
try:
611+
history = _get_chat_messages(conn, sid)
503612
context, citations = _load_agent_chat_context(
504613
conn, resolved_run_id, resolved_agent_id, timeline_n=12
505614
)
@@ -508,14 +617,10 @@ def chat_ask(
508617
user_prompt=prompt,
509618
history=history,
510619
)
511-
finally:
512-
conn.close()
513-
514-
latency_ms = int((time.time() - started) * 1000)
515-
516-
with open_study_db(study_db) as db:
517-
user_turn = db.append_chat_message(sid, "user", prompt)
518-
assistant_turn = db.append_chat_message(
620+
latency_ms = int((time.time() - started) * 1000)
621+
user_turn = _append_chat_message(conn, sid, "user", prompt)
622+
assistant_turn = _append_chat_message(
623+
conn,
519624
sid,
520625
"assistant",
521626
answer,
@@ -526,6 +631,11 @@ def chat_ask(
526631
"latency_ms": latency_ms,
527632
},
528633
)
634+
except ValueError as e:
635+
console.print(f"[red]✗[/red] {e}")
636+
raise typer.Exit(1)
637+
finally:
638+
conn.close()
529639

530640
payload = {
531641
"session_id": sid,

0 commit comments

Comments
 (0)