55import json
66import sqlite3
77import time
8+ import uuid
9+ from datetime import datetime
810from pathlib import Path
911from typing import Any
1012
1113import typer
1214
1315from ...config import get_config
1416from ...core .llm import simple_call
15- from ...storage import open_study_db
1617from ..app import app , console , get_json_mode
1718
1819chat_app = typer .Typer (help = "Chat with simulated agents using DB-backed history" )
1920app .add_typer (chat_app , name = "chat" )
2021
2122
23+ def _now_iso () -> str :
24+ return datetime .now ().isoformat ()
25+
26+
27+ def _ensure_chat_tables (conn : sqlite3 .Connection ) -> None :
28+ cur = conn .cursor ()
29+ cur .executescript (
30+ """
31+ CREATE TABLE IF NOT EXISTS chat_sessions (
32+ session_id TEXT PRIMARY KEY,
33+ run_id TEXT NOT NULL,
34+ agent_id TEXT NOT NULL,
35+ mode TEXT NOT NULL,
36+ created_at TEXT NOT NULL,
37+ closed_at TEXT,
38+ meta_json TEXT
39+ );
40+
41+ CREATE TABLE IF NOT EXISTS chat_messages (
42+ session_id TEXT NOT NULL,
43+ turn_index INTEGER NOT NULL,
44+ role TEXT NOT NULL,
45+ content TEXT NOT NULL,
46+ citations_json TEXT,
47+ token_usage_json TEXT,
48+ created_at TEXT NOT NULL,
49+ PRIMARY KEY (session_id, turn_index)
50+ );
51+ """
52+ )
53+ conn .commit ()
54+
55+
56+ def _create_chat_session (
57+ conn : sqlite3 .Connection ,
58+ run_id : str ,
59+ agent_id : str ,
60+ mode : str ,
61+ meta : dict [str , Any ] | None = None ,
62+ session_id : str | None = None ,
63+ ) -> str :
64+ _ensure_chat_tables (conn )
65+ sid = session_id or str (uuid .uuid4 ())
66+ cur = conn .cursor ()
67+ cur .execute (
68+ """
69+ INSERT OR REPLACE INTO chat_sessions
70+ (session_id, run_id, agent_id, mode, created_at, meta_json)
71+ VALUES (?, ?, ?, ?, ?, ?)
72+ """ ,
73+ (sid , run_id , agent_id , mode , _now_iso (), json .dumps (meta or {})),
74+ )
75+ conn .commit ()
76+ return sid
77+
78+
79+ def _append_chat_message (
80+ conn : sqlite3 .Connection ,
81+ session_id : str ,
82+ role : str ,
83+ content : str ,
84+ citations : dict [str , Any ] | None = None ,
85+ token_usage : dict [str , Any ] | None = None ,
86+ ) -> int :
87+ _ensure_chat_tables (conn )
88+ cur = conn .cursor ()
89+ cur .execute (
90+ "SELECT COALESCE(MAX(turn_index), -1) AS max_turn FROM chat_messages WHERE session_id = ?" ,
91+ (session_id ,),
92+ )
93+ turn = int (cur .fetchone ()["max_turn" ]) + 1
94+ cur .execute (
95+ """
96+ INSERT INTO chat_messages
97+ (session_id, turn_index, role, content, citations_json, token_usage_json, created_at)
98+ VALUES (?, ?, ?, ?, ?, ?, ?)
99+ """ ,
100+ (
101+ session_id ,
102+ turn ,
103+ role ,
104+ content ,
105+ json .dumps (citations or {}),
106+ json .dumps (token_usage or {}),
107+ _now_iso (),
108+ ),
109+ )
110+ conn .commit ()
111+ return turn
112+
113+
114+ def _get_chat_messages (conn : sqlite3 .Connection , session_id : str ) -> list [dict [str , Any ]]:
115+ _ensure_chat_tables (conn )
116+ cur = conn .cursor ()
117+ cur .execute (
118+ """
119+ SELECT turn_index, role, content, citations_json, token_usage_json, created_at
120+ FROM chat_messages
121+ WHERE session_id = ?
122+ ORDER BY turn_index
123+ """ ,
124+ (session_id ,),
125+ )
126+ rows = []
127+ for row in cur .fetchall ():
128+ rows .append (
129+ {
130+ "turn_index" : int (row ["turn_index" ]),
131+ "role" : str (row ["role" ]),
132+ "content" : str (row ["content" ]),
133+ "citations" : json .loads (row ["citations_json" ] or "{}" ),
134+ "token_usage" : json .loads (row ["token_usage_json" ] or "{}" ),
135+ "created_at" : str (row ["created_at" ]),
136+ }
137+ )
138+ return rows
139+
140+
22141def _load_agent_chat_context (
23142 conn : sqlite3 .Connection ,
24143 run_id : str ,
@@ -368,13 +487,14 @@ def chat_interactive(
368487 console .print (f"[red]✗[/red] { e } " )
369488 raise typer .Exit (1 )
370489
371- with open_study_db (study_db ) as db :
372- sid = session_id or db .create_chat_session (
373- run_id = resolved_run_id ,
374- agent_id = resolved_agent_id ,
375- mode = "interactive" ,
376- meta = {"entrypoint" : "repl" },
377- )
490+ sid = _create_chat_session (
491+ conn = conn ,
492+ run_id = resolved_run_id ,
493+ agent_id = resolved_agent_id ,
494+ mode = "interactive" ,
495+ meta = {"entrypoint" : "repl" },
496+ session_id = session_id ,
497+ )
378498
379499 console .print (f"[bold]Chat session[/bold] { sid } " )
380500 console .print (
@@ -394,8 +514,7 @@ def chat_interactive(
394514 if prompt == "/exit" :
395515 break
396516 if prompt == "/history" :
397- with open_study_db (study_db ) as db :
398- messages = db .get_chat_messages (sid )
517+ messages = _get_chat_messages (conn , sid )
399518 for m in messages :
400519 console .print (f"[{ m ['role' ]} ] { m ['content' ]} " )
401520 continue
@@ -424,8 +543,7 @@ def chat_interactive(
424543 context , citations = _load_agent_chat_context (
425544 conn , resolved_run_id , resolved_agent_id , timeline_n = 12
426545 )
427- with open_study_db (study_db ) as db :
428- history = db .get_chat_messages (sid )
546+ history = _get_chat_messages (conn , sid )
429547 try :
430548 answer , model_used = _generate_agent_chat_reply (
431549 context = context ,
@@ -437,19 +555,19 @@ def chat_interactive(
437555 continue
438556 latency_ms = int ((time .time () - started ) * 1000 )
439557
440- with open_study_db ( study_db ) as db :
441- db . append_chat_message ( sid , "user" , prompt )
442- db . append_chat_message (
443- sid ,
444- "assistant" ,
445- answer ,
446- citations = {"sources" : citations , "model" : model_used },
447- token_usage = {
448- "input_tokens" : 0 ,
449- "output_tokens" : 0 ,
450- "latency_ms" : latency_ms ,
451- },
452- )
558+ _append_chat_message ( conn , sid , "user" , prompt )
559+ _append_chat_message (
560+ conn ,
561+ sid ,
562+ "assistant" ,
563+ answer ,
564+ citations = {"sources" : citations , "model" : model_used },
565+ token_usage = {
566+ "input_tokens" : 0 ,
567+ "output_tokens" : 0 ,
568+ "latency_ms" : latency_ms ,
569+ },
570+ )
453571
454572 console .print (answer )
455573
@@ -482,24 +600,15 @@ def chat_ask(
482600 resolved_run_id , resolved_agent_id = _resolve_run_and_agent (
483601 conn , run_id , agent_id
484602 )
485- except ValueError as e :
486- console .print (f"[red]✗[/red] { e } " )
487- raise typer .Exit (1 )
488- finally :
489- conn .close ()
490-
491- with open_study_db (study_db ) as db :
492- sid = session_id or db .create_chat_session (
603+ sid = _create_chat_session (
604+ conn = conn ,
493605 run_id = resolved_run_id ,
494606 agent_id = resolved_agent_id ,
495607 mode = "machine" ,
496608 meta = {"entrypoint" : "ask" },
609+ session_id = session_id ,
497610 )
498- history = db .get_chat_messages (sid )
499-
500- conn = sqlite3 .connect (str (study_db ))
501- conn .row_factory = sqlite3 .Row
502- try :
611+ history = _get_chat_messages (conn , sid )
503612 context , citations = _load_agent_chat_context (
504613 conn , resolved_run_id , resolved_agent_id , timeline_n = 12
505614 )
@@ -508,14 +617,10 @@ def chat_ask(
508617 user_prompt = prompt ,
509618 history = history ,
510619 )
511- finally :
512- conn .close ()
513-
514- latency_ms = int ((time .time () - started ) * 1000 )
515-
516- with open_study_db (study_db ) as db :
517- user_turn = db .append_chat_message (sid , "user" , prompt )
518- assistant_turn = db .append_chat_message (
620+ latency_ms = int ((time .time () - started ) * 1000 )
621+ user_turn = _append_chat_message (conn , sid , "user" , prompt )
622+ assistant_turn = _append_chat_message (
623+ conn ,
519624 sid ,
520625 "assistant" ,
521626 answer ,
@@ -526,6 +631,11 @@ def chat_ask(
526631 "latency_ms" : latency_ms ,
527632 },
528633 )
634+ except ValueError as e :
635+ console .print (f"[red]✗[/red] { e } " )
636+ raise typer .Exit (1 )
637+ finally :
638+ conn .close ()
529639
530640 payload = {
531641 "session_id" : sid ,
0 commit comments