diff --git a/docs/commands.md b/docs/commands.md index 5188d04..6805d66 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -473,32 +473,47 @@ Interactive chat with simulated agents. ### Interactive REPL ```bash -extropy chat --run-id --agent-id agent_042 +extropy chat --study-db austin/study.db ``` | Flag | Type | Default | Description | |------|------|---------|-------------| -| `--study-db` | path | `./study.db` | Study database path | -| `--run-id` | string | required | Simulation run ID | -| `--agent-id` | string | required | Agent ID | +| `--study-db` | path | required | Study database path | +| `--run-id` | string | latest | Simulation run ID | +| `--agent-id` | string | first agent in run | Agent ID | | `--session-id` | string | auto | Chat session ID | REPL commands: `/context`, `/timeline `, `/history`, `/exit` +### extropy chat list + +Show recent runs and sample agents so users can pick chat targets quickly. + +```bash +extropy chat list --study-db austin/study.db +``` + +| Flag | Type | Default | Description | +|------|------|---------|-------------| +| `--study-db` | path | required | Study database path | +| `--limit-runs` | int | `10` | Number of recent runs to list | +| `--agents-per-run` | int | `5` | Number of sample agent IDs per run | +| `--json` | flag | false | Output JSON response | + ### extropy chat ask Non-interactive API for automation. ```bash -extropy chat ask --run-id --agent-id agent_042 \ +extropy chat ask --study-db austin/study.db \ --prompt "What changed your mind?" 
--json ``` | Flag | Type | Default | Description | |------|------|---------|-------------| -| `--study-db` | path | `./study.db` | Study database path | -| `--run-id` | string | required | Simulation run ID | -| `--agent-id` | string | required | Agent ID | +| `--study-db` | path | required | Study database path | +| `--run-id` | string | latest | Simulation run ID | +| `--agent-id` | string | first agent in run | Agent ID | | `--prompt` | string | required | Question to ask | | `--session-id` | string | auto | Chat session ID | | `--json` | flag | false | Output JSON response | diff --git a/extropy/cli/commands/chat.py b/extropy/cli/commands/chat.py index 2a15796..85109bd 100644 --- a/extropy/cli/commands/chat.py +++ b/extropy/cli/commands/chat.py @@ -5,18 +5,141 @@ import json import sqlite3 import time +import uuid +from datetime import datetime from pathlib import Path from typing import Any import typer -from ...storage import open_study_db +from ...config import get_config +from ...core.llm import simple_call from ..app import app, console, get_json_mode chat_app = typer.Typer(help="Chat with simulated agents using DB-backed history") app.add_typer(chat_app, name="chat") +def _now_iso() -> str: + return datetime.now().isoformat() + + +def _ensure_chat_tables(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.executescript( + """ + CREATE TABLE IF NOT EXISTS chat_sessions ( + session_id TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + agent_id TEXT NOT NULL, + mode TEXT NOT NULL, + created_at TEXT NOT NULL, + closed_at TEXT, + meta_json TEXT + ); + + CREATE TABLE IF NOT EXISTS chat_messages ( + session_id TEXT NOT NULL, + turn_index INTEGER NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + citations_json TEXT, + token_usage_json TEXT, + created_at TEXT NOT NULL, + PRIMARY KEY (session_id, turn_index) + ); + """ + ) + conn.commit() + + +def _create_chat_session( + conn: sqlite3.Connection, + run_id: str, + agent_id: str, + mode: str, + meta: 
dict[str, Any] | None = None, + session_id: str | None = None, +) -> str: + _ensure_chat_tables(conn) + sid = session_id or str(uuid.uuid4()) + cur = conn.cursor() + cur.execute( + """ + INSERT OR REPLACE INTO chat_sessions + (session_id, run_id, agent_id, mode, created_at, meta_json) + VALUES (?, ?, ?, ?, ?, ?) + """, + (sid, run_id, agent_id, mode, _now_iso(), json.dumps(meta or {})), + ) + conn.commit() + return sid + + +def _append_chat_message( + conn: sqlite3.Connection, + session_id: str, + role: str, + content: str, + citations: dict[str, Any] | None = None, + token_usage: dict[str, Any] | None = None, +) -> int: + _ensure_chat_tables(conn) + cur = conn.cursor() + cur.execute( + "SELECT COALESCE(MAX(turn_index), -1) AS max_turn FROM chat_messages WHERE session_id = ?", + (session_id,), + ) + turn = int(cur.fetchone()["max_turn"]) + 1 + cur.execute( + """ + INSERT INTO chat_messages + (session_id, turn_index, role, content, citations_json, token_usage_json, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + session_id, + turn, + role, + content, + json.dumps(citations or {}), + json.dumps(token_usage or {}), + _now_iso(), + ), + ) + conn.commit() + return turn + + +def _get_chat_messages( + conn: sqlite3.Connection, session_id: str +) -> list[dict[str, Any]]: + _ensure_chat_tables(conn) + cur = conn.cursor() + cur.execute( + """ + SELECT turn_index, role, content, citations_json, token_usage_json, created_at + FROM chat_messages + WHERE session_id = ? 
def _render_chat_history(
    history: list[dict[str, Any]],
    max_turns: int = 12,
    max_chars: int = 400,
) -> str:
    """Render the tail of a chat history as one ``Role: text`` line per turn.

    Args:
        history: Messages in chronological order; each is a mapping with
            ``role`` and ``content`` keys.
        max_turns: Number of most recent messages to include. Zero or a
            negative value includes none.
        max_chars: Per-message character cap; longer content is cut and
            suffixed with ``...``.

    Returns:
        The rendered lines joined by newlines, or ``"(no prior conversation)"``
        when there is nothing to render.
    """
    # Guard max_turns explicitly: history[-0:] is the *whole* list, so a zero
    # budget would otherwise render everything instead of nothing.
    if not history or max_turns <= 0:
        return "(no prior conversation)"

    rendered: list[str] = []
    for msg in history[-max_turns:]:
        role = "User" if msg.get("role") == "user" else "Agent"
        raw = str(msg.get("content") or "").strip()
        # splitlines() covers \n, \r\n, bare \r and the other line separators,
        # so every turn is guaranteed to occupy exactly one rendered line.
        content = " ".join(raw.splitlines())
        if len(content) > max_chars:
            content = content[:max_chars].rstrip() + "..."
        rendered.append(f"{role}: {content}")
    return "\n".join(rendered)
+ context_payload = { + "run_id": context.get("run_id"), + "population_id": context.get("population_id"), + "agent_id": context.get("agent_id"), + "attributes": attrs, + "state": { + k: state.get(k) + for k in ( + "aware", + "position", + "private_position", + "public_position", + "sentiment", + "private_sentiment", + "public_sentiment", + "conviction", + "private_conviction", + "public_conviction", + "action_intent", + "public_statement", + "raw_reasoning", ) + if k in state + }, + "recent_timeline": timeline[-8:], + } - lines.append(f"- Your prompt: {prompt}") - lines.append( - "This answer is grounded in persisted DB state and does not mutate simulation state." + return ( + "You are answering as this simulated person from a completed simulation run.\n" + "Stay in first person and in character.\n" + "Use only the provided simulation context and chat history.\n" + "Do not claim facts outside the run data.\n" + "Do not mention being an AI model.\n" + "If the data is missing, say you're unsure based on what you experienced in this run.\n" + "Keep responses conversational and concise (2-6 sentences unless asked for more).\n\n" + "SIMULATION CONTEXT (JSON):\n" + f"{json.dumps(context_payload, indent=2, default=str)}\n\n" + "CHAT HISTORY:\n" + f"{_render_chat_history(history)}\n\n" + "NEW USER QUESTION:\n" + f"{user_prompt}" ) - return "\n".join(lines) + + +def _generate_agent_chat_reply( + context: dict[str, Any], + user_prompt: str, + history: list[dict[str, Any]], +) -> tuple[str, str]: + model = get_config().resolve_sim_strong() + prompt = _build_agent_chat_prompt(context, user_prompt, history) + schema = { + "type": "object", + "properties": { + "assistant_text": { + "type": "string", + "description": "In-character reply from the simulated agent", + } + }, + "required": ["assistant_text"], + "additionalProperties": False, + } + response = simple_call( + prompt=prompt, + response_schema=schema, + schema_name="agent_chat_reply", + model=model, + log=False, + 
max_tokens=500, + ) + assistant_text = str(response.get("assistant_text", "")).strip() + if not assistant_text: + raise ValueError("LLM returned empty assistant_text for chat reply") + return assistant_text, model def _print_repl_help() -> None: console.print("[dim]Commands: /context, /timeline , /history, /exit[/dim]") +def _resolve_run_and_agent( + conn: sqlite3.Connection, + run_id: str | None, + agent_id: str | None, +) -> tuple[str, str]: + cur = conn.cursor() + if run_id: + cur.execute( + """ + SELECT run_id, population_id + FROM simulation_runs + WHERE run_id = ? + LIMIT 1 + """, + (run_id,), + ) + else: + cur.execute( + """ + SELECT run_id, population_id + FROM simulation_runs + ORDER BY started_at DESC + LIMIT 1 + """ + ) + run_row = cur.fetchone() + if not run_row: + raise ValueError("No simulation runs found in study DB") + resolved_run_id = str(run_row["run_id"]) + population_id = str(run_row["population_id"]) + + if agent_id: + cur.execute( + "SELECT 1 FROM agent_states WHERE run_id = ? AND agent_id = ? LIMIT 1", + (resolved_run_id, agent_id), + ) + if cur.fetchone(): + return resolved_run_id, agent_id + cur.execute( + "SELECT 1 FROM agents WHERE population_id = ? AND agent_id = ? LIMIT 1", + (population_id, agent_id), + ) + if cur.fetchone(): + return resolved_run_id, agent_id + raise ValueError(f"agent_id not found for run/population: {agent_id}") + + cur.execute( + "SELECT agent_id FROM agent_states WHERE run_id = ? ORDER BY agent_id LIMIT 1", + (resolved_run_id,), + ) + agent_row = cur.fetchone() + if not agent_row: + cur.execute( + "SELECT agent_id FROM agents WHERE population_id = ? 
ORDER BY agent_id LIMIT 1", + (population_id,), + ) + agent_row = cur.fetchone() + if not agent_row: + raise ValueError("No agents found for resolved run") + return resolved_run_id, str(agent_row["agent_id"]) + + +@chat_app.command("list") +def chat_list( + study_db: Path = typer.Option(..., "--study-db"), + limit_runs: int = typer.Option(10, "--limit-runs", min=1, max=100), + agents_per_run: int = typer.Option(5, "--agents-per-run", min=1, max=25), + json_output: bool = typer.Option(False, "--json"), +): + """List recent runs with sample agents for quick chat selection.""" + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") + raise typer.Exit(1) + + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute( + """ + SELECT run_id, status, started_at, completed_at, population_id + FROM simulation_runs + ORDER BY started_at DESC + LIMIT ? + """, + (limit_runs,), + ) + runs = [dict(r) for r in cur.fetchall()] + for run in runs: + run_id = str(run["run_id"]) + population_id = str(run["population_id"]) + cur.execute( + """ + SELECT agent_id + FROM agent_states + WHERE run_id = ? + ORDER BY agent_id + LIMIT ? + """, + (run_id, agents_per_run), + ) + sample_agents = [str(r["agent_id"]) for r in cur.fetchall()] + if not sample_agents: + cur.execute( + """ + SELECT agent_id + FROM agents + WHERE population_id = ? + ORDER BY agent_id + LIMIT ? 
+ """, + (population_id, agents_per_run), + ) + sample_agents = [str(r["agent_id"]) for r in cur.fetchall()] + run["sample_agents"] = sample_agents + finally: + conn.close() + + payload = {"study_db": str(study_db), "runs": runs} + if json_output or get_json_mode(): + console.print_json(data=payload) + return + + if not runs: + console.print("[yellow]No simulation runs found.[/yellow]") + return + console.print(f"[bold]Recent Runs[/bold] ({len(runs)})") + for run in runs: + agents = ", ".join(run["sample_agents"]) if run["sample_agents"] else "-" + console.print( + f"- {run['run_id']} status={run['status']} started={run['started_at']} " + f"population={run['population_id']} sample_agents=[{agents}]" + ) + + @chat_app.callback(invoke_without_command=True) def chat_interactive( ctx: typer.Context, @@ -162,10 +470,8 @@ def chat_interactive( if ctx.invoked_subcommand is not None: return - if not study_db or not run_id or not agent_id: - console.print( - "[red]✗[/red] interactive chat requires --study-db, --run-id, and --agent-id" - ) + if not study_db: + console.print("[red]✗[/red] interactive chat requires --study-db") raise typer.Exit(1) if not study_db.exists(): @@ -174,22 +480,28 @@ def chat_interactive( conn = sqlite3.connect(str(study_db)) conn.row_factory = sqlite3.Row - cur = conn.cursor() - cur.execute("SELECT 1 FROM simulation_runs WHERE run_id = ? 
LIMIT 1", (run_id,)) - if not cur.fetchone(): + try: + resolved_run_id, resolved_agent_id = _resolve_run_and_agent( + conn, run_id, agent_id + ) + except ValueError as e: conn.close() - console.print(f"[red]✗[/red] run_id not found: {run_id}") + console.print(f"[red]✗[/red] {e}") raise typer.Exit(1) - with open_study_db(study_db) as db: - sid = session_id or db.create_chat_session( - run_id=run_id, - agent_id=agent_id, - mode="interactive", - meta={"entrypoint": "repl"}, - ) + sid = _create_chat_session( + conn=conn, + run_id=resolved_run_id, + agent_id=resolved_agent_id, + mode="interactive", + meta={"entrypoint": "repl"}, + session_id=session_id, + ) console.print(f"[bold]Chat session[/bold] {sid}") + console.print( + f"[dim]Using run_id={resolved_run_id} agent_id={resolved_agent_id}[/dim]" + ) _print_repl_help() try: @@ -204,8 +516,7 @@ def chat_interactive( if prompt == "/exit": break if prompt == "/history": - with open_study_db(study_db) as db: - messages = db.get_chat_messages(sid) + messages = _get_chat_messages(conn, sid) for m in messages: console.print(f"[{m['role']}] {m['content']}") continue @@ -216,7 +527,7 @@ def chat_interactive( except ValueError: n = 10 context, _ = _load_agent_chat_context( - conn, run_id, agent_id, timeline_n=max(1, n) + conn, resolved_run_id, resolved_agent_id, timeline_n=max(1, n) ) for item in context.get("timeline", []): console.print( @@ -225,31 +536,40 @@ def chat_interactive( continue if prompt == "/context": context, _ = _load_agent_chat_context( - conn, run_id, agent_id, timeline_n=10 + conn, resolved_run_id, resolved_agent_id, timeline_n=10 ) console.print_json(data=context) continue started = time.time() context, citations = _load_agent_chat_context( - conn, run_id, agent_id, timeline_n=12 + conn, resolved_run_id, resolved_agent_id, timeline_n=12 ) - answer = _summarize_context(context, prompt) + history = _get_chat_messages(conn, sid) + try: + answer, model_used = _generate_agent_chat_reply( + context=context, + 
user_prompt=prompt, + history=history, + ) + except Exception as e: + console.print(f"[red]✗[/red] LLM chat failed: {e}") + continue latency_ms = int((time.time() - started) * 1000) - with open_study_db(study_db) as db: - db.append_chat_message(sid, "user", prompt) - db.append_chat_message( - sid, - "assistant", - answer, - citations={"sources": citations}, - token_usage={ - "input_tokens": 0, - "output_tokens": 0, - "latency_ms": latency_ms, - }, - ) + _append_chat_message(conn, sid, "user", prompt) + _append_chat_message( + conn, + sid, + "assistant", + answer, + citations={"sources": citations, "model": model_used}, + token_usage={ + "input_tokens": 0, + "output_tokens": 0, + "latency_ms": latency_ms, + }, + ) console.print(answer) @@ -260,8 +580,8 @@ def chat_interactive( @chat_app.command("ask") def chat_ask( study_db: Path = typer.Option(..., "--study-db"), - run_id: str = typer.Option(..., "--run-id"), - agent_id: str = typer.Option(..., "--agent-id"), + run_id: str | None = typer.Option(None, "--run-id"), + agent_id: str | None = typer.Option(None, "--agent-id"), prompt: str = typer.Option(..., "--prompt"), session_id: str | None = typer.Option(None, "--session-id"), json_output: bool = typer.Option(False, "--json"), @@ -279,50 +599,50 @@ def chat_ask( conn = sqlite3.connect(str(study_db)) conn.row_factory = sqlite3.Row try: - cur = conn.cursor() - cur.execute("SELECT 1 FROM simulation_runs WHERE run_id = ? 
LIMIT 1", (run_id,)) - if not cur.fetchone(): - console.print(f"[red]✗[/red] run_id not found: {run_id}") - raise typer.Exit(1) - finally: - conn.close() - - with open_study_db(study_db) as db: - sid = session_id or db.create_chat_session( - run_id=run_id, - agent_id=agent_id, + resolved_run_id, resolved_agent_id = _resolve_run_and_agent( + conn, run_id, agent_id + ) + sid = _create_chat_session( + conn=conn, + run_id=resolved_run_id, + agent_id=resolved_agent_id, mode="machine", meta={"entrypoint": "ask"}, + session_id=session_id, ) - - conn = sqlite3.connect(str(study_db)) - conn.row_factory = sqlite3.Row - try: + history = _get_chat_messages(conn, sid) context, citations = _load_agent_chat_context( - conn, run_id, agent_id, timeline_n=12 + conn, resolved_run_id, resolved_agent_id, timeline_n=12 ) - answer = _summarize_context(context, prompt) - finally: - conn.close() - - latency_ms = int((time.time() - started) * 1000) - - with open_study_db(study_db) as db: - user_turn = db.append_chat_message(sid, "user", prompt) - assistant_turn = db.append_chat_message( + answer, model_used = _generate_agent_chat_reply( + context=context, + user_prompt=prompt, + history=history, + ) + latency_ms = int((time.time() - started) * 1000) + user_turn = _append_chat_message(conn, sid, "user", prompt) + assistant_turn = _append_chat_message( + conn, sid, "assistant", answer, - citations={"sources": citations}, + citations={"sources": citations, "model": model_used}, token_usage={ "input_tokens": 0, "output_tokens": 0, "latency_ms": latency_ms, }, ) + except ValueError as e: + console.print(f"[red]✗[/red] {e}") + raise typer.Exit(1) + finally: + conn.close() payload = { "session_id": sid, + "run_id": resolved_run_id, + "agent_id": resolved_agent_id, "user_turn_index": user_turn, "turn_index": assistant_turn, "assistant_text": answer, diff --git a/extropy/core/providers/openai.py b/extropy/core/providers/openai.py index 45d5599..ab33350 100644 --- a/extropy/core/providers/openai.py 
+++ b/extropy/core/providers/openai.py @@ -86,6 +86,12 @@ def _extract_output_text(response) -> str | None: Returns: Raw text string, or None if not found. """ + # Newer SDKs expose a convenience property that already concatenates + # message text segments (empty string when no assistant text exists). + direct = getattr(response, "output_text", None) + if isinstance(direct, str) and direct: + return direct + for item in response.output: if hasattr(item, "type") and item.type == "message": for content_item in item.content: @@ -97,6 +103,25 @@ def _extract_output_text(response) -> str | None: return content_item.text return None + @staticmethod + def _is_max_tokens_incomplete(response) -> bool: + """Whether a Responses API call ended incomplete due to token cap.""" + if getattr(response, "status", None) != "incomplete": + return False + details = getattr(response, "incomplete_details", None) + if details is None: + return False + if isinstance(details, dict): + return details.get("reason") == "max_output_tokens" + return getattr(details, "reason", None) == "max_output_tokens" + + @staticmethod + def _bump_max_tokens(max_tokens: int | None) -> int: + """Increase max output budget for an incomplete retry.""" + if max_tokens is None: + return 2048 + return min(8192, max(256, int(max_tokens * 2))) + @staticmethod def _extract_chat_completions_text(response) -> str | None: """Extract text content from a Chat Completions API response. 
@@ -255,44 +280,70 @@ def simple_call( self._acquire_rate_limit(prompt, model, max_output=max_tokens or 4096) use_chat = self._api_format == "chat_completions" - - if use_chat: - request_params = self._build_chat_completions_params( - model, prompt, response_schema, schema_name, max_tokens - ) - else: - request_params = self._build_responses_params( - model, prompt, response_schema, schema_name, max_tokens - ) + request_params: dict = {} + response = None + structured_data = None + max_tokens_eff = max_tokens logger.info(f"[LLM] simple_call starting - model={model}, schema={schema_name}") logger.info(f"[LLM] prompt length: {len(prompt)} chars") - api_start = time.time() - if use_chat: - response = self._with_retry( - lambda: client.chat.completions.create(**request_params) - ) - else: - response = self._with_retry( - lambda: client.responses.create(**request_params) - ) - api_elapsed = time.time() - api_start + for attempt in range(2): + if use_chat: + request_params = self._build_chat_completions_params( + model, prompt, response_schema, schema_name, max_tokens_eff + ) + else: + request_params = self._build_responses_params( + model, prompt, response_schema, schema_name, max_tokens_eff + ) - logger.info(f"[LLM] API response received in {api_elapsed:.2f}s") + api_start = time.time() + if use_chat: + response = self._with_retry( + lambda: client.chat.completions.create(**request_params) + ) + else: + response = self._with_retry( + lambda: client.responses.create(**request_params) + ) + api_elapsed = time.time() - api_start + logger.info(f"[LLM] API response received in {api_elapsed:.2f}s") - # Extract structured data - if use_chat: - raw_text = self._extract_chat_completions_text(response) - else: - raw_text = self._extract_output_text(response) - structured_data = json.loads(raw_text) if raw_text else None + usage = self._extract_usage(response, use_chat=use_chat) + self._record_usage(model, usage, call_type="simple") - # Extract and record token usage - usage = 
self._extract_usage(response, use_chat=use_chat) - self._record_usage(model, usage, call_type="simple") + if use_chat: + raw_text = self._extract_chat_completions_text(response) + else: + raw_text = self._extract_output_text(response) - if log: + if raw_text: + try: + structured_data = json.loads(raw_text) + except json.JSONDecodeError: + if ( + not use_chat + and self._is_max_tokens_incomplete(response) + and attempt == 0 + ): + max_tokens_eff = self._bump_max_tokens(max_tokens_eff) + continue + raise + else: + structured_data = None + + if ( + not use_chat + and structured_data is None + and self._is_max_tokens_incomplete(response) + and attempt == 0 + ): + max_tokens_eff = self._bump_max_tokens(max_tokens_eff) + continue + break + + if log and response is not None: log_request_response( function_name="simple_call", request=request_params, @@ -314,35 +365,63 @@ async def simple_call_async( client = self._get_async_client() use_chat = self._api_format == "chat_completions" + request_params: dict = {} + response = None + structured_data = None + usage = TokenUsage() + max_tokens_eff = max_tokens + + for attempt in range(2): + if use_chat: + request_params = self._build_chat_completions_params( + model, prompt, response_schema, schema_name, max_tokens_eff + ) + else: + request_params = self._build_responses_params( + model, prompt, response_schema, schema_name, max_tokens_eff + ) - if use_chat: - request_params = self._build_chat_completions_params( - model, prompt, response_schema, schema_name, max_tokens - ) - else: - request_params = self._build_responses_params( - model, prompt, response_schema, schema_name, max_tokens - ) + if use_chat: + response = await self._with_retry_async( + lambda: client.chat.completions.create(**request_params) + ) + else: + response = await self._with_retry_async( + lambda: client.responses.create(**request_params) + ) - if use_chat: - response = await self._with_retry_async( - lambda: client.chat.completions.create(**request_params) 
- ) - else: - response = await self._with_retry_async( - lambda: client.responses.create(**request_params) - ) + usage = self._extract_usage(response, use_chat=use_chat) + self._record_usage(model, usage, call_type="async") - # Extract structured data - if use_chat: - raw_text = self._extract_chat_completions_text(response) - else: - raw_text = self._extract_output_text(response) - structured_data = json.loads(raw_text) if raw_text else None + if use_chat: + raw_text = self._extract_chat_completions_text(response) + else: + raw_text = self._extract_output_text(response) - # Extract and record token usage - usage = self._extract_usage(response, use_chat=use_chat) - self._record_usage(model, usage, call_type="async") + if raw_text: + try: + structured_data = json.loads(raw_text) + except json.JSONDecodeError: + if ( + not use_chat + and self._is_max_tokens_incomplete(response) + and attempt == 0 + ): + max_tokens_eff = self._bump_max_tokens(max_tokens_eff) + continue + raise + else: + structured_data = None + + if ( + not use_chat + and structured_data is None + and self._is_max_tokens_incomplete(response) + and attempt == 0 + ): + max_tokens_eff = self._bump_max_tokens(max_tokens_eff) + continue + break return structured_data or {}, usage diff --git a/tests/test_cli.py b/tests/test_cli.py index 0b6fab1..1b6c889 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -462,9 +462,27 @@ def test_export_states_defaults_to_latest_run(self, tmp_path): assert rows[0]["run_id"] == "run_new" assert rows[0]["private_position"] == "new_pos" - def test_chat_ask_reads_state_for_requested_run(self, tmp_path): + def test_chat_ask_reads_state_for_requested_run(self, tmp_path, monkeypatch): study_db = tmp_path / "study.db" _seed_run_scoped_state(study_db) + import extropy.cli.commands.chat as chat_cmd + + captured: dict[str, str] = {} + + def fake_simple_call( + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + log: bool = True, + 
max_tokens: int | None = None, + ): + del response_schema, schema_name, log, max_tokens + captured["prompt"] = prompt + captured["model"] = model or "" + return {"assistant_text": "I am still at old_pos."} + + monkeypatch.setattr(chat_cmd, "simple_call", fake_simple_call) result = runner.invoke( app, @@ -485,8 +503,129 @@ def test_chat_ask_reads_state_for_requested_run(self, tmp_path): assert result.exit_code == 0 payload = json.loads(result.stdout.strip()) assert payload["session_id"] - assert "old_pos" in payload["assistant_text"] - assert "new_pos" not in payload["assistant_text"] + assert payload["assistant_text"] == "I am still at old_pos." + assert "old_pos" in captured["prompt"] + assert "new_pos" not in captured["prompt"] + + def test_chat_ask_includes_session_history(self, tmp_path, monkeypatch): + study_db = tmp_path / "study.db" + _seed_run_scoped_state(study_db) + import extropy.cli.commands.chat as chat_cmd + + with open_study_db(study_db) as db: + sid = db.create_chat_session( + run_id="run_old", + agent_id="a0", + mode="machine", + meta={"entrypoint": "test"}, + ) + db.append_chat_message(sid, "user", "first question") + db.append_chat_message(sid, "assistant", "first answer") + + captured: dict[str, str] = {} + + def fake_simple_call( + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + log: bool = True, + max_tokens: int | None = None, + ): + del response_schema, schema_name, model, log, max_tokens + captured["prompt"] = prompt + return {"assistant_text": "second answer"} + + monkeypatch.setattr(chat_cmd, "simple_call", fake_simple_call) + + result = runner.invoke( + app, + [ + "chat", + "ask", + "--study-db", + str(study_db), + "--run-id", + "run_old", + "--agent-id", + "a0", + "--session-id", + sid, + "--prompt", + "second question", + "--json", + ], + ) + assert result.exit_code == 0 + assert "first question" in captured["prompt"] + assert "first answer" in captured["prompt"] + assert "second 
question" in captured["prompt"] + + def test_chat_ask_defaults_to_latest_run_and_first_agent( + self, tmp_path, monkeypatch + ): + study_db = tmp_path / "study.db" + _seed_run_scoped_state(study_db) + import extropy.cli.commands.chat as chat_cmd + + captured: dict[str, str] = {} + + def fake_simple_call( + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + log: bool = True, + max_tokens: int | None = None, + ): + del response_schema, schema_name, model, log, max_tokens + captured["prompt"] = prompt + return {"assistant_text": "latest run default works"} + + monkeypatch.setattr(chat_cmd, "simple_call", fake_simple_call) + + result = runner.invoke( + app, + [ + "chat", + "ask", + "--study-db", + str(study_db), + "--prompt", + "default target?", + "--json", + ], + ) + assert result.exit_code == 0 + payload = json.loads(result.stdout.strip()) + assert payload["run_id"] == "run_new" + assert payload["agent_id"] == "a0" + assert "new_pos" in captured["prompt"] + assert "old_pos" not in captured["prompt"] + + def test_chat_list_outputs_recent_runs_and_sample_agents(self, tmp_path): + study_db = tmp_path / "study.db" + _seed_run_scoped_state(study_db) + + result = runner.invoke( + app, + [ + "chat", + "list", + "--study-db", + str(study_db), + "--limit-runs", + "2", + "--agents-per-run", + "2", + "--json", + ], + ) + assert result.exit_code == 0 + payload = json.loads(result.stdout.strip()) + assert payload["runs"] + assert payload["runs"][0]["run_id"] == "run_new" + assert "a0" in payload["runs"][0]["sample_agents"] class TestPersonaCommand: diff --git a/tests/test_providers.py b/tests/test_providers.py index 1aad81c..971be7d 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -119,9 +119,18 @@ def test_extracts_text(self): text = provider._extract_output_text(response) assert text == '{"hello": "world"}' + def test_extracts_from_output_text_property(self): + provider = _make_openai_provider() + response 
= MagicMock() + response.output_text = '{"assistant_text":"ok"}' + response.output = [] + text = provider._extract_output_text(response) + assert text == '{"assistant_text":"ok"}' + def test_returns_none_on_empty(self): provider = _make_openai_provider() response = MagicMock() + response.output_text = "" response.output = [] text = provider._extract_output_text(response) assert text is None @@ -165,6 +174,46 @@ def test_returns_empty_dict_on_no_output(self, mock_get_client): ) assert result == {} + @patch.object(OpenAIProvider, "_get_client") + def test_retries_when_incomplete_due_to_max_output_tokens(self, mock_get_client): + provider = _make_openai_provider() + + incomplete_details = MagicMock() + incomplete_details.reason = "max_output_tokens" + + reasoning_item = MagicMock() + reasoning_item.type = "reasoning" + + incomplete_response = MagicMock() + incomplete_response.status = "incomplete" + incomplete_response.incomplete_details = incomplete_details + incomplete_response.output = [reasoning_item] + incomplete_response.output_text = "" + incomplete_response.usage = MagicMock() + incomplete_response.usage.input_tokens = 10 + incomplete_response.usage.output_tokens = 20 + + complete_response = _make_openai_response('{"assistant_text":"ok"}') + complete_response.status = "completed" + complete_response.incomplete_details = None + + mock_client = MagicMock() + mock_client.responses.create.side_effect = [ + incomplete_response, + complete_response, + ] + mock_get_client.return_value = mock_client + + result = provider.simple_call( + prompt="test", + response_schema={"type": "object", "properties": {}}, + max_tokens=120, + log=False, + ) + + assert result == {"assistant_text": "ok"} + assert mock_client.responses.create.call_count == 2 + class TestOpenAIRetry: """Test OpenAI transient error retry."""