From 4b618cec8eaa8319ac271834349692a664a1ba91 Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 00:11:05 -0500 Subject: [PATCH 01/15] feat(db-first): add study.db storage, strict cli contracts, and new db tooling --- README.md | 3 +- docs/commands.md | 21 +- extropy/cli/app.py | 6 + extropy/cli/commands/__init__.py | 12 + extropy/cli/commands/chat.py | 280 ++++++++ extropy/cli/commands/estimate.py | 40 +- extropy/cli/commands/export.py | 93 +++ extropy/cli/commands/inspect.py | 138 ++++ extropy/cli/commands/migrate.py | 148 ++++ extropy/cli/commands/network.py | 199 +++++- extropy/cli/commands/query.py | 75 ++ extropy/cli/commands/report.py | 89 +++ extropy/cli/commands/results.py | 221 +++++- extropy/cli/commands/sample.py | 73 +- extropy/cli/commands/scenario.py | 46 +- extropy/cli/commands/simulate.py | 102 ++- extropy/cli/commands/validate.py | 16 +- extropy/core/models/scenario.py | 34 +- extropy/population/network/config.py | 18 + extropy/population/network/generator.py | 597 +++++++++++++++- extropy/scenario/compiler.py | 113 ++- extropy/scenario/validator.py | 68 +- extropy/simulation/__init__.py | 3 +- extropy/simulation/engine.py | 186 ++++- extropy/simulation/reasoning.py | 104 +-- extropy/storage/__init__.py | 18 + extropy/storage/schemas.py | 46 ++ extropy/storage/study_db.py | 898 ++++++++++++++++++++++++ extropy/utils/__init__.py | 7 + extropy/utils/resource_governor.py | 102 +++ tests/test_cli.py | 71 +- tests/test_engine.py | 5 +- tests/test_estimator.py | 5 +- tests/test_integration_timestep.py | 5 +- tests/test_network.py | 72 ++ tests/test_propagation.py | 5 +- tests/test_reasoning_prompts.py | 5 +- tests/test_scenario.py | 15 +- tests/test_scenario_validator.py | 57 +- 39 files changed, 3640 insertions(+), 356 deletions(-) create mode 100644 extropy/cli/commands/chat.py create mode 100644 extropy/cli/commands/export.py create mode 100644 extropy/cli/commands/inspect.py create mode 100644 extropy/cli/commands/migrate.py create 
mode 100644 extropy/cli/commands/query.py create mode 100644 extropy/cli/commands/report.py create mode 100644 extropy/storage/__init__.py create mode 100644 extropy/storage/schemas.py create mode 100644 extropy/storage/study_db.py create mode 100644 extropy/utils/resource_governor.py diff --git a/README.md b/README.md index 9ceb30f..33ddbb7 100644 --- a/README.md +++ b/README.md @@ -125,8 +125,7 @@ $50-100k: drive_and_pay 40% | switch_to_transit 28% | shift_schedule 21% Each agent reasoned individually. A low-income commuter with no transit access reacts differently than a tech worker near a rail stop — not because we scripted it, but because their attributes, persona, and social context led them there. Simulation output directory (`austin/results/`) contains: -- `simulation.db` (checkpointable state store) -- `timeline.jsonl` (streaming event log) +- `study.db` (canonical state + checkpoint store) - `agent_states.json` (final per-agent states) - `by_timestep.json` (time-series aggregates) - `outcome_distributions.json` (final distributions) diff --git a/docs/commands.md b/docs/commands.md index c4b11d3..ff156e6 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -242,8 +242,9 @@ Auto-detection: if `{population_stem}.network-config.yaml` exists alongside the Save a generated config for inspection/editing with `--save-config`: ```bash -extropy network austin/agents.json -o austin/network.json \ - -p austin/population.yaml --save-config austin/network-config.yaml +extropy network --study-db austin/study.db --network-id baseline \ + -p austin/population.yaml --save-config austin/network-config.yaml \ + -o austin/network.json ``` ### How connections form @@ -260,7 +261,8 @@ The network uses a **Watts-Strogatz small-world model** with attribute-based sim Add `-v` to print network quality metrics: ```bash -extropy network austin/agents.json -o austin/network.json -p austin/population.yaml --validate +extropy network --study-db austin/study.db --network-id baseline \ 
+ -p austin/population.yaml --validate ``` This shows clustering coefficient, average path length, modularity, and flags anything outside expected ranges for a realistic social network. @@ -269,8 +271,10 @@ This shows clustering coefficient, average path length, modularity, and flags an | | Name | Description | |---|---|---| -| **Arg** | `agents_file` | Agents JSON file | -| **Opt** | `--output` / `-o` | Output network JSON file **(required)** | +| **Opt** | `--study-db` | Canonical study DB path **(required)** | +| **Opt** | `--population-id` | Population ID in study DB (default: `default`) | +| **Opt** | `--network-id` | Network ID to write/read (default: `default`) | +| **Opt** | `--output` / `-o` | Optional network JSON export path | | **Opt** | `--population` / `-p` | Population spec YAML — generates network config via LLM | | **Opt** | `--network-config` / `-c` | Custom network config YAML file | | **Opt** | `--save-config` | Save the generated/loaded config to YAML | @@ -282,7 +286,7 @@ This shows clustering coefficient, average path length, modularity, and flags an ### Output -A JSON file (`network.json`) containing nodes (agent IDs) and weighted, typed edges. +Canonical output is `study.db` (`network_edges`, `network_runs`, `network_metrics`). Optional JSON export can be written with `--output`. --- @@ -461,8 +465,7 @@ These aren't scripted responses. 
They emerge from each agent's unique combinatio ### Output A results directory containing: -- `simulation.db` — SQLite database with full simulation state -- `timeline.jsonl` — Event-by-event timeline +- `study.db` — canonical SQLite state and checkpoint store - `agent_states.json` — Final state of every agent - `by_timestep.json` — Per-timestep metrics (exposure, sentiment, conviction, position distributions) - `outcome_distributions.json` — Aggregate outcome distributions @@ -473,7 +476,7 @@ A results directory containing: ## Viewing Results ```bash -extropy results austin/results/ +extropy results --study-db austin/study.db ``` Display a summary of simulation outcomes — exposure rates, outcome distributions, and convergence information. diff --git a/extropy/cli/app.py b/extropy/cli/app.py index aa320c3..561f0c4 100644 --- a/extropy/cli/app.py +++ b/extropy/cli/app.py @@ -71,4 +71,10 @@ def main_callback( estimate, results, config_cmd, + inspect, + query, + report, + export, + chat, + migrate, ) diff --git a/extropy/cli/commands/__init__.py b/extropy/cli/commands/__init__.py index 8c36014..df2e0cc 100644 --- a/extropy/cli/commands/__init__.py +++ b/extropy/cli/commands/__init__.py @@ -12,6 +12,12 @@ estimate, results, config_cmd, + inspect, + query, + report, + export, + chat, + migrate, ) __all__ = [ @@ -26,4 +32,10 @@ "estimate", "results", "config_cmd", + "inspect", + "query", + "report", + "export", + "chat", + "migrate", ] diff --git a/extropy/cli/commands/chat.py b/extropy/cli/commands/chat.py new file mode 100644 index 0000000..6ea4cdf --- /dev/null +++ b/extropy/cli/commands/chat.py @@ -0,0 +1,280 @@ +"""Agent chat commands backed by study DB history.""" + +from __future__ import annotations + +import json +import sqlite3 +import time +from pathlib import Path +from typing import Any + +import typer + +from ...storage import open_study_db +from ..app import app, console, get_json_mode + +chat_app = typer.Typer(help="Chat with simulated agents using 
DB-backed history") +app.add_typer(chat_app, name="chat") + + +def _load_agent_chat_context( + conn: sqlite3.Connection, + run_id: str, + agent_id: str, + timeline_n: int = 10, +) -> tuple[dict[str, Any], list[dict[str, Any]]]: + cur = conn.cursor() + + cur.execute( + "SELECT attrs_json FROM agents WHERE agent_id = ? ORDER BY rowid DESC LIMIT 1", + (agent_id,), + ) + attrs_row = cur.fetchone() + attrs = {} + if attrs_row and attrs_row["attrs_json"]: + try: + attrs = json.loads(attrs_row["attrs_json"]) + except json.JSONDecodeError: + attrs = {} + + cur.execute("SELECT * FROM agent_states WHERE agent_id = ? LIMIT 1", (agent_id,)) + state_row = cur.fetchone() + state = dict(state_row) if state_row else {} + + cur.execute( + """ + SELECT timestep, event_type, details_json + FROM timeline + WHERE agent_id = ? + ORDER BY id DESC + LIMIT ? + """, + (agent_id, timeline_n), + ) + timeline_rows = [dict(r) for r in cur.fetchall()] + + context = { + "run_id": run_id, + "agent_id": agent_id, + "attributes": attrs, + "state": state, + "timeline": list(reversed(timeline_rows)), + } + + citations = [ + {"table": "agents", "agent_id": agent_id}, + {"table": "agent_states", "agent_id": agent_id}, + {"table": "timeline", "agent_id": agent_id, "limit": timeline_n}, + ] + return context, citations + + +def _summarize_context(context: dict[str, Any], prompt: str) -> str: + state = context.get("state", {}) + attrs = context.get("attributes", {}) + timeline = context.get("timeline", []) + agent_id = context.get("agent_id") + + private_position = state.get("private_position") or state.get("position") + private_sentiment = state.get("private_sentiment") + if private_sentiment is None: + private_sentiment = state.get("sentiment") + private_conviction = state.get("private_conviction") + if private_conviction is None: + private_conviction = state.get("conviction") + + lines = [f"Agent `{agent_id}` context snapshot:"] + if private_position is not None: + lines.append(f"- Position: 
{private_position}") + if private_sentiment is not None: + lines.append(f"- Sentiment: {private_sentiment:.3f}") + if private_conviction is not None: + lines.append(f"- Conviction: {private_conviction:.3f}") + + if state.get("public_statement"): + lines.append(f"- Public statement: {state['public_statement']}") + if state.get("raw_reasoning"): + lines.append(f"- Latest raw reasoning: {state['raw_reasoning']}") + + if attrs: + top_attrs = [(k, v) for k, v in attrs.items() if not str(k).startswith("_")] + top_attrs = sorted(top_attrs)[:8] + if top_attrs: + lines.append("- Key attributes: " + ", ".join(f"{k}={v}" for k, v in top_attrs)) + + if timeline: + lines.append("- Recent timeline events:") + for item in timeline[-5:]: + details = item.get("details_json") or "{}" + lines.append( + f" - t={item.get('timestep')} {item.get('event_type')} details={details}" + ) + + lines.append(f"- Your prompt: {prompt}") + lines.append( + "This answer is grounded in persisted DB state and does not mutate simulation state." + ) + return "\n".join(lines) + + +def _print_repl_help() -> None: + console.print("[dim]Commands: /context, /timeline , /history, /exit[/dim]") + + +@chat_app.callback(invoke_without_command=True) +def chat_interactive( + ctx: typer.Context, + study_db: Path | None = typer.Option(None, "--study-db"), + run_id: str | None = typer.Option(None, "--run-id"), + agent_id: str | None = typer.Option(None, "--agent-id"), + session_id: str | None = typer.Option(None, "--session-id"), +): + """Interactive chat REPL. 
+ + Example: + extropy chat --study-db study.db --run-id run_123 --agent-id a_42 + """ + if ctx.invoked_subcommand is not None: + return + + if not study_db or not run_id or not agent_id: + console.print( + "[red]✗[/red] interactive chat requires --study-db, --run-id, and --agent-id" + ) + raise typer.Exit(1) + + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") + raise typer.Exit(1) + + with open_study_db(study_db) as db: + sid = session_id or db.create_chat_session( + run_id=run_id, + agent_id=agent_id, + mode="interactive", + meta={"entrypoint": "repl"}, + ) + + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + + console.print(f"[bold]Chat session[/bold] {sid}") + _print_repl_help() + + try: + while True: + try: + prompt = input("chat> ").strip() + except EOFError: + break + + if not prompt: + continue + if prompt == "/exit": + break + if prompt == "/history": + with open_study_db(study_db) as db: + messages = db.get_chat_messages(sid) + for m in messages: + console.print(f"[{m['role']}] {m['content']}") + continue + if prompt.startswith("/timeline"): + parts = prompt.split() + try: + n = int(parts[1]) if len(parts) > 1 else 10 + except ValueError: + n = 10 + context, _ = _load_agent_chat_context(conn, run_id, agent_id, timeline_n=max(1, n)) + for item in context.get("timeline", []): + console.print( + f"t={item.get('timestep')} {item.get('event_type')} {item.get('details_json') or '{}'}" + ) + continue + if prompt == "/context": + context, _ = _load_agent_chat_context(conn, run_id, agent_id, timeline_n=10) + console.print_json(data=context) + continue + + started = time.time() + context, citations = _load_agent_chat_context(conn, run_id, agent_id, timeline_n=12) + answer = _summarize_context(context, prompt) + latency_ms = int((time.time() - started) * 1000) + + with open_study_db(study_db) as db: + db.append_chat_message(sid, "user", prompt) + db.append_chat_message( + sid, + "assistant", + answer, 
+ citations={"sources": citations}, + token_usage={"input_tokens": 0, "output_tokens": 0, "latency_ms": latency_ms}, + ) + + console.print(answer) + + finally: + conn.close() + + +@chat_app.command("ask") +def chat_ask( + study_db: Path = typer.Option(..., "--study-db"), + run_id: str = typer.Option(..., "--run-id"), + agent_id: str = typer.Option(..., "--agent-id"), + prompt: str = typer.Option(..., "--prompt"), + session_id: str | None = typer.Option(None, "--session-id"), + json_output: bool = typer.Option(False, "--json"), +): + """Non-interactive chat API for automation. + + Example: + extropy chat ask --study-db study.db --run-id r1 --agent-id a1 --prompt "What changed?" --json + """ + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") + raise typer.Exit(1) + + started = time.time() + + with open_study_db(study_db) as db: + sid = session_id or db.create_chat_session( + run_id=run_id, + agent_id=agent_id, + mode="machine", + meta={"entrypoint": "ask"}, + ) + + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + context, citations = _load_agent_chat_context(conn, run_id, agent_id, timeline_n=12) + answer = _summarize_context(context, prompt) + finally: + conn.close() + + latency_ms = int((time.time() - started) * 1000) + + with open_study_db(study_db) as db: + user_turn = db.append_chat_message(sid, "user", prompt) + assistant_turn = db.append_chat_message( + sid, + "assistant", + answer, + citations={"sources": citations}, + token_usage={"input_tokens": 0, "output_tokens": 0, "latency_ms": latency_ms}, + ) + + payload = { + "session_id": sid, + "user_turn_index": user_turn, + "turn_index": assistant_turn, + "assistant_text": answer, + "citations": {"sources": citations}, + "token_usage": {"input_tokens": 0, "output_tokens": 0}, + "latency_ms": latency_ms, + } + + if json_output or get_json_mode(): + console.print_json(data=payload) + else: + console.print(answer) diff --git 
a/extropy/cli/commands/estimate.py b/extropy/cli/commands/estimate.py index ec734e5..90f26a0 100644 --- a/extropy/cli/commands/estimate.py +++ b/extropy/cli/commands/estimate.py @@ -1,6 +1,5 @@ """Estimate command for predicting simulation costs before running.""" -import json from pathlib import Path import typer @@ -11,6 +10,7 @@ @app.command("estimate") def estimate_command( scenario_file: Path = typer.Argument(..., help="Scenario spec YAML file"), + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), model: str = typer.Option( "", "--model", @@ -41,19 +41,22 @@ def estimate_command( model, and predicts LLM calls, tokens, and USD cost. No API keys required. Example: - extropy estimate scenario.yaml - extropy estimate scenario.yaml --model gpt-5-mini - extropy estimate scenario.yaml --pivotal-model gpt-5 --routine-model gpt-5-mini -v + extropy estimate scenario.yaml --study-db study.db + extropy estimate scenario.yaml --study-db study.db --model gpt-5-mini + extropy estimate scenario.yaml --study-db study.db --pivotal-model gpt-5 --routine-model gpt-5-mini -v """ from ...config import get_config from ...core.models import ScenarioSpec, PopulationSpec - from ...population.network import load_agents_json from ...simulation.estimator import estimate_simulation_cost + from ...storage import open_study_db # Validate input file if not scenario_file.exists(): console.print(f"[red]x[/red] Scenario file not found: {scenario_file}") raise typer.Exit(1) + if not study_db.exists(): + console.print(f"[red]x[/red] Study DB not found: {study_db}") + raise typer.Exit(1) # Load scenario try: @@ -71,24 +74,19 @@ def estimate_command( raise typer.Exit(1) population_spec = PopulationSpec.from_yaml(pop_path) - # Load agents - agents_path = Path(scenario.meta.agents_file) - if not agents_path.is_absolute(): - agents_path = scenario_file.parent / agents_path - if not agents_path.exists(): - console.print(f"[red]x[/red] Agents file not found: 
{agents_path}") + with open_study_db(study_db) as db: + agents = db.get_agents(scenario.meta.population_id) + network = db.get_network(scenario.meta.network_id) + if not agents: + console.print( + f"[red]x[/red] Population ID not found in study DB: {scenario.meta.population_id}" + ) raise typer.Exit(1) - agents = load_agents_json(agents_path) - - # Load network - network_path = Path(scenario.meta.network_file) - if not network_path.is_absolute(): - network_path = scenario_file.parent / network_path - if not network_path.exists(): - console.print(f"[red]x[/red] Network file not found: {network_path}") + if not network.get("edges"): + console.print( + f"[red]x[/red] Network ID not found in study DB: {scenario.meta.network_id}" + ) raise typer.Exit(1) - with open(network_path) as f: - network = json.load(f) # Resolve config config = get_config() diff --git a/extropy/cli/commands/export.py b/extropy/cli/commands/export.py new file mode 100644 index 0000000..337ddfd --- /dev/null +++ b/extropy/cli/commands/export.py @@ -0,0 +1,93 @@ +"""Explicit exports from study DB.""" + +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path + +import typer + +from ..app import app, console + +export_app = typer.Typer(help="Export datasets from study DB") +app.add_typer(export_app, name="export") + + +def _write_jsonl(path: Path, rows: list[dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, default=str) + "\n") + + +@export_app.command("agents") +def export_agents( + study_db: Path = typer.Option(..., "--study-db"), + population_id: str = typer.Option("default", "--population-id"), + output: Path = typer.Option(..., "--to"), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute( + "SELECT agent_id, attrs_json FROM agents WHERE population_id = ? 
ORDER BY agent_id", + (population_id,), + ) + rows = [] + for row in cur.fetchall(): + try: + rows.append(json.loads(row["attrs_json"])) + except json.JSONDecodeError: + rows.append({"_id": row["agent_id"]}) + finally: + conn.close() + + _write_jsonl(output, rows) + console.print(f"[green]✓[/green] Exported {len(rows)} agents -> {output}") + + +@export_app.command("edges") +def export_edges( + study_db: Path = typer.Option(..., "--study-db"), + network_id: str = typer.Option("default", "--network-id"), + output: Path = typer.Option(..., "--to"), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute( + """ + SELECT source_id, target_id, weight, edge_type, influence_st, influence_ts + FROM network_edges + WHERE network_id = ? + ORDER BY source_id, target_id + """, + (network_id,), + ) + rows = [dict(row) for row in cur.fetchall()] + finally: + conn.close() + + _write_jsonl(output, rows) + console.print(f"[green]✓[/green] Exported {len(rows)} edges -> {output}") + + +@export_app.command("states") +def export_states( + study_db: Path = typer.Option(..., "--study-db"), + output: Path = typer.Option(..., "--to"), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute("SELECT * FROM agent_states ORDER BY agent_id") + rows = [dict(row) for row in cur.fetchall()] + finally: + conn.close() + + _write_jsonl(output, rows) + console.print(f"[green]✓[/green] Exported {len(rows)} agent states -> {output}") diff --git a/extropy/cli/commands/inspect.py b/extropy/cli/commands/inspect.py new file mode 100644 index 0000000..3f83101 --- /dev/null +++ b/extropy/cli/commands/inspect.py @@ -0,0 +1,138 @@ +"""Inspect commands for DB-backed artifacts.""" + +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path + +import typer + +from ...storage import open_study_db +from ..app import app, console + +inspect_app = 
typer.Typer(help="Inspect study DB entities") +app.add_typer(inspect_app, name="inspect") + + +@inspect_app.command("summary") +def inspect_summary( + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), + population_id: str = typer.Option("default", "--population-id"), + network_id: str = typer.Option("default", "--network-id"), +): + with open_study_db(study_db) as db: + agent_count = db.get_agent_count(population_id) + edge_count = db.get_network_edge_count(network_id) + + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute("SELECT COUNT(*) AS cnt FROM agent_states") + sim_agents = int(cur.fetchone()["cnt"]) + cur.execute("SELECT COUNT(*) AS cnt FROM timestep_summaries") + timesteps = int(cur.fetchone()["cnt"]) + cur.execute("SELECT COUNT(*) AS cnt FROM timeline") + events = int(cur.fetchone()["cnt"]) + finally: + conn.close() + + console.print("[bold]Study Summary[/bold]") + console.print(f"study_db: {study_db}") + console.print(f"population_id={population_id} agents={agent_count}") + console.print(f"network_id={network_id} edges={edge_count}") + console.print(f"simulation.agent_states={sim_agents}") + console.print(f"simulation.timesteps={timesteps}") + console.print(f"simulation.events={events}") + + +@inspect_app.command("agent") +def inspect_agent( + study_db: Path = typer.Option(..., "--study-db"), + agent_id: str = typer.Option(..., "--agent-id"), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute("SELECT attrs_json FROM agents WHERE agent_id = ? LIMIT 1", (agent_id,)) + attrs_row = cur.fetchone() + attrs = json.loads(attrs_row["attrs_json"]) if attrs_row else {} + + cur.execute("SELECT * FROM agent_states WHERE agent_id = ? LIMIT 1", (agent_id,)) + state = cur.fetchone() + + cur.execute( + "SELECT timestep, event_type, details_json FROM timeline WHERE agent_id = ? 
ORDER BY id DESC LIMIT 10", + (agent_id,), + ) + events = cur.fetchall() + finally: + conn.close() + + console.print(f"[bold]Agent {agent_id}[/bold]") + if attrs: + console.print("[bold]Attributes[/bold]") + for key in sorted(attrs.keys()): + if key.startswith("_"): + continue + console.print(f" - {key}: {attrs[key]}") + + if state: + console.print("[bold]State[/bold]") + console.print(f" aware={bool(state['aware'])} will_share={bool(state['will_share'])}") + console.print( + f" position={state['private_position'] or state['position']} " + f"sentiment={state['private_sentiment'] if state['private_sentiment'] is not None else state['sentiment']}" + ) + if state["raw_reasoning"]: + console.print("[bold]Raw reasoning[/bold]") + console.print(str(state["raw_reasoning"])) + + if events: + console.print("[bold]Recent events[/bold]") + for row in events: + details = row["details_json"] or "{}" + console.print(f" t={row['timestep']} {row['event_type']} {details}") + + +@inspect_app.command("network") +def inspect_network( + study_db: Path = typer.Option(..., "--study-db"), + network_id: str = typer.Option("default", "--network-id"), + top: int = typer.Option(10, "--top", min=1), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute( + "SELECT COUNT(*) AS cnt, AVG(weight) AS avg_w FROM network_edges WHERE network_id = ?", + (network_id,), + ) + row = cur.fetchone() + edge_count = int(row["cnt"]) if row else 0 + avg_w = float(row["avg_w"]) if row and row["avg_w"] is not None else 0.0 + + cur.execute( + """ + SELECT source_id, COUNT(*) AS degree + FROM network_edges + WHERE network_id = ? + GROUP BY source_id + ORDER BY degree DESC + LIMIT ? 
+ """, + (network_id, top), + ) + top_rows = cur.fetchall() + finally: + conn.close() + + console.print(f"[bold]Network {network_id}[/bold]") + console.print(f"edges={edge_count} avg_weight={avg_w:.4f}") + if top_rows: + console.print("top source degrees:") + for r in top_rows: + console.print(f" - {r['source_id']}: {r['degree']}") diff --git a/extropy/cli/commands/migrate.py b/extropy/cli/commands/migrate.py new file mode 100644 index 0000000..e26bb91 --- /dev/null +++ b/extropy/cli/commands/migrate.py @@ -0,0 +1,148 @@ +"""Migration commands for DB-first runtime artifacts.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import typer +import yaml + +from ...storage import open_study_db +from ..app import app, console + +migrate_app = typer.Typer(help="Migrate legacy artifacts to DB-first schema") +app.add_typer(migrate_app, name="migrate") + + +def _load_json(path: Path) -> Any: + with open(path, encoding="utf-8") as f: + return json.load(f) + + +@migrate_app.command("legacy") +def migrate_legacy_artifacts( + study_db: Path = typer.Option(..., "--study-db", help="Target canonical study DB"), + agents_file: Path | None = typer.Option( + None, "--agents-file", help="Legacy agents JSON" + ), + network_file: Path | None = typer.Option( + None, "--network-file", help="Legacy network JSON" + ), + population_spec: Path | None = typer.Option( + None, + "--population-spec", + help="Optional population spec YAML source used for sample provenance", + ), + population_id: str = typer.Option("default", "--population-id"), + network_id: str = typer.Option("default", "--network-id"), +): + """Ingest legacy `agents.json`/`network.json` into `study.db`.""" + if agents_file is None and network_file is None: + console.print("[red]✗[/red] Provide at least one of --agents-file or --network-file") + raise typer.Exit(1) + + with open_study_db(study_db) as db: + if population_spec is not None: + if not population_spec.exists(): 
+ console.print(f"[red]✗[/red] population spec not found: {population_spec}") + raise typer.Exit(1) + db.save_population_spec( + population_id=population_id, + spec_yaml=population_spec.read_text(encoding="utf-8"), + source_path=str(population_spec), + ) + + if agents_file is not None: + if not agents_file.exists(): + console.print(f"[red]✗[/red] agents file not found: {agents_file}") + raise typer.Exit(1) + agents_data = _load_json(agents_file) + if not isinstance(agents_data, list): + console.print("[red]✗[/red] agents JSON must be a list") + raise typer.Exit(1) + sample_run_id = db.save_sample_result( + population_id=population_id, + agents=agents_data, + meta={ + "source": "legacy_migration", + "source_file": str(agents_file), + }, + seed=None, + ) + console.print( + f"[green]✓[/green] Imported {len(agents_data)} agents " + f"(population_id={population_id}, sample_run_id={sample_run_id})" + ) + + if network_file is not None: + if not network_file.exists(): + console.print(f"[red]✗[/red] network file not found: {network_file}") + raise typer.Exit(1) + network_data = _load_json(network_file) + if not isinstance(network_data, dict): + console.print("[red]✗[/red] network JSON must be an object") + raise typer.Exit(1) + + raw_edges = network_data.get("edges", []) + if not isinstance(raw_edges, list): + console.print("[red]✗[/red] network.edges must be a list") + raise typer.Exit(1) + + network_run_id = db.save_network_result( + population_id=population_id, + network_id=network_id, + config=network_data.get("config", {}), + result_meta=network_data.get("meta", {}), + edges=raw_edges, + seed=None, + candidate_mode="legacy", + network_metrics=network_data.get("metrics"), + ) + console.print( + f"[green]✓[/green] Imported {len(raw_edges)} edges " + f"(network_id={network_id}, network_run_id={network_run_id})" + ) + + console.print(f"[green]✓[/green] Migration complete: {study_db}") + + +@migrate_app.command("scenario") +def migrate_scenario_yaml( + input_path: Path = 
typer.Option(..., "--input", help="Legacy scenario YAML"), + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB path"), + population_id: str = typer.Option("default", "--population-id"), + network_id: str = typer.Option("default", "--network-id"), + output: Path | None = typer.Option(None, "--output", "-o"), +): + """Rewrite a legacy scenario YAML to DB-first metadata fields.""" + if not input_path.exists(): + console.print(f"[red]✗[/red] Scenario file not found: {input_path}") + raise typer.Exit(1) + + with open(input_path, encoding="utf-8") as f: + data = yaml.safe_load(f) or {} + + meta = data.get("meta") + if not isinstance(meta, dict): + console.print("[red]✗[/red] Invalid scenario YAML: missing meta object") + raise typer.Exit(1) + + had_legacy = "agents_file" in meta or "network_file" in meta + meta.pop("agents_file", None) + meta.pop("network_file", None) + meta["study_db"] = str(study_db) + meta["population_id"] = population_id + meta["network_id"] = network_id + data["meta"] = meta + + out = output or input_path.with_suffix(".db-first.yaml") + out.parent.mkdir(parents=True, exist_ok=True) + with open(out, "w", encoding="utf-8") as f: + yaml.safe_dump(data, f, sort_keys=False) + + if had_legacy: + console.print(f"[green]✓[/green] Migrated legacy scenario -> {out}") + else: + console.print(f"[green]✓[/green] Rewrote scenario metadata -> {out}") diff --git a/extropy/cli/commands/network.py b/extropy/cli/commands/network.py index 3444d7d..d1b8eaa 100644 --- a/extropy/cli/commands/network.py +++ b/extropy/cli/commands/network.py @@ -14,10 +14,18 @@ @app.command("network") def network_command( - agents_file: Path = typer.Argument( - ..., help="Agents JSON file to generate network from" + study_db: Path = typer.Option( + ..., "--study-db", help="Canonical study DB file" + ), + population_id: str = typer.Option( + "default", "--population-id", help="Population ID in study DB" + ), + network_id: str = typer.Option( + "default", 
"--network-id", help="Network ID to write in study DB" + ), + output: Path | None = typer.Option( + None, "--output", "-o", help="Optional JSON export path (non-canonical)" ), - output: Path = typer.Option(..., "--output", "-o", help="Output network JSON file"), population: Path | None = typer.Option( None, "--population", @@ -50,6 +58,65 @@ def network_command( no_metrics: bool = typer.Option( False, "--no-metrics", help="Skip computing node metrics (faster)" ), + candidate_mode: str = typer.Option( + "exact", + "--candidate-mode", + help="Similarity candidate mode: exact | blocked", + ), + candidate_pool_multiplier: float = typer.Option( + 12.0, + "--candidate-pool-multiplier", + help="Blocked mode candidate pool size as multiple of avg_degree", + ), + block_attr: list[str] | None = typer.Option( + None, + "--block-attr", + help="Blocking attribute (repeatable). If omitted, auto-selects top attributes", + ), + similarity_workers: int = typer.Option( + 1, + "--similarity-workers", + min=1, + help="Worker processes for similarity computation", + ), + similarity_chunk_size: int = typer.Option( + 64, + "--similarity-chunk-size", + min=8, + help="Row chunk size for similarity worker tasks", + ), + checkpoint: Path | None = typer.Option( + None, + "--checkpoint", + help="Path to similarity checkpoint file (.pkl) or study DB (.db)", + ), + resume_checkpoint: bool = typer.Option( + False, + "--resume-checkpoint", + help="Resume similarity stage from --checkpoint file", + ), + checkpoint_every: int = typer.Option( + 250, + "--checkpoint-every", + min=1, + help="Write checkpoint every N processed similarity rows", + ), + resource_mode: str = typer.Option( + "auto", + "--resource-mode", + help="Resource tuning mode: auto | manual", + ), + safe_auto_workers: bool = typer.Option( + True, + "--safe-auto-workers/--unsafe-auto-workers", + help="When auto mode is enabled, keep worker count conservative for laptops/VMs", + ), + max_memory_gb: float | None = typer.Option( + None, + 
"--max-memory-gb", + min=0.5, + help="Optional memory budget cap for auto resource tuning", + ), ): """ Generate a social network from sampled agents. @@ -65,37 +132,47 @@ def network_command( 4. None of the above → empty config (flat network, no similarity structure) Example: - extropy network agents.json -o network.json - extropy network agents.json -o network.json -p population.yaml - extropy network agents.json -o network.json -c network-config.yaml - extropy network agents.json -o network.json -p population.yaml --save-config my-config.yaml + extropy network --study-db study.db + extropy network --study-db study.db --population-id main --network-id main + extropy network --study-db study.db -p population.yaml -c network-config.yaml """ from ...population.network import ( generate_network, generate_network_with_metrics, - load_agents_json, NetworkConfig, generate_network_config, ) from ...core.models import PopulationSpec + from ...storage import open_study_db + from ...utils import ResourceGovernor start_time = time.time() console.print() + if resume_checkpoint and checkpoint is None: + checkpoint = study_db + # Load Agents - if not agents_file.exists(): - console.print(f"[red]✗[/red] Agents file not found: {agents_file}") + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") raise typer.Exit(1) with console.status("[cyan]Loading agents...[/cyan]"): try: - agents = load_agents_json(agents_file) + with open_study_db(study_db) as db: + agents = db.get_agents(population_id) except Exception as e: console.print(f"[red]✗[/red] Failed to load agents: {e}") raise typer.Exit(1) + if not agents: + console.print( + f"[red]✗[/red] No agents found for population_id '{population_id}' in {study_db}" + ) + raise typer.Exit(1) console.print( - f"[green]✓[/green] Loaded {len(agents)} agents from [bold]{agents_file}[/bold]" + f"[green]✓[/green] Loaded {len(agents)} agents from [bold]{study_db}[/bold] " + f"(population_id={population_id})" ) 
# ========================================================================= @@ -165,9 +242,48 @@ def network_command( "avg_degree": avg_degree, "rewire_prob": rewire_prob, "seed": seed if seed is not None else config.seed, + "candidate_mode": candidate_mode, + "candidate_pool_multiplier": candidate_pool_multiplier, + "blocking_attributes": block_attr or config.blocking_attributes, + "similarity_workers": similarity_workers, + "similarity_chunk_size": similarity_chunk_size, + "checkpoint_every_rows": checkpoint_every, + } + ) + + if resource_mode not in {"auto", "manual"}: + console.print("[red]✗[/red] --resource-mode must be 'auto' or 'manual'") + raise typer.Exit(1) + + governor = ResourceGovernor( + resource_mode=resource_mode, + safe_auto_workers=safe_auto_workers, + max_memory_gb=max_memory_gb, + ) + tuned_workers = governor.recommend_workers( + requested_workers=config.similarity_workers, + memory_per_worker_gb=0.75, + ) + tuned_chunk = governor.recommend_chunk_size( + requested_chunk_size=config.similarity_chunk_size, + min_chunk_size=8, + max_chunk_size=2048, + ) + + config = config.model_copy( + update={ + "similarity_workers": tuned_workers, + "similarity_chunk_size": tuned_chunk, } ) + if config.candidate_mode not in {"exact", "blocked"}: + console.print( + f"[red]✗[/red] Invalid --candidate-mode '{config.candidate_mode}' " + "(expected: exact | blocked)" + ) + raise typer.Exit(1) + # Save config if requested if save_config: config.to_yaml(save_config) @@ -175,6 +291,17 @@ def network_command( f"[green]✓[/green] Saved network config to [bold]{save_config}[/bold]" ) + console.print( + f"[dim]Mode: {config.candidate_mode} | workers={config.similarity_workers} " + f"| checkpoint={'on' if checkpoint else 'off'}[/dim]" + ) + if resource_mode == "auto": + snap = governor.snapshot() + console.print( + f"[dim]Auto resources: cpu={snap.cpu_count}, " + f"total_mem={snap.total_memory_gb:.1f}GB, budget={snap.memory_budget_gb:.1f}GB[/dim]" + ) + console.print() 
generation_start = time.time() current_stage = ["Initializing", 0, 0] @@ -192,9 +319,21 @@ def do_generation(): nonlocal result, generation_error try: if no_metrics: - result = generate_network(agents, config, on_progress) + result = generate_network( + agents, + config, + on_progress, + checkpoint_path=checkpoint, + resume_from_checkpoint=resume_checkpoint, + ) else: - result = generate_network_with_metrics(agents, config, on_progress) + result = generate_network_with_metrics( + agents, + config, + on_progress, + checkpoint_path=checkpoint, + resume_from_checkpoint=resume_checkpoint, + ) except Exception as e: generation_error = e finally: @@ -274,14 +413,38 @@ def do_generation(): pct = count / len(result.edges) * 100 if result.edges else 0 console.print(f" {edge_type}: {count} ({pct:.1f}%)") - # Save Output + # Save canonical output to study DB console.print() - with console.status(f"[cyan]Saving to {output}...[/cyan]"): - result.save_json(output) + with console.status(f"[cyan]Saving network to {study_db}...[/cyan]"): + with open_study_db(study_db) as db: + network_metrics = ( + result.network_metrics.model_dump(mode="json") + if result.network_metrics + else None + ) + db.save_network_result( + population_id=population_id, + network_id=network_id, + config=config.model_dump(mode="json"), + result_meta=result.meta, + edges=[e.to_dict() for e in result.edges], + seed=config.seed, + candidate_mode=config.candidate_mode, + network_metrics=network_metrics, + ) + + if output is not None: + with console.status(f"[cyan]Exporting JSON to {output}...[/cyan]"): + result.save_json(output) elapsed = time.time() - start_time console.print("═" * 60) - console.print(f"[green]✓[/green] Network saved to [bold]{output}[/bold]") + console.print( + f"[green]✓[/green] Network saved to [bold]{study_db}[/bold] " + f"(network_id={network_id})" + ) + if output is not None: + console.print(f"[dim]Exported JSON: {output}[/dim]") console.print(f"[dim]Total time: 
{format_elapsed(elapsed)}[/dim]") console.print("═" * 60) diff --git a/extropy/cli/commands/query.py b/extropy/cli/commands/query.py new file mode 100644 index 0000000..f15210c --- /dev/null +++ b/extropy/cli/commands/query.py @@ -0,0 +1,75 @@ +"""Ad-hoc read-only query command.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import typer + +from ...storage import open_study_db, ReadOnlySQLRequest +from ..app import app, console + +query_app = typer.Typer(help="Read-only SQL query tools") +app.add_typer(query_app, name="query") + +_ALLOWED_PREFIXES = ("select", "with", "explain") +_DENYLIST_TOKENS = ( + " insert ", + " update ", + " delete ", + " alter ", + " drop ", + " create ", + " attach ", + " vacuum ", + " pragma ", + " replace ", + " truncate ", +) + + +@query_app.command("sql") +def query_sql( + study_db: Path = typer.Option(..., "--study-db"), + sql: str = typer.Option(..., "--sql", help="Read-only SQL statement"), + limit: int = typer.Option(1000, "--limit", min=1), + format: str = typer.Option("table", "--format", help="table|json|jsonl"), +): + req = ReadOnlySQLRequest(sql=sql, limit=limit) + normalized = req.sql.strip().lower() + if not normalized.startswith(_ALLOWED_PREFIXES): + console.print("[red]✗[/red] Only read-only SELECT/WITH/EXPLAIN queries are allowed") + raise typer.Exit(1) + padded = f" {normalized} " + if ";" in req.sql.strip().rstrip(";"): + console.print("[red]✗[/red] Multi-statement SQL is not allowed") + raise typer.Exit(1) + if any(tok in padded for tok in _DENYLIST_TOKENS): + console.print("[red]✗[/red] Mutating SQL tokens are not allowed") + raise typer.Exit(1) + + with open_study_db(study_db) as db: + try: + rows = db.run_select(req.sql, limit=req.limit) + except Exception as e: + console.print(f"[red]✗[/red] Query failed: {e}") + raise typer.Exit(1) + + if format == "json": + console.print_json(data=rows) + return + if format == "jsonl": + for row in rows: + console.print(json.dumps(row, 
default=str)) + return + + if not rows: + console.print("[dim](no rows)[/dim]") + return + + columns = list(rows[0].keys()) + console.print(" | ".join(columns)) + console.print("-" * max(20, len(" | ".join(columns)))) + for row in rows: + console.print(" | ".join(str(row.get(c, "")) for c in columns)) diff --git a/extropy/cli/commands/report.py b/extropy/cli/commands/report.py new file mode 100644 index 0000000..6b2657f --- /dev/null +++ b/extropy/cli/commands/report.py @@ -0,0 +1,89 @@ +"""Reusable report generation commands.""" + +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path + +import typer + +from ..app import app, console + +report_app = typer.Typer(help="Generate reusable JSON reports") +app.add_typer(report_app, name="report") + + +@report_app.command("run") +def report_run( + study_db: Path = typer.Option(..., "--study-db"), + output: Path = typer.Option(..., "--output", "-o"), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute("SELECT COUNT(*) AS cnt FROM agent_states") + total = int(cur.fetchone()["cnt"]) + cur.execute("SELECT COUNT(*) AS cnt FROM agent_states WHERE aware = 1") + aware = int(cur.fetchone()["cnt"]) + cur.execute( + """ + SELECT COALESCE(private_position, position) AS position, COUNT(*) AS cnt + FROM agent_states + WHERE COALESCE(private_position, position) IS NOT NULL + GROUP BY COALESCE(private_position, position) + """ + ) + positions = {row["position"]: int(row["cnt"]) for row in cur.fetchall()} + finally: + conn.close() + + payload = { + "agent_count": total, + "aware_count": aware, + "aware_rate": (aware / total) if total else 0.0, + "positions": positions, + } + + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(json.dumps(payload, indent=2), encoding="utf-8") + console.print(f"[green]✓[/green] Wrote run report: {output}") + + +@report_app.command("network") +def report_network( + study_db: Path = 
typer.Option(..., "--study-db"), + network_id: str = typer.Option("default", "--network-id"), + output: Path = typer.Option(..., "--output", "-o"), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute( + "SELECT COUNT(*) AS cnt, AVG(weight) AS avg_w FROM network_edges WHERE network_id = ?", + (network_id,), + ) + row = cur.fetchone() + edge_count = int(row["cnt"]) if row else 0 + avg_weight = float(row["avg_w"]) if row and row["avg_w"] is not None else 0.0 + + cur.execute( + "SELECT edge_type, COUNT(*) AS cnt FROM network_edges WHERE network_id = ? GROUP BY edge_type", + (network_id,), + ) + edge_types = {r["edge_type"]: int(r["cnt"]) for r in cur.fetchall()} + finally: + conn.close() + + payload = { + "network_id": network_id, + "edge_count": edge_count, + "avg_weight": avg_weight, + "edge_types": edge_types, + } + + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(json.dumps(payload, indent=2), encoding="utf-8") + console.print(f"[green]✓[/green] Wrote network report: {output}") diff --git a/extropy/cli/commands/results.py b/extropy/cli/commands/results.py index 3011eee..b6ece60 100644 --- a/extropy/cli/commands/results.py +++ b/extropy/cli/commands/results.py @@ -1,5 +1,9 @@ -"""Results command for displaying simulation results.""" +"""Results command for DB-first simulation results.""" +from __future__ import annotations + +import json +import sqlite3 from pathlib import Path import typer @@ -9,7 +13,10 @@ @app.command("results") def results_command( - results_dir: Path = typer.Argument(..., help="Results directory from simulation"), + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), + run_id: str | None = typer.Option( + None, "--run-id", help="Run ID (reserved for multi-run support)" + ), segment: str | None = typer.Option( None, "--segment", "-s", help="Attribute to segment by" ), @@ -18,43 +25,185 @@ def results_command( None, "--agent", 
"-a", help="Show single agent details" ), ): - """ - Display simulation results. - - Load and display results from a completed simulation run. - - Example: - extropy results results/ # Summary view - extropy results results/ --segment age # Breakdown by age - extropy results results/ --timeline # Timeline view - extropy results results/ --agent agent_001 # Single agent - """ - from ...results import ( - load_results, - display_summary, - display_segment_breakdown, - display_timeline, - display_agent, + """Display simulation results from the canonical study DB.""" + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") + raise typer.Exit(1) + + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + + try: + if run_id: + cur = conn.cursor() + cur.execute( + "SELECT status, started_at, completed_at, stopped_reason FROM simulation_runs WHERE run_id = ?", + (run_id,), + ) + run_row = cur.fetchone() + if not run_row: + console.print(f"[red]✗[/red] run_id not found: {run_id}") + raise typer.Exit(1) + console.print( + f"[dim]run_id={run_id} status={run_row['status']} " + f"started_at={run_row['started_at']} completed_at={run_row['completed_at'] or '-'}[/dim]" + ) + if agent: + _display_agent(conn, agent) + return + if segment: + _display_segment(conn, segment) + return + if timeline: + _display_timeline(conn) + return + _display_summary(conn) + finally: + conn.close() + + +def _display_summary(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute("SELECT COUNT(*) AS cnt FROM agent_states") + total = int(cur.fetchone()["cnt"]) + if total == 0: + console.print("[yellow]No simulation state found in study DB.[/yellow]") + return + + cur.execute("SELECT COUNT(*) AS cnt FROM agent_states WHERE aware = 1") + aware = int(cur.fetchone()["cnt"]) + + cur.execute( + """ + SELECT COALESCE(private_position, position) AS position, COUNT(*) AS cnt + FROM agent_states + WHERE COALESCE(private_position, position) IS NOT 
NULL + GROUP BY COALESCE(private_position, position) + ORDER BY cnt DESC + """ ) + rows = cur.fetchall() console.print() + console.print("[bold]Simulation Summary[/bold]") + console.print(f"Agents: {total}") + console.print(f"Aware: {aware} ({aware / total:.1%})") + console.print("Positions:") + for row in rows: + pct = int(row["cnt"]) / total + console.print(f" - {row['position']}: {row['cnt']} ({pct:.1%})") - if not results_dir.exists(): - console.print(f"[red]✗[/red] Results directory not found: {results_dir}") - raise typer.Exit(1) - try: - reader = load_results(results_dir) - except Exception as e: - console.print(f"[red]✗[/red] Failed to load results: {e}") - raise typer.Exit(1) +def _display_timeline(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + SELECT timestep, new_exposures, agents_reasoned, shares_occurred, exposure_rate + FROM timestep_summaries + ORDER BY timestep + """ + ) + rows = cur.fetchall() + if not rows: + console.print("[yellow]No timestep summaries found.[/yellow]") + return + + console.print() + console.print("[bold]Timeline[/bold]") + for row in rows: + console.print( + f"t={row['timestep']:>3} | new_exp={row['new_exposures']:>5} | " + f"reasoned={row['agents_reasoned']:>5} | shares={row['shares_occurred']:>5} | " + f"exposure={float(row['exposure_rate']):.1%}" + ) + + +def _display_segment(conn: sqlite3.Connection, attribute: str) -> None: + cur = conn.cursor() + cur.execute("SELECT agent_id, attrs_json FROM agents") + attr_by_agent: dict[str, str] = {} + for row in cur.fetchall(): + try: + attrs = json.loads(row["attrs_json"]) + except json.JSONDecodeError: + continue + attr_by_agent[str(row["agent_id"])] = str(attrs.get(attribute, "unknown")) + + if not attr_by_agent: + console.print("[yellow]No agent attribute records found.[/yellow]") + return + + cur.execute( + """ + SELECT agent_id, aware, COALESCE(private_position, position) AS position + FROM agent_states + """ + ) + groups: dict[str, dict[str, 
int]] = {} + for row in cur.fetchall(): + aid = str(row["agent_id"]) + key = attr_by_agent.get(aid, "unknown") + if key not in groups: + groups[key] = {"total": 0, "aware": 0} + groups[key]["total"] += 1 + if int(row["aware"]) == 1: + groups[key]["aware"] += 1 + + console.print() + console.print(f"[bold]Segment by {attribute}[/bold]") + for key, data in sorted(groups.items(), key=lambda x: x[1]["total"], reverse=True): + total = data["total"] + aware = data["aware"] + pct = aware / total if total else 0.0 + console.print(f" - {key}: {total} agents, aware={aware} ({pct:.1%})") - # Dispatch to appropriate view - if agent: - display_agent(console, reader, agent) - elif segment: - display_segment_breakdown(console, reader, segment) - elif timeline: - display_timeline(console, reader) - else: - display_summary(console, reader) + +def _display_agent(conn: sqlite3.Connection, agent_id: str) -> None: + cur = conn.cursor() + cur.execute( + """ + SELECT * + FROM agent_states + WHERE agent_id = ? + """, + (agent_id,), + ) + row = cur.fetchone() + if not row: + console.print(f"[yellow]Agent not found in simulation state: {agent_id}[/yellow]") + return + + cur.execute("SELECT attrs_json FROM agents WHERE agent_id = ? 
LIMIT 1", (agent_id,)) + attrs_row = cur.fetchone() + attrs = {} + if attrs_row: + try: + attrs = json.loads(attrs_row["attrs_json"]) + except json.JSONDecodeError: + attrs = {} + + console.print() + console.print(f"[bold]Agent {agent_id}[/bold]") + console.print(f"Aware: {bool(row['aware'])}") + console.print(f"Position: {row['private_position'] or row['position']}") + console.print( + f"Sentiment: {row['private_sentiment'] if row['private_sentiment'] is not None else row['sentiment']}" + ) + console.print( + f"Conviction: {row['private_conviction'] if row['private_conviction'] is not None else row['conviction']}" + ) + if row["public_statement"]: + console.print(f"Public statement: {row['public_statement']}") + if row["action_intent"]: + console.print(f"Action intent: {row['action_intent']}") + if row["raw_reasoning"]: + console.print() + console.print("[bold]Raw Reasoning[/bold]") + console.print(str(row["raw_reasoning"])) + if attrs: + console.print() + console.print("[bold]Attributes[/bold]") + for key in sorted(attrs.keys()): + if key.startswith("_"): + continue + console.print(f" - {key}: {attrs[key]}") diff --git a/extropy/cli/commands/sample.py b/extropy/cli/commands/sample.py index 5db1f20..bbea94c 100644 --- a/extropy/cli/commands/sample.py +++ b/extropy/cli/commands/sample.py @@ -22,8 +22,9 @@ def sample_command( spec_file: Path = typer.Argument( ..., help="Population spec YAML file to sample from" ), - output: Path = typer.Option( - ..., "--output", "-o", help="Output file path (.json or .db)" + study_db: Path = typer.Option(..., "--study-db", help="Canonical study database"), + population_id: str = typer.Option( + "default", "--population-id", help="Population identifier inside study DB" ), count: int | None = typer.Option( None, "--count", "-n", help="Number of agents (default: spec.meta.size)" @@ -31,9 +32,6 @@ def sample_command( seed: int | None = typer.Option( None, "--seed", help="Random seed for reproducibility" ), - format: str = typer.Option( 
- "json", "--format", "-f", help="Output format: json or sqlite" - ), report: bool = typer.Option( False, "--report", "-r", help="Show distribution summaries and stats" ), @@ -54,18 +52,16 @@ def sample_command( 3 = Sampling error EXAMPLES: - extropy sample surgeons.yaml -o agents.json - extropy sample surgeons.yaml -n 500 -o agents.json --seed 42 - extropy sample surgeons.yaml -n 1000 -o agents.db --format sqlite - extropy sample surgeons.yaml -o agents.json --report - extropy --json sample surgeons.yaml -o agents.json --report + extropy sample surgeons.yaml --study-db study.db + extropy sample surgeons.yaml --study-db study.db --population-id main --seed 42 + extropy sample surgeons.yaml --study-db study.db --count 1000 --report + extropy --json sample surgeons.yaml --study-db study.db --report """ from ...population.sampler import ( sample_population, - save_json, - save_sqlite, SamplingError, ) + from ...storage import open_study_db out = Output(console, json_mode=get_json_mode()) start_time = time.time() @@ -289,35 +285,48 @@ def on_progress(current: int, total: int): ) out.blank() - # Save Output + # Save to canonical DB out.blank() - output_format = format.lower() - - if output.suffix.lower() == ".db": - output_format = "sqlite" - elif output.suffix.lower() == ".json": - output_format = "json" - if not get_json_mode(): - with console.status(f"[cyan]Saving to {output_format}...[/cyan]"): - if output_format == "sqlite": - save_sqlite(result, output) - else: - save_json(result, output) + with console.status(f"[cyan]Saving to study DB: {study_db}...[/cyan]"): + with open_study_db(study_db) as db: + db.save_population_spec( + population_id=population_id, + spec_yaml=spec_file.read_text(encoding="utf-8"), + source_path=str(spec_file), + ) + sample_run_id = db.save_sample_result( + population_id=population_id, + agents=result.agents, + meta=result.meta, + seed=result.meta.get("seed"), + ) else: - if output_format == "sqlite": - save_sqlite(result, output) - else: - 
save_json(result, output) + with open_study_db(study_db) as db: + db.save_population_spec( + population_id=population_id, + spec_yaml=spec_file.read_text(encoding="utf-8"), + source_path=str(spec_file), + ) + sample_run_id = db.save_sample_result( + population_id=population_id, + agents=result.agents, + meta=result.meta, + seed=result.meta.get("seed"), + ) elapsed = time.time() - start_time - out.set_data("output_file", str(output)) - out.set_data("output_format", output_format) + out.set_data("study_db", str(study_db)) + out.set_data("population_id", population_id) + out.set_data("sample_run_id", sample_run_id) out.set_data("total_time_seconds", elapsed) out.divider() - out.success(f"Saved {len(result.agents)} agents to [bold]{output}[/bold]") + out.success( + f"Saved {len(result.agents)} agents to [bold]{study_db}[/bold] " + f"(population_id={population_id}, sample_run_id={sample_run_id})" + ) out.text(f"[dim]Total time: {format_elapsed(elapsed)}[/dim]") out.divider() diff --git a/extropy/cli/commands/scenario.py b/extropy/cli/commands/scenario.py index 0a23d52..6a5d99d 100644 --- a/extropy/cli/commands/scenario.py +++ b/extropy/cli/commands/scenario.py @@ -17,8 +17,13 @@ def scenario_command( population: Path = typer.Option( ..., "--population", "-p", help="Population spec YAML file" ), - agents: Path = typer.Option(..., "--agents", "-a", help="Sampled agents JSON file"), - network: Path = typer.Option(..., "--network", "-n", help="Network JSON file"), + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), + population_id: str = typer.Option( + "default", "--population-id", help="Population ID in study DB" + ), + network_id: str = typer.Option( + "default", "--network-id", help="Network ID in study DB" + ), description: str | None = typer.Option( None, "--description", @@ -44,12 +49,12 @@ def scenario_command( - Outcome definitions (what to measure) Example: - extropy scenario -p population.yaml -a agents.json -n network.json - 
extropy scenario -p pop.yaml -a agents.json -n net.json -d "Custom description" -o custom.yaml + extropy scenario -p population.yaml --study-db study.db + extropy scenario -p pop.yaml --study-db study.db --population-id main --network-id main -d "Custom description" -o custom.yaml """ from ...core.models import PopulationSpec from ...scenario import create_scenario - from ...utils import make_relative_to + from ...storage import open_study_db start_time = time.time() console.print() @@ -59,13 +64,21 @@ def scenario_command( console.print(f"[red]✗[/red] Population spec not found: {population}") raise typer.Exit(1) - if not agents.exists(): - console.print(f"[red]✗[/red] Agents file not found: {agents}") + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") raise typer.Exit(1) - if not network.exists(): - console.print(f"[red]✗[/red] Network file not found: {network}") - raise typer.Exit(1) + with open_study_db(study_db) as db: + if db.get_agent_count(population_id) == 0: + console.print( + f"[red]✗[/red] No agents found for population_id '{population_id}' in {study_db}" + ) + raise typer.Exit(1) + if db.get_network_edge_count(network_id) == 0: + console.print( + f"[red]✗[/red] No network edges found for network_id '{network_id}' in {study_db}" + ) + raise typer.Exit(1) # Load population spec to get scenario description if not provided try: @@ -106,8 +119,9 @@ def run_pipeline(): result_spec, validation_result = create_scenario( description=scenario_desc, population_spec_path=population, - agents_path=agents, - network_path=network, + study_db_path=study_db, + population_id=population_id, + network_id=network_id, output_path=None, # Don't save yet on_progress=on_progress, ) @@ -239,10 +253,10 @@ def run_pipeline(): console.print("[dim]Cancelled.[/dim]") raise typer.Exit(0) - # Convert paths to be relative to output file before saving - result_spec.meta.population_spec = make_relative_to(population, output_path) - 
result_spec.meta.agents_file = make_relative_to(agents, output_path) - result_spec.meta.network_file = make_relative_to(network, output_path) + result_spec.meta.population_spec = str(population) + result_spec.meta.study_db = str(study_db) + result_spec.meta.population_id = population_id + result_spec.meta.network_id = network_id # Save to YAML result_spec.to_yaml(output_path) diff --git a/extropy/cli/commands/simulate.py b/extropy/cli/commands/simulate.py index 73f162f..7d6168e 100644 --- a/extropy/cli/commands/simulate.py +++ b/extropy/cli/commands/simulate.py @@ -101,6 +101,7 @@ def setup_logging(verbose: bool = False, debug: bool = False): def simulate_command( scenario_file: Path = typer.Argument(..., help="Scenario spec YAML file"), output: Path = typer.Option(..., "--output", "-o", help="Output results directory"), + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), model: str = typer.Option( "", "--model", @@ -132,6 +133,55 @@ def simulate_command( chunk_size: int = typer.Option( 50, "--chunk-size", help="Agents per reasoning chunk for checkpointing" ), + checkpoint_every_chunks: int = typer.Option( + 1, + "--checkpoint-every-chunks", + min=1, + help="Persist simulation chunk checkpoints every N chunks", + ), + run_id: str | None = typer.Option( + None, + "--run-id", + help="Explicit run id (required with --resume)", + ), + resume: bool = typer.Option( + False, + "--resume", + help="Resume an existing run from study DB checkpoints", + ), + writer_queue_size: int = typer.Option( + 256, + "--writer-queue-size", + min=1, + help="Reserved writer queue size (future pipeline tuning)", + ), + db_write_batch_size: int = typer.Option( + 100, + "--db-write-batch-size", + min=1, + help="Reserved DB write batch size (future pipeline tuning)", + ), + retention_lite: bool = typer.Option( + False, + "--retention-lite", + help="Reduce retained payload volume (drops full raw reasoning text)", + ), + resource_mode: str = typer.Option( + "auto", 
+ "--resource-mode", + help="Resource tuning mode: auto | manual", + ), + safe_auto_workers: bool = typer.Option( + True, + "--safe-auto-workers/--unsafe-auto-workers", + help="Conservative auto tuning for laptop/VM environments", + ), + max_memory_gb: float | None = typer.Option( + None, + "--max-memory-gb", + min=0.5, + help="Optional memory budget cap for auto resource tuning", + ), seed: int | None = typer.Option( None, "--seed", help="Random seed for reproducibility" ), @@ -157,12 +207,13 @@ def simulate_command( used automatically for embodied first-person personas. Example: - extropy simulate scenario.yaml -o results/ - extropy simulate scenario.yaml -o results/ --model gpt-5-nano --seed 42 - extropy simulate scenario.yaml -o results/ --persona population.persona.yaml + extropy simulate scenario.yaml --study-db study.db -o results/ + extropy simulate scenario.yaml --study-db study.db -o results/ --model gpt-5-nano --seed 42 + extropy simulate scenario.yaml --study-db study.db -o results/ --persona population.persona.yaml """ from ...simulation import run_simulation from ...simulation.progress import SimulationProgress + from ...utils import ResourceGovernor # Setup logging based on verbosity setup_logging(verbose=verbose, debug=debug) @@ -174,6 +225,15 @@ def simulate_command( if not scenario_file.exists(): console.print(f"[red]✗[/red] Scenario file not found: {scenario_file}") raise typer.Exit(1) + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") + raise typer.Exit(1) + if resume and not run_id: + console.print("[red]✗[/red] --resume requires --run-id") + raise typer.Exit(1) + if resource_mode not in {"auto", "manual"}: + console.print("[red]✗[/red] --resource-mode must be 'auto' or 'manual'") + raise typer.Exit(1) from ...config import get_config @@ -192,6 +252,7 @@ def simulate_command( console.print(f"Simulating: [bold]{scenario_file}[/bold]") console.print(f"Output: {output}") + console.print(f"Study DB: 
{study_db}") console.print( f"Provider: {display_provider} | Model: {display_model} | Threshold: {threshold}" ) @@ -213,6 +274,27 @@ def simulate_command( console.print(f"Rate overrides: {' | '.join(parts)}") if seed: console.print(f"Seed: {seed}") + governor = ResourceGovernor( + resource_mode=resource_mode, + safe_auto_workers=safe_auto_workers, + max_memory_gb=max_memory_gb, + ) + tuned_chunk_size = governor.recommend_chunk_size( + requested_chunk_size=chunk_size, + min_chunk_size=8, + max_chunk_size=2000, + ) + if resource_mode == "auto": + snap = governor.snapshot() + console.print( + f"Resources(auto): cpu={snap.cpu_count} mem={snap.total_memory_gb:.1f}GB " + f"budget={snap.memory_budget_gb:.1f}GB chunk={tuned_chunk_size}" + ) + if writer_queue_size != 256 or db_write_batch_size != 100: + console.print( + "[dim]Note: writer queue/batch flags are accepted now and will be fully enforced " + "by the upcoming async writer pipeline.[/dim]" + ) if verbose or debug: console.print(f"Logging: {'DEBUG' if debug else 'VERBOSE'}") console.print() @@ -236,6 +318,7 @@ def on_progress(timestep: int, max_timesteps: int, status: str): result = run_simulation( scenario_path=scenario_file, output_dir=output, + study_db_path=study_db, model=effective_model, pivotal_model=effective_pivotal, routine_model=effective_routine, @@ -246,8 +329,12 @@ def on_progress(timestep: int, max_timesteps: int, status: str): rate_tier=effective_tier, rpm_override=effective_rpm, tpm_override=effective_tpm, - chunk_size=chunk_size, + chunk_size=tuned_chunk_size, progress=progress_state, + run_id=run_id, + resume=resume, + checkpoint_every_chunks=checkpoint_every_chunks, + retention_lite=retention_lite, ) simulation_error = None except Exception as e: @@ -265,6 +352,7 @@ def do_simulation(): result = run_simulation( scenario_path=scenario_file, output_dir=output, + study_db_path=study_db, model=effective_model, pivotal_model=effective_pivotal, routine_model=effective_routine, @@ -275,8 +363,12 @@ def 
do_simulation(): rate_tier=effective_tier, rpm_override=effective_rpm, tpm_override=effective_tpm, - chunk_size=chunk_size, + chunk_size=tuned_chunk_size, progress=progress_state, + run_id=run_id, + resume=resume, + checkpoint_every_chunks=checkpoint_every_chunks, + retention_lite=retention_lite, ) except Exception as e: simulation_error = e diff --git a/extropy/cli/commands/validate.py b/extropy/cli/commands/validate.py index 0ca0fa3..d09177a 100644 --- a/extropy/cli/commands/validate.py +++ b/extropy/cli/commands/validate.py @@ -185,17 +185,13 @@ def _validate_scenario_spec(spec_file: Path, out: Output) -> int: f" [red]✗[/red] Population: {spec.meta.population_spec} (not found)" ) - agents_path = resolve_relative_to(spec.meta.agents_file, spec_file) - if agents_path.exists(): - out.text(f" [green]✓[/green] Agents: {spec.meta.agents_file}") + study_db_path = resolve_relative_to(spec.meta.study_db, spec_file) + if study_db_path.exists(): + out.text(f" [green]✓[/green] Study DB: {spec.meta.study_db}") else: - out.text(f" [red]✗[/red] Agents: {spec.meta.agents_file} (not found)") - - network_path = resolve_relative_to(spec.meta.network_file, spec_file) - if network_path.exists(): - out.text(f" [green]✓[/green] Network: {spec.meta.network_file}") - else: - out.text(f" [red]✗[/red] Network: {spec.meta.network_file} (not found)") + out.text(f" [red]✗[/red] Study DB: {spec.meta.study_db} (not found)") + out.text(f" [cyan]•[/cyan] population_id: {spec.meta.population_id}") + out.text(f" [cyan]•[/cyan] network_id: {spec.meta.network_id}") out.blank() diff --git a/extropy/core/models/scenario.py b/extropy/core/models/scenario.py index 0288667..5cabe1f 100644 --- a/extropy/core/models/scenario.py +++ b/extropy/core/models/scenario.py @@ -268,8 +268,9 @@ class ScenarioMeta(BaseModel): name: str = Field(description="Short identifier for the scenario") description: str = Field(description="Full scenario description") population_spec: str = Field(description="Path to population 
YAML") - agents_file: str = Field(description="Path to sampled agents JSON") - network_file: str = Field(description="Path to network JSON") + study_db: str = Field(description="Path to canonical study DB") + population_id: str = Field(default="default", description="Population ID in study DB") + network_id: str = Field(default="default", description="Network ID in study DB") created_at: datetime = Field(default_factory=datetime.now) @@ -305,7 +306,34 @@ def from_yaml(cls, path: Path | str) -> "ScenarioSpec": with open(path) as f: data = yaml.safe_load(f) - return cls.model_validate(data) + if not isinstance(data, dict): + raise ValueError("Scenario YAML must parse to an object") + + meta = data.get("meta", {}) + if isinstance(meta, dict) and ( + "agents_file" in meta or "network_file" in meta + ): + raise ValueError( + "Legacy scenario schema detected (meta.agents_file/meta.network_file). " + "Migrate with: extropy migrate scenario --input " + f"{path} --study-db study.db --population-id default --network-id default" + ) + + try: + return cls.model_validate(data) + except Exception as e: + if isinstance(meta, dict) and ( + "study_db" not in meta + or "population_id" not in meta + or "network_id" not in meta + ): + raise ValueError( + "Scenario metadata must include meta.study_db, meta.population_id, " + "and meta.network_id. 
If this is an older scenario, run: " + "extropy migrate scenario --input " + f"{path} --study-db study.db --population-id default --network-id default" + ) from e + raise def summary(self) -> str: """Get a text summary of the scenario spec.""" diff --git a/extropy/population/network/config.py b/extropy/population/network/config.py index f411c2d..08a4024 100644 --- a/extropy/population/network/config.py +++ b/extropy/population/network/config.py @@ -103,8 +103,18 @@ class NetworkConfig(BaseModel): Attributes: avg_degree: Target average degree (connections per agent) rewire_prob: Watts-Strogatz rewiring probability + similarity_store_threshold: Minimum similarity retained in sparse matrix similarity_threshold: Sigmoid threshold for edge probability similarity_steepness: Sigmoid steepness for edge probability + candidate_mode: Similarity candidate strategy. + - "exact": all-pairs (highest fidelity, slowest) + - "blocked": block-based candidate pruning (near-equivalent, much faster) + candidate_pool_multiplier: Candidate pool size per node as a multiple of avg_degree + min_candidate_pool: Lower bound for candidate pool size per node in blocked mode + blocking_attributes: Attributes used for blocking. Auto-selected if empty. + similarity_workers: Worker processes for similarity stage (1 = serial) + similarity_chunk_size: Row chunk size per worker task + checkpoint_every_rows: Save similarity checkpoint every N rows triadic_closure_prob: Probability of closing open triads (A-B, B-C -> A-C). Higher values create more realistic clustering. Default 0.4. target_clustering: Target clustering coefficient (0.3-0.5 is realistic). 
@@ -123,8 +133,16 @@ class NetworkConfig(BaseModel): avg_degree: float = 20.0 rewire_prob: float = 0.05 + similarity_store_threshold: float = 0.05 similarity_threshold: float = 0.3 similarity_steepness: float = 10.0 + candidate_mode: Literal["exact", "blocked"] = "exact" + candidate_pool_multiplier: float = 12.0 + min_candidate_pool: int = 80 + blocking_attributes: list[str] = Field(default_factory=list) + similarity_workers: int = 1 + similarity_chunk_size: int = 64 + checkpoint_every_rows: int = 250 triadic_closure_prob: float = 0.6 target_clustering: float = 0.35 target_modularity: float = 0.55 # Target modularity (0.4-0.7 range) diff --git a/extropy/population/network/generator.py b/extropy/population/network/generator.py index 09fa31c..75ad7b7 100644 --- a/extropy/population/network/generator.py +++ b/extropy/population/network/generator.py @@ -5,12 +5,17 @@ import json import logging +import hashlib +import multiprocessing as mp +import pickle import random +from concurrent.futures import ProcessPoolExecutor, as_completed from datetime import datetime from pathlib import Path from typing import Any from ...core.models import Edge, NetworkResult +from ...storage import open_study_db from ...utils.callbacks import NetworkProgressCallback from ...utils.eval_safe import ConditionError, eval_condition from .config import NetworkConfig, InfluenceFactorConfig @@ -21,6 +26,463 @@ logger = logging.getLogger(__name__) +_SIM_WORKER_AGENTS: list[dict[str, Any]] | None = None +_SIM_WORKER_ATTRIBUTE_WEIGHTS = None +_SIM_WORKER_ORDINAL_LEVELS: dict[str, dict[str, int]] | None = None +_SIM_WORKER_THRESHOLD: float = 0.05 +_SIM_WORKER_CANDIDATE_MAP: list[list[int]] | None = None + + +def _is_db_checkpoint(path: Path | None) -> bool: + return path is not None and path.suffix.lower() == ".db" + + +def _choose_blocking_attributes(config: NetworkConfig) -> list[str]: + """Choose blocking attributes for candidate pruning.""" + if config.blocking_attributes: + return 
list(config.blocking_attributes) + + weighted = sorted( + config.attribute_weights.items(), + key=lambda x: x[1].weight, + reverse=True, + ) + preferred = [ + attr for attr, cfg in weighted if cfg.match_type in {"exact", "within_n"} + ] + + if preferred: + return preferred[:3] + + return [attr for attr, _ in weighted[:2]] + + +def _build_blocked_candidate_map( + agents: list[dict[str, Any]], + config: NetworkConfig, + seed: int, +) -> tuple[list[list[int]] | None, list[str]]: + """Build per-agent candidate lists for blocked similarity mode.""" + attrs = _choose_blocking_attributes(config) + n = len(agents) + + if not attrs or n <= 1: + return None, attrs + + blocks: dict[str, dict[Any, list[int]]] = {attr: {} for attr in attrs} + + for idx, agent in enumerate(agents): + for attr in attrs: + val = agent.get(attr) + if val is None: + continue + blocks[attr].setdefault(val, []).append(idx) + + target_pool = max(config.min_candidate_pool, int(config.avg_degree * config.candidate_pool_multiplier)) + target_pool = max(1, min(n - 1, target_pool)) + + candidate_map: list[list[int]] = [[] for _ in range(n)] + + for i, agent in enumerate(agents): + scores: dict[int, int] = {} + + for attr in attrs: + val = agent.get(attr) + if val is None: + continue + for j in blocks[attr].get(val, []): + if j == i: + continue + scores[j] = scores.get(j, 0) + 1 + + ranked = sorted(scores.items(), key=lambda x: (-x[1], x[0])) + chosen = [j for j, _ in ranked[:target_pool]] + + if len(chosen) < target_pool: + rng = random.Random(seed + (i + 1) * 7919) + seen = set(chosen) + seen.add(i) + while len(chosen) < target_pool and len(seen) < n: + j = rng.randrange(n) + if j in seen: + continue + seen.add(j) + chosen.append(j) + + candidate_map[i] = sorted(chosen) + + return candidate_map, attrs + + +def _similarity_checkpoint_signature( + n: int, + seed: int, + config: NetworkConfig, + blocking_attrs: list[str], +) -> dict[str, Any]: + """Build a minimal signature to validate checkpoint 
compatibility.""" + return { + "n": n, + "seed": seed, + "candidate_mode": config.candidate_mode, + "threshold": config.similarity_store_threshold, + "candidate_pool_multiplier": config.candidate_pool_multiplier, + "min_candidate_pool": config.min_candidate_pool, + "blocking_attributes": blocking_attrs, + } + + +def _similarity_checkpoint_job_id(signature: dict[str, Any]) -> str: + raw = json.dumps(signature, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:24] + + +def _save_similarity_checkpoint( + path: Path, + similarities: dict[tuple[int, int], float], + completed_rows: int, + signature: dict[str, Any], + completed_chunks: list[tuple[int, int]] | None = None, +) -> None: + """Persist sparse similarities so generation can resume after interruption.""" + payload = { + "version": 1, + "completed_rows": completed_rows, + "completed_chunks": completed_chunks or [], + "signature": signature, + "similarities": similarities, + "saved_at": datetime.now().isoformat(), + } + if _is_db_checkpoint(path): + job_id = _similarity_checkpoint_job_id(signature) + with open_study_db(path) as db: + db.init_network_similarity_job( + network_run_id=f"checkpoint:{job_id}", + signature=signature, + job_id=job_id, + ) + db.save_similarity_snapshot(job_id=job_id, payload=pickle.dumps(payload)) + return + + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(path.suffix + ".tmp") + with open(tmp_path, "wb") as f: + pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL) + tmp_path.replace(path) + + +def _load_similarity_checkpoint( + path: Path, + expected_signature: dict[str, Any], +) -> tuple[dict[tuple[int, int], float], int, set[int]]: + """Load checkpoint and validate compatibility with current run settings.""" + if _is_db_checkpoint(path): + job_id = _similarity_checkpoint_job_id(expected_signature) + with open_study_db(path) as db: + signature = db.get_network_similarity_job_signature(job_id) + if signature 
is not None: + if signature != expected_signature: + raise ValueError( + "Checkpoint settings do not match current run. " + "Delete checkpoint or run with matching config." + ) + + done_chunks = db.list_completed_similarity_chunks(job_id) + done_starts = {start for start, _ in done_chunks} + similarities = db.load_similarity_pairs(job_id) + + # Resume serial fallback only from contiguous completed prefix. + contiguous_rows = 0 + for start, end in done_chunks: + if start != contiguous_rows: + break + contiguous_rows = end + + return similarities, max(0, contiguous_rows), done_starts + + payload_bytes = db.get_similarity_snapshot(job_id) + if payload_bytes is None: + raise ValueError(f"Checkpoint not found in study DB: job_id={job_id}") + payload = pickle.loads(payload_bytes) + else: + with open(path, "rb") as f: + payload = pickle.load(f) + + signature = payload.get("signature", {}) + if signature != expected_signature: + raise ValueError( + "Checkpoint settings do not match current run. " + "Delete checkpoint or run with matching config." 
+ ) + + similarities = payload.get("similarities", {}) + completed_rows = int(payload.get("completed_rows", 0)) + completed_chunk_starts: set[int] = set() + for item in payload.get("completed_chunks", []): + if ( + isinstance(item, (list, tuple)) + and len(item) == 2 + and isinstance(item[0], int) + and isinstance(item[1], int) + ): + completed_chunk_starts.add(item[0]) + + if not isinstance(similarities, dict): + raise ValueError("Invalid checkpoint similarities payload") + + return similarities, max(0, completed_rows), completed_chunk_starts + + +def _init_similarity_worker( + agents: list[dict[str, Any]], + attribute_weights, + ordinal_levels: dict[str, dict[str, int]] | None, + threshold: float, + candidate_map: list[list[int]] | None, +) -> None: + """Initialize process-local globals for similarity workers.""" + global _SIM_WORKER_AGENTS + global _SIM_WORKER_ATTRIBUTE_WEIGHTS + global _SIM_WORKER_ORDINAL_LEVELS + global _SIM_WORKER_THRESHOLD + global _SIM_WORKER_CANDIDATE_MAP + + _SIM_WORKER_AGENTS = agents + _SIM_WORKER_ATTRIBUTE_WEIGHTS = attribute_weights + _SIM_WORKER_ORDINAL_LEVELS = ordinal_levels + _SIM_WORKER_THRESHOLD = threshold + _SIM_WORKER_CANDIDATE_MAP = candidate_map + + +def _compute_similarity_chunk(task: tuple[int, int]) -> tuple[int, list[tuple[int, int, float]]]: + """Compute similarities for a chunk of row indices in a worker process.""" + start, end = task + if _SIM_WORKER_AGENTS is None: + raise RuntimeError("Similarity worker not initialized") + + n = len(_SIM_WORKER_AGENTS) + rows: list[tuple[int, int, float]] = [] + + for i in range(start, min(end, n)): + if _SIM_WORKER_CANDIDATE_MAP is None: + candidates = range(i + 1, n) + else: + candidates = _SIM_WORKER_CANDIDATE_MAP[i] + + for j in candidates: + if j <= i: + continue + sim = compute_similarity( + _SIM_WORKER_AGENTS[i], + _SIM_WORKER_AGENTS[j], + _SIM_WORKER_ATTRIBUTE_WEIGHTS, + _SIM_WORKER_ORDINAL_LEVELS, + ) + if sim >= _SIM_WORKER_THRESHOLD: + rows.append((i, j, sim)) + + 
return end, rows + + +def _compute_similarities_parallel( + agents: list[dict[str, Any]], + config: NetworkConfig, + candidate_map: list[list[int]] | None, + on_progress: NetworkProgressCallback | None = None, + checkpoint_path: Path | None = None, + checkpoint_signature: dict[str, Any] | None = None, + initial_similarities: dict[tuple[int, int], float] | None = None, + completed_rows: int = 0, + completed_chunk_starts: set[int] | None = None, + checkpoint_job_id: str | None = None, +) -> dict[tuple[int, int], float]: + """Compute sparse similarities with process parallelism.""" + n = len(agents) + similarities: dict[tuple[int, int], float] = dict(initial_similarities or {}) + + chunk_size = max(8, config.similarity_chunk_size) + tasks = [(i, min(i + chunk_size, n)) for i in range(0, n, chunk_size)] + task_ends = {start: end for start, end in tasks} + completed_starts: set[int] = set(completed_chunk_starts or set()) + for start, _ in tasks: + if start < completed_rows: + completed_starts.add(start) + pending_tasks = [(s, e) for s, e in tasks if s not in completed_starts] + workers = max(1, config.similarity_workers) + + completed_row_count = sum((e - s) for s, e in tasks if s in completed_starts) + if on_progress and completed_row_count > 0: + on_progress("Computing similarities", min(completed_row_count, n), n) + + try: + ctx = mp.get_context("spawn") + with ProcessPoolExecutor( + max_workers=workers, + mp_context=ctx, + initializer=_init_similarity_worker, + initargs=( + agents, + config.attribute_weights, + config.ordinal_levels, + config.similarity_store_threshold, + candidate_map, + ), + ) as ex: + futures = { + ex.submit(_compute_similarity_chunk, task): task for task in pending_tasks + } + pending_results: dict[int, list[tuple[int, int, float]]] = {} + sorted_starts = [start for start, _ in tasks] + next_commit_idx = 0 + + for fut in as_completed(futures): + task_start, _task_end = futures[fut] + _row_end, local_rows = fut.result() + 
pending_results[task_start] = local_rows + + # Deterministic merge: commit completed chunks in chunk_start order. + while next_commit_idx < len(sorted_starts): + current_start = sorted_starts[next_commit_idx] + current_end = task_ends[current_start] + if current_start in completed_starts: + next_commit_idx += 1 + continue + if current_start not in pending_results: + break + + chunk_rows = pending_results.pop(current_start) + for i, j, sim in chunk_rows: + similarities[(i, j)] = sim + completed_starts.add(current_start) + completed_row_count += current_end - current_start + completed_rows = max(completed_rows, current_end) + + if ( + checkpoint_path is not None + and checkpoint_signature is not None + ): + if _is_db_checkpoint(checkpoint_path) and checkpoint_job_id: + with open_study_db(checkpoint_path) as db: + db.save_similarity_chunk_rows( + job_id=checkpoint_job_id, + chunk_start=current_start, + chunk_end=current_end, + rows=chunk_rows, + ) + else: + completed_chunks = [ + (s, e) for s, e in tasks if s in completed_starts + ] + _save_similarity_checkpoint( + path=checkpoint_path, + similarities=similarities, + completed_rows=min(completed_row_count, n), + signature=checkpoint_signature, + completed_chunks=completed_chunks, + ) + + if on_progress: + on_progress( + "Computing similarities", min(completed_row_count, n), n + ) + next_commit_idx += 1 + + except Exception as e: + logger.warning( + "Parallel similarity failed (%s). 
def _compute_similarities_serial(
    agents: list[dict[str, Any]],
    config: NetworkConfig,
    candidate_map: list[list[int]] | None = None,
    on_progress: NetworkProgressCallback | None = None,
    checkpoint_path: Path | None = None,
    initial_similarities: dict[tuple[int, int], float] | None = None,
    start_row: int = 0,
    checkpoint_signature: dict[str, Any] | None = None,
    completed_chunk_starts: set[int] | None = None,
    checkpoint_job_id: str | None = None,
) -> dict[tuple[int, int], float]:
    """Compute sparse similarities serially, with optional checkpointing.

    Rows are processed in chunks of ``max(8, config.similarity_chunk_size)``.
    A chunk is skipped when its start index is in ``completed_chunk_starts``
    or lies below ``start_row``.

    Args:
        agents: Agent attribute dicts; pairs are identified by list index.
        config: Supplies the store threshold, attribute weights, ordinal
            levels, chunk size and ``checkpoint_every_rows``.
        candidate_map: Per-row candidate indices; ``None`` means all ``j > i``.
        on_progress: Called as ``(stage, rows_done, n)`` after each computed chunk.
        checkpoint_path: Checkpoint target; checkpointing is active only when
            both this and ``checkpoint_signature`` are given.
        initial_similarities: Pairs restored from a checkpoint (copied, not mutated).
        start_row: Chunks starting below this row are treated as complete.
        checkpoint_signature: Signature stored with file checkpoints.
        completed_chunk_starts: Start indices of chunks already completed.
        checkpoint_job_id: Job id used for per-chunk study-DB checkpoints.

    Returns:
        Mapping ``(i, j) -> similarity`` with ``i < j`` for every pair whose
        similarity is at least ``config.similarity_store_threshold``.
    """
    n = len(agents)
    threshold = config.similarity_store_threshold
    similarities = dict(initial_similarities or {})
    checkpoint_every = max(1, config.checkpoint_every_rows)
    chunk_size = max(8, config.similarity_chunk_size)
    tasks = [(i, min(i + chunk_size, n)) for i in range(0, n, chunk_size)]

    completed_starts: set[int] = set(completed_chunk_starts or ())
    completed_starts.update(start for start, _ in tasks if start < start_row)
    completed_row_count = sum(
        end - start for start, end in tasks if start in completed_starts
    )

    pending_tasks = [task for task in tasks if task[0] not in completed_starts]
    last_pending_idx = len(pending_tasks) - 1
    checkpointing = checkpoint_path is not None and checkpoint_signature is not None
    # Rows finished since the last file checkpoint. Progress advances a whole
    # chunk at a time, so a modulo test on the running total would only fire
    # when the total happens to be a multiple of checkpoint_every (e.g. every
    # 8000 rows for 64-row chunks and checkpoint_every=250).
    rows_since_checkpoint = 0

    for pending_idx, (start, end) in enumerate(pending_tasks):
        local_rows: list[tuple[int, int, float]] = []
        for i in range(start, end):
            candidates = range(i + 1, n) if candidate_map is None else candidate_map[i]
            for j in candidates:
                # Each unordered pair is stored once, keyed with i < j.
                if j <= i:
                    continue
                sim = compute_similarity(
                    agents[i], agents[j], config.attribute_weights, config.ordinal_levels
                )
                if sim >= threshold:
                    similarities[(i, j)] = sim
                    local_rows.append((i, j, sim))

        chunk_rows = end - start
        completed_starts.add(start)
        completed_row_count += chunk_rows
        rows_since_checkpoint += chunk_rows

        if checkpointing:
            if _is_db_checkpoint(checkpoint_path) and checkpoint_job_id:
                # Study-DB checkpoints are incremental: persist every chunk.
                with open_study_db(checkpoint_path) as db:
                    db.save_similarity_chunk_rows(
                        job_id=checkpoint_job_id,
                        chunk_start=start,
                        chunk_end=end,
                        rows=local_rows,
                    )
            elif (
                rows_since_checkpoint >= checkpoint_every
                or pending_idx == last_pending_idx
            ):
                # Full snapshots are expensive, so write them roughly every
                # checkpoint_every rows and always after the final chunk.
                completed_chunks = [(s, e) for s, e in tasks if s in completed_starts]
                _save_similarity_checkpoint(
                    path=checkpoint_path,
                    similarities=similarities,
                    completed_rows=min(completed_row_count, n),
                    signature=checkpoint_signature,
                    completed_chunks=completed_chunks,
                )
                rows_since_checkpoint = 0

        if on_progress:
            on_progress("Computing similarities", min(completed_row_count, n), n)

    return similarities
target_clustering=config.target_clustering, max_edge_increase=2.5, # Allow up to 2.5x edges for better clustering @@ -718,6 +1190,8 @@ def generate_network( agents: list[dict[str, Any]], config: NetworkConfig | None = None, on_progress: NetworkProgressCallback | None = None, + checkpoint_path: Path | str | None = None, + resume_from_checkpoint: bool = False, ) -> NetworkResult: """Generate a social network from sampled agents. @@ -746,27 +1220,102 @@ def generate_network( n = len(agents) agent_ids = [a.get("_id", f"agent_{i}") for i, a in enumerate(agents)] - - if on_progress: - on_progress("Computing similarities", 0, n) + checkpoint_file = Path(checkpoint_path) if checkpoint_path else None # Step 1: Compute degree factors degree_factors = [compute_degree_factor(a, config) for a in agents] - # Step 2: Compute similarity matrix (sparse) - similarities: dict[tuple[int, int], float] = {} - threshold = 0.05 + # Step 2: Build similarity candidates (exact/blocked) + candidate_map: list[list[int]] | None = None + blocking_attrs: list[str] = [] + candidate_mode = config.candidate_mode - for i in range(n): - for j in range(i + 1, n): - sim = compute_similarity( - agents[i], agents[j], config.attribute_weights, config.ordinal_levels + if config.candidate_mode == "blocked": + if on_progress: + on_progress("Preparing candidate blocks", 0, n) + candidate_map, blocking_attrs = _build_blocked_candidate_map(agents, config, seed) + if on_progress: + on_progress("Preparing candidate blocks", n, n) + if candidate_map is None: + logger.warning( + "Blocked candidate mode could not be initialized. Falling back to exact mode." 
) - if sim >= threshold: - similarities[(i, j)] = sim + candidate_mode = "exact" + + if on_progress: + on_progress("Computing similarities", 0, n) + + checkpoint_signature = _similarity_checkpoint_signature( + n=n, + seed=seed, + config=config, + blocking_attrs=blocking_attrs, + ) + checkpoint_job_id: str | None = None + if _is_db_checkpoint(checkpoint_file): + checkpoint_job_id = _similarity_checkpoint_job_id(checkpoint_signature) + if not resume_from_checkpoint: + with open_study_db(checkpoint_file) as db: + db.init_network_similarity_job( + network_run_id=f"checkpoint:{checkpoint_job_id}", + signature=checkpoint_signature, + job_id=checkpoint_job_id, + ) + db.mark_similarity_job_running(checkpoint_job_id) + + similarities: dict[tuple[int, int], float] + start_row = 0 + completed_chunk_starts: set[int] = set() + + if resume_from_checkpoint and checkpoint_file is None: + raise ValueError("--resume-checkpoint requires --checkpoint path") + + if resume_from_checkpoint: + if checkpoint_file is None or not checkpoint_file.exists(): + raise ValueError(f"Checkpoint not found: {checkpoint_file}") + similarities, start_row, completed_chunk_starts = _load_similarity_checkpoint( + checkpoint_file, checkpoint_signature + ) + if checkpoint_job_id and checkpoint_file is not None: + with open_study_db(checkpoint_file) as db: + db.mark_similarity_job_running(checkpoint_job_id) + if on_progress: + on_progress("Computing similarities", min(start_row, n), n) + else: + similarities = {} + + use_parallel_similarity = config.similarity_workers > 1 + + if use_parallel_similarity: + similarities = _compute_similarities_parallel( + agents=agents, + config=config, + candidate_map=candidate_map if candidate_mode == "blocked" else None, + on_progress=on_progress, + checkpoint_path=checkpoint_file, + checkpoint_signature=checkpoint_signature, + initial_similarities=similarities, + completed_rows=start_row, + completed_chunk_starts=completed_chunk_starts, + 
checkpoint_job_id=checkpoint_job_id, + ) + else: + similarities = _compute_similarities_serial( + agents=agents, + config=config, + candidate_map=candidate_map if candidate_mode == "blocked" else None, + on_progress=on_progress, + checkpoint_path=checkpoint_file, + initial_similarities=similarities, + start_row=start_row, + checkpoint_signature=checkpoint_signature, + completed_chunk_starts=completed_chunk_starts, + checkpoint_job_id=checkpoint_job_id, + ) - if on_progress and i % 50 == 0: - on_progress("Computing similarities", i, n) + if checkpoint_job_id and checkpoint_file is not None: + with open_study_db(checkpoint_file) as db: + db.mark_similarity_job_complete(checkpoint_job_id) if on_progress: on_progress("Computing similarities", n, n) @@ -935,6 +1484,10 @@ def generate_network( "rewired_count": rewired_count, "algorithm": "adaptive_calibration", "seed": seed, + "candidate_mode": candidate_mode, + "similarity_pairs": len(similarities), + "blocking_attributes": blocking_attrs if candidate_mode == "blocked" else [], + "resumed_from_checkpoint": resume_from_checkpoint, "config": { "avg_degree_target": config.avg_degree, "rewire_prob": config.rewire_prob, @@ -951,6 +1504,8 @@ def generate_network_with_metrics( agents: list[dict[str, Any]], config: NetworkConfig | None = None, on_progress: NetworkProgressCallback | None = None, + checkpoint_path: Path | str | None = None, + resume_from_checkpoint: bool = False, ) -> NetworkResult: """Generate network and compute all metrics. 
@@ -960,7 +1515,13 @@ def generate_network_with_metrics( """ from .metrics import compute_network_metrics, compute_node_metrics - result = generate_network(agents, config, on_progress) + result = generate_network( + agents, + config, + on_progress, + checkpoint_path=checkpoint_path, + resume_from_checkpoint=resume_from_checkpoint, + ) agent_ids = [a.get("_id", f"agent_{i}") for i, a in enumerate(agents)] diff --git a/extropy/scenario/compiler.py b/extropy/scenario/compiler.py index 7be6831..bb108f2 100644 --- a/extropy/scenario/compiler.py +++ b/extropy/scenario/compiler.py @@ -8,7 +8,6 @@ 5. Assemble and validate ScenarioSpec """ -import json import re from datetime import datetime from pathlib import Path @@ -26,7 +25,8 @@ from .interaction import determine_interaction_model from .outcomes import define_outcomes from ..utils.callbacks import StepProgressCallback -from .validator import validate_scenario, get_agent_count +from .validator import validate_scenario +from ..storage import open_study_db def _generate_scenario_name(description: str) -> str: @@ -57,43 +57,38 @@ def _determine_simulation_config(population_size: int) -> SimulationConfig: ) -def _load_network_summary(network_path: Path) -> dict | None: - """Load network summary for exposure generation.""" - if not network_path.exists(): - return None +def _load_network_summary(network_data: dict[str, object]) -> dict[str, object]: + """Build network summary for exposure generation from network payload.""" + edge_types = set() + node_count = 0 - try: - with open(network_path) as f: - network = json.load(f) + meta = network_data.get("meta") + if isinstance(meta, dict): + raw_count = meta.get("node_count") + if isinstance(raw_count, int): + node_count = raw_count - # Extract summary information - edge_types = set() - node_count = 0 + edges = network_data.get("edges") + if isinstance(edges, list): + for edge in edges: + if not isinstance(edge, dict): + continue + edge_type = edge.get("edge_type") or 
edge.get("type") + if isinstance(edge_type, str): + edge_types.add(edge_type) - if "meta" in network: - node_count = network["meta"].get("node_count", 0) - - if "edges" in network: - for edge in network["edges"]: - # Check both 'edge_type' and 'type' fields (different network formats) - if "edge_type" in edge: - edge_types.add(edge["edge_type"]) - elif "type" in edge: - edge_types.add(edge["type"]) - - return { - "node_count": node_count, - "edge_types": list(edge_types), - } - except (json.JSONDecodeError, KeyError, TypeError): - return None + return { + "node_count": node_count, + "edge_types": list(edge_types), + } def create_scenario( description: str, population_spec_path: str | Path, - agents_path: str | Path, - network_path: str | Path, + study_db_path: str | Path, + population_id: str = "default", + network_id: str = "default", output_path: str | Path | None = None, on_progress: StepProgressCallback | None = None, ) -> tuple[ScenarioSpec, ValidationResult]: @@ -114,8 +109,9 @@ def create_scenario( Args: description: Natural language scenario description population_spec_path: Path to population YAML file - agents_path: Path to agents JSON file - network_path: Path to network JSON file + study_db_path: Path to canonical study DB + population_id: Population ID in study DB + network_id: Network ID in study DB output_path: Optional path to save scenario YAML on_progress: Optional callback(step, status) for progress updates @@ -130,16 +126,16 @@ def create_scenario( >>> spec, result = create_scenario( ... "Netflix announces $3 price increase", ... "population.yaml", - ... "agents.json", - ... "network.json", + ... "study.db", + ... "default", + ... "default", ... "scenario.yaml" ... 
) >>> result.valid True """ population_spec_path = Path(population_spec_path) - agents_path = Path(agents_path) - network_path = Path(network_path) + study_db_path = Path(study_db_path) def progress(step: str, status: str): if on_progress: @@ -157,7 +153,13 @@ def progress(step: str, status: str): population_spec = PopulationSpec.from_yaml(population_spec_path) # Load network summary for exposure generation - network_summary = _load_network_summary(network_path) + with open_study_db(study_db_path) as db: + network = db.get_network(network_id) + if not network.get("edges"): + raise FileNotFoundError( + f"Network '{network_id}' not found in study DB: {study_db_path}" + ) + network_summary = _load_network_summary(network) # ========================================================================= # Step 1: Parse scenario description @@ -220,8 +222,9 @@ def progress(step: str, status: str): name=scenario_name, description=description, population_spec=str(population_spec_path), - agents_file=str(agents_path), - network_file=str(network_path), + study_db=str(study_db_path), + population_id=population_id, + network_id=network_id, created_at=datetime.now(), ) @@ -240,18 +243,9 @@ def progress(step: str, status: str): # Validate # ========================================================================= - # Note: We validate agent count consistency, which requires loading the file. - # We use get_agent_count() to do this safely/robustly. 
- agent_count = get_agent_count(agents_path) - - # Load network for validation (needed for edge type reference validation) - network = None - if network_path.exists(): - try: - with open(network_path) as f: - network = json.load(f) - except (json.JSONDecodeError, OSError): - pass + with open_study_db(study_db_path) as db: + agent_count = db.get_agent_count(population_id) + network = db.get_network(network_id) validation_result = validate_scenario(spec, population_spec, agent_count, network) @@ -268,8 +262,9 @@ def progress(step: str, status: str): def compile_scenario_from_files( description: str, population_spec_path: str | Path, - agents_path: str | Path, - network_path: str | Path, + study_db_path: str | Path, + population_id: str = "default", + network_id: str = "default", ) -> ScenarioSpec: """ Convenience function to create a scenario spec. @@ -279,8 +274,9 @@ def compile_scenario_from_files( Args: description: Natural language scenario description population_spec_path: Path to population YAML file - agents_path: Path to agents JSON file - network_path: Path to network JSON file + study_db_path: Path to canonical study DB + population_id: Population ID in study DB + network_id: Network ID in study DB Returns: ScenarioSpec @@ -292,8 +288,9 @@ def compile_scenario_from_files( spec, result = create_scenario( description, population_spec_path, - agents_path, - network_path, + study_db_path, + population_id, + network_id, ) if not result.valid: diff --git a/extropy/scenario/validator.py b/extropy/scenario/validator.py index 13ab389..cdc8ac4 100644 --- a/extropy/scenario/validator.py +++ b/extropy/scenario/validator.py @@ -20,6 +20,7 @@ extract_names_from_expression, validate_expression_syntax, ) +from ..storage import open_study_db # Helper functions to create ValidationIssue with appropriate severity @@ -411,12 +412,10 @@ def validate_scenario( base_file = Path(spec_file) population_path = resolve_relative_to(spec.meta.population_spec, base_file) - agents_path = 
resolve_relative_to(spec.meta.agents_file, base_file) - network_path = resolve_relative_to(spec.meta.network_file, base_file) + study_db_path = resolve_relative_to(spec.meta.study_db, base_file) else: population_path = Path(spec.meta.population_spec) - agents_path = Path(spec.meta.agents_file) - network_path = Path(spec.meta.network_file) + study_db_path = Path(spec.meta.study_db) if not population_path.exists(): errors.append( @@ -428,22 +427,12 @@ def validate_scenario( ) ) - if not agents_path.exists(): + if not study_db_path.exists(): errors.append( ValidationError( category="file_reference", - location="meta.agents_file", - message=f"Agents file not found: {spec.meta.agents_file}", - suggestion="Check the file path", - ) - ) - - if not network_path.exists(): - errors.append( - ValidationError( - category="file_reference", - location="meta.network_file", - message=f"Network file not found: {spec.meta.network_file}", + location="meta.study_db", + message=f"Study DB not found: {spec.meta.study_db}", suggestion="Check the file path", ) ) @@ -462,6 +451,38 @@ def validate_scenario( ) ) + # Validate IDs inside study DB when available. + if study_db_path.exists(): + try: + with open_study_db(study_db_path) as db: + if db.get_agent_count(spec.meta.population_id) == 0: + errors.append( + ValidationError( + category="file_reference", + location="meta.population_id", + message=f"Population ID not found in study DB: {spec.meta.population_id}", + suggestion="Run `extropy sample ... --study-db ... --population-id ...` first", + ) + ) + if db.get_network_edge_count(spec.meta.network_id) == 0: + errors.append( + ValidationError( + category="file_reference", + location="meta.network_id", + message=f"Network ID not found in study DB: {spec.meta.network_id}", + suggestion="Run `extropy network ... --study-db ... 
--network-id ...` first", + ) + ) + except Exception: + errors.append( + ValidationError( + category="file_reference", + location="meta.study_db", + message=f"Failed to read study DB: {spec.meta.study_db}", + suggestion="Check that the file is a valid SQLite study DB", + ) + ) + return ValidationResult(issues=[*errors, *warnings]) @@ -541,15 +562,12 @@ def load_and_validate_scenario( except Exception: pass # Will be caught as validation error - agents_path = resolve_relative_to(spec.meta.agents_file, scenario_path) - if agents_path.exists(): - agent_count = get_agent_count(agents_path) - - network_path = resolve_relative_to(spec.meta.network_file, scenario_path) - if network_path.exists(): + study_db_path = resolve_relative_to(spec.meta.study_db, scenario_path) + if study_db_path.exists(): try: - with open(network_path) as f: - network = json.load(f) + with open_study_db(study_db_path) as db: + agent_count = db.get_agent_count(spec.meta.population_id) + network = db.get_network(spec.meta.network_id) except Exception: pass diff --git a/extropy/simulation/__init__.py b/extropy/simulation/__init__.py index 1f22a33..2d6ed67 100644 --- a/extropy/simulation/__init__.py +++ b/extropy/simulation/__init__.py @@ -25,8 +25,7 @@ Output: Results directory containing: - - simulation.db: SQLite database with all state - - timeline.jsonl: Streaming event log + - study.db: Canonical SQLite database with simulation state/checkpoints - agent_states.json: Final state per agent - by_timestep.json: Metrics over time - outcome_distributions.json: Final outcome distributions diff --git a/extropy/simulation/engine.py b/extropy/simulation/engine.py index 29f4053..5237dc6 100644 --- a/extropy/simulation/engine.py +++ b/extropy/simulation/engine.py @@ -15,7 +15,9 @@ import json import logging import random +import sqlite3 import time +import uuid from datetime import datetime from pathlib import Path from typing import Any @@ -36,15 +38,14 @@ float_to_conviction, ) from ..core.rate_limiter 
import DualRateLimiter -from ..population.network import load_agents_json from ..population.persona import PersonaConfig +from ..storage import open_study_db from .progress import SimulationProgress from .state import StateManager from .persona import generate_persona from .reasoning import batch_reason_agents, create_reasoning_context from .propagation import apply_seed_exposures, propagate_through_network from .stopping import evaluate_stopping_conditions -from .timeline import TimelineManager from ..utils.callbacks import TimestepProgressCallback from .aggregation import ( compute_timestep_summary, @@ -65,6 +66,22 @@ _PRIVATE_FLIP_CONVICTION = CONVICTION_MAP[ConvictionLevel.FIRM] +class _StateTimelineAdapter: + """Timeline adapter that persists events into StateManager timeline table.""" + + def __init__(self, state_manager: StateManager): + self.state_manager = state_manager + + def log_event(self, event: SimulationEvent) -> None: + self.state_manager.log_event(event) + + def flush(self) -> None: + return + + def close(self) -> None: + return + + class SimulationSummary: """Summary of a completed simulation run.""" @@ -128,6 +145,10 @@ def __init__( persona_config: PersonaConfig | None = None, rate_limiter: DualRateLimiter | None = None, chunk_size: int = 50, + state_db_path: Path | str | None = None, + run_id: str | None = None, + checkpoint_every_chunks: int = 1, + retention_lite: bool = False, ): """Initialize simulation engine. 
@@ -149,6 +170,9 @@ def __init__( self.persona_config = persona_config self.rate_limiter = rate_limiter self.chunk_size = chunk_size + self.run_id = run_id or f"run_{uuid.uuid4().hex[:12]}" + self.checkpoint_every_chunks = max(1, checkpoint_every_chunks) + self.retention_lite = retention_lite # Build agent map for quick lookup self.agent_map = {a.get("_id", str(i)): a for i, a in enumerate(agents)} @@ -174,13 +198,15 @@ def __init__( self.output_dir.mkdir(parents=True, exist_ok=True) # Initialize state manager + state_db_file = Path(state_db_path) if state_db_path else self.output_dir / "study.db" self.state_manager = StateManager( - self.output_dir / "simulation.db", + state_db_file, agents, ) + self.study_db = open_study_db(state_db_file) # Initialize timeline manager - self.timeline = TimelineManager(self.output_dir / "timeline.jsonl") + self.timeline = _StateTimelineAdapter(self.state_manager) # Pre-generate personas for all agents # Extract decision-relevant attributes from outcome config (trait salience) @@ -320,7 +346,7 @@ def run(self) -> SimulationSummary: """Execute the full simulation. Supports automatic resume: if the output directory contains a - simulation.db with partial progress, the engine picks up where + study.db with partial progress, the engine picks up where it left off. 
Returns: @@ -379,6 +405,7 @@ def run(self) -> SimulationSummary: self._export_results() finally: self.state_manager.close() + self.study_db.close() return summary @@ -552,7 +579,17 @@ def _on_agent_done(agent_id: str, result: Any) -> None: total_changes = 0 total_shares = 0 + completed_chunks = self.study_db.get_completed_simulation_chunks( + self.run_id, timestep + ) + for chunk_start in range(0, len(contexts), self.chunk_size): + chunk_index = chunk_start // self.chunk_size + if chunk_index in completed_chunks: + logger.info( + f"[TIMESTEP {timestep}] Skipping completed chunk {chunk_index}" + ) + continue chunk_contexts = contexts[chunk_start : chunk_start + self.chunk_size] reasoning_start = time.time() @@ -584,6 +621,16 @@ def _on_agent_done(agent_id: str, result: Any) -> None: reasoned, changes, shares = self._process_reasoning_chunk( timestep, chunk_results, old_states ) + if ( + ((chunk_index + 1) % self.checkpoint_every_chunks == 0) + or (chunk_start + self.chunk_size >= len(contexts)) + ): + self.study_db.save_simulation_checkpoint( + run_id=self.run_id, + timestep=timestep, + chunk_index=chunk_index, + status="done", + ) total_reasoned += reasoned total_changes += changes @@ -778,7 +825,7 @@ def _process_reasoning_chunk( private_outcomes=private_outcomes, committed=is_committed, outcomes=private_outcomes, - raw_reasoning=response.reasoning, + raw_reasoning=None if self.retention_lite else response.reasoning, updated_at=timestep, ) @@ -813,6 +860,9 @@ def _process_reasoning_chunk( "public_conviction": new_state.public_conviction, "private_conviction": new_state.private_conviction, "will_share": new_state.will_share, + "raw_reasoning": None + if self.retention_lite + else response.reasoning, }, ) ) @@ -1209,6 +1259,7 @@ def _export_results(self) -> None: def run_simulation( scenario_path: str | Path, output_dir: str | Path, + study_db_path: str | Path | None = None, model: str = "", pivotal_model: str = "", routine_model: str = "", @@ -1221,6 +1272,10 @@ 
def run_simulation( tpm_override: int | None = None, chunk_size: int = 50, progress: SimulationProgress | None = None, + run_id: str | None = None, + resume: bool = False, + checkpoint_every_chunks: int = 1, + retention_lite: bool = False, ) -> SimulationSummary: """Run a simulation from a scenario file. @@ -1241,12 +1296,40 @@ def run_simulation( tpm_override: Override TPM limit chunk_size: Agents per reasoning chunk for checkpointing progress: Optional SimulationProgress for live display tracking + run_id: Optional run identifier for resume and bookkeeping + resume: Resume a prior run from DB checkpoints + checkpoint_every_chunks: Mark simulation checkpoint every N chunks + retention_lite: Reduce payload volume by dropping full raw reasoning text Returns: SimulationSummary with results """ scenario_path = Path(scenario_path) output_dir = Path(output_dir) + if resume and not run_id: + raise ValueError("--resume requires --run-id") + + def _reset_runtime_tables(path: Path) -> None: + conn = sqlite3.connect(str(path)) + try: + cur = conn.cursor() + cur.executescript( + """ + DELETE FROM agent_states; + DELETE FROM exposures; + DELETE FROM memory_traces; + DELETE FROM timeline; + DELETE FROM timestep_summaries; + DELETE FROM shared_to; + DELETE FROM simulation_metadata; + """ + ) + conn.commit() + except sqlite3.OperationalError: + # First run on this DB may not have simulation tables yet. 
+ pass + finally: + conn.close() # Load scenario scenario = ScenarioSpec.from_yaml(scenario_path) @@ -1257,18 +1340,62 @@ def run_simulation( pop_path = scenario_path.parent / pop_path population_spec = PopulationSpec.from_yaml(pop_path) - # Load agents - agents_path = Path(scenario.meta.agents_file) - if not agents_path.is_absolute(): - agents_path = scenario_path.parent / agents_path - agents = load_agents_json(agents_path) + # Resolve canonical study DB + if study_db_path is None: + if not getattr(scenario.meta, "study_db", None): + raise ValueError( + "Legacy scenario format detected. Rebuild scenario with --study-db." + ) + study_db_resolved = Path(scenario.meta.study_db) + if not study_db_resolved.is_absolute(): + study_db_resolved = scenario_path.parent / study_db_resolved + else: + study_db_resolved = Path(study_db_path) + + if not study_db_resolved.exists(): + raise FileNotFoundError(f"Study DB not found: {study_db_resolved}") - # Load network - network_path = Path(scenario.meta.network_file) - if not network_path.is_absolute(): - network_path = scenario_path.parent / network_path - with open(network_path) as f: - network = json.load(f) + resolved_run_id = ( + run_id + or f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" + ) + + with open_study_db(study_db_resolved) as db: + agents = db.get_agents(scenario.meta.population_id) + if not agents: + raise ValueError( + f"No agents for population_id '{scenario.meta.population_id}' in {study_db_resolved}" + ) + network = db.get_network(scenario.meta.network_id) + if not network.get("edges"): + raise ValueError( + f"No network edges for network_id '{scenario.meta.network_id}' in {study_db_resolved}" + ) + db.create_simulation_run( + run_id=resolved_run_id, + scenario_name=scenario.meta.name, + population_id=scenario.meta.population_id, + network_id=scenario.meta.network_id, + config={ + "scenario_path": str(scenario_path), + "output_dir": str(output_dir), + "model": model, + "pivotal_model": 
pivotal_model, + "routine_model": routine_model, + "multi_touch_threshold": multi_touch_threshold, + "chunk_size": chunk_size, + "checkpoint_every_chunks": checkpoint_every_chunks, + "retention_lite": retention_lite, + "resume": resume, + }, + seed=random_seed, + status="running", + ) + db.set_run_metadata(resolved_run_id, "output_dir", str(output_dir)) + db.set_run_metadata(resolved_run_id, "study_db", str(study_db_resolved)) + + if not resume: + _reset_runtime_tables(study_db_resolved) # Load persona config if provided persona_config = None @@ -1323,6 +1450,10 @@ def run_simulation( persona_config=persona_config, rate_limiter=rate_limiter, chunk_size=chunk_size, + state_db_path=study_db_resolved, + run_id=resolved_run_id, + checkpoint_every_chunks=checkpoint_every_chunks, + retention_lite=retention_lite, ) if on_progress: @@ -1331,4 +1462,23 @@ def run_simulation( if progress: engine.set_progress_state(progress) - return engine.run() + try: + summary = engine.run() + except Exception as e: + with open_study_db(study_db_resolved) as db: + db.update_simulation_run( + run_id=resolved_run_id, + status="failed", + stopped_reason=str(e), + ) + raise + + final_status = "stopped" if summary.stopped_reason else "completed" + with open_study_db(study_db_resolved) as db: + db.update_simulation_run( + run_id=resolved_run_id, + status=final_status, + stopped_reason=summary.stopped_reason, + ) + + return summary diff --git a/extropy/simulation/reasoning.py b/extropy/simulation/reasoning.py index bf3ad52..7c51151 100644 --- a/extropy/simulation/reasoning.py +++ b/extropy/simulation/reasoning.py @@ -776,69 +776,85 @@ def batch_reason_agents( logger.info(f"[BATCH] Starting two-pass async reasoning for {total} agents") async def run_all(): - # Always use a semaphore to cap concurrent tasks. - # When rate limiter is available, size it from max_safe_concurrent. 
if rate_limiter: - concurrency = rate_limiter.max_safe_concurrent - # Stagger interval: spread launches across the RPM window - # e.g. 500 RPM → 8.3 req/s → 120ms between launches + target_concurrency = max(1, rate_limiter.max_safe_concurrent) stagger_interval = 60.0 / rate_limiter.pivotal.rpm logger.info( - f"[BATCH] Concurrency cap: {concurrency}, " + f"[BATCH] Concurrency cap: {target_concurrency}, " f"stagger: {stagger_interval * 1000:.0f}ms between launches" ) else: - concurrency = max_concurrency + target_concurrency = max(1, max_concurrency) stagger_interval = 0.0 - semaphore = asyncio.Semaphore(concurrency) completed = [0] + adaptive_concurrency = target_concurrency async def reason_with_pacing( + idx: int, ctx: ReasoningContext, - ) -> tuple[str, ReasoningResponse | None]: - async with semaphore: - start = time.time() - result = await _reason_agent_two_pass_async( - ctx, scenario, config, rate_limiter - ) - elapsed = time.time() - start - completed[0] += 1 - - if result: - logger.info( - f"[BATCH] {completed[0]}/{total}: {ctx.agent_id} done in {elapsed:.2f}s " - f"(position={result.position}, sentiment={result.sentiment}, " - f"conviction={float_to_conviction(result.conviction)})" - ) - else: - logger.warning( - f"[BATCH] {completed[0]}/{total}: {ctx.agent_id} FAILED" - ) - - if on_agent_done: - on_agent_done(ctx.agent_id, result) + ) -> tuple[int, str, ReasoningResponse | None, float]: + start = time.time() + result = await _reason_agent_two_pass_async(ctx, scenario, config, rate_limiter) + elapsed = time.time() - start + completed[0] += 1 - return (ctx.agent_id, result) - - # Stagger task launches to avoid burst of requests hitting API at once. - # Each task is created with a small delay so they don't all enter - # the semaphore simultaneously. 
- tasks = [] - for i, ctx in enumerate(contexts): - tasks.append(asyncio.create_task(reason_with_pacing(ctx))) - if stagger_interval > 0 and i < concurrency - 1: - # Only stagger the first batch — after that the semaphore - # naturally gates as tasks complete and new ones enter - await asyncio.sleep(stagger_interval) + if result: + logger.info( + f"[BATCH] {completed[0]}/{total}: {ctx.agent_id} done in {elapsed:.2f}s " + f"(position={result.position}, sentiment={result.sentiment}, " + f"conviction={float_to_conviction(result.conviction)})" + ) + else: + logger.warning(f"[BATCH] {completed[0]}/{total}: {ctx.agent_id} FAILED") + + if on_agent_done: + on_agent_done(ctx.agent_id, result) + + return (idx, ctx.agent_id, result, elapsed) + + results: list[tuple[str, ReasoningResponse | None] | None] = [None] * total + next_idx = 0 + while next_idx < total: + batch_end = min(total, next_idx + adaptive_concurrency) + batch_contexts = contexts[next_idx:batch_end] + tasks = [] + for local_offset, ctx in enumerate(batch_contexts): + idx = next_idx + local_offset + tasks.append(asyncio.create_task(reason_with_pacing(idx, ctx))) + if stagger_interval > 0 and local_offset < len(batch_contexts) - 1: + await asyncio.sleep(stagger_interval) + + batch_results = await asyncio.gather(*tasks) + latencies: list[float] = [] + failures = 0 + for idx, agent_id, result, elapsed in batch_results: + results[idx] = (agent_id, result) + latencies.append(elapsed) + if result is None: + failures += 1 + + # Adaptive concurrency control: + # - high error rate or high latency => downshift + # - clean/fast batches => cautiously upshift + avg_latency = sum(latencies) / len(latencies) if latencies else 0.0 + fail_rate = failures / len(batch_results) if batch_results else 0.0 + if fail_rate >= 0.2 or avg_latency >= 20.0: + adaptive_concurrency = max(1, int(adaptive_concurrency * 0.7)) + elif fail_rate == 0 and avg_latency <= 8.0: + adaptive_concurrency = min(target_concurrency, adaptive_concurrency + 1) 
- results = await asyncio.gather(*tasks) + logger.info( + f"[BATCH] Adaptive concurrency={adaptive_concurrency} " + f"(avg_latency={avg_latency:.2f}s, fail_rate={fail_rate:.0%})" + ) + next_idx = batch_end # Close the async HTTP client before the event loop shuts down. # Without this, orphaned httpx connections produce "Event loop is # closed" errors during garbage collection. await close_simulation_provider() - return results + return [r for r in results if r is not None] batch_start = time.time() results = asyncio.run(run_all()) diff --git a/extropy/storage/__init__.py b/extropy/storage/__init__.py new file mode 100644 index 0000000..a1b35ab --- /dev/null +++ b/extropy/storage/__init__.py @@ -0,0 +1,18 @@ +"""Storage layer for canonical study database.""" + +from .study_db import StudyDB, open_study_db +from .schemas import ( + AgentDBRecord, + NetworkEdgeDBRecord, + ChatMessagePayload, + ReadOnlySQLRequest, +) + +__all__ = [ + "StudyDB", + "open_study_db", + "AgentDBRecord", + "NetworkEdgeDBRecord", + "ChatMessagePayload", + "ReadOnlySQLRequest", +] diff --git a/extropy/storage/schemas.py b/extropy/storage/schemas.py new file mode 100644 index 0000000..198e122 --- /dev/null +++ b/extropy/storage/schemas.py @@ -0,0 +1,46 @@ +"""Pydantic schemas for canonical study DB payloads.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, Field, ConfigDict + + +class AgentDBRecord(BaseModel): + """Validated representation of an agent row in the study DB.""" + + population_id: str + agent_id: str + attrs_json: dict[str, Any] + sample_run_id: str + + +class NetworkEdgeDBRecord(BaseModel): + """Validated representation of a network edge row in the study DB.""" + + network_id: str + source_id: str + target_id: str + weight: float + edge_type: str + influence_st: float | None = None + influence_ts: float | None = None + + +class ChatMessagePayload(BaseModel): + """Validated chat message payload persisted in chat_messages.""" + + 
role: str = Field(min_length=1) + content: str = Field(min_length=1) + citations: dict[str, Any] = Field(default_factory=dict) + token_usage: dict[str, Any] = Field(default_factory=dict) + + +class ReadOnlySQLRequest(BaseModel): + """Read-only SQL request contract for query CLI.""" + + model_config = ConfigDict(str_strip_whitespace=True) + + sql: str = Field(min_length=1) + limit: int = Field(default=1000, ge=1) diff --git a/extropy/storage/study_db.py b/extropy/storage/study_db.py new file mode 100644 index 0000000..1005b98 --- /dev/null +++ b/extropy/storage/study_db.py @@ -0,0 +1,898 @@ +"""Canonical study database storage for Extropy. + +This module provides the schema and helper operations for ``study.db``. +""" + +from __future__ import annotations + +import json +import sqlite3 +import uuid +from datetime import datetime +from pathlib import Path +from typing import Any + +from .schemas import AgentDBRecord, NetworkEdgeDBRecord, ChatMessagePayload + + +def _now_iso() -> str: + return datetime.now().isoformat() + + +def _dumps(data: Any) -> str: + return json.dumps(data, default=str) + + +class StudyDB: + """SQLite-backed canonical study store.""" + + def __init__(self, path: Path | str): + self.path = Path(path) + self.path.parent.mkdir(parents=True, exist_ok=True) + self.conn = sqlite3.connect(str(self.path)) + self.conn.row_factory = sqlite3.Row + self._set_pragmas() + self.init_schema() + + def _set_pragmas(self) -> None: + cursor = self.conn.cursor() + cursor.execute("PRAGMA foreign_keys = ON") + cursor.execute("PRAGMA journal_mode = WAL") + cursor.execute("PRAGMA synchronous = NORMAL") + cursor.execute("PRAGMA temp_store = MEMORY") + self.conn.commit() + + def init_schema(self) -> None: + """Create canonical schema and indexes.""" + cursor = self.conn.cursor() + cursor.executescript( + """ + CREATE TABLE IF NOT EXISTS study_meta ( + key TEXT PRIMARY KEY, + value TEXT + ); + + CREATE TABLE IF NOT EXISTS population_specs ( + population_id TEXT PRIMARY 
KEY, + spec_yaml TEXT NOT NULL, + source_path TEXT, + created_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS sample_runs ( + sample_run_id TEXT PRIMARY KEY, + population_id TEXT NOT NULL, + seed INTEGER, + count INTEGER, + created_at TEXT NOT NULL, + meta_json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS agents ( + population_id TEXT NOT NULL, + agent_id TEXT NOT NULL, + attrs_json TEXT NOT NULL, + sample_run_id TEXT NOT NULL, + PRIMARY KEY (population_id, agent_id) + ); + + CREATE TABLE IF NOT EXISTS network_runs ( + network_run_id TEXT PRIMARY KEY, + population_id TEXT NOT NULL, + network_id TEXT NOT NULL, + config_json TEXT NOT NULL, + seed INTEGER, + candidate_mode TEXT NOT NULL, + status TEXT NOT NULL, + created_at TEXT NOT NULL, + completed_at TEXT, + meta_json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS network_edges ( + network_id TEXT NOT NULL, + source_id TEXT NOT NULL, + target_id TEXT NOT NULL, + weight REAL NOT NULL, + edge_type TEXT NOT NULL, + influence_st REAL, + influence_ts REAL, + PRIMARY KEY (network_id, source_id, target_id) + ); + + CREATE TABLE IF NOT EXISTS network_metrics ( + network_id TEXT PRIMARY KEY, + metrics_json TEXT NOT NULL, + computed_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS network_similarity_jobs ( + job_id TEXT PRIMARY KEY, + network_run_id TEXT NOT NULL, + signature_json TEXT NOT NULL, + status TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS network_similarity_chunks ( + job_id TEXT NOT NULL, + chunk_start INTEGER NOT NULL, + chunk_end INTEGER NOT NULL, + status TEXT NOT NULL, + pair_count INTEGER NOT NULL DEFAULT 0, + updated_at TEXT NOT NULL, + PRIMARY KEY (job_id, chunk_start) + ); + + CREATE TABLE IF NOT EXISTS network_similarity_pairs ( + job_id TEXT NOT NULL, + i INTEGER NOT NULL, + j INTEGER NOT NULL, + sim REAL NOT NULL, + PRIMARY KEY (job_id, i, j) + ) WITHOUT ROWID; + + CREATE TABLE IF NOT EXISTS network_similarity_snapshots ( + 
job_id TEXT PRIMARY KEY, + payload BLOB NOT NULL, + updated_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS simulation_runs ( + run_id TEXT PRIMARY KEY, + scenario_name TEXT, + population_id TEXT NOT NULL, + network_id TEXT NOT NULL, + config_json TEXT NOT NULL, + seed INTEGER, + status TEXT NOT NULL, + started_at TEXT NOT NULL, + completed_at TEXT, + stopped_reason TEXT + ); + + CREATE TABLE IF NOT EXISTS agent_states ( + agent_id TEXT PRIMARY KEY, + aware INTEGER DEFAULT 0, + exposure_count INTEGER DEFAULT 0, + last_reasoning_timestep INTEGER DEFAULT -1, + position TEXT, + sentiment REAL, + conviction REAL, + public_statement TEXT, + action_intent TEXT, + will_share INTEGER DEFAULT 0, + outcomes_json TEXT, + public_position TEXT, + public_sentiment REAL, + public_conviction REAL, + private_position TEXT, + private_sentiment REAL, + private_conviction REAL, + private_outcomes_json TEXT, + raw_reasoning TEXT, + committed INTEGER DEFAULT 0, + network_hop_depth INTEGER, + updated_at INTEGER DEFAULT 0 + ); + + CREATE TABLE IF NOT EXISTS exposures ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + agent_id TEXT, + timestep INTEGER, + channel TEXT, + source_agent_id TEXT, + content TEXT, + credibility REAL, + FOREIGN KEY (agent_id) REFERENCES agent_states(agent_id) + ); + + CREATE TABLE IF NOT EXISTS memory_traces ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + agent_id TEXT, + timestep INTEGER, + sentiment REAL, + conviction REAL, + summary TEXT, + FOREIGN KEY (agent_id) REFERENCES agent_states(agent_id) + ); + + CREATE TABLE IF NOT EXISTS timeline ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestep INTEGER, + event_type TEXT, + agent_id TEXT, + details_json TEXT, + wall_timestamp TEXT + ); + + CREATE TABLE IF NOT EXISTS timestep_summaries ( + timestep INTEGER PRIMARY KEY, + new_exposures INTEGER, + agents_reasoned INTEGER, + shares_occurred INTEGER, + state_changes INTEGER, + exposure_rate REAL, + position_distribution_json TEXT, + average_sentiment REAL, + 
average_conviction REAL, + sentiment_variance REAL + ); + + CREATE TABLE IF NOT EXISTS shared_to ( + source_agent_id TEXT, + target_agent_id TEXT, + timestep INTEGER, + position TEXT, + PRIMARY KEY (source_agent_id, target_agent_id) + ); + + CREATE TABLE IF NOT EXISTS simulation_metadata ( + key TEXT PRIMARY KEY, + value TEXT + ); + + CREATE TABLE IF NOT EXISTS run_metadata ( + run_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT, + PRIMARY KEY (run_id, key) + ); + + CREATE TABLE IF NOT EXISTS simulation_checkpoints ( + run_id TEXT NOT NULL, + timestep INTEGER NOT NULL, + chunk_index INTEGER NOT NULL, + status TEXT NOT NULL, + updated_at TEXT NOT NULL, + PRIMARY KEY (run_id, timestep, chunk_index) + ); + + CREATE TABLE IF NOT EXISTS chat_sessions ( + session_id TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + agent_id TEXT NOT NULL, + mode TEXT NOT NULL, + created_at TEXT NOT NULL, + closed_at TEXT, + meta_json TEXT + ); + + CREATE TABLE IF NOT EXISTS chat_messages ( + session_id TEXT NOT NULL, + turn_index INTEGER NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + citations_json TEXT, + token_usage_json TEXT, + created_at TEXT NOT NULL, + PRIMARY KEY (session_id, turn_index) + ); + + CREATE TABLE IF NOT EXISTS chat_artifacts ( + session_id TEXT NOT NULL, + key TEXT NOT NULL, + value_json TEXT NOT NULL, + PRIMARY KEY (session_id, key) + ); + + CREATE INDEX IF NOT EXISTS idx_agents_population ON agents(population_id); + CREATE INDEX IF NOT EXISTS idx_network_edges_src ON network_edges(network_id, source_id); + CREATE INDEX IF NOT EXISTS idx_network_edges_tgt ON network_edges(network_id, target_id); + CREATE INDEX IF NOT EXISTS idx_net_sim_chunks_status ON network_similarity_chunks(job_id, status); + CREATE INDEX IF NOT EXISTS idx_sim_ckpt ON simulation_checkpoints(run_id, timestep, chunk_index); + CREATE INDEX IF NOT EXISTS idx_chat_session_agent ON chat_sessions(run_id, agent_id); + CREATE INDEX IF NOT EXISTS idx_agent_states_aware ON agent_states(aware); + 
CREATE INDEX IF NOT EXISTS idx_agent_states_will_share ON agent_states(will_share); + CREATE INDEX IF NOT EXISTS idx_agent_states_last_reasoning ON agent_states(last_reasoning_timestep); + CREATE INDEX IF NOT EXISTS idx_exposures_agent_timestep ON exposures(agent_id, timestep); + CREATE INDEX IF NOT EXISTS idx_timeline_timestep ON timeline(timestep); + CREATE INDEX IF NOT EXISTS idx_shared_to_source ON shared_to(source_agent_id); + """ + ) + self.conn.commit() + + def close(self) -> None: + self.conn.close() + + def __enter__(self) -> "StudyDB": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + self.close() + return False + + def save_population_spec( + self, + population_id: str, + spec_yaml: str, + source_path: str | None, + ) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO population_specs (population_id, spec_yaml, source_path, created_at) + VALUES (?, ?, ?, ?) + """, + (population_id, spec_yaml, source_path, _now_iso()), + ) + self.conn.commit() + + def get_population_spec_yaml(self, population_id: str) -> str | None: + cursor = self.conn.cursor() + cursor.execute( + "SELECT spec_yaml FROM population_specs WHERE population_id = ?", + (population_id,), + ) + row = cursor.fetchone() + return str(row["spec_yaml"]) if row else None + + def save_sample_result( + self, + population_id: str, + agents: list[dict[str, Any]], + meta: dict[str, Any], + seed: int | None = None, + sample_run_id: str | None = None, + ) -> str: + run_id = sample_run_id or str(uuid.uuid4()) + cursor = self.conn.cursor() + + cursor.execute( + """ + INSERT OR REPLACE INTO sample_runs + (sample_run_id, population_id, seed, count, created_at, meta_json) + VALUES (?, ?, ?, ?, ?, ?) 
+ """, + (run_id, population_id, seed, len(agents), _now_iso(), _dumps(meta)), + ) + + cursor.execute("DELETE FROM agents WHERE population_id = ?", (population_id,)) + + rows = [] + for i, agent in enumerate(agents): + agent_id = str(agent.get("_id", f"agent_{i}")) + row_agent = dict(agent) + row_agent["_id"] = agent_id + rec = AgentDBRecord( + population_id=population_id, + agent_id=agent_id, + attrs_json=row_agent, + sample_run_id=run_id, + ) + rows.append( + ( + rec.population_id, + rec.agent_id, + _dumps(rec.attrs_json), + rec.sample_run_id, + ) + ) + + cursor.executemany( + """ + INSERT INTO agents (population_id, agent_id, attrs_json, sample_run_id) + VALUES (?, ?, ?, ?) + """, + rows, + ) + self.conn.commit() + return run_id + + def get_agents(self, population_id: str) -> list[dict[str, Any]]: + cursor = self.conn.cursor() + cursor.execute( + """ + SELECT attrs_json + FROM agents + WHERE population_id = ? + ORDER BY agent_id + """, + (population_id,), + ) + agents = [] + for row in cursor.fetchall(): + try: + agents.append(json.loads(row["attrs_json"])) + except json.JSONDecodeError: + continue + return agents + + def get_agent_count(self, population_id: str) -> int: + cursor = self.conn.cursor() + cursor.execute( + "SELECT COUNT(*) AS cnt FROM agents WHERE population_id = ?", + (population_id,), + ) + row = cursor.fetchone() + return int(row["cnt"]) if row else 0 + + def save_network_result( + self, + population_id: str, + network_id: str, + config: dict[str, Any], + result_meta: dict[str, Any], + edges: list[dict[str, Any]], + seed: int | None, + candidate_mode: str, + network_metrics: dict[str, Any] | None = None, + network_run_id: str | None = None, + ) -> str: + run_id = network_run_id or str(uuid.uuid4()) + cursor = self.conn.cursor() + now = _now_iso() + + cursor.execute( + """ + INSERT OR REPLACE INTO network_runs + (network_run_id, population_id, network_id, config_json, seed, candidate_mode, + status, created_at, completed_at, meta_json) + VALUES 
(?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + run_id, + population_id, + network_id, + _dumps(config), + seed, + candidate_mode, + "completed", + now, + now, + _dumps(result_meta), + ), + ) + + cursor.execute("DELETE FROM network_edges WHERE network_id = ?", (network_id,)) + + rows = [] + for edge in edges: + infl = edge.get("influence_weight") or {} + rec = NetworkEdgeDBRecord( + network_id=network_id, + source_id=str(edge.get("source", "")), + target_id=str(edge.get("target", "")), + weight=float(edge.get("weight", 0.0)), + edge_type=str(edge.get("type", edge.get("edge_type", "unknown"))), + influence_st=float(infl.get("source_to_target", edge.get("weight", 0.0))), + influence_ts=float(infl.get("target_to_source", edge.get("weight", 0.0))), + ) + rows.append( + ( + rec.network_id, + rec.source_id, + rec.target_id, + rec.weight, + rec.edge_type, + rec.influence_st, + rec.influence_ts, + ) + ) + + cursor.executemany( + """ + INSERT INTO network_edges + (network_id, source_id, target_id, weight, edge_type, influence_st, influence_ts) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + rows, + ) + + if network_metrics is not None: + cursor.execute( + """ + INSERT OR REPLACE INTO network_metrics (network_id, metrics_json, computed_at) + VALUES (?, ?, ?) + """, + (network_id, _dumps(network_metrics), now), + ) + + self.conn.commit() + return run_id + + def get_network(self, network_id: str) -> dict[str, Any]: + cursor = self.conn.cursor() + + cursor.execute( + "SELECT meta_json FROM network_runs WHERE network_id = ? ORDER BY completed_at DESC LIMIT 1", + (network_id,), + ) + run_row = cursor.fetchone() + meta = {} + if run_row: + try: + meta = json.loads(run_row["meta_json"]) + except json.JSONDecodeError: + meta = {} + + cursor.execute( + """ + SELECT source_id, target_id, weight, edge_type, influence_st, influence_ts + FROM network_edges + WHERE network_id = ? 
+ ORDER BY source_id, target_id + """, + (network_id,), + ) + edges = [] + for row in cursor.fetchall(): + edges.append( + { + "source": row["source_id"], + "target": row["target_id"], + "weight": row["weight"], + "type": row["edge_type"], + "bidirectional": True, + "influence_weight": { + "source_to_target": row["influence_st"], + "target_to_source": row["influence_ts"], + }, + } + ) + + return {"meta": meta, "edges": edges} + + def get_network_edge_count(self, network_id: str) -> int: + cursor = self.conn.cursor() + cursor.execute( + "SELECT COUNT(*) AS cnt FROM network_edges WHERE network_id = ?", + (network_id,), + ) + row = cursor.fetchone() + return int(row["cnt"]) if row else 0 + + def init_network_similarity_job( + self, + network_run_id: str, + signature: dict[str, Any], + job_id: str | None = None, + ) -> str: + job = job_id or str(uuid.uuid4()) + cursor = self.conn.cursor() + now = _now_iso() + cursor.execute( + """ + INSERT OR REPLACE INTO network_similarity_jobs + (job_id, network_run_id, signature_json, status, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + """, + (job, network_run_id, _dumps(signature), "running", now, now), + ) + self.conn.commit() + return job + + def get_network_similarity_job_signature(self, job_id: str) -> dict[str, Any] | None: + cursor = self.conn.cursor() + cursor.execute( + "SELECT signature_json FROM network_similarity_jobs WHERE job_id = ?", + (job_id,), + ) + row = cursor.fetchone() + if not row: + return None + try: + return json.loads(row["signature_json"]) + except json.JSONDecodeError: + return None + + def get_completed_similarity_chunks(self, job_id: str) -> set[int]: + cursor = self.conn.cursor() + cursor.execute( + """ + SELECT chunk_start + FROM network_similarity_chunks + WHERE job_id = ? 
AND status = 'done' + """, + (job_id,), + ) + return {int(row["chunk_start"]) for row in cursor.fetchall()} + + def list_completed_similarity_chunks(self, job_id: str) -> list[tuple[int, int]]: + cursor = self.conn.cursor() + cursor.execute( + """ + SELECT chunk_start, chunk_end + FROM network_similarity_chunks + WHERE job_id = ? AND status = 'done' + ORDER BY chunk_start + """, + (job_id,), + ) + return [ + (int(row["chunk_start"]), int(row["chunk_end"])) for row in cursor.fetchall() + ] + + def save_similarity_chunk_rows( + self, + job_id: str, + chunk_start: int, + chunk_end: int, + rows: list[tuple[int, int, float]], + ) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO network_similarity_chunks + (job_id, chunk_start, chunk_end, status, pair_count, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + """, + (job_id, chunk_start, chunk_end, "running", len(rows), _now_iso()), + ) + if rows: + cursor.executemany( + """ + INSERT OR REPLACE INTO network_similarity_pairs (job_id, i, j, sim) + VALUES (?, ?, ?, ?) + """, + [(job_id, i, j, sim) for i, j, sim in rows], + ) + cursor.execute( + """ + UPDATE network_similarity_chunks + SET status = 'done', updated_at = ? + WHERE job_id = ? AND chunk_start = ? + """, + (_now_iso(), job_id, chunk_start), + ) + cursor.execute( + """ + UPDATE network_similarity_jobs + SET updated_at = ? + WHERE job_id = ? + """, + (_now_iso(), job_id), + ) + self.conn.commit() + + def mark_similarity_job_running(self, job_id: str) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + UPDATE network_similarity_jobs + SET status = 'running', updated_at = ? + WHERE job_id = ? 
+ """, + (_now_iso(), job_id), + ) + self.conn.commit() + + def load_similarity_pairs(self, job_id: str) -> dict[tuple[int, int], float]: + cursor = self.conn.cursor() + cursor.execute( + "SELECT i, j, sim FROM network_similarity_pairs WHERE job_id = ?", + (job_id,), + ) + return {(int(row["i"]), int(row["j"])): float(row["sim"]) for row in cursor.fetchall()} + + def mark_similarity_job_complete(self, job_id: str, drop_pairs: bool = False) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + UPDATE network_similarity_jobs + SET status = 'completed', updated_at = ? + WHERE job_id = ? + """, + (_now_iso(), job_id), + ) + if drop_pairs: + cursor.execute("DELETE FROM network_similarity_pairs WHERE job_id = ?", (job_id,)) + self.conn.commit() + + def save_similarity_snapshot(self, job_id: str, payload: bytes) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO network_similarity_snapshots (job_id, payload, updated_at) + VALUES (?, ?, ?) + """, + (job_id, payload, _now_iso()), + ) + self.conn.commit() + + def get_similarity_snapshot(self, job_id: str) -> bytes | None: + cursor = self.conn.cursor() + cursor.execute( + "SELECT payload FROM network_similarity_snapshots WHERE job_id = ?", + (job_id,), + ) + row = cursor.fetchone() + return bytes(row["payload"]) if row else None + + def create_simulation_run( + self, + run_id: str, + scenario_name: str, + population_id: str, + network_id: str, + config: dict[str, Any], + seed: int | None, + status: str = "running", + ) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO simulation_runs + (run_id, scenario_name, population_id, network_id, config_json, seed, status, started_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + (run_id, scenario_name, population_id, network_id, _dumps(config), seed, status, _now_iso()), + ) + self.conn.commit() + + def update_simulation_run( + self, + run_id: str, + status: str, + stopped_reason: str | None = None, + ) -> None: + cursor = self.conn.cursor() + completed_at = _now_iso() if status in {"completed", "failed", "stopped"} else None + cursor.execute( + """ + UPDATE simulation_runs + SET status = ?, stopped_reason = ?, completed_at = COALESCE(?, completed_at) + WHERE run_id = ? + """, + (status, stopped_reason, completed_at, run_id), + ) + self.conn.commit() + + def set_run_metadata(self, run_id: str, key: str, value: str) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO run_metadata (run_id, key, value) + VALUES (?, ?, ?) + """, + (run_id, key, value), + ) + self.conn.commit() + + def get_run_metadata(self, run_id: str, key: str) -> str | None: + cursor = self.conn.cursor() + cursor.execute( + "SELECT value FROM run_metadata WHERE run_id = ? AND key = ?", + (run_id, key), + ) + row = cursor.fetchone() + return str(row["value"]) if row else None + + def save_simulation_checkpoint( + self, + run_id: str, + timestep: int, + chunk_index: int, + status: str, + ) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO simulation_checkpoints + (run_id, timestep, chunk_index, status, updated_at) + VALUES (?, ?, ?, ?, ?) + """, + (run_id, timestep, chunk_index, status, _now_iso()), + ) + self.conn.commit() + + def get_completed_simulation_chunks(self, run_id: str, timestep: int) -> set[int]: + cursor = self.conn.cursor() + cursor.execute( + """ + SELECT chunk_index + FROM simulation_checkpoints + WHERE run_id = ? AND timestep = ? 
AND status = 'done' + """, + (run_id, timestep), + ) + return {int(row["chunk_index"]) for row in cursor.fetchall()} + + def create_chat_session( + self, + run_id: str, + agent_id: str, + mode: str, + meta: dict[str, Any] | None = None, + session_id: str | None = None, + ) -> str: + sid = session_id or str(uuid.uuid4()) + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO chat_sessions + (session_id, run_id, agent_id, mode, created_at, meta_json) + VALUES (?, ?, ?, ?, ?, ?) + """, + (sid, run_id, agent_id, mode, _now_iso(), _dumps(meta or {})), + ) + self.conn.commit() + return sid + + def append_chat_message( + self, + session_id: str, + role: str, + content: str, + citations: dict[str, Any] | None = None, + token_usage: dict[str, Any] | None = None, + ) -> int: + payload = ChatMessagePayload( + role=role, + content=content, + citations=citations or {}, + token_usage=token_usage or {}, + ) + + cursor = self.conn.cursor() + cursor.execute( + "SELECT COALESCE(MAX(turn_index), -1) AS max_turn FROM chat_messages WHERE session_id = ?", + (session_id,), + ) + turn = int(cursor.fetchone()["max_turn"]) + 1 + cursor.execute( + """ + INSERT INTO chat_messages + (session_id, turn_index, role, content, citations_json, token_usage_json, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + session_id, + turn, + payload.role, + payload.content, + _dumps(payload.citations), + _dumps(payload.token_usage), + _now_iso(), + ), + ) + self.conn.commit() + return turn + + def get_chat_messages(self, session_id: str) -> list[dict[str, Any]]: + cursor = self.conn.cursor() + cursor.execute( + """ + SELECT turn_index, role, content, citations_json, token_usage_json, created_at + FROM chat_messages + WHERE session_id = ? 
+ ORDER BY turn_index + """, + (session_id,), + ) + out: list[dict[str, Any]] = [] + for row in cursor.fetchall(): + out.append( + { + "turn_index": int(row["turn_index"]), + "role": row["role"], + "content": row["content"], + "citations": json.loads(row["citations_json"] or "{}"), + "token_usage": json.loads(row["token_usage_json"] or "{}"), + "created_at": row["created_at"], + } + ) + return out + + def run_select( + self, + query: str, + params: tuple[Any, ...] = (), + limit: int | None = None, + ) -> list[dict[str, Any]]: + cursor = self.conn.cursor() + sql = query.strip().rstrip(";") + if limit is not None and " limit " not in sql.lower(): + sql = f"{sql} LIMIT {int(limit)}" + cursor.execute(sql, params) + cols = [d[0] for d in cursor.description or []] + rows = [] + for row in cursor.fetchall(): + rows.append({k: row[idx] for idx, k in enumerate(cols)}) + return rows + + +def open_study_db(path: Path | str) -> StudyDB: + """Open ``study.db`` and ensure schema exists.""" + return StudyDB(path) diff --git a/extropy/utils/__init__.py b/extropy/utils/__init__.py index 347c833..6ddb4a0 100644 --- a/extropy/utils/__init__.py +++ b/extropy/utils/__init__.py @@ -40,6 +40,10 @@ resolve_relative_to, make_relative_to, ) +from .resource_governor import ( + ResourceGovernor, + ResourceSnapshot, +) __all__ = [ # Graphs @@ -69,4 +73,7 @@ # Paths "resolve_relative_to", "make_relative_to", + # Resource governor + "ResourceGovernor", + "ResourceSnapshot", ] diff --git a/extropy/utils/resource_governor.py b/extropy/utils/resource_governor.py new file mode 100644 index 0000000..33b81a2 --- /dev/null +++ b/extropy/utils/resource_governor.py @@ -0,0 +1,102 @@ +"""Resource auto-tuning helpers for CPU/memory constrained environments.""" + +from __future__ import annotations + +import os +import platform +import subprocess +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ResourceSnapshot: + cpu_count: int + total_memory_gb: float + memory_budget_gb: float + + 
+class ResourceGovernor: + """Computes safe worker/chunk recommendations from local machine resources.""" + + def __init__( + self, + resource_mode: str = "auto", + safe_auto_workers: bool = True, + max_memory_gb: float | None = None, + ): + self.resource_mode = resource_mode + self.safe_auto_workers = safe_auto_workers + self.max_memory_gb = max_memory_gb + + @staticmethod + def _detect_total_memory_gb() -> float: + # Linux and many Unix systems + try: + page_size = os.sysconf("SC_PAGE_SIZE") + phys_pages = os.sysconf("SC_PHYS_PAGES") + if page_size > 0 and phys_pages > 0: + return (page_size * phys_pages) / (1024**3) + except (ValueError, OSError, AttributeError): + pass + + # macOS fallback + if platform.system().lower() == "darwin": + try: + out = subprocess.check_output(["sysctl", "-n", "hw.memsize"], text=True) + return int(out.strip()) / (1024**3) + except Exception: + pass + + # Conservative fallback + return 8.0 + + def snapshot(self) -> ResourceSnapshot: + cpu_count = max(1, os.cpu_count() or 1) + total_mem = self._detect_total_memory_gb() + capped = min(total_mem, self.max_memory_gb) if self.max_memory_gb else total_mem + budget = max(1.0, capped * 0.80) + return ResourceSnapshot( + cpu_count=cpu_count, + total_memory_gb=round(total_mem, 2), + memory_budget_gb=round(budget, 2), + ) + + def recommend_workers( + self, + requested_workers: int, + memory_per_worker_gb: float, + ) -> int: + requested_workers = max(1, int(requested_workers)) + if self.resource_mode != "auto": + return requested_workers + + snap = self.snapshot() + cpu_cap = max(1, snap.cpu_count - 1) if self.safe_auto_workers else snap.cpu_count + mem_cap = max(1, int(snap.memory_budget_gb / max(0.1, memory_per_worker_gb))) + + if self.safe_auto_workers: + cpu_cap = min(cpu_cap, 8) + + return max(1, min(requested_workers, cpu_cap, mem_cap)) + + def recommend_chunk_size( + self, + requested_chunk_size: int, + min_chunk_size: int = 8, + max_chunk_size: int = 4096, + ) -> int: + 
requested_chunk_size = max(min_chunk_size, int(requested_chunk_size)) + if self.resource_mode != "auto": + return min(max_chunk_size, requested_chunk_size) + + snap = self.snapshot() + if snap.memory_budget_gb <= 4: + tuned = min(requested_chunk_size, 32) + elif snap.memory_budget_gb <= 8: + tuned = min(requested_chunk_size, 64) + elif snap.memory_budget_gb <= 16: + tuned = min(requested_chunk_size, 128) + else: + tuned = requested_chunk_size + + return max(min_chunk_size, min(max_chunk_size, tuned)) diff --git a/tests/test_cli.py b/tests/test_cli.py index d38079a..3c624af 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,11 +1,12 @@ """CLI smoke tests using typer's CliRunner.""" - from pathlib import Path from typer.testing import CliRunner from extropy.cli.app import app from extropy.cli.commands.validate import _is_scenario_file +from extropy.population.network.config import NetworkConfig +from extropy.storage import open_study_db runner = CliRunner() @@ -59,3 +60,71 @@ def test_version_output(self): result = runner.invoke(app, ["--version"]) assert result.exit_code == 0 assert "extropy" in result.output + + +class TestNetworkCommand: + """Smoke tests for the network command options.""" + + def test_network_command_supports_fast_mode_and_checkpoint(self, tmp_path): + study_db = tmp_path / "study.db" + config_path = tmp_path / "network-config.yaml" + output_path = tmp_path / "network.json" + checkpoint_path = tmp_path / "network-checkpoint.pkl" + + agents = [ + {"_id": "a0", "role": "x", "team": "alpha"}, + {"_id": "a1", "role": "x", "team": "alpha"}, + {"_id": "a2", "role": "y", "team": "beta"}, + {"_id": "a3", "role": "y", "team": "beta"}, + ] + with open_study_db(study_db) as db: + db.save_sample_result(population_id="default", agents=agents, meta={"source": "test"}) + + NetworkConfig(seed=42, avg_degree=2.0).to_yaml(config_path) + + result = runner.invoke( + app, + [ + "network", + "--study-db", + str(study_db), + "-o", + str(output_path), + "-c", 
class TestNetworkCommand:
    """Smoke tests for the network command options."""

    def test_network_command_supports_fast_mode_and_checkpoint(self, tmp_path):
        """Blocked candidate mode with checkpointing writes both output files."""
        study_db = tmp_path / "study.db"
        config_path = tmp_path / "network-config.yaml"
        output_path = tmp_path / "network.json"
        checkpoint_path = tmp_path / "network-checkpoint.pkl"

        agents = [
            {"_id": "a0", "role": "x", "team": "alpha"},
            {"_id": "a1", "role": "x", "team": "alpha"},
            {"_id": "a2", "role": "y", "team": "beta"},
            {"_id": "a3", "role": "y", "team": "beta"},
        ]
        with open_study_db(study_db) as db:
            db.save_sample_result(
                population_id="default", agents=agents, meta={"source": "test"}
            )

        NetworkConfig(seed=42, avg_degree=2.0).to_yaml(config_path)

        result = runner.invoke(
            app,
            [
                "network",
                "--study-db",
                str(study_db),
                "-o",
                str(output_path),
                "-c",
                str(config_path),
                "--no-metrics",
                "--candidate-mode",
                "blocked",
                "--candidate-pool-multiplier",
                "4.0",
                "--block-attr",
                "role",
                "--similarity-workers",
                "1",
                "--similarity-chunk-size",
                "8",
                "--checkpoint",
                str(checkpoint_path),
                "--checkpoint-every",
                "1",
            ],
        )

        assert result.exit_code == 0
        assert output_path.exists()
        assert checkpoint_path.exists()

    def test_network_resume_requires_checkpoint(self, tmp_path):
        """A missing study DB is reported with exit code 1.

        The DB and output paths live under ``tmp_path`` so the outcome does
        not depend on whether the working directory happens to contain a
        ``study.db`` (the previous relative path made the test
        environment-sensitive).
        """
        missing_db = tmp_path / "study.db"
        result = runner.invoke(
            app,
            [
                "network",
                "--study-db",
                str(missing_db),
                "-o",
                str(tmp_path / "network.json"),
                "--resume-checkpoint",
            ],
        )
        assert result.exit_code == 1
        assert "Study DB not found" in result.output
population_id="default", + network_id="default", created_at=datetime(2024, 1, 1), ), event=Event( diff --git a/tests/test_network.py b/tests/test_network.py index 7564cff..eb1effe 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -158,8 +158,14 @@ def test_default_config(self): assert config.avg_degree == 20.0 assert config.rewire_prob == 0.05 + assert config.similarity_store_threshold == 0.05 assert config.similarity_threshold == 0.3 assert config.similarity_steepness == 10.0 + assert config.candidate_mode == "exact" + assert config.candidate_pool_multiplier == 12.0 + assert config.min_candidate_pool == 80 + assert config.similarity_workers == 1 + assert config.checkpoint_every_rows == 250 assert config.seed is None def test_custom_config(self): @@ -704,6 +710,72 @@ def on_progress(stage, current, total): stages = set(call[0] for call in progress_calls) assert "Computing similarities" in stages + def test_generate_network_blocked_mode_reproducibility(self, sample_agents): + """Blocked candidate mode should remain deterministic with fixed seed.""" + config = REFERENCE_NETWORK_CONFIG.model_copy( + update={ + "seed": 42, + "candidate_mode": "blocked", + "candidate_pool_multiplier": 8.0, + "blocking_attributes": ["employer_type", "federal_state"], + } + ) + + result1 = generate_network(sample_agents, config) + result2 = generate_network(sample_agents, config) + + edges1 = {(e.source, e.target) for e in result1.edges} + edges2 = {(e.source, e.target) for e in result2.edges} + + assert result1.meta["candidate_mode"] == "blocked" + assert result2.meta["candidate_mode"] == "blocked" + assert edges1 == edges2 + + def test_generate_network_resume_from_checkpoint_matches_fresh(self, sample_agents): + """Resuming from a saved similarity checkpoint should match a fresh run.""" + import pickle + + config = REFERENCE_NETWORK_CONFIG.model_copy(update={"seed": 42}) + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / 
"network-similarity.pkl" + + # Build and persist checkpoint from a full run. + result_checkpointed = generate_network( + sample_agents, + config, + checkpoint_path=checkpoint_path, + ) + assert checkpoint_path.exists() + + # Simulate interruption by truncating completed_rows in checkpoint metadata. + with open(checkpoint_path, "rb") as f: + payload = pickle.load(f) + completed_rows = max(1, len(sample_agents) // 2) + payload["completed_rows"] = completed_rows + payload["similarities"] = { + pair: sim + for pair, sim in payload["similarities"].items() + if pair[0] < completed_rows + } + with open(checkpoint_path, "wb") as f: + pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL) + + resumed = generate_network( + sample_agents, + config, + checkpoint_path=checkpoint_path, + resume_from_checkpoint=True, + ) + fresh = generate_network(sample_agents, config) + + resumed_edges = {(e.source, e.target) for e in resumed.edges} + fresh_edges = {(e.source, e.target) for e in fresh.edges} + + assert resumed.meta["resumed_from_checkpoint"] is True + assert resumed_edges == fresh_edges + assert len(resumed.edges) == len(result_checkpointed.edges) + class TestGenerateNetworkWithMetrics: """Tests for network generation with metrics.""" diff --git a/tests/test_propagation.py b/tests/test_propagation.py index a52075b..b0d9051 100644 --- a/tests/test_propagation.py +++ b/tests/test_propagation.py @@ -72,8 +72,9 @@ def _make_scenario( name="test", description="Test scenario", population_spec="test.yaml", - agents_file="test.json", - network_file="test_network.json", + study_db="study.db", + population_id="default", + network_id="default", created_at=datetime(2024, 1, 1), ), event=Event( diff --git a/tests/test_reasoning_prompts.py b/tests/test_reasoning_prompts.py index f043dba..d8feb18 100644 --- a/tests/test_reasoning_prompts.py +++ b/tests/test_reasoning_prompts.py @@ -48,8 +48,9 @@ def _make_scenario(**overrides): name="test", description="Test", population_spec="test.yaml", 
- agents_file="test.json", - network_file="test_network.json", + study_db="study.db", + population_id="default", + network_id="default", created_at=datetime(2024, 1, 1), ), event=Event( diff --git a/tests/test_scenario.py b/tests/test_scenario.py index 15111fb..274b4af 100644 --- a/tests/test_scenario.py +++ b/tests/test_scenario.py @@ -342,8 +342,9 @@ def test_scenario_meta_creation(self): name="ai_tool_announcement", description="Hospital announces new AI diagnostic tool", population_spec="surgeons.yaml", - agents_file="agents.json", - network_file="network.json", + study_db="study.db", + population_id="default", + network_id="default", ) assert meta.name == "ai_tool_announcement" assert meta.population_spec == "surgeons.yaml" @@ -361,8 +362,9 @@ def sample_scenario_spec(self): name="test_scenario", description="Test scenario", population_spec="pop.yaml", - agents_file="agents.json", - network_file="network.json", + study_db="study.db", + population_id="default", + network_id="default", ), event=Event( type=EventType.ANNOUNCEMENT, @@ -596,8 +598,9 @@ def test_full_scenario_with_all_features(self): name="ai_tool_full_scenario", description="Hospital announces mandatory AI diagnostic tool", population_spec="german_surgeons.yaml", - agents_file="agents_500.json", - network_file="network_500.json", + study_db="study.db", + population_id="default", + network_id="default", ), event=Event( type=EventType.ANNOUNCEMENT, diff --git a/tests/test_scenario_validator.py b/tests/test_scenario_validator.py index 3aa6d57..25082cb 100644 --- a/tests/test_scenario_validator.py +++ b/tests/test_scenario_validator.py @@ -21,20 +21,21 @@ SpreadConfig, ) from extropy.scenario.validator import load_and_validate_scenario, validate_scenario +from extropy.storage import open_study_db def _make_scenario_spec( population_path: str, - agents_path: str, - network_path: str, + study_db_path: str, ) -> ScenarioSpec: return ScenarioSpec( meta=ScenarioMeta( name="test_scenario", 
description="Validation test scenario", population_spec=population_path, - agents_file=agents_path, - network_file=network_path, + study_db=study_db_path, + population_id="default", + network_id="default", ), event=Event( type=EventType.ANNOUNCEMENT, @@ -83,17 +84,19 @@ def _make_scenario_spec( def test_validate_scenario_surfaces_errors(tmp_path: Path): """Validation should preserve and return discovered errors.""" population_path = tmp_path / "population.yaml" - agents_path = tmp_path / "agents.json" - network_path = tmp_path / "network.json" + study_db = tmp_path / "study.db" population_path.write_text("placeholder: true\n") - agents_path.write_text("[]\n") - network_path.write_text('{"meta": {}, "edges": []}\n') + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", + agents=[], + meta={"source": "test"}, + ) spec = _make_scenario_spec( str(population_path), - str(agents_path), - str(network_path), + str(study_db), ) spec.seed_exposure.rules[0].channel = "missing_channel" @@ -110,15 +113,27 @@ def test_load_and_validate_scenario_resolves_relative_paths( ): """Relative file references should resolve against scenario file location.""" population_path = tmp_path / "population.yaml" - agents_path = tmp_path / "agents.json" - network_path = tmp_path / "network.json" + study_db = tmp_path / "study.db" scenario_path = tmp_path / "scenario.yaml" minimal_population_spec.to_yaml(population_path) - agents_path.write_text('[{"_id": "agent_0", "age": 35, "gender": "male"}]\n') - network_path.write_text(json.dumps({"meta": {"node_count": 1}, "edges": []})) - - spec = _make_scenario_spec("population.yaml", "agents.json", "network.json") + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", + agents=[{"_id": "agent_0", "age": 35, "gender": "male"}], + meta={"source": "test"}, + ) + db.save_network_result( + population_id="default", + network_id="default", + config={}, + result_meta={"node_count": 1}, + 
edges=[], + seed=None, + candidate_mode="test", + ) + + spec = _make_scenario_spec("population.yaml", "study.db") spec.to_yaml(scenario_path) _, result = load_and_validate_scenario(scenario_path) @@ -133,17 +148,15 @@ def test_load_and_validate_scenario_resolves_relative_paths( def test_validate_scenario_allows_edge_weight_in_spread_modifier(tmp_path: Path): """edge_weight should be treated as a valid spread modifier reference.""" population_path = tmp_path / "population.yaml" - agents_path = tmp_path / "agents.json" - network_path = tmp_path / "network.json" + study_db = tmp_path / "study.db" population_path.write_text("placeholder: true\n") - agents_path.write_text("[]\n") - network_path.write_text('{"meta": {}, "edges": []}\n') + with open_study_db(study_db) as db: + db.save_sample_result(population_id="default", agents=[], meta={"source": "test"}) spec = _make_scenario_spec( str(population_path), - str(agents_path), - str(network_path), + str(study_db), ) spec.spread.share_modifiers = [ SpreadModifier(when="edge_weight > 0.7", multiply=1.1, add=0.0) From cb45139b40335b9a694734bf889a223810cfab61 Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 00:21:25 -0500 Subject: [PATCH 02/15] fix(tests+checkpoint): handle partial chunk resume and update DB-first compiler fixtures --- extropy/population/network/generator.py | 14 ++++-- tests/test_compiler.py | 65 ++++++++++++++----------- tests/test_scenario_validator.py | 13 ++++- 3 files changed, 57 insertions(+), 35 deletions(-) diff --git a/extropy/population/network/generator.py b/extropy/population/network/generator.py index 75ad7b7..4e412df 100644 --- a/extropy/population/network/generator.py +++ b/extropy/population/network/generator.py @@ -218,6 +218,7 @@ def _load_similarity_checkpoint( similarities = payload.get("similarities", {}) completed_rows = int(payload.get("completed_rows", 0)) completed_chunk_starts: set[int] = set() + allowed_completed_rows = max(0, completed_rows) for item in 
payload.get("completed_chunks", []): if ( isinstance(item, (list, tuple)) @@ -225,7 +226,10 @@ def _load_similarity_checkpoint( and isinstance(item[0], int) and isinstance(item[1], int) ): - completed_chunk_starts.add(item[0]) + # Guard against stale/inconsistent payloads where completed_chunks + # were not truncated with completed_rows. + if item[0] < allowed_completed_rows and item[1] <= allowed_completed_rows: + completed_chunk_starts.add(item[0]) if not isinstance(similarities, dict): raise ValueError("Invalid checkpoint similarities payload") @@ -304,8 +308,8 @@ def _compute_similarities_parallel( tasks = [(i, min(i + chunk_size, n)) for i in range(0, n, chunk_size)] task_ends = {start: end for start, end in tasks} completed_starts: set[int] = set(completed_chunk_starts or set()) - for start, _ in tasks: - if start < completed_rows: + for start, end in tasks: + if end <= completed_rows: completed_starts.add(start) pending_tasks = [(s, e) for s, e in tasks if s not in completed_starts] workers = max(1, config.similarity_workers) @@ -427,8 +431,8 @@ def _compute_similarities_serial( chunk_size = max(8, config.similarity_chunk_size) tasks = [(i, min(i + chunk_size, n)) for i in range(0, n, chunk_size)] completed_starts: set[int] = set(completed_chunk_starts or set()) - for start, _ in tasks: - if start < start_row: + for start, end in tasks: + if end <= start_row: completed_starts.add(start) completed_row_count = sum((e - s) for s, e in tasks if s in completed_starts) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index c6413ea..32c0728 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -3,7 +3,6 @@ Tests the 5-step compilation pipeline and auto-configuration logic. 
""" -import json from unittest.mock import patch import pytest @@ -17,6 +16,7 @@ _determine_simulation_config, create_scenario, ) +from extropy.storage import open_study_db class TestGenerateScenarioName: @@ -85,31 +85,36 @@ def mock_files(self, minimal_population_spec, tmp_path): pop_path = tmp_path / "population.yaml" minimal_population_spec.to_yaml(pop_path) - # Create agents JSON - agents = [ - {"_id": f"agent_{i:03d}", "age": 30 + i, "gender": "male"} + agents = [{"_id": f"agent_{i:03d}", "age": 30 + i, "gender": "male"} for i in range(10)] + edges = [ + { + "source": f"agent_{i:03d}", + "target": f"agent_{(i + 1) % 10:03d}", + "weight": 1.0, + "type": "colleague", + "influence_weight": {"source_to_target": 1.0, "target_to_source": 1.0}, + } for i in range(10) ] - agents_path = tmp_path / "agents.json" - agents_path.write_text(json.dumps(agents)) - - # Create network JSON - network = { - "meta": {"node_count": 10}, - "nodes": [{"id": f"agent_{i:03d}"} for i in range(10)], - "edges": [ - { - "source": f"agent_{i:03d}", - "target": f"agent_{(i + 1) % 10:03d}", - "type": "colleague", - } - for i in range(10) - ], - } - network_path = tmp_path / "network.json" - network_path.write_text(json.dumps(network)) - return pop_path, agents_path, network_path + study_db = tmp_path / "study.db" + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", + agents=agents, + meta={"source": "test_fixture"}, + ) + db.save_network_result( + population_id="default", + network_id="default", + config={}, + result_meta={"node_count": 10}, + edges=edges, + seed=42, + candidate_mode="test", + ) + + return pop_path, study_db @patch("extropy.scenario.compiler.parse_scenario") @patch("extropy.scenario.compiler.generate_seed_exposure") @@ -137,7 +142,7 @@ def test_creates_valid_scenario( OutcomeType, ) - pop_path, agents_path, network_path = mock_files + pop_path, study_db = mock_files # Configure mocks mock_parse.return_value = Event( @@ -188,8 +193,9 @@ def 
test_creates_valid_scenario( spec, validation_result = create_scenario( description="Test product launch scenario", population_spec_path=pop_path, - agents_path=agents_path, - network_path=network_path, + study_db_path=study_db, + population_id="default", + network_id="default", ) assert spec.meta.name is not None @@ -223,7 +229,7 @@ def test_progress_callback_called( OutcomeType, ) - pop_path, agents_path, network_path = mock_files + pop_path, study_db = mock_files mock_parse.return_value = Event( type=EventType.PRODUCT_LAUNCH, @@ -270,8 +276,9 @@ def on_progress(step, status): create_scenario( description="Test", population_spec_path=pop_path, - agents_path=agents_path, - network_path=network_path, + study_db_path=study_db, + population_id="default", + network_id="default", on_progress=on_progress, ) diff --git a/tests/test_scenario_validator.py b/tests/test_scenario_validator.py index 25082cb..5464228 100644 --- a/tests/test_scenario_validator.py +++ b/tests/test_scenario_validator.py @@ -128,7 +128,18 @@ def test_load_and_validate_scenario_resolves_relative_paths( network_id="default", config={}, result_meta={"node_count": 1}, - edges=[], + edges=[ + { + "source": "agent_0", + "target": "agent_0", + "weight": 1.0, + "type": "self", + "influence_weight": { + "source_to_target": 1.0, + "target_to_source": 1.0, + }, + } + ], seed=None, candidate_mode="test", ) From bcf0810a077bcd9408dd0eeabf1006ca073a4b83 Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 00:47:31 -0500 Subject: [PATCH 03/15] feat(sim): run-scoped state tables and writer-queue checkpoints --- extropy/cli/commands/simulate.py | 13 +- extropy/simulation/engine.py | 174 ++++++++++++++----- extropy/simulation/state.py | 285 ++++++++++++++++++++++--------- extropy/storage/study_db.py | 41 +++-- tests/test_engine.py | 51 ++++++ tests/test_simulation.py | 28 +++ 6 files changed, 450 insertions(+), 142 deletions(-) diff --git a/extropy/cli/commands/simulate.py 
b/extropy/cli/commands/simulate.py index 7d6168e..bd592eb 100644 --- a/extropy/cli/commands/simulate.py +++ b/extropy/cli/commands/simulate.py @@ -153,13 +153,13 @@ def simulate_command( 256, "--writer-queue-size", min=1, - help="Reserved writer queue size (future pipeline tuning)", + help="Max reasoning chunks buffered before DB writer backpressure", ), db_write_batch_size: int = typer.Option( 100, "--db-write-batch-size", min=1, - help="Reserved DB write batch size (future pipeline tuning)", + help="Number of chunks applied per DB writer transaction", ), retention_lite: bool = typer.Option( False, @@ -290,11 +290,6 @@ def simulate_command( f"Resources(auto): cpu={snap.cpu_count} mem={snap.total_memory_gb:.1f}GB " f"budget={snap.memory_budget_gb:.1f}GB chunk={tuned_chunk_size}" ) - if writer_queue_size != 256 or db_write_batch_size != 100: - console.print( - "[dim]Note: writer queue/batch flags are accepted now and will be fully enforced " - "by the upcoming async writer pipeline.[/dim]" - ) if verbose or debug: console.print(f"Logging: {'DEBUG' if debug else 'VERBOSE'}") console.print() @@ -335,6 +330,8 @@ def on_progress(timestep: int, max_timesteps: int, status: str): resume=resume, checkpoint_every_chunks=checkpoint_every_chunks, retention_lite=retention_lite, + writer_queue_size=writer_queue_size, + db_write_batch_size=db_write_batch_size, ) simulation_error = None except Exception as e: @@ -369,6 +366,8 @@ def do_simulation(): resume=resume, checkpoint_every_chunks=checkpoint_every_chunks, retention_lite=retention_lite, + writer_queue_size=writer_queue_size, + db_write_batch_size=db_write_batch_size, ) except Exception as e: simulation_error = e diff --git a/extropy/simulation/engine.py b/extropy/simulation/engine.py index 5237dc6..fe85c0e 100644 --- a/extropy/simulation/engine.py +++ b/extropy/simulation/engine.py @@ -14,8 +14,10 @@ import json import logging +import queue import random import sqlite3 +import threading import time import uuid from datetime 
import datetime @@ -88,6 +90,7 @@ class SimulationSummary: def __init__( self, scenario_name: str, + run_id: str | None, population_size: int, total_timesteps: int, stopped_reason: str | None, @@ -100,6 +103,7 @@ def __init__( completed_at: datetime, ): self.scenario_name = scenario_name + self.run_id = run_id self.population_size = population_size self.total_timesteps = total_timesteps self.stopped_reason = stopped_reason @@ -115,6 +119,7 @@ def to_dict(self) -> dict[str, Any]: """Convert to dictionary.""" return { "scenario_name": self.scenario_name, + "run_id": self.run_id, "population_size": self.population_size, "total_timesteps": self.total_timesteps, "stopped_reason": self.stopped_reason, @@ -149,6 +154,8 @@ def __init__( run_id: str | None = None, checkpoint_every_chunks: int = 1, retention_lite: bool = False, + writer_queue_size: int = 256, + db_write_batch_size: int = 100, ): """Initialize simulation engine. @@ -173,6 +180,8 @@ def __init__( self.run_id = run_id or f"run_{uuid.uuid4().hex[:12]}" self.checkpoint_every_chunks = max(1, checkpoint_every_chunks) self.retention_lite = retention_lite + self.writer_queue_size = max(1, writer_queue_size) + self.db_write_batch_size = max(1, db_write_batch_size) # Build agent map for quick lookup self.agent_map = {a.get("_id", str(i)): a for i, a in enumerate(agents)} @@ -202,6 +211,7 @@ def __init__( self.state_manager = StateManager( state_db_file, agents, + run_id=self.run_id, ) self.study_db = open_study_db(state_db_file) @@ -574,16 +584,87 @@ def _on_agent_done(agent_id: str, result: Any) -> None: context = self._build_reasoning_context(agent_id, old_state) contexts.append(context) - # Split into chunks - total_reasoned = 0 - total_changes = 0 - total_shares = 0 - completed_chunks = self.study_db.get_completed_simulation_chunks( self.run_id, timestep ) + totals = {"reasoned": 0, "changes": 0, "shares": 0} + work_queue: queue.Queue[tuple[int, list[tuple[str, Any]], bool] | object] = ( + 
queue.Queue(maxsize=self.writer_queue_size) + ) + sentinel = object() + writer_error: list[Exception] = [] + + def _writer_loop() -> None: + chunks_since_checkpoint = 0 + pending_chunks: list[tuple[int, list[tuple[str, Any]], bool]] = [] + + def _flush_pending() -> None: + nonlocal chunks_since_checkpoint + if not pending_chunks: + return + + with self.state_manager.transaction(): + for chunk_index, chunk_results, _is_last_chunk in pending_chunks: + reasoned, changes, shares = self._process_reasoning_chunk( + timestep, chunk_results, old_states + ) + totals["reasoned"] += reasoned + totals["changes"] += changes + totals["shares"] += shares + + for chunk_index, _chunk_results, is_last_chunk in pending_chunks: + self.study_db.save_simulation_checkpoint( + run_id=self.run_id, + timestep=timestep, + chunk_index=chunk_index, + status="done", + ) + chunks_since_checkpoint += 1 + if ( + chunks_since_checkpoint >= self.checkpoint_every_chunks + or is_last_chunk + ): + self.study_db.set_run_metadata( + self.run_id, + "last_checkpoint", + f"{timestep}:{chunk_index}", + ) + chunks_since_checkpoint = 0 + + pending_chunks.clear() + + try: + while True: + item = work_queue.get() + try: + if item is sentinel: + _flush_pending() + break + + chunk_index, chunk_results, is_last_chunk = item + if chunk_index in completed_chunks: + continue + pending_chunks.append((chunk_index, chunk_results, is_last_chunk)) + if ( + len(pending_chunks) >= self.db_write_batch_size + or is_last_chunk + ): + _flush_pending() + finally: + work_queue.task_done() + except Exception as e: # pragma: no cover - surfaced to caller + writer_error.append(e) + + writer_thread = threading.Thread( + target=_writer_loop, + name=f"sim-writer-{self.run_id}-{timestep}", + daemon=True, + ) + writer_thread.start() for chunk_start in range(0, len(contexts), self.chunk_size): + if writer_error: + break chunk_index = chunk_start // self.chunk_size if chunk_index in completed_chunks: logger.info( @@ -615,28 +696,27 @@ def 
_on_agent_done(agent_id: str, result: Any) -> None: if chunk_results else f"[TIMESTEP {timestep}] Chunk empty" ) - - # Process and commit this chunk - with self.state_manager.transaction(): - reasoned, changes, shares = self._process_reasoning_chunk( - timestep, chunk_results, old_states - ) - if ( - ((chunk_index + 1) % self.checkpoint_every_chunks == 0) - or (chunk_start + self.chunk_size >= len(contexts)) - ): - self.study_db.save_simulation_checkpoint( - run_id=self.run_id, - timestep=timestep, - chunk_index=chunk_index, - status="done", - ) - - total_reasoned += reasoned - total_changes += changes - total_shares += shares - - return total_reasoned, total_changes, total_shares + is_last_chunk = chunk_start + self.chunk_size >= len(contexts) + work_queue.put((chunk_index, chunk_results, is_last_chunk)) + + work_queue.put(sentinel) + while work_queue.unfinished_tasks > 0: + if writer_error: + while True: + try: + work_queue.get_nowait() + work_queue.task_done() + except queue.Empty: + break + break + time.sleep(0.01) + + work_queue.join() + writer_thread.join(timeout=1) + if writer_error: + raise writer_error[0] + + return totals["reasoned"], totals["changes"], totals["shares"] def _process_reasoning_chunk( self, @@ -1104,6 +1184,7 @@ def _finalize( return SimulationSummary( scenario_name=self.scenario.meta.name, + run_id=self.run_id, population_size=len(self.agents), total_timesteps=final_timestep + 1, stopped_reason=stopped_reason, @@ -1276,6 +1357,8 @@ def run_simulation( resume: bool = False, checkpoint_every_chunks: int = 1, retention_lite: bool = False, + writer_queue_size: int = 256, + db_write_batch_size: int = 100, ) -> SimulationSummary: """Run a simulation from a scenario file. 
@@ -1300,6 +1383,8 @@ def run_simulation( resume: Resume a prior run from DB checkpoints checkpoint_every_chunks: Mark simulation checkpoint every N chunks retention_lite: Reduce payload volume by dropping full raw reasoning text + writer_queue_size: Maximum buffered chunks waiting for DB writer + db_write_batch_size: Number of chunks applied per DB writer transaction Returns: SimulationSummary with results @@ -1309,21 +1394,26 @@ def run_simulation( if resume and not run_id: raise ValueError("--resume requires --run-id") - def _reset_runtime_tables(path: Path) -> None: + def _reset_runtime_tables(path: Path, run_key: str) -> None: conn = sqlite3.connect(str(path)) try: cur = conn.cursor() - cur.executescript( - """ - DELETE FROM agent_states; - DELETE FROM exposures; - DELETE FROM memory_traces; - DELETE FROM timeline; - DELETE FROM timestep_summaries; - DELETE FROM shared_to; - DELETE FROM simulation_metadata; - """ - ) + statements = [ + "DELETE FROM agent_states WHERE run_id = ?", + "DELETE FROM exposures WHERE run_id = ?", + "DELETE FROM memory_traces WHERE run_id = ?", + "DELETE FROM timeline WHERE run_id = ?", + "DELETE FROM timestep_summaries WHERE run_id = ?", + "DELETE FROM shared_to WHERE run_id = ?", + "DELETE FROM simulation_metadata WHERE run_id = ?", + ] + for sql in statements: + try: + cur.execute(sql, (run_key,)) + except sqlite3.OperationalError: + # Legacy table shape fallback. + table = sql.split()[2] + cur.execute(f"DELETE FROM {table}") conn.commit() except sqlite3.OperationalError: # First run on this DB may not have simulation tables yet. 
@@ -1386,6 +1476,8 @@ def _reset_runtime_tables(path: Path) -> None: "chunk_size": chunk_size, "checkpoint_every_chunks": checkpoint_every_chunks, "retention_lite": retention_lite, + "writer_queue_size": writer_queue_size, + "db_write_batch_size": db_write_batch_size, "resume": resume, }, seed=random_seed, @@ -1395,7 +1487,7 @@ def _reset_runtime_tables(path: Path) -> None: db.set_run_metadata(resolved_run_id, "study_db", str(study_db_resolved)) if not resume: - _reset_runtime_tables(study_db_resolved) + _reset_runtime_tables(study_db_resolved, resolved_run_id) # Load persona config if provided persona_config = None @@ -1454,6 +1546,8 @@ def _reset_runtime_tables(path: Path) -> None: run_id=resolved_run_id, checkpoint_every_chunks=checkpoint_every_chunks, retention_lite=retention_lite, + writer_queue_size=writer_queue_size, + db_write_batch_size=db_write_batch_size, ) if on_progress: diff --git a/extropy/simulation/state.py b/extropy/simulation/state.py index 94a4350..0b331a3 100644 --- a/extropy/simulation/state.py +++ b/extropy/simulation/state.py @@ -27,16 +27,23 @@ class StateManager: for frequently accessed data. """ - def __init__(self, db_path: Path | str, agents: list[dict[str, Any]] | None = None): + def __init__( + self, + db_path: Path | str, + agents: list[dict[str, Any]] | None = None, + run_id: str = "default", + ): """Initialize state manager with database path. 
Args: db_path: Path to SQLite database file agents: Optional list of agents to initialize + run_id: Simulation run scope key """ self.db_path = Path(db_path) + self.run_id = run_id self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path)) + self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) self.conn.row_factory = sqlite3.Row self.conn.execute("PRAGMA foreign_keys = ON") @@ -54,7 +61,8 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS agent_states ( - agent_id TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + agent_id TEXT NOT NULL, aware INTEGER DEFAULT 0, exposure_count INTEGER DEFAULT 0, last_reasoning_timestep INTEGER DEFAULT -1, @@ -73,7 +81,8 @@ def _create_schema(self) -> None: private_conviction REAL, private_outcomes_json TEXT, raw_reasoning TEXT, - updated_at INTEGER DEFAULT 0 + updated_at INTEGER DEFAULT 0, + PRIMARY KEY (run_id, agent_id) ) """ ) @@ -82,6 +91,7 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS exposures ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, agent_id TEXT, timestep INTEGER, @@ -89,7 +99,7 @@ def _create_schema(self) -> None: source_agent_id TEXT, content TEXT, credibility REAL, - FOREIGN KEY (agent_id) REFERENCES agent_states(agent_id) + FOREIGN KEY (run_id, agent_id) REFERENCES agent_states(run_id, agent_id) ) """ ) @@ -98,13 +108,14 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS memory_traces ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, agent_id TEXT, timestep INTEGER, sentiment REAL, conviction REAL, summary TEXT, - FOREIGN KEY (agent_id) REFERENCES agent_states(agent_id) + FOREIGN KEY (run_id, agent_id) REFERENCES agent_states(run_id, agent_id) ) """ ) @@ -113,6 +124,7 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS timeline ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, 
timestep INTEGER, event_type TEXT, @@ -127,7 +139,8 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS timestep_summaries ( - timestep INTEGER PRIMARY KEY, + run_id TEXT NOT NULL, + timestep INTEGER NOT NULL, new_exposures INTEGER, agents_reasoned INTEGER, shares_occurred INTEGER, @@ -136,7 +149,8 @@ def _create_schema(self) -> None: position_distribution_json TEXT, average_sentiment REAL, average_conviction REAL, - sentiment_variance REAL + sentiment_variance REAL, + PRIMARY KEY (run_id, timestep) ) """ ) @@ -145,37 +159,37 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_exposures_agent - ON exposures(agent_id) + ON exposures(run_id, agent_id) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_exposures_timestep - ON exposures(timestep) + ON exposures(run_id, timestep) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_timeline_timestep - ON timeline(timestep) + ON timeline(run_id, timestep) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_agent_states_aware - ON agent_states(aware) + ON agent_states(run_id, aware) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_agent_states_will_share - ON agent_states(will_share) + ON agent_states(run_id, will_share) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_memory_traces_agent - ON memory_traces(agent_id) + ON memory_traces(run_id, agent_id) """ ) @@ -183,18 +197,19 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS shared_to ( + run_id TEXT NOT NULL, source_agent_id TEXT, target_agent_id TEXT, timestep INTEGER, position TEXT, - PRIMARY KEY (source_agent_id, target_agent_id) + PRIMARY KEY (run_id, source_agent_id, target_agent_id) ) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_shared_to_source - ON shared_to(source_agent_id) + ON shared_to(run_id, source_agent_id) """ ) @@ -202,8 +217,11 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT 
EXISTS simulation_metadata ( - key TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + key TEXT NOT NULL, value TEXT + , + PRIMARY KEY (run_id, key) ) """ ) @@ -215,6 +233,13 @@ def _upgrade_schema(self) -> None: cursor = self.conn.cursor() migrations = [ + ("agent_states", "run_id", "TEXT DEFAULT 'default'"), + ("exposures", "run_id", "TEXT DEFAULT 'default'"), + ("memory_traces", "run_id", "TEXT DEFAULT 'default'"), + ("timeline", "run_id", "TEXT DEFAULT 'default'"), + ("timestep_summaries", "run_id", "TEXT DEFAULT 'default'"), + ("shared_to", "run_id", "TEXT DEFAULT 'default'"), + ("simulation_metadata", "run_id", "TEXT DEFAULT 'default'"), ("agent_states", "conviction", "REAL"), ("agent_states", "public_statement", "TEXT"), ("timestep_summaries", "average_conviction", "REAL"), @@ -237,6 +262,28 @@ def _upgrade_schema(self) -> None: # Column already exists pass + cursor.execute( + "UPDATE agent_states SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + cursor.execute( + "UPDATE exposures SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + cursor.execute( + "UPDATE memory_traces SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + cursor.execute( + "UPDATE timeline SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + cursor.execute( + "UPDATE timestep_summaries SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + cursor.execute( + "UPDATE shared_to SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + cursor.execute( + "UPDATE simulation_metadata SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + self.conn.commit() @contextmanager @@ -264,10 +311,10 @@ def initialize_agents(self, agents: list[dict[str, Any]]) -> None: agent_id = agent.get("_id", str(agent.get("id", ""))) cursor.execute( """ - INSERT OR IGNORE INTO agent_states (agent_id) - VALUES (?) + INSERT OR IGNORE INTO agent_states (run_id, agent_id) + VALUES (?, ?) 
""", - (agent_id,), + (self.run_id, agent_id), ) self.conn.commit() @@ -286,9 +333,9 @@ def get_agent_state(self, agent_id: str) -> AgentState: # Get main state cursor.execute( """ - SELECT * FROM agent_states WHERE agent_id = ? + SELECT * FROM agent_states WHERE run_id = ? AND agent_id = ? """, - (agent_id,), + (self.run_id, agent_id), ) row = cursor.fetchone() @@ -298,9 +345,11 @@ def get_agent_state(self, agent_id: str) -> AgentState: # Get exposure history cursor.execute( """ - SELECT * FROM exposures WHERE agent_id = ? ORDER BY timestep + SELECT * FROM exposures + WHERE run_id = ? AND agent_id = ? + ORDER BY timestep """, - (agent_id,), + (self.run_id, agent_id), ) exposure_rows = cursor.fetchall() @@ -386,35 +435,42 @@ def get_agent_state(self, agent_id: str) -> AgentState: def get_unaware_agents(self) -> list[str]: """Get IDs of agents who haven't been exposed yet.""" cursor = self.conn.cursor() - cursor.execute("SELECT agent_id FROM agent_states WHERE aware = 0") + cursor.execute( + "SELECT agent_id FROM agent_states WHERE run_id = ? AND aware = 0", + (self.run_id,), + ) return [row["agent_id"] for row in cursor.fetchall()] def get_aware_agents(self) -> list[str]: """Get IDs of agents who are aware of the event.""" cursor = self.conn.cursor() - cursor.execute("SELECT agent_id FROM agent_states WHERE aware = 1") + cursor.execute( + "SELECT agent_id FROM agent_states WHERE run_id = ? AND aware = 1", + (self.run_id,), + ) return [row["agent_id"] for row in cursor.fetchall()] def get_sharers(self) -> list[str]: """Get IDs of agents who will share.""" cursor = self.conn.cursor() cursor.execute( - "SELECT agent_id FROM agent_states WHERE aware = 1 AND will_share = 1" + "SELECT agent_id FROM agent_states WHERE run_id = ? 
AND aware = 1 AND will_share = 1", + (self.run_id,), ) return [row["agent_id"] for row in cursor.fetchall()] def get_all_agent_ids(self) -> list[str]: """Get all agent IDs in the database.""" cursor = self.conn.cursor() - cursor.execute("SELECT agent_id FROM agent_states") + cursor.execute("SELECT agent_id FROM agent_states WHERE run_id = ?", (self.run_id,)) return [row["agent_id"] for row in cursor.fetchall()] def get_network_hop_depth(self, agent_id: str) -> int | None: """Get the minimum network hop depth from a seed exposure for an agent.""" cursor = self.conn.cursor() cursor.execute( - "SELECT network_hop_depth FROM agent_states WHERE agent_id = ?", - (agent_id,), + "SELECT network_hop_depth FROM agent_states WHERE run_id = ? AND agent_id = ?", + (self.run_id, agent_id), ) row = cursor.fetchone() if not row: @@ -443,8 +499,10 @@ def get_agents_to_reason(self, timestep: int, threshold: int) -> list[str]: cursor.execute( """ SELECT agent_id FROM agent_states - WHERE aware = 1 AND last_reasoning_timestep < 0 + WHERE run_id = ? AND aware = 1 AND last_reasoning_timestep < 0 """ + , + (self.run_id,), ) never_reasoned = [row["agent_id"] for row in cursor.fetchall()] @@ -456,14 +514,18 @@ def get_agents_to_reason(self, timestep: int, threshold: int) -> list[str]: COUNT(DISTINCT e.source_agent_id) as unique_sources FROM agent_states s JOIN exposures e - ON e.agent_id = s.agent_id + ON e.run_id = s.run_id + AND e.agent_id = s.agent_id AND e.timestep > s.last_reasoning_timestep AND e.source_agent_id IS NOT NULL - WHERE s.aware = 1 + WHERE s.run_id = ? + AND s.aware = 1 AND s.last_reasoning_timestep >= 0 AND s.committed = 0 GROUP BY s.agent_id """ + , + (self.run_id,), ) multi_touch = [] @@ -490,10 +552,10 @@ def record_share( cursor.execute( """ INSERT OR REPLACE INTO shared_to - (source_agent_id, target_agent_id, timestep, position) - VALUES (?, ?, ?, ?) + (run_id, source_agent_id, target_agent_id, timestep, position) + VALUES (?, ?, ?, ?, ?) 
""", - (source_id, target_id, timestep, position), + (self.run_id, source_id, target_id, timestep, position), ) def get_unshared_neighbors( @@ -518,10 +580,11 @@ def get_unshared_neighbors( f""" SELECT target_agent_id, position FROM shared_to - WHERE source_agent_id = ? + WHERE run_id = ? + AND source_agent_id = ? AND target_agent_id IN ({placeholders}) """, - [source_id] + neighbor_ids, + [self.run_id, source_id] + neighbor_ids, ) already_shared = { @@ -547,8 +610,8 @@ def save_metadata(self, key: str, value: str) -> None: """ cursor = self.conn.cursor() cursor.execute( - "INSERT OR REPLACE INTO simulation_metadata (key, value) VALUES (?, ?)", - (key, value), + "INSERT OR REPLACE INTO simulation_metadata (run_id, key, value) VALUES (?, ?, ?)", + (self.run_id, key, value), ) self.conn.commit() @@ -562,7 +625,10 @@ def get_metadata(self, key: str) -> str | None: Value string or None if not found """ cursor = self.conn.cursor() - cursor.execute("SELECT value FROM simulation_metadata WHERE key = ?", (key,)) + cursor.execute( + "SELECT value FROM simulation_metadata WHERE run_id = ? AND key = ?", + (self.run_id, key), + ) row = cursor.fetchone() return row["value"] if row else None @@ -573,7 +639,10 @@ def delete_metadata(self, key: str) -> None: key: Metadata key to delete """ cursor = self.conn.cursor() - cursor.execute("DELETE FROM simulation_metadata WHERE key = ?", (key,)) + cursor.execute( + "DELETE FROM simulation_metadata WHERE run_id = ? AND key = ?", + (self.run_id, key), + ) self.conn.commit() def get_last_completed_timestep(self) -> int: @@ -583,7 +652,10 @@ def get_last_completed_timestep(self) -> int: Max timestep from timestep_summaries, or -1 if none exist. 
""" cursor = self.conn.cursor() - cursor.execute("SELECT MAX(timestep) as max_ts FROM timestep_summaries") + cursor.execute( + "SELECT MAX(timestep) as max_ts FROM timestep_summaries WHERE run_id = ?", + (self.run_id,), + ) row = cursor.fetchone() if row and row["max_ts"] is not None: return row["max_ts"] @@ -625,8 +697,8 @@ def get_agents_already_reasoned_this_timestep(self, timestep: int) -> set[str]: """ cursor = self.conn.cursor() cursor.execute( - "SELECT agent_id FROM agent_states WHERE last_reasoning_timestep = ?", - (timestep,), + "SELECT agent_id FROM agent_states WHERE run_id = ? AND last_reasoning_timestep = ?", + (self.run_id, timestep), ) return {row["agent_id"] for row in cursor.fetchall()} @@ -642,10 +714,19 @@ def record_exposure(self, agent_id: str, exposure: ExposureRecord) -> None: # Insert exposure record cursor.execute( """ - INSERT INTO exposures (agent_id, timestep, channel, source_agent_id, content, credibility) - VALUES (?, ?, ?, ?, ?, ?) + INSERT INTO exposures ( + run_id, + agent_id, + timestep, + channel, + source_agent_id, + content, + credibility + ) + VALUES (?, ?, ?, ?, ?, ?, ?) """, ( + self.run_id, agent_id, exposure.timestep, exposure.channel, @@ -675,13 +756,15 @@ def record_exposure(self, agent_id: str, exposure: ExposureRecord) -> None: ELSE MIN(network_hop_depth, ?) END, updated_at = ? - WHERE agent_id = ? + WHERE run_id = ? + AND agent_id = ? """, ( new_hop_depth, new_hop_depth, new_hop_depth, exposure.timestep, + self.run_id, agent_id, ), ) @@ -724,7 +807,8 @@ def apply_conviction_decay( ELSE will_share END, updated_at = ? - WHERE aware = 1 + WHERE run_id = ? + AND aware = 1 AND conviction IS NOT NULL AND conviction > ? AND last_reasoning_timestep < ? @@ -738,6 +822,7 @@ def apply_conviction_decay( decay_multiplier, sharing_threshold, timestep, + self.run_id, sharing_threshold, timestep, ), @@ -783,7 +868,8 @@ def update_agent_state( raw_reasoning = ?, last_reasoning_timestep = ?, updated_at = ? - WHERE agent_id = ? 
+ WHERE run_id = ? + AND agent_id = ? """, ( state.position, @@ -808,6 +894,7 @@ def update_agent_state( state.raw_reasoning, timestep, timestep, + self.run_id, agent_id, ), ) @@ -851,7 +938,8 @@ def batch_update_states( raw_reasoning = ?, last_reasoning_timestep = ?, updated_at = ? - WHERE agent_id = ? + WHERE run_id = ? + AND agent_id = ? """, ( state.position, @@ -876,6 +964,7 @@ def batch_update_states( state.raw_reasoning, timestep, timestep, + self.run_id, agent_id, ), ) @@ -895,10 +984,11 @@ def save_memory_entry(self, agent_id: str, entry: MemoryEntry) -> None: # Insert new entry cursor.execute( """ - INSERT INTO memory_traces (agent_id, timestep, sentiment, conviction, summary) - VALUES (?, ?, ?, ?, ?) + INSERT INTO memory_traces (run_id, agent_id, timestep, sentiment, conviction, summary) + VALUES (?, ?, ?, ?, ?, ?) """, ( + self.run_id, agent_id, entry.timestep, entry.sentiment, @@ -913,12 +1003,12 @@ def save_memory_entry(self, agent_id: str, entry: MemoryEntry) -> None: DELETE FROM memory_traces WHERE id NOT IN ( SELECT id FROM memory_traces - WHERE agent_id = ? + WHERE run_id = ? AND agent_id = ? ORDER BY timestep DESC LIMIT 3 - ) AND agent_id = ? + ) AND run_id = ? AND agent_id = ? """, - (agent_id, agent_id), + (self.run_id, agent_id, self.run_id, agent_id), ) def get_memory_traces(self, agent_id: str) -> list[MemoryEntry]: @@ -934,10 +1024,10 @@ def get_memory_traces(self, agent_id: str) -> list[MemoryEntry]: cursor.execute( """ SELECT * FROM memory_traces - WHERE agent_id = ? + WHERE run_id = ? AND agent_id = ? ORDER BY timestep ASC """, - (agent_id,), + (self.run_id, agent_id), ) return [ @@ -960,10 +1050,18 @@ def log_event(self, event: SimulationEvent) -> None: cursor.execute( """ - INSERT INTO timeline (timestep, event_type, agent_id, details_json, wall_timestamp) - VALUES (?, ?, ?, ?, ?) + INSERT INTO timeline ( + run_id, + timestep, + event_type, + agent_id, + details_json, + wall_timestamp + ) + VALUES (?, ?, ?, ?, ?, ?) 
""", ( + self.run_id, event.timestep, event.event_type.value, event.agent_id, @@ -976,13 +1074,19 @@ def get_exposure_rate(self) -> float: """Get fraction of population that is aware.""" cursor = self.conn.cursor() - cursor.execute("SELECT COUNT(*) as total FROM agent_states") + cursor.execute( + "SELECT COUNT(*) as total FROM agent_states WHERE run_id = ?", + (self.run_id,), + ) total = cursor.fetchone()["total"] if total == 0: return 0.0 - cursor.execute("SELECT COUNT(*) as aware FROM agent_states WHERE aware = 1") + cursor.execute( + "SELECT COUNT(*) as aware FROM agent_states WHERE run_id = ? AND aware = 1", + (self.run_id,), + ) aware = cursor.fetchone()["aware"] return aware / total @@ -995,9 +1099,11 @@ def get_position_distribution(self) -> dict[str, int]: """ SELECT COALESCE(private_position, position) as position, COUNT(*) as cnt FROM agent_states - WHERE COALESCE(private_position, position) IS NOT NULL + WHERE run_id = ? + AND COALESCE(private_position, position) IS NOT NULL GROUP BY COALESCE(private_position, position) - """ + """, + (self.run_id,), ) return {row["position"]: row["cnt"] for row in cursor.fetchall()} @@ -1010,8 +1116,10 @@ def get_average_sentiment(self) -> float | None: """ SELECT AVG(COALESCE(private_sentiment, sentiment)) as avg_sentiment FROM agent_states - WHERE COALESCE(private_sentiment, sentiment) IS NOT NULL - """ + WHERE run_id = ? + AND COALESCE(private_sentiment, sentiment) IS NOT NULL + """, + (self.run_id,), ) row = cursor.fetchone() @@ -1025,8 +1133,10 @@ def get_average_conviction(self) -> float | None: """ SELECT AVG(COALESCE(private_conviction, conviction)) as avg_conviction FROM agent_states - WHERE COALESCE(private_conviction, conviction) IS NOT NULL - """ + WHERE run_id = ? 
+ AND COALESCE(private_conviction, conviction) IS NOT NULL + """, + (self.run_id,), ) row = cursor.fetchone() @@ -1040,8 +1150,10 @@ def get_sentiment_variance(self) -> float | None: """ SELECT AVG(COALESCE(private_sentiment, sentiment)) as mean_s, COUNT(*) as cnt FROM agent_states - WHERE COALESCE(private_sentiment, sentiment) IS NOT NULL - """ + WHERE run_id = ? + AND COALESCE(private_sentiment, sentiment) IS NOT NULL + """, + (self.run_id,), ) row = cursor.fetchone() @@ -1056,9 +1168,10 @@ def get_sentiment_variance(self) -> float | None: * (COALESCE(private_sentiment, sentiment) - ?) ) as variance FROM agent_states - WHERE COALESCE(private_sentiment, sentiment) IS NOT NULL + WHERE run_id = ? + AND COALESCE(private_sentiment, sentiment) IS NOT NULL """, - (mean, mean), + (mean, mean, self.run_id), ) var_row = cursor.fetchone() return var_row["variance"] if var_row else None @@ -1074,12 +1187,13 @@ def save_timestep_summary(self, summary: TimestepSummary) -> None: cursor.execute( """ INSERT OR REPLACE INTO timestep_summaries - (timestep, new_exposures, agents_reasoned, shares_occurred, + (run_id, timestep, new_exposures, agents_reasoned, shares_occurred, state_changes, exposure_rate, position_distribution_json, average_sentiment, average_conviction, sentiment_variance) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( + self.run_id, summary.timestep, summary.new_exposures, summary.agents_reasoned, @@ -1099,8 +1213,11 @@ def get_timestep_summaries(self) -> list[TimestepSummary]: cursor.execute( """ - SELECT * FROM timestep_summaries ORDER BY timestep - """ + SELECT * FROM timestep_summaries + WHERE run_id = ? 
+ ORDER BY timestep + """, + (self.run_id,), ) summaries = [] @@ -1137,7 +1254,7 @@ def export_final_states(self) -> list[dict[str, Any]]: """ cursor = self.conn.cursor() - cursor.execute("SELECT * FROM agent_states") + cursor.execute("SELECT * FROM agent_states WHERE run_id = ?", (self.run_id,)) agent_rows = cursor.fetchall() states = [] @@ -1145,8 +1262,10 @@ def export_final_states(self) -> list[dict[str, Any]]: """ SELECT agent_id, COUNT(*) as cnt FROM exposures + WHERE run_id = ? GROUP BY agent_id - """ + """, + (self.run_id,), ) exposure_counts = {row["agent_id"]: row["cnt"] for row in cursor.fetchall()} @@ -1227,7 +1346,10 @@ def export_timeline(self) -> list[dict[str, Any]]: """ cursor = self.conn.cursor() - cursor.execute("SELECT * FROM timeline ORDER BY timestep, id") + cursor.execute( + "SELECT * FROM timeline WHERE run_id = ? ORDER BY timestep, id", + (self.run_id,), + ) events = [] for row in cursor.fetchall(): @@ -1253,7 +1375,10 @@ def export_timeline(self) -> list[dict[str, Any]]: def get_population_count(self) -> int: """Get total number of agents.""" cursor = self.conn.cursor() - cursor.execute("SELECT COUNT(*) as cnt FROM agent_states") + cursor.execute( + "SELECT COUNT(*) as cnt FROM agent_states WHERE run_id = ?", + (self.run_id,), + ) return cursor.fetchone()["cnt"] def close(self) -> None: diff --git a/extropy/storage/study_db.py b/extropy/storage/study_db.py index 1005b98..8e4731b 100644 --- a/extropy/storage/study_db.py +++ b/extropy/storage/study_db.py @@ -29,7 +29,7 @@ class StudyDB: def __init__(self, path: Path | str): self.path = Path(path) self.path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.path)) + self.conn = sqlite3.connect(str(self.path), check_same_thread=False) self.conn.row_factory = sqlite3.Row self._set_pragmas() self.init_schema() @@ -153,7 +153,8 @@ def init_schema(self) -> None: ); CREATE TABLE IF NOT EXISTS agent_states ( - agent_id TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + 
agent_id TEXT NOT NULL, aware INTEGER DEFAULT 0, exposure_count INTEGER DEFAULT 0, last_reasoning_timestep INTEGER DEFAULT -1, @@ -174,10 +175,12 @@ def init_schema(self) -> None: raw_reasoning TEXT, committed INTEGER DEFAULT 0, network_hop_depth INTEGER, - updated_at INTEGER DEFAULT 0 + updated_at INTEGER DEFAULT 0, + PRIMARY KEY (run_id, agent_id) ); CREATE TABLE IF NOT EXISTS exposures ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, agent_id TEXT, timestep INTEGER, @@ -185,20 +188,22 @@ def init_schema(self) -> None: source_agent_id TEXT, content TEXT, credibility REAL, - FOREIGN KEY (agent_id) REFERENCES agent_states(agent_id) + FOREIGN KEY (run_id, agent_id) REFERENCES agent_states(run_id, agent_id) ); CREATE TABLE IF NOT EXISTS memory_traces ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, agent_id TEXT, timestep INTEGER, sentiment REAL, conviction REAL, summary TEXT, - FOREIGN KEY (agent_id) REFERENCES agent_states(agent_id) + FOREIGN KEY (run_id, agent_id) REFERENCES agent_states(run_id, agent_id) ); CREATE TABLE IF NOT EXISTS timeline ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, timestep INTEGER, event_type TEXT, @@ -208,7 +213,8 @@ def init_schema(self) -> None: ); CREATE TABLE IF NOT EXISTS timestep_summaries ( - timestep INTEGER PRIMARY KEY, + run_id TEXT NOT NULL, + timestep INTEGER NOT NULL, new_exposures INTEGER, agents_reasoned INTEGER, shares_occurred INTEGER, @@ -217,20 +223,25 @@ def init_schema(self) -> None: position_distribution_json TEXT, average_sentiment REAL, average_conviction REAL, - sentiment_variance REAL + sentiment_variance REAL, + PRIMARY KEY (run_id, timestep) ); CREATE TABLE IF NOT EXISTS shared_to ( + run_id TEXT NOT NULL, source_agent_id TEXT, target_agent_id TEXT, timestep INTEGER, position TEXT, - PRIMARY KEY (source_agent_id, target_agent_id) + PRIMARY KEY (run_id, source_agent_id, target_agent_id) ); CREATE TABLE IF NOT EXISTS simulation_metadata ( - key TEXT PRIMARY KEY, + 
run_id TEXT NOT NULL, + key TEXT NOT NULL, value TEXT + , + PRIMARY KEY (run_id, key) ); CREATE TABLE IF NOT EXISTS run_metadata ( @@ -283,12 +294,12 @@ def init_schema(self) -> None: CREATE INDEX IF NOT EXISTS idx_net_sim_chunks_status ON network_similarity_chunks(job_id, status); CREATE INDEX IF NOT EXISTS idx_sim_ckpt ON simulation_checkpoints(run_id, timestep, chunk_index); CREATE INDEX IF NOT EXISTS idx_chat_session_agent ON chat_sessions(run_id, agent_id); - CREATE INDEX IF NOT EXISTS idx_agent_states_aware ON agent_states(aware); - CREATE INDEX IF NOT EXISTS idx_agent_states_will_share ON agent_states(will_share); - CREATE INDEX IF NOT EXISTS idx_agent_states_last_reasoning ON agent_states(last_reasoning_timestep); - CREATE INDEX IF NOT EXISTS idx_exposures_agent_timestep ON exposures(agent_id, timestep); - CREATE INDEX IF NOT EXISTS idx_timeline_timestep ON timeline(timestep); - CREATE INDEX IF NOT EXISTS idx_shared_to_source ON shared_to(source_agent_id); + CREATE INDEX IF NOT EXISTS idx_agent_states_aware ON agent_states(run_id, aware); + CREATE INDEX IF NOT EXISTS idx_agent_states_will_share ON agent_states(run_id, will_share); + CREATE INDEX IF NOT EXISTS idx_agent_states_last_reasoning ON agent_states(run_id, last_reasoning_timestep); + CREATE INDEX IF NOT EXISTS idx_exposures_agent_timestep ON exposures(run_id, agent_id, timestep); + CREATE INDEX IF NOT EXISTS idx_timeline_timestep ON timeline(run_id, timestep); + CREATE INDEX IF NOT EXISTS idx_shared_to_source ON shared_to(run_id, source_agent_id); """ ) self.conn.commit() diff --git a/tests/test_engine.py b/tests/test_engine.py index 3b427cc..6171f5d 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -855,6 +855,57 @@ def test_agents_already_reasoned( already_4 = engine.state_manager.get_agents_already_reasoned_this_timestep(4) assert "a0" not in already_4 + def test_chunk_checkpoints_written_with_writer_pipeline( + self, + minimal_scenario, + simple_agents, + simple_network, + 
minimal_pop_spec, + tmp_path, + ): + """Writer pipeline should persist per-chunk checkpoints and last checkpoint marker.""" + config = SimulationRunConfig( + scenario_path="test.yaml", + output_dir=str(tmp_path / "output"), + chunk_size=1, + ) + engine = SimulationEngine( + scenario=minimal_scenario, + population_spec=minimal_pop_spec, + agents=simple_agents, + network=simple_network, + config=config, + chunk_size=1, + checkpoint_every_chunks=2, + writer_queue_size=2, + db_write_batch_size=2, + ) + + for aid in ["a0", "a1", "a2"]: + exposure = ExposureRecord( + timestep=0, channel="broadcast", content="Test", credibility=0.9 + ) + engine.state_manager.record_exposure(aid, exposure) + + def fake_batch(contexts, scenario, cfg, rate_limiter=None, on_agent_done=None): + response = _make_reasoning_response() + results = [] + for ctx in contexts: + if on_agent_done: + on_agent_done(ctx.agent_id, response) + results.append((ctx.agent_id, response)) + return results, BatchTokenUsage() + + with patch( + "extropy.simulation.engine.batch_reason_agents", side_effect=fake_batch + ): + reasoned, _, _ = engine._reason_agents(0) + + assert reasoned == 3 + completed = engine.study_db.get_completed_simulation_chunks(engine.run_id, 0) + assert completed == {0, 1, 2} + assert engine.study_db.get_run_metadata(engine.run_id, "last_checkpoint") == "0:2" + class TestResumeLogic: """Test engine resume/checkpoint logic.""" diff --git a/tests/test_simulation.py b/tests/test_simulation.py index a5866d3..aecec95 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -402,6 +402,34 @@ def test_get_population_count(self, temp_db, agents): count = manager.get_population_count() assert count == 3 + def test_run_scope_isolation(self, temp_db, agents): + """Different run_id views should not leak state into each other.""" + exposure = ExposureRecord( + timestep=0, + channel="email", + content="Scoped exposure", + credibility=0.9, + ) + + with StateManager(temp_db, agents=agents, 
run_id="run_a") as run_a: + run_a.record_exposure("agent_000", exposure) + assert run_a.get_exposure_rate() == pytest.approx(1 / 3, abs=0.01) + assert run_a.get_checkpoint_timestep() is None + run_a.mark_timestep_started(2) + assert run_a.get_checkpoint_timestep() == 2 + + with StateManager(temp_db, agents=agents, run_id="run_b") as run_b: + assert run_b.get_exposure_rate() == 0.0 + assert run_b.get_agent_state("agent_000").aware is False + assert run_b.get_checkpoint_timestep() is None + run_b.record_exposure("agent_001", exposure) + assert run_b.get_exposure_rate() == pytest.approx(1 / 3, abs=0.01) + + with StateManager(temp_db, agents=agents, run_id="run_a") as run_a_again: + assert run_a_again.get_agent_state("agent_000").aware is True + assert run_a_again.get_agent_state("agent_001").aware is False + assert run_a_again.get_checkpoint_timestep() == 2 + class TestPersonaGeneration: """Tests for persona generation.""" From d07f6be03af2255060b5854779ef1f536e2297d3 Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 00:47:38 -0500 Subject: [PATCH 04/15] feat(cli): make inspect/results/report/export/chat run-aware --- extropy/cli/commands/chat.py | 56 +++++++++++++---- extropy/cli/commands/export.py | 21 ++++++- extropy/cli/commands/inspect.py | 86 ++++++++++++++++++++++--- extropy/cli/commands/report.py | 38 +++++++++-- extropy/cli/commands/results.py | 98 ++++++++++++++++++++--------- tests/test_cli.py | 108 ++++++++++++++++++++++++++++++++ 6 files changed, 351 insertions(+), 56 deletions(-) diff --git a/extropy/cli/commands/chat.py b/extropy/cli/commands/chat.py index 6ea4cdf..1844d34 100644 --- a/extropy/cli/commands/chat.py +++ b/extropy/cli/commands/chat.py @@ -24,10 +24,24 @@ def _load_agent_chat_context( timeline_n: int = 10, ) -> tuple[dict[str, Any], list[dict[str, Any]]]: cur = conn.cursor() + cur.execute( + "SELECT population_id FROM simulation_runs WHERE run_id = ? 
LIMIT 1", + (run_id,), + ) + run_row = cur.fetchone() + if not run_row: + return {"run_id": run_id, "agent_id": agent_id, "error": "run_id not found"}, [] + population_id = str(run_row["population_id"]) cur.execute( - "SELECT attrs_json FROM agents WHERE agent_id = ? ORDER BY rowid DESC LIMIT 1", - (agent_id,), + """ + SELECT attrs_json + FROM agents + WHERE population_id = ? AND agent_id = ? + ORDER BY rowid DESC + LIMIT 1 + """, + (population_id, agent_id), ) attrs_row = cur.fetchone() attrs = {} @@ -37,7 +51,10 @@ def _load_agent_chat_context( except json.JSONDecodeError: attrs = {} - cur.execute("SELECT * FROM agent_states WHERE agent_id = ? LIMIT 1", (agent_id,)) + cur.execute( + "SELECT * FROM agent_states WHERE run_id = ? AND agent_id = ? LIMIT 1", + (run_id, agent_id), + ) state_row = cur.fetchone() state = dict(state_row) if state_row else {} @@ -45,16 +62,17 @@ def _load_agent_chat_context( """ SELECT timestep, event_type, details_json FROM timeline - WHERE agent_id = ? + WHERE run_id = ? AND agent_id = ? ORDER BY id DESC LIMIT ? 
""", - (agent_id, timeline_n), + (run_id, agent_id, timeline_n), ) timeline_rows = [dict(r) for r in cur.fetchall()] context = { "run_id": run_id, + "population_id": population_id, "agent_id": agent_id, "attributes": attrs, "state": state, @@ -62,9 +80,9 @@ def _load_agent_chat_context( } citations = [ - {"table": "agents", "agent_id": agent_id}, - {"table": "agent_states", "agent_id": agent_id}, - {"table": "timeline", "agent_id": agent_id, "limit": timeline_n}, + {"table": "agents", "population_id": population_id, "agent_id": agent_id}, + {"table": "agent_states", "run_id": run_id, "agent_id": agent_id}, + {"table": "timeline", "run_id": run_id, "agent_id": agent_id, "limit": timeline_n}, ] return context, citations @@ -147,6 +165,15 @@ def chat_interactive( console.print(f"[red]✗[/red] Study DB not found: {study_db}") raise typer.Exit(1) + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + cur = conn.cursor() + cur.execute("SELECT 1 FROM simulation_runs WHERE run_id = ? LIMIT 1", (run_id,)) + if not cur.fetchone(): + conn.close() + console.print(f"[red]✗[/red] run_id not found: {run_id}") + raise typer.Exit(1) + with open_study_db(study_db) as db: sid = session_id or db.create_chat_session( run_id=run_id, @@ -155,9 +182,6 @@ def chat_interactive( meta={"entrypoint": "repl"}, ) - conn = sqlite3.connect(str(study_db)) - conn.row_factory = sqlite3.Row - console.print(f"[bold]Chat session[/bold] {sid}") _print_repl_help() @@ -235,6 +259,16 @@ def chat_ask( raise typer.Exit(1) started = time.time() + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute("SELECT 1 FROM simulation_runs WHERE run_id = ? 
LIMIT 1", (run_id,)) + if not cur.fetchone(): + console.print(f"[red]✗[/red] run_id not found: {run_id}") + raise typer.Exit(1) + finally: + conn.close() with open_study_db(study_db) as db: sid = session_id or db.create_chat_session( diff --git a/extropy/cli/commands/export.py b/extropy/cli/commands/export.py index 337ddfd..dbef866 100644 --- a/extropy/cli/commands/export.py +++ b/extropy/cli/commands/export.py @@ -78,13 +78,32 @@ def export_edges( @export_app.command("states") def export_states( study_db: Path = typer.Option(..., "--study-db"), + run_id: str | None = typer.Option(None, "--run-id"), output: Path = typer.Option(..., "--to"), ): conn = sqlite3.connect(str(study_db)) conn.row_factory = sqlite3.Row try: cur = conn.cursor() - cur.execute("SELECT * FROM agent_states ORDER BY agent_id") + if run_id: + cur.execute( + "SELECT run_id FROM simulation_runs WHERE run_id = ?", + (run_id,), + ) + else: + cur.execute( + "SELECT run_id FROM simulation_runs ORDER BY started_at DESC LIMIT 1" + ) + run_row = cur.fetchone() + if not run_row: + console.print("[yellow]No simulation runs found.[/yellow]") + raise typer.Exit(1) + resolved_run_id = str(run_row["run_id"]) + + cur.execute( + "SELECT * FROM agent_states WHERE run_id = ? ORDER BY agent_id", + (resolved_run_id,), + ) rows = [dict(row) for row in cur.fetchall()] finally: conn.close() diff --git a/extropy/cli/commands/inspect.py b/extropy/cli/commands/inspect.py index 3f83101..dcb954a 100644 --- a/extropy/cli/commands/inspect.py +++ b/extropy/cli/commands/inspect.py @@ -15,11 +15,35 @@ app.add_typer(inspect_app, name="inspect") +def _resolve_run(conn: sqlite3.Connection, run_id: str | None) -> sqlite3.Row | None: + cur = conn.cursor() + if run_id: + cur.execute( + """ + SELECT run_id, population_id, network_id, status, started_at, completed_at + FROM simulation_runs + WHERE run_id = ? 
+ """, + (run_id,), + ) + else: + cur.execute( + """ + SELECT run_id, population_id, network_id, status, started_at, completed_at + FROM simulation_runs + ORDER BY started_at DESC + LIMIT 1 + """ + ) + return cur.fetchone() + + @inspect_app.command("summary") def inspect_summary( study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), population_id: str = typer.Option("default", "--population-id"), network_id: str = typer.Option("default", "--network-id"), + run_id: str | None = typer.Option(None, "--run-id"), ): with open_study_db(study_db) as db: agent_count = db.get_agent_count(population_id) @@ -28,13 +52,33 @@ def inspect_summary( conn = sqlite3.connect(str(study_db)) conn.row_factory = sqlite3.Row try: + run_row = _resolve_run(conn, run_id) + resolved_run_id = str(run_row["run_id"]) if run_row else None + if run_row: + population_id = str(run_row["population_id"]) + network_id = str(run_row["network_id"]) + cur = conn.cursor() - cur.execute("SELECT COUNT(*) AS cnt FROM agent_states") - sim_agents = int(cur.fetchone()["cnt"]) - cur.execute("SELECT COUNT(*) AS cnt FROM timestep_summaries") - timesteps = int(cur.fetchone()["cnt"]) - cur.execute("SELECT COUNT(*) AS cnt FROM timeline") - events = int(cur.fetchone()["cnt"]) + if resolved_run_id: + cur.execute( + "SELECT COUNT(*) AS cnt FROM agent_states WHERE run_id = ?", + (resolved_run_id,), + ) + sim_agents = int(cur.fetchone()["cnt"]) + cur.execute( + "SELECT COUNT(*) AS cnt FROM timestep_summaries WHERE run_id = ?", + (resolved_run_id,), + ) + timesteps = int(cur.fetchone()["cnt"]) + cur.execute( + "SELECT COUNT(*) AS cnt FROM timeline WHERE run_id = ?", + (resolved_run_id,), + ) + events = int(cur.fetchone()["cnt"]) + else: + sim_agents = 0 + timesteps = 0 + events = 0 finally: conn.close() @@ -42,6 +86,8 @@ def inspect_summary( console.print(f"study_db: {study_db}") console.print(f"population_id={population_id} agents={agent_count}") console.print(f"network_id={network_id} 
edges={edge_count}") + if resolved_run_id: + console.print(f"run_id={resolved_run_id}") console.print(f"simulation.agent_states={sim_agents}") console.print(f"simulation.timesteps={timesteps}") console.print(f"simulation.events={events}") @@ -51,21 +97,41 @@ def inspect_summary( def inspect_agent( study_db: Path = typer.Option(..., "--study-db"), agent_id: str = typer.Option(..., "--agent-id"), + run_id: str | None = typer.Option(None, "--run-id"), ): conn = sqlite3.connect(str(study_db)) conn.row_factory = sqlite3.Row try: + run_row = _resolve_run(conn, run_id) + if not run_row: + console.print("[yellow]No simulation runs found.[/yellow]") + return + resolved_run_id = str(run_row["run_id"]) + population_id = str(run_row["population_id"]) + cur = conn.cursor() - cur.execute("SELECT attrs_json FROM agents WHERE agent_id = ? LIMIT 1", (agent_id,)) + cur.execute( + "SELECT attrs_json FROM agents WHERE population_id = ? AND agent_id = ? LIMIT 1", + (population_id, agent_id), + ) attrs_row = cur.fetchone() attrs = json.loads(attrs_row["attrs_json"]) if attrs_row else {} - cur.execute("SELECT * FROM agent_states WHERE agent_id = ? LIMIT 1", (agent_id,)) + cur.execute( + "SELECT * FROM agent_states WHERE run_id = ? AND agent_id = ? LIMIT 1", + (resolved_run_id, agent_id), + ) state = cur.fetchone() cur.execute( - "SELECT timestep, event_type, details_json FROM timeline WHERE agent_id = ? ORDER BY id DESC LIMIT 10", - (agent_id,), + """ + SELECT timestep, event_type, details_json + FROM timeline + WHERE run_id = ? AND agent_id = ? 
+ ORDER BY id DESC + LIMIT 10 + """, + (resolved_run_id, agent_id), ) events = cur.fetchall() finally: diff --git a/extropy/cli/commands/report.py b/extropy/cli/commands/report.py index 6b2657f..91221d8 100644 --- a/extropy/cli/commands/report.py +++ b/extropy/cli/commands/report.py @@ -17,29 +17,59 @@ @report_app.command("run") def report_run( study_db: Path = typer.Option(..., "--study-db"), + run_id: str | None = typer.Option(None, "--run-id"), output: Path = typer.Option(..., "--output", "-o"), ): conn = sqlite3.connect(str(study_db)) conn.row_factory = sqlite3.Row try: cur = conn.cursor() - cur.execute("SELECT COUNT(*) AS cnt FROM agent_states") + if run_id: + cur.execute( + "SELECT run_id, population_id FROM simulation_runs WHERE run_id = ?", + (run_id,), + ) + else: + cur.execute( + """ + SELECT run_id, population_id + FROM simulation_runs + ORDER BY started_at DESC + LIMIT 1 + """ + ) + run_row = cur.fetchone() + if not run_row: + console.print("[yellow]No simulation runs found.[/yellow]") + raise typer.Exit(1) + resolved_run_id = str(run_row["run_id"]) + + cur.execute( + "SELECT COUNT(*) AS cnt FROM agent_states WHERE run_id = ?", + (resolved_run_id,), + ) total = int(cur.fetchone()["cnt"]) - cur.execute("SELECT COUNT(*) AS cnt FROM agent_states WHERE aware = 1") + cur.execute( + "SELECT COUNT(*) AS cnt FROM agent_states WHERE run_id = ? AND aware = 1", + (resolved_run_id,), + ) aware = int(cur.fetchone()["cnt"]) cur.execute( """ SELECT COALESCE(private_position, position) AS position, COUNT(*) AS cnt FROM agent_states - WHERE COALESCE(private_position, position) IS NOT NULL + WHERE run_id = ? 
+ AND COALESCE(private_position, position) IS NOT NULL GROUP BY COALESCE(private_position, position) - """ + """, + (resolved_run_id,), ) positions = {row["position"]: int(row["cnt"]) for row in cur.fetchall()} finally: conn.close() payload = { + "run_id": resolved_run_id, "agent_count": total, "aware_count": aware, "aware_rate": (aware / total) if total else 0.0, diff --git a/extropy/cli/commands/results.py b/extropy/cli/commands/results.py index b6ece60..372767d 100644 --- a/extropy/cli/commands/results.py +++ b/extropy/cli/commands/results.py @@ -14,9 +14,7 @@ @app.command("results") def results_command( study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), - run_id: str | None = typer.Option( - None, "--run-id", help="Run ID (reserved for multi-run support)" - ), + run_id: str | None = typer.Option(None, "--run-id", help="Simulation run id"), segment: str | None = typer.Option( None, "--segment", "-s", help="Attribute to segment by" ), @@ -34,53 +32,73 @@ def results_command( conn.row_factory = sqlite3.Row try: + cur = conn.cursor() if run_id: - cur = conn.cursor() cur.execute( - "SELECT status, started_at, completed_at, stopped_reason FROM simulation_runs WHERE run_id = ?", + """ + SELECT run_id, status, started_at, completed_at, stopped_reason, population_id + FROM simulation_runs + WHERE run_id = ? 
+ """, (run_id,), ) - run_row = cur.fetchone() - if not run_row: - console.print(f"[red]✗[/red] run_id not found: {run_id}") - raise typer.Exit(1) - console.print( - f"[dim]run_id={run_id} status={run_row['status']} " - f"started_at={run_row['started_at']} completed_at={run_row['completed_at'] or '-'}[/dim]" + else: + cur.execute( + """ + SELECT run_id, status, started_at, completed_at, stopped_reason, population_id + FROM simulation_runs + ORDER BY started_at DESC + LIMIT 1 + """ ) + run_row = cur.fetchone() + if not run_row: + console.print("[yellow]No simulation runs found in study DB.[/yellow]") + raise typer.Exit(0) + resolved_run_id = str(run_row["run_id"]) + population_id = str(run_row["population_id"]) + console.print( + f"[dim]run_id={resolved_run_id} status={run_row['status']} " + f"started_at={run_row['started_at']} completed_at={run_row['completed_at'] or '-'}[/dim]" + ) if agent: - _display_agent(conn, agent) + _display_agent(conn, resolved_run_id, population_id, agent) return if segment: - _display_segment(conn, segment) + _display_segment(conn, resolved_run_id, population_id, segment) return if timeline: - _display_timeline(conn) + _display_timeline(conn, resolved_run_id) return - _display_summary(conn) + _display_summary(conn, resolved_run_id) finally: conn.close() -def _display_summary(conn: sqlite3.Connection) -> None: +def _display_summary(conn: sqlite3.Connection, run_id: str) -> None: cur = conn.cursor() - cur.execute("SELECT COUNT(*) AS cnt FROM agent_states") + cur.execute("SELECT COUNT(*) AS cnt FROM agent_states WHERE run_id = ?", (run_id,)) total = int(cur.fetchone()["cnt"]) if total == 0: console.print("[yellow]No simulation state found in study DB.[/yellow]") return - cur.execute("SELECT COUNT(*) AS cnt FROM agent_states WHERE aware = 1") + cur.execute( + "SELECT COUNT(*) AS cnt FROM agent_states WHERE run_id = ? 
AND aware = 1", + (run_id,), + ) aware = int(cur.fetchone()["cnt"]) cur.execute( """ SELECT COALESCE(private_position, position) AS position, COUNT(*) AS cnt FROM agent_states - WHERE COALESCE(private_position, position) IS NOT NULL + WHERE run_id = ? + AND COALESCE(private_position, position) IS NOT NULL GROUP BY COALESCE(private_position, position) ORDER BY cnt DESC - """ + """, + (run_id,), ) rows = cur.fetchall() @@ -94,14 +112,16 @@ def _display_summary(conn: sqlite3.Connection) -> None: console.print(f" - {row['position']}: {row['cnt']} ({pct:.1%})") -def _display_timeline(conn: sqlite3.Connection) -> None: +def _display_timeline(conn: sqlite3.Connection, run_id: str) -> None: cur = conn.cursor() cur.execute( """ SELECT timestep, new_exposures, agents_reasoned, shares_occurred, exposure_rate FROM timestep_summaries + WHERE run_id = ? ORDER BY timestep - """ + """, + (run_id,), ) rows = cur.fetchall() if not rows: @@ -118,9 +138,17 @@ def _display_timeline(conn: sqlite3.Connection) -> None: ) -def _display_segment(conn: sqlite3.Connection, attribute: str) -> None: +def _display_segment( + conn: sqlite3.Connection, + run_id: str, + population_id: str, + attribute: str, +) -> None: cur = conn.cursor() - cur.execute("SELECT agent_id, attrs_json FROM agents") + cur.execute( + "SELECT agent_id, attrs_json FROM agents WHERE population_id = ?", + (population_id,), + ) attr_by_agent: dict[str, str] = {} for row in cur.fetchall(): try: @@ -137,7 +165,9 @@ def _display_segment(conn: sqlite3.Connection, attribute: str) -> None: """ SELECT agent_id, aware, COALESCE(private_position, position) AS position FROM agent_states - """ + WHERE run_id = ? 
+ """, + (run_id,), ) groups: dict[str, dict[str, int]] = {} for row in cur.fetchall(): @@ -158,22 +188,30 @@ def _display_segment(conn: sqlite3.Connection, attribute: str) -> None: console.print(f" - {key}: {total} agents, aware={aware} ({pct:.1%})") -def _display_agent(conn: sqlite3.Connection, agent_id: str) -> None: +def _display_agent( + conn: sqlite3.Connection, + run_id: str, + population_id: str, + agent_id: str, +) -> None: cur = conn.cursor() cur.execute( """ SELECT * FROM agent_states - WHERE agent_id = ? + WHERE run_id = ? AND agent_id = ? """, - (agent_id,), + (run_id, agent_id), ) row = cur.fetchone() if not row: console.print(f"[yellow]Agent not found in simulation state: {agent_id}[/yellow]") return - cur.execute("SELECT attrs_json FROM agents WHERE agent_id = ? LIMIT 1", (agent_id,)) + cur.execute( + "SELECT attrs_json FROM agents WHERE population_id = ? AND agent_id = ? LIMIT 1", + (population_id, agent_id), + ) attrs_row = cur.fetchone() attrs = {} if attrs_row: diff --git a/tests/test_cli.py b/tests/test_cli.py index 3c624af..87cc13d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,6 @@ """CLI smoke tests using typer's CliRunner.""" +import json +import sqlite3 from pathlib import Path from typer.testing import CliRunner @@ -128,3 +130,109 @@ def test_network_resume_requires_checkpoint(self): ) assert result.exit_code == 1 assert "Study DB not found" in result.output + + +def _seed_run_scoped_state(study_db: Path) -> None: + agents = [ + {"_id": "a0", "team": "alpha"}, + {"_id": "a1", "team": "beta"}, + ] + with open_study_db(study_db) as db: + db.save_sample_result(population_id="default", agents=agents, meta={"source": "test"}) + db.create_simulation_run( + run_id="run_old", + scenario_name="s", + population_id="default", + network_id="default", + config={}, + seed=1, + status="completed", + ) + db.create_simulation_run( + run_id="run_new", + scenario_name="s", + population_id="default", + network_id="default", + config={}, + 
seed=2, + status="running", + ) + + conn = sqlite3.connect(str(study_db)) + cur = conn.cursor() + cur.execute( + """ + INSERT INTO agent_states (run_id, agent_id, aware, position, private_position, updated_at) + VALUES ('run_old', 'a0', 1, 'old_pos', 'old_pos', 0) + """ + ) + cur.execute( + """ + INSERT INTO agent_states (run_id, agent_id, aware, position, private_position, updated_at) + VALUES ('run_new', 'a0', 1, 'new_pos', 'new_pos', 0) + """ + ) + cur.execute( + """ + INSERT INTO timestep_summaries ( + run_id, timestep, new_exposures, agents_reasoned, shares_occurred, + state_changes, exposure_rate, position_distribution_json + ) + VALUES ('run_new', 0, 1, 1, 0, 1, 0.5, '{}') + """ + ) + conn.commit() + conn.close() + + +class TestRunScopedCliReads: + def test_results_defaults_to_latest_run(self, tmp_path): + study_db = tmp_path / "study.db" + _seed_run_scoped_state(study_db) + + result = runner.invoke(app, ["results", "--study-db", str(study_db)]) + assert result.exit_code == 0 + assert "run_id=run_new" in result.output + assert "new_pos" in result.output + assert "old_pos" not in result.output + + def test_export_states_defaults_to_latest_run(self, tmp_path): + study_db = tmp_path / "study.db" + out = tmp_path / "states.jsonl" + _seed_run_scoped_state(study_db) + + result = runner.invoke( + app, + ["export", "states", "--study-db", str(study_db), "--to", str(out)], + ) + assert result.exit_code == 0 + rows = [json.loads(line) for line in out.read_text(encoding="utf-8").splitlines()] + assert len(rows) == 1 + assert rows[0]["run_id"] == "run_new" + assert rows[0]["private_position"] == "new_pos" + + def test_chat_ask_reads_state_for_requested_run(self, tmp_path): + study_db = tmp_path / "study.db" + _seed_run_scoped_state(study_db) + + result = runner.invoke( + app, + [ + "chat", + "ask", + "--study-db", + str(study_db), + "--run-id", + "run_old", + "--agent-id", + "a0", + "--prompt", + "what is my stance", + "--json", + ], + ) + assert result.exit_code == 0 + 
payload = json.loads(result.stdout.strip()) + assert payload["session_id"] + assert "old_pos" in payload["assistant_text"] + assert "new_pos" not in payload["assistant_text"] From db3d36400190c29a7354f880a974657998236da9 Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 01:00:34 -0500 Subject: [PATCH 05/15] feat(config): redesign config to 2-tier fast/strong with provider/model strings Replace 8 model/provider fields with 2 tiers (fast/strong) using "provider/model" format strings. Add parse_model_string(), v1 config auto-migration, legacy env var backward compat, and CustomProviderConfig for third-party endpoints. Remove unused DefaultsConfig (db_path, population_size). Co-Authored-By: Claude Opus 4.6 --- extropy/config.py | 524 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 415 insertions(+), 109 deletions(-) diff --git a/extropy/config.py b/extropy/config.py index 55b7cb0..056ff4b 100644 --- a/extropy/config.py +++ b/extropy/config.py @@ -1,12 +1,14 @@ """Configuration management for Extropy. -Two-zone config system: -- pipeline: provider + models for phases 1-2 (spec, extend, sample, network, persona, scenario) -- simulation: provider + model for phase 3 (agent reasoning) +Two-tier config system: +- models: fast/strong model strings for pipeline phases 1-2 +- simulation: fast/strong model strings for phase 3 (agent reasoning) + +Model strings use "provider/model" format (e.g., "openai/gpt-5-mini"). Config resolution order (highest priority first): 1. Programmatic (ExtropyConfig constructed in code) -2. Environment variables (PIPELINE_PROVIDER, SIMULATION_MODEL, etc.) +2. Environment variables (MODELS_FAST, MODELS_STRONG, etc.) 3. Config file (~/.config/extropy/config.json, managed by `extropy config`) 4. 
Hardcoded defaults @@ -16,6 +18,7 @@ import json import logging import os +import warnings from dataclasses import dataclass, field, asdict from pathlib import Path from typing import Any @@ -33,37 +36,113 @@ # ============================================================================= -# Two-zone config dataclasses +# Model string parsing +# ============================================================================= + + +def parse_model_string(model_string: str) -> tuple[str, str]: + """Parse a "provider/model" string into (provider, model) tuple. + + Examples: + "openai/gpt-5-mini" → ("openai", "gpt-5-mini") + "anthropic/claude-sonnet-4.5" → ("anthropic", "claude-sonnet-4.5") + "openrouter/anthropic/claude-sonnet-4.5" → ("openrouter", "anthropic/claude-sonnet-4.5") + + Raises: + ValueError: If the string doesn't contain a '/' separator. + """ + if "/" not in model_string: + raise ValueError( + f"Invalid model string: {model_string!r}. " + f"Expected format: 'provider/model' (e.g., 'openai/gpt-5-mini')" + ) + provider, _, model = model_string.partition("/") + if not provider or not model: + raise ValueError( + f"Invalid model string: {model_string!r}. " + f"Both provider and model must be non-empty." + ) + return provider, model + + +# ============================================================================= +# New two-tier config dataclasses +# ============================================================================= + + +@dataclass +class ModelsConfig: + """Pipeline model configuration (phases 1-2). + + Uses "provider/model" format strings. + - fast: used for simple_call (cheap, fast tasks) + - strong: used for reasoning_call, agentic_research (complex tasks) + """ + + fast: str = "openai/gpt-5-mini" + strong: str = "openai/gpt-5" + + +@dataclass +class SimulationConfig: + """Simulation model + tuning configuration (phase 3). + + Uses "provider/model" format strings. 
+ - fast: used for Pass 2 (classification/routine) + - strong: used for Pass 1 (pivotal/role-play reasoning) + """ + + fast: str = "" # empty = same as models.fast + strong: str = "" # empty = same as models.strong + max_concurrent: int = 50 + rate_tier: int | None = None + rpm_override: int | None = None + tpm_override: int | None = None + + +@dataclass +class CustomProviderConfig: + """Configuration for a custom OpenAI-compatible provider endpoint.""" + + base_url: str = "" + api_key_env: str = "" + + + + +# ============================================================================= +# Legacy config dataclasses (kept for migration) # ============================================================================= @dataclass class PipelineConfig: - """Config for phases 1-2: spec, extend, sample, network, persona, scenario.""" + """DEPRECATED: Config for phases 1-2. Use ModelsConfig instead.""" provider: str = "openai" - model_simple: str = "" # empty = provider default - model_reasoning: str = "" # empty = provider default - model_research: str = "" # empty = provider default + model_simple: str = "" + model_reasoning: str = "" + model_research: str = "" @dataclass class SimZoneConfig: - """Config for phase 3: agent reasoning during simulation.""" + """DEPRECATED: Config for phase 3. 
Use SimulationConfig instead.""" provider: str = "openai" - model: str = "" # empty = provider default - pivotal_model: str = "" # model for pivotal reasoning (default: same as model) - routine_model: str = ( - "" # cheap model for classification (default: provider cheap tier) - ) + model: str = "" + pivotal_model: str = "" + routine_model: str = "" max_concurrent: int = 50 - rate_tier: int | None = None # rate limit tier (1-4, None = Tier 1) - rpm_override: int | None = None # override RPM limit - tpm_override: int | None = None # override TPM limit - api_format: str = ( - "" # empty = auto (responses for openai, chat_completions for azure) - ) + rate_tier: int | None = None + rpm_override: int | None = None + tpm_override: int | None = None + api_format: str = "" + + +# ============================================================================= +# Main config class +# ============================================================================= @dataclass @@ -75,8 +154,7 @@ class ExtropyConfig: Examples: # Package use — no files needed config = ExtropyConfig( - pipeline=PipelineConfig(provider="claude"), - simulation=SimZoneConfig(provider="openai", model="gpt-5-mini"), + models=ModelsConfig(fast="openai/gpt-5-mini", strong="anthropic/claude-sonnet-4.5"), ) # CLI use — loads from ~/.config/extropy/config.json @@ -84,21 +162,19 @@ class ExtropyConfig: # Override just simulation config = ExtropyConfig.load() - config.simulation.model = "gpt-5-nano" + config.simulation.strong = "openrouter/anthropic/claude-sonnet-4.5" """ - pipeline: PipelineConfig = field(default_factory=PipelineConfig) - simulation: SimZoneConfig = field(default_factory=SimZoneConfig) - - # Non-zone settings - db_path: str = "./storage/extropy.db" - default_population_size: int = 1000 + models: ModelsConfig = field(default_factory=ModelsConfig) + simulation: SimulationConfig = field(default_factory=SimulationConfig) + providers: dict[str, CustomProviderConfig] = field(default_factory=dict) 
@classmethod def load(cls) -> "ExtropyConfig": """Load config from file + env vars. Priority: env var values > config.json values > defaults. + Auto-migrates v1 config format if detected. """ config = cls() @@ -107,31 +183,35 @@ def load(cls) -> "ExtropyConfig": try: with open(CONFIG_FILE) as f: data = json.load(f) + + # Auto-migrate v1 config + if _is_v1_config(data): + warnings.warn( + "Detected legacy config format. Migrating to v2. " + "Run `extropy config show` to verify, then `extropy config set` to update.", + DeprecationWarning, + stacklevel=2, + ) + data = _migrate_v1_to_v2(data) + _apply_dict(config, data) except (json.JSONDecodeError, OSError) as exc: logger.warning("Failed to load config from %s: %s", CONFIG_FILE, exc) - # Layer 2: Env var overrides - if provider := os.environ.get("LLM_PROVIDER"): - # Legacy: single provider applied to both zones - config.pipeline.provider = provider - config.simulation.provider = provider - if val := os.environ.get("PIPELINE_PROVIDER"): - config.pipeline.provider = val - if val := os.environ.get("SIMULATION_PROVIDER"): - config.simulation.provider = val - if val := os.environ.get("MODEL_SIMPLE"): - config.pipeline.model_simple = val - if val := os.environ.get("MODEL_REASONING"): - config.pipeline.model_reasoning = val - if val := os.environ.get("MODEL_RESEARCH"): - config.pipeline.model_research = val - if val := os.environ.get("SIMULATION_MODEL"): - config.simulation.model = val - if val := os.environ.get("SIMULATION_PIVOTAL_MODEL"): - config.simulation.pivotal_model = val - if val := os.environ.get("SIMULATION_ROUTINE_MODEL"): - config.simulation.routine_model = val + # Layer 2: Env var overrides (new format) + if val := os.environ.get("MODELS_FAST"): + config.models.fast = val + if val := os.environ.get("MODELS_STRONG"): + config.models.strong = val + if val := os.environ.get("SIMULATION_FAST"): + config.simulation.fast = val + if val := os.environ.get("SIMULATION_STRONG"): + config.simulation.strong = val + if val 
:= os.environ.get("SIMULATION_MAX_CONCURRENT"): + try: + config.simulation.max_concurrent = int(val) + except ValueError: + logger.warning("Invalid SIMULATION_MAX_CONCURRENT=%r, ignoring", val) if val := os.environ.get("SIMULATION_RATE_TIER"): try: config.simulation.rate_tier = int(val) @@ -147,38 +227,56 @@ def load(cls) -> "ExtropyConfig": config.simulation.tpm_override = int(val) except ValueError: logger.warning("Invalid SIMULATION_TPM_OVERRIDE=%r, ignoring", val) - if val := os.environ.get("SIMULATION_API_FORMAT"): - config.simulation.api_format = val - if val := os.environ.get("DB_PATH"): - config.db_path = val - if val := os.environ.get("DEFAULT_POPULATION_SIZE"): - try: - config.default_population_size = int(val) - except ValueError: - logger.warning("Invalid DEFAULT_POPULATION_SIZE=%r, ignoring", val) + # Layer 3: Legacy env var overrides (emit deprecation warnings) + _apply_legacy_env_vars(config) return config def save(self) -> None: """Save config to ~/.config/extropy/config.json.""" CONFIG_DIR.mkdir(parents=True, exist_ok=True) - data = asdict(self) - # Don't persist non-zone settings that are better as env vars - data.pop("db_path", None) - data.pop("default_population_size", None) + data: dict[str, Any] = { + "models": asdict(self.models), + "simulation": asdict(self.simulation), + } + if self.providers: + data["providers"] = { + name: asdict(cfg) for name, cfg in self.providers.items() + } with open(CONFIG_FILE, "w") as f: json.dump(data, f, indent=2) def to_dict(self) -> dict[str, Any]: """Convert to dict for display.""" - return asdict(self) + result = { + "models": asdict(self.models), + "simulation": asdict(self.simulation), + } + if self.providers: + result["providers"] = { + name: asdict(cfg) for name, cfg in self.providers.items() + } + return result - @property - def db_path_resolved(self) -> Path: - """Resolve database path.""" - path = Path(self.db_path) - path.parent.mkdir(parents=True, exist_ok=True) - return path + # ── Convenience 
resolution methods ── + + def resolve_pipeline_fast(self) -> str: + """Resolve the fast model string for pipeline use.""" + return self.models.fast + + def resolve_pipeline_strong(self) -> str: + """Resolve the strong model string for pipeline use.""" + return self.models.strong + + def resolve_sim_strong(self) -> str: + """Resolve the strong model string for simulation.""" + return self.simulation.strong or self.models.strong + + def resolve_sim_fast(self) -> str: + """Resolve the fast model string for simulation.""" + return self.simulation.fast or self.models.fast + + # ── Backward compat properties ── @property def cache_dir(self) -> Path: @@ -188,24 +286,210 @@ def cache_dir(self) -> Path: return path +# ============================================================================= +# Config dict application +# ============================================================================= + + def _apply_dict(config: ExtropyConfig, data: dict) -> None: - """Apply a dict of values onto an ExtropyConfig.""" - if "pipeline" in data and isinstance(data["pipeline"], dict): - for k, v in data["pipeline"].items(): - if hasattr(config.pipeline, k): - setattr(config.pipeline, k, v) + """Apply a dict of values onto an ExtropyConfig (v2 format).""" + if "models" in data and isinstance(data["models"], dict): + for k, v in data["models"].items(): + if hasattr(config.models, k): + setattr(config.models, k, v) if "simulation" in data and isinstance(data["simulation"], dict): for k, v in data["simulation"].items(): if hasattr(config.simulation, k): setattr(config.simulation, k, v) - if "db_path" in data: - config.db_path = data["db_path"] - if "default_population_size" in data: - config.default_population_size = int(data["default_population_size"]) + if "providers" in data and isinstance(data["providers"], dict): + for name, provider_data in data["providers"].items(): + if isinstance(provider_data, dict): + config.providers[name] = CustomProviderConfig( + 
base_url=provider_data.get("base_url", ""), + api_key_env=provider_data.get("api_key_env", ""), + ) + + +# ============================================================================= +# V1 → V2 migration +# ============================================================================= + +# Provider name mapping for migration +_PROVIDER_CANONICAL = { + "openai": "openai", + "claude": "anthropic", + "anthropic": "anthropic", + "azure_openai": "azure", +} + +# Default model names per old provider +_V1_PROVIDER_DEFAULTS = { + "openai": {"fast": "gpt-5-mini", "strong": "gpt-5"}, + "claude": { + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", + }, + "anthropic": { + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", + }, + "azure_openai": {"fast": "gpt-5-mini", "strong": "gpt-5"}, +} + + +def _is_v1_config(data: dict) -> bool: + """Detect if config data is in v1 format (has 'pipeline' key).""" + return "pipeline" in data and "models" not in data + + +def _migrate_v1_to_v2(data: dict) -> dict: + """Convert v1 config format to v2. 
+ + v1 format: + {"pipeline": {"provider": "openai", "model_simple": "...", ...}, + "simulation": {"provider": "openai", "model": "...", ...}} + + v2 format: + {"models": {"fast": "openai/gpt-5-mini", "strong": "openai/gpt-5"}, + "simulation": {"fast": "...", "strong": "...", ...}} + """ + result: dict[str, Any] = {} + + # Migrate pipeline → models + pipeline = data.get("pipeline", {}) + old_provider = pipeline.get("provider", "openai") + canonical = _PROVIDER_CANONICAL.get(old_provider, old_provider) + defaults = _V1_PROVIDER_DEFAULTS.get(old_provider, _V1_PROVIDER_DEFAULTS["openai"]) + + fast_model = pipeline.get("model_simple") or defaults["fast"] + strong_model = pipeline.get("model_reasoning") or defaults["strong"] + + result["models"] = { + "fast": f"{canonical}/{fast_model}", + "strong": f"{canonical}/{strong_model}", + } + + # Migrate simulation + sim = data.get("simulation", {}) + sim_provider = sim.get("provider", "openai") + sim_canonical = _PROVIDER_CANONICAL.get(sim_provider, sim_provider) + sim_defaults = _V1_PROVIDER_DEFAULTS.get( + sim_provider, _V1_PROVIDER_DEFAULTS["openai"] + ) + + sim_result: dict[str, Any] = {} + + # Map model/pivotal_model → strong, routine_model → fast + pivotal = sim.get("pivotal_model") or sim.get("model") or "" + routine = sim.get("routine_model") or "" + + if pivotal: + sim_result["strong"] = f"{sim_canonical}/{pivotal}" + if routine: + sim_result["fast"] = f"{sim_canonical}/{routine}" + + for k in ("max_concurrent", "rate_tier", "rpm_override", "tpm_override"): + if k in sim and sim[k] is not None: + sim_result[k] = sim[k] + + result["simulation"] = sim_result + + return result + + +# ============================================================================= +# Legacy env var handling +# ============================================================================= + +_LEGACY_ENV_WARNED: set[str] = set() + + +def _warn_legacy_env(name: str, replacement: str) -> None: + """Emit a one-time deprecation warning for a 
legacy env var.""" + if name not in _LEGACY_ENV_WARNED: + _LEGACY_ENV_WARNED.add(name) + warnings.warn( + f"Environment variable {name} is deprecated. Use {replacement} instead.", + DeprecationWarning, + stacklevel=4, + ) + + +def _apply_legacy_env_vars(config: ExtropyConfig) -> None: + """Apply legacy env vars with deprecation warnings.""" + # LLM_PROVIDER → both zones + if val := os.environ.get("LLM_PROVIDER"): + _warn_legacy_env("LLM_PROVIDER", "MODELS_FAST / MODELS_STRONG") + canonical = _PROVIDER_CANONICAL.get(val, val) + defaults = _V1_PROVIDER_DEFAULTS.get(val, _V1_PROVIDER_DEFAULTS["openai"]) + # Only override if no new-format env vars set + if not os.environ.get("MODELS_FAST"): + config.models.fast = f"{canonical}/{defaults['fast']}" + if not os.environ.get("MODELS_STRONG"): + config.models.strong = f"{canonical}/{defaults['strong']}" + + if val := os.environ.get("PIPELINE_PROVIDER"): + _warn_legacy_env("PIPELINE_PROVIDER", "MODELS_FAST / MODELS_STRONG") + canonical = _PROVIDER_CANONICAL.get(val, val) + defaults = _V1_PROVIDER_DEFAULTS.get(val, _V1_PROVIDER_DEFAULTS["openai"]) + if not os.environ.get("MODELS_FAST"): + config.models.fast = f"{canonical}/{defaults['fast']}" + if not os.environ.get("MODELS_STRONG"): + config.models.strong = f"{canonical}/{defaults['strong']}" + + if val := os.environ.get("SIMULATION_PROVIDER"): + _warn_legacy_env("SIMULATION_PROVIDER", "SIMULATION_FAST / SIMULATION_STRONG") + canonical = _PROVIDER_CANONICAL.get(val, val) + defaults = _V1_PROVIDER_DEFAULTS.get(val, _V1_PROVIDER_DEFAULTS["openai"]) + if not os.environ.get("SIMULATION_FAST"): + config.simulation.fast = f"{canonical}/{defaults['fast']}" + if not os.environ.get("SIMULATION_STRONG"): + config.simulation.strong = f"{canonical}/{defaults['strong']}" + + if val := os.environ.get("MODEL_SIMPLE"): + _warn_legacy_env("MODEL_SIMPLE", "MODELS_FAST") + if not os.environ.get("MODELS_FAST"): + provider, _ = parse_model_string(config.models.fast) + config.models.fast = 
f"{provider}/{val}" + + if val := os.environ.get("MODEL_REASONING"): + _warn_legacy_env("MODEL_REASONING", "MODELS_STRONG") + if not os.environ.get("MODELS_STRONG"): + provider, _ = parse_model_string(config.models.strong) + config.models.strong = f"{provider}/{val}" + + if val := os.environ.get("SIMULATION_MODEL"): + _warn_legacy_env("SIMULATION_MODEL", "SIMULATION_STRONG") + if not os.environ.get("SIMULATION_STRONG"): + # Resolve provider from sim strong or models strong + base = config.simulation.strong or config.models.strong + provider, _ = parse_model_string(base) + config.simulation.strong = f"{provider}/{val}" + + if val := os.environ.get("SIMULATION_PIVOTAL_MODEL"): + _warn_legacy_env("SIMULATION_PIVOTAL_MODEL", "SIMULATION_STRONG") + if not os.environ.get("SIMULATION_STRONG"): + base = config.simulation.strong or config.models.strong + provider, _ = parse_model_string(base) + config.simulation.strong = f"{provider}/{val}" + + if val := os.environ.get("SIMULATION_ROUTINE_MODEL"): + _warn_legacy_env("SIMULATION_ROUTINE_MODEL", "SIMULATION_FAST") + if not os.environ.get("SIMULATION_FAST"): + base = config.simulation.fast or config.models.fast + provider, _ = parse_model_string(base) + config.simulation.fast = f"{provider}/{val}" + + # SIMULATION_API_FORMAT — no direct replacement, just warn + if os.environ.get("SIMULATION_API_FORMAT"): + _warn_legacy_env( + "SIMULATION_API_FORMAT", + "provider-based routing (api_format is now automatic)", + ) # ============================================================================= -# API key resolution (env vars + .env file) +# API key resolution # ============================================================================= _dotenv_loaded = False @@ -219,52 +503,74 @@ def _ensure_dotenv() -> None: try: from dotenv import find_dotenv, load_dotenv - # Resolve from current working directory first so CLI commands run - # from study repos consistently pick up that repo's `.env`. 
dotenv_path = find_dotenv(usecwd=True) if dotenv_path: load_dotenv(dotenv_path=dotenv_path, override=False) else: - # Fallback for environments where no discoverable .env exists. load_dotenv(override=False) except ImportError: - pass # python-dotenv not installed, skip + pass except Exception: - # Keep config loading resilient even if dotenv discovery has runtime issues. pass -def get_api_key(provider: str) -> str: - """Get API key for a provider from environment variables or .env file. +def get_api_key_for_provider( + provider_name: str, + custom_providers: dict[str, CustomProviderConfig] | None = None, +) -> str: + """Get API key for a provider. - Supports: - - openai: OPENAI_API_KEY - - claude: ANTHROPIC_API_KEY - - azure_openai: AZURE_OPENAI_API_KEY + Resolution order: + 1. Custom provider api_key_env override + 2. Convention: {PROVIDER_UPPER}_API_KEY - Returns empty string if not found (providers will raise on missing keys). + Special cases: + - "anthropic" → ANTHROPIC_API_KEY + - "azure" → AZURE_OPENAI_API_KEY + + Returns empty string if not found. """ _ensure_dotenv() - if provider == "openai": - return os.environ.get("OPENAI_API_KEY", "") - elif provider == "claude": - return os.environ.get("ANTHROPIC_API_KEY", "") - elif provider == "azure_openai": - return os.environ.get("AZURE_OPENAI_API_KEY", "") - return "" + # Check custom provider override first + if custom_providers and provider_name in custom_providers: + custom = custom_providers[provider_name] + if custom.api_key_env: + return os.environ.get(custom.api_key_env, "") + + # Convention: {PROVIDER}_API_KEY + # Special cases for backward compat + key_map = { + "azure": "AZURE_OPENAI_API_KEY", + "azure_openai": "AZURE_OPENAI_API_KEY", + } + env_var = key_map.get( + provider_name, f"{provider_name.upper()}_API_KEY" + ) + return os.environ.get(env_var, "") -def get_azure_config(provider: str) -> dict[str, str]: - """Get Azure-specific configuration from environment variables. 
- Args: - provider: 'azure_openai' +def get_api_key(provider: str) -> str: + """DEPRECATED: Get API key for a provider. Use get_api_key_for_provider instead. + + Kept for backward compatibility. + """ + # Map old provider names + mapping = { + "claude": "anthropic", + "azure_openai": "azure", + } + canonical = mapping.get(provider, provider) + return get_api_key_for_provider(canonical) + + +def get_azure_config(provider: str) -> dict[str, str]: + """DEPRECATED: Get Azure-specific configuration. - Returns: - Dict of Azure config values (endpoint, api_version, deployment). + Azure is now handled as an OpenAI-compatible provider. """ _ensure_dotenv() - if provider == "azure_openai": + if provider in ("azure_openai", "azure"): return { "azure_endpoint": os.environ.get("AZURE_OPENAI_ENDPOINT", ""), "api_version": os.environ.get( @@ -298,8 +604,8 @@ def configure(config: ExtropyConfig) -> None: """Set the global ExtropyConfig programmatically. Use this when extropy is used as a package: - from extropy.config import configure, ExtropyConfig, PipelineConfig - configure(ExtropyConfig(pipeline=PipelineConfig(provider="claude"))) + from extropy.config import configure, ExtropyConfig, ModelsConfig + configure(ExtropyConfig(models=ModelsConfig(fast="openai/gpt-5-mini"))) """ global _config _config = config From 093bc93c7969f544d3f45876274ec617af4d466a Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 01:00:41 -0500 Subject: [PATCH 06/15] =?UTF-8?q?feat(providers):=20add=20provider=20regis?= =?UTF-8?q?try,=20OpenAICompatProvider,=20rename=20claude=E2=86=92anthropi?= =?UTF-8?q?c?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add OpenAICompatProvider for third-party endpoints (OpenRouter, DeepSeek, Together, Groq) using Chat Completions API with json_schema response format. Build lazy-import provider registry in __init__.py. Rename ClaudeProvider→AnthropicProvider (keep alias). 
Update base.py to 2-tier abstract properties (default_fast_model, default_strong_model). Co-Authored-By: Claude Opus 4.6 --- extropy/core/providers/__init__.py | 246 ++++++++++++---- extropy/core/providers/anthropic.py | 369 +++++++++++++++++++++++ extropy/core/providers/base.py | 26 +- extropy/core/providers/claude.py | 370 +----------------------- extropy/core/providers/openai.py | 8 +- extropy/core/providers/openai_compat.py | 346 ++++++++++++++++++++++ 6 files changed, 932 insertions(+), 433 deletions(-) create mode 100644 extropy/core/providers/anthropic.py create mode 100644 extropy/core/providers/openai_compat.py diff --git a/extropy/core/providers/__init__.py b/extropy/core/providers/__init__.py index 195f949..2dfa595 100644 --- a/extropy/core/providers/__init__.py +++ b/extropy/core/providers/__init__.py @@ -1,97 +1,235 @@ -"""LLM Provider factory. +"""LLM Provider registry and factory. -Provides two-zone provider routing: -- Pipeline provider: used for phases 1-2 (spec, extend, persona, scenario) -- Simulation provider: used for phase 3 (agent reasoning) +Provides: +- BUILTIN_PROVIDERS: Registry of known provider names → factory info +- get_provider(): Create a provider instance from a provider name +- get_pipeline_provider() / get_simulation_provider(): Zone-based provider access The simulation provider is cached so its async client can be reused across batch calls and closed cleanly before the event loop shuts down. """ -from .base import LLMProvider -from ...config import get_config, get_api_key, get_azure_config - +import os -# Cached simulation provider — reused across batch calls so the async -# client isn't re-created per request, and can be closed cleanly. 
-_simulation_provider: LLMProvider | None = None +from .base import LLMProvider +from ...config import ( + get_config, + get_api_key_for_provider, + parse_model_string, + CustomProviderConfig, +) + + +# ============================================================================= +# Provider Registry +# ============================================================================= + +# Each entry: (module, class_name, default_kwargs) +# Lazy-imported to avoid loading all SDKs at startup. +_BUILTIN_REGISTRY: dict[str, dict] = { + "openai": { + "module": ".openai", + "class": "OpenAIProvider", + }, + "anthropic": { + "module": ".anthropic", + "class": "AnthropicProvider", + }, + "openrouter": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "https://openrouter.ai/api/v1", + "supports_search": True, + "provider_label": "openrouter", + "default_fast": "openai/gpt-5-mini", + "default_strong": "openai/gpt-5", + }, + }, + "azure": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "", # resolved from env + "supports_search": False, + "provider_label": "azure", + "default_fast": "gpt-5-mini", + "default_strong": "gpt-5", + }, + }, + "deepseek": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "https://api.deepseek.com/v1", + "supports_search": False, + "provider_label": "deepseek", + "default_fast": "deepseek-chat", + "default_strong": "deepseek-reasoner", + }, + }, + "together": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "https://api.together.xyz/v1", + "supports_search": False, + "provider_label": "together", + "default_fast": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + "default_strong": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + }, + }, + "groq": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "https://api.groq.com/openai/v1", + 
"supports_search": False, + "provider_label": "groq", + "default_fast": "llama-3.3-70b-versatile", + "default_strong": "llama-3.3-70b-versatile", + }, + }, +} + + +def get_provider( + provider_name: str, + custom_providers: dict[str, CustomProviderConfig] | None = None, +) -> LLMProvider: + """Create a provider instance by name. + + Checks custom providers first, then built-in registry. + + Args: + provider_name: Provider name (e.g., "openai", "anthropic", "openrouter") + custom_providers: Optional custom provider configs from ExtropyConfig + + Returns: + LLMProvider instance + + Raises: + ValueError: If provider is unknown + """ + api_key = get_api_key_for_provider(provider_name, custom_providers) + # Check custom providers first + if custom_providers and provider_name in custom_providers: + from .openai_compat import OpenAICompatProvider -def _create_provider(provider_name: str) -> LLMProvider: - """Create a provider instance by name.""" - api_key = get_api_key(provider_name) - - if provider_name == "openai": - from .openai import OpenAIProvider + custom = custom_providers[provider_name] + return OpenAICompatProvider( + api_key=api_key, + base_url=custom.base_url, + supports_search=False, + provider_label=provider_name, + ) - return OpenAIProvider(api_key=api_key) - elif provider_name == "claude": - from .claude import ClaudeProvider + # Check built-in registry + if provider_name not in _BUILTIN_REGISTRY: + available = sorted(set(list(_BUILTIN_REGISTRY.keys()) + list((custom_providers or {}).keys()))) + raise ValueError( + f"Unknown LLM provider: {provider_name!r}. 
" + f"Available: {', '.join(available)}" + ) - return ClaudeProvider(api_key=api_key) - elif provider_name == "azure_openai": - from .openai import OpenAIProvider + entry = _BUILTIN_REGISTRY[provider_name] - azure_cfg = get_azure_config(provider_name) - if not azure_cfg.get("azure_endpoint"): + # Special case: Azure needs endpoint from env + if provider_name == "azure": + endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "") + if not endpoint: raise ValueError( "AZURE_OPENAI_ENDPOINT not found. Set it as an environment variable.\n" " export AZURE_OPENAI_ENDPOINT=https://.cognitiveservices.azure.com/" ) - # Resolve api_format: config value > auto-default (chat_completions for Azure) + entry = dict(entry) + entry["kwargs"] = dict(entry.get("kwargs", {})) + entry["kwargs"]["base_url"] = endpoint + + # Lazy import + import importlib + + module = importlib.import_module(entry["module"], package=__package__) + cls = getattr(module, entry["class"]) + + kwargs = dict(entry.get("kwargs", {})) + kwargs["api_key"] = api_key + + return cls(**kwargs) + + +# ============================================================================= +# Zone-based provider access (backward compat) +# ============================================================================= + +# Cached providers — reused across calls for connection reuse +_cached_providers: dict[str, LLMProvider] = {} + + +def _get_or_create_provider(provider_name: str, cache_key: str = "") -> LLMProvider: + """Get or create a cached provider instance.""" + key = cache_key or provider_name + if key not in _cached_providers: config = get_config() - api_format = config.simulation.api_format or "chat_completions" - return OpenAIProvider( - api_key=api_key, - azure_endpoint=azure_cfg["azure_endpoint"], - api_version=azure_cfg.get("api_version", "2025-03-01-preview"), - azure_deployment=azure_cfg.get("azure_deployment", ""), - api_format=api_format, - ) - else: - raise ValueError( - f"Unknown LLM provider: {provider_name}. 
" - f"Valid options: 'openai', 'claude', 'azure_openai'" - ) + _cached_providers[key] = get_provider(provider_name, config.providers) + return _cached_providers[key] def get_pipeline_provider() -> LLMProvider: - """Get the provider for pipeline phases (spec, extend, persona, scenario).""" + """Get the provider for pipeline phases (spec, extend, persona, scenario). + + Uses the provider from models.fast (pipeline calls use both fast and strong, + but the provider is determined by the fast model string). + """ config = get_config() - return _create_provider(config.pipeline.provider) + provider, _ = parse_model_string(config.models.fast) + return _get_or_create_provider(provider, f"pipeline:{provider}") def get_simulation_provider() -> LLMProvider: """Get the cached provider for simulation phase (agent reasoning). - Caches the provider so the underlying async HTTP client is reused - across all calls in a batch, avoiding orphaned connections. + Uses the provider from the resolved simulation strong model. """ - global _simulation_provider config = get_config() - provider_name = config.simulation.provider - - if _simulation_provider is None: - _simulation_provider = _create_provider(provider_name) - - return _simulation_provider + strong_model = config.resolve_sim_strong() + provider, _ = parse_model_string(strong_model) + return _get_or_create_provider(provider, f"simulation:{provider}") async def close_simulation_provider() -> None: - """Close the cached simulation provider's async client. + """Close cached providers' async clients. Call this before the event loop shuts down to cleanly release HTTP connections and avoid 'Event loop is closed' errors. 
""" - global _simulation_provider - if _simulation_provider is not None: - await _simulation_provider.close_async() - _simulation_provider = None + for key, provider in list(_cached_providers.items()): + await provider.close_async() + _cached_providers.clear() + + +def reset_provider_cache() -> None: + """Reset the provider cache (for testing).""" + _cached_providers.clear() + + +# Legacy factory (kept for backward compat in tests) +def _create_provider(provider_name: str) -> LLMProvider: + """DEPRECATED: Use get_provider() instead.""" + # Map old names + name_map = {"claude": "anthropic", "azure_openai": "azure"} + canonical = name_map.get(provider_name, provider_name) + config = get_config() + return get_provider(canonical, config.providers) __all__ = [ "LLMProvider", + "get_provider", "get_pipeline_provider", "get_simulation_provider", "close_simulation_provider", + "reset_provider_cache", + "parse_model_string", ] diff --git a/extropy/core/providers/anthropic.py b/extropy/core/providers/anthropic.py new file mode 100644 index 0000000..d0dd737 --- /dev/null +++ b/extropy/core/providers/anthropic.py @@ -0,0 +1,369 @@ +"""Anthropic (Claude) LLM Provider implementation. + +Uses the tool use pattern for reliable structured output: +instead of asking Claude to output JSON in text, we define a tool +with the response schema. Claude "calls" the tool, returning structured +data guaranteed to match the schema. +""" + +import logging +import random +import time + +import anthropic + +from .base import LLMProvider, TokenUsage, ValidatorCallback, RetryCallback +from .logging import log_request_response, extract_error_summary + +_TRANSIENT_ANTHROPIC_ERRORS = ( + anthropic.APIConnectionError, + anthropic.InternalServerError, + anthropic.RateLimitError, +) +_MAX_API_RETRIES = 3 + + +logger = logging.getLogger(__name__) + + +def _clean_schema_for_tool(schema: dict) -> dict: + """Clean a JSON schema for use as a tool input_schema. 
+ + Removes fields that aren't valid in tool input schemas + (like 'additionalProperties' in nested objects that Claude + doesn't support in tool definitions). + """ + cleaned = {} + for key, value in schema.items(): + if key == "additionalProperties": + continue + if isinstance(value, dict): + cleaned[key] = _clean_schema_for_tool(value) + elif isinstance(value, list): + cleaned[key] = [ + _clean_schema_for_tool(item) if isinstance(item, dict) else item + for item in value + ] + else: + cleaned[key] = value + return cleaned + + +def _make_structured_tool(schema_name: str, response_schema: dict) -> dict: + """Create a tool definition that forces structured output.""" + return { + "name": schema_name, + "description": ( + "Return your response as structured data. " + "You MUST call this tool with your complete response." + ), + "input_schema": _clean_schema_for_tool(response_schema), + } + + +def _extract_tool_input(response) -> dict | None: + """Extract tool_use input from a Claude response.""" + for block in response.content: + if block.type == "tool_use": + return block.input + return None + + +class AnthropicProvider(LLMProvider): + """Anthropic (Claude) LLM provider. + + Uses the tool use pattern for structured output — Claude "calls" a tool + with the response data, guaranteeing valid JSON matching the schema. + """ + + provider_name = "anthropic" + + def __init__(self, api_key: str = "") -> None: + if not api_key: + raise ValueError( + "Anthropic API key not found. 
Set it via:\n" + " export ANTHROPIC_API_KEY=sk-ant-...\n" + "Get your key from: https://console.anthropic.com/settings/keys" + ) + super().__init__(api_key) + + def _with_retry(self, fn, max_retries: int = _MAX_API_RETRIES): + """Retry a sync API call on transient errors with exponential backoff.""" + for attempt in range(max_retries + 1): + try: + return fn() + except _TRANSIENT_ANTHROPIC_ERRORS as e: + if attempt == max_retries: + raise + wait = (2**attempt) + random.random() + logger.warning( + f"[Claude] Transient error (attempt {attempt + 1}/{max_retries + 1}): " + f"{type(e).__name__}: {e}. Retrying in {wait:.1f}s" + ) + time.sleep(wait) + + async def _with_retry_async(self, fn, max_retries: int = _MAX_API_RETRIES): + """Retry an async API call on transient errors with exponential backoff.""" + import asyncio + + for attempt in range(max_retries + 1): + try: + return await fn() + except _TRANSIENT_ANTHROPIC_ERRORS as e: + if attempt == max_retries: + raise + wait = (2**attempt) + random.random() + logger.warning( + f"[Claude] Transient error (attempt {attempt + 1}/{max_retries + 1}): " + f"{type(e).__name__}: {e}. 
Retrying in {wait:.1f}s" + ) + await asyncio.sleep(wait) + + @property + def default_fast_model(self) -> str: + return "claude-haiku-4-5-20251001" + + @property + def default_strong_model(self) -> str: + return "claude-sonnet-4-5-20250929" + + def _get_client(self) -> anthropic.Anthropic: + return anthropic.Anthropic(api_key=self._api_key) + + def _get_async_client(self) -> anthropic.AsyncAnthropic: + if self._cached_async_client is None: + self._cached_async_client = anthropic.AsyncAnthropic(api_key=self._api_key) + return self._cached_async_client + + def simple_call( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + log: bool = True, + max_tokens: int | None = None, + ) -> dict: + model = model or self.default_simple_model + client = self._get_client() + tool = _make_structured_tool(schema_name, response_schema) + + # Acquire rate limit capacity before making the call + self._acquire_rate_limit(prompt, model, max_output=max_tokens or 4096) + + logger.info( + f"[Claude] simple_call starting - model={model}, schema={schema_name}" + ) + + response = self._with_retry( + lambda: client.messages.create( + model=model, + max_tokens=max_tokens or 4096, + tools=[tool], + tool_choice={"type": "tool", "name": schema_name}, + messages=[{"role": "user", "content": prompt}], + ) + ) + + structured_data = _extract_tool_input(response) + + if log: + log_request_response( + function_name="simple_call", + request={"model": model, "prompt_length": len(prompt)}, + response=response, + provider="claude", + ) + + return structured_data or {} + + async def simple_call_async( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + max_tokens: int | None = None, + ) -> tuple[dict, TokenUsage]: + model = model or self.default_simple_model + client = self._get_async_client() + tool = _make_structured_tool(schema_name, response_schema) + + response = await 
self._with_retry_async( + lambda: client.messages.create( + model=model, + max_tokens=max_tokens or 4096, + tools=[tool], + tool_choice={"type": "tool", "name": schema_name}, + messages=[{"role": "user", "content": prompt}], + ) + ) + + # Extract token usage + usage = TokenUsage() + if hasattr(response, "usage") and response.usage is not None: + usage = TokenUsage( + input_tokens=getattr(response.usage, "input_tokens", 0) or 0, + output_tokens=getattr(response.usage, "output_tokens", 0) or 0, + ) + + return _extract_tool_input(response) or {}, usage + + def reasoning_call( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + reasoning_effort: str = "low", + log: bool = True, + previous_errors: str | None = None, + validator: ValidatorCallback | None = None, + max_retries: int = 2, + on_retry: RetryCallback | None = None, + ) -> dict: + """Claude reasoning call with tool-based structured output.""" + model = model or self.default_reasoning_model + client = self._get_client() + tool = _make_structured_tool(schema_name, response_schema) + + effective_prompt = prompt + if previous_errors: + effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" + + def _call(ep: str) -> dict: + # Acquire rate limit capacity before each API call + self._acquire_rate_limit(ep, model, max_output=16384) + + response = self._with_retry( + lambda: client.messages.create( + model=model, + max_tokens=16384, + tools=[tool], + tool_choice={"type": "tool", "name": schema_name}, + messages=[{"role": "user", "content": ep}], + ) + ) + structured_data = _extract_tool_input(response) + if log: + log_request_response( + function_name="reasoning_call", + request={"model": model, "prompt_length": len(ep)}, + response=response, + provider="claude", + ) + return structured_data or {} + + return self._retry_with_validation( + call_fn=_call, + prompt=prompt, + validator=validator, + max_retries=max_retries, + on_retry=on_retry, + 
extract_error_summary_fn=extract_error_summary, + initial_prompt=effective_prompt if previous_errors else None, + ) + + def agentic_research( + self, + prompt: str, + response_schema: dict, + schema_name: str = "research_data", + model: str | None = None, + reasoning_effort: str = "low", + log: bool = True, + previous_errors: str | None = None, + validator: ValidatorCallback | None = None, + max_retries: int = 2, + on_retry: RetryCallback | None = None, + ) -> tuple[dict, list[str]]: + """Claude agentic research with web search + tool-based structured output. + + Uses web_search tool for research and a structured output tool for the response. + Claude first searches, then calls the output tool with results. + """ + model = model or self.default_research_model + client = self._get_client() + output_tool = _make_structured_tool(schema_name, response_schema) + + effective_prompt = prompt + if previous_errors: + effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" + + all_sources: list[str] = [] + + def _call(ep: str) -> dict: + research_prompt = ( + f"{ep}\n\n" + f"After researching, call the '{schema_name}' tool with your structured findings." 
+ ) + + # Acquire rate limit capacity before each API call + self._acquire_rate_limit(research_prompt, model, max_output=16384) + + logger.info(f"[Claude] agentic_research - model={model}") + + response = self._with_retry( + lambda: client.messages.create( + model=model, + max_tokens=16384, + tools=[ + { + "type": "web_search_20250305", + "name": "web_search", + "max_uses": 5, + }, + output_tool, + ], + messages=[{"role": "user", "content": research_prompt}], + ) + ) + + structured_data = None + sources: list[str] = [] + + for block in response.content: + if block.type == "web_search_tool_result": + if hasattr(block, "content") and block.content: + for res in block.content: + if hasattr(res, "url"): + sources.append(res.url) + + if block.type == "tool_use" and block.name == schema_name: + structured_data = block.input + + if block.type == "text": + if hasattr(block, "citations") and block.citations: + for citation in block.citations: + if hasattr(citation, "url"): + sources.append(citation.url) + + all_sources.extend(sources) + logger.info(f"[Claude] Web search completed, found {len(sources)} sources") + + if log: + log_request_response( + function_name="agentic_research", + request={"model": model, "prompt_length": len(research_prompt)}, + response=response, + provider="claude", + sources=list(set(sources)), + ) + + return structured_data or {} + + result = self._retry_with_validation( + call_fn=_call, + prompt=prompt, + validator=validator, + max_retries=max_retries, + on_retry=on_retry, + extract_error_summary_fn=extract_error_summary, + initial_prompt=effective_prompt if previous_errors else None, + ) + + return result, list(set(all_sources)) + + +# Backward compat alias +ClaudeProvider = AnthropicProvider diff --git a/extropy/core/providers/base.py b/extropy/core/providers/base.py index f33fbcf..fc82f53 100644 --- a/extropy/core/providers/base.py +++ b/extropy/core/providers/base.py @@ -1,5 +1,6 @@ """Abstract base class for LLM providers.""" +import logging 
from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Callable, TYPE_CHECKING @@ -7,6 +8,8 @@ if TYPE_CHECKING: from ..rate_limiter import RateLimiter +logger = logging.getLogger(__name__) + @dataclass class TokenUsage: @@ -34,6 +37,8 @@ class LLMProvider(ABC): All providers must implement these methods with the same signatures to ensure drop-in compatibility. + Automatically records token usage into CostTracker after each call. + Args: api_key: API key or access token for the provider. """ @@ -101,21 +106,28 @@ async def close_async(self) -> None: @property @abstractmethod - def default_simple_model(self) -> str: - """Default model for simple_call (fast, cheap).""" + def default_fast_model(self) -> str: + """Default model for fast/cheap calls (simple_call, Pass 2).""" ... @property @abstractmethod - def default_reasoning_model(self) -> str: - """Default model for reasoning_call (balanced).""" + def default_strong_model(self) -> str: + """Default model for strong/reasoning calls (reasoning_call, agentic_research, Pass 1).""" ... + # Backward-compat aliases (read-only) + @property + def default_simple_model(self) -> str: + return self.default_fast_model + + @property + def default_reasoning_model(self) -> str: + return self.default_strong_model + @property - @abstractmethod def default_research_model(self) -> str: - """Default model for agentic_research (with web search).""" - ... + return self.default_strong_model @abstractmethod def simple_call( diff --git a/extropy/core/providers/claude.py b/extropy/core/providers/claude.py index c06691a..15aa201 100644 --- a/extropy/core/providers/claude.py +++ b/extropy/core/providers/claude.py @@ -1,370 +1,8 @@ -"""Claude (Anthropic) LLM Provider implementation. +"""DEPRECATED: Use extropy.core.providers.anthropic instead. -Uses the tool use pattern for reliable structured output: -instead of asking Claude to output JSON in text, we define a tool -with the response schema. 
Claude "calls" the tool, returning structured -data guaranteed to match the schema. +This module re-exports AnthropicProvider as ClaudeProvider for backward compatibility. """ -import logging -import random -import time +from .anthropic import AnthropicProvider, ClaudeProvider # noqa: F401 -import anthropic - -from .base import LLMProvider, TokenUsage, ValidatorCallback, RetryCallback -from .logging import log_request_response, extract_error_summary - -_TRANSIENT_ANTHROPIC_ERRORS = ( - anthropic.APIConnectionError, - anthropic.InternalServerError, - anthropic.RateLimitError, -) -_MAX_API_RETRIES = 3 - - -logger = logging.getLogger(__name__) - - -def _clean_schema_for_tool(schema: dict) -> dict: - """Clean a JSON schema for use as a tool input_schema. - - Removes fields that aren't valid in tool input schemas - (like 'additionalProperties' in nested objects that Claude - doesn't support in tool definitions). - """ - cleaned = {} - for key, value in schema.items(): - if key == "additionalProperties": - continue - if isinstance(value, dict): - cleaned[key] = _clean_schema_for_tool(value) - elif isinstance(value, list): - cleaned[key] = [ - _clean_schema_for_tool(item) if isinstance(item, dict) else item - for item in value - ] - else: - cleaned[key] = value - return cleaned - - -def _make_structured_tool(schema_name: str, response_schema: dict) -> dict: - """Create a tool definition that forces structured output.""" - return { - "name": schema_name, - "description": ( - "Return your response as structured data. " - "You MUST call this tool with your complete response." - ), - "input_schema": _clean_schema_for_tool(response_schema), - } - - -def _extract_tool_input(response) -> dict | None: - """Extract tool_use input from a Claude response.""" - for block in response.content: - if block.type == "tool_use": - return block.input - return None - - -class ClaudeProvider(LLMProvider): - """Claude (Anthropic) LLM provider. 
- - Uses the tool use pattern for structured output — Claude "calls" a tool - with the response data, guaranteeing valid JSON matching the schema. - - """ - - provider_name = "anthropic" - - def __init__(self, api_key: str = "") -> None: - if not api_key: - raise ValueError( - "Anthropic API key not found. Set it via:\n" - " export ANTHROPIC_API_KEY=sk-ant-...\n" - "Get your key from: https://console.anthropic.com/settings/keys" - ) - super().__init__(api_key) - - def _with_retry(self, fn, max_retries: int = _MAX_API_RETRIES): - """Retry a sync API call on transient errors with exponential backoff.""" - for attempt in range(max_retries + 1): - try: - return fn() - except _TRANSIENT_ANTHROPIC_ERRORS as e: - if attempt == max_retries: - raise - wait = (2**attempt) + random.random() - logger.warning( - f"[Claude] Transient error (attempt {attempt + 1}/{max_retries + 1}): " - f"{type(e).__name__}: {e}. Retrying in {wait:.1f}s" - ) - time.sleep(wait) - - async def _with_retry_async(self, fn, max_retries: int = _MAX_API_RETRIES): - """Retry an async API call on transient errors with exponential backoff.""" - import asyncio - - for attempt in range(max_retries + 1): - try: - return await fn() - except _TRANSIENT_ANTHROPIC_ERRORS as e: - if attempt == max_retries: - raise - wait = (2**attempt) + random.random() - logger.warning( - f"[Claude] Transient error (attempt {attempt + 1}/{max_retries + 1}): " - f"{type(e).__name__}: {e}. 
Retrying in {wait:.1f}s" - ) - await asyncio.sleep(wait) - - @property - def default_simple_model(self) -> str: - return "claude-haiku-4-5-20251001" - - @property - def default_reasoning_model(self) -> str: - return "claude-sonnet-4-5-20250929" - - @property - def default_research_model(self) -> str: - return "claude-sonnet-4-5-20250929" - - def _get_client(self) -> anthropic.Anthropic: - return anthropic.Anthropic(api_key=self._api_key) - - def _get_async_client(self) -> anthropic.AsyncAnthropic: - if self._cached_async_client is None: - self._cached_async_client = anthropic.AsyncAnthropic(api_key=self._api_key) - return self._cached_async_client - - def simple_call( - self, - prompt: str, - response_schema: dict, - schema_name: str = "response", - model: str | None = None, - log: bool = True, - max_tokens: int | None = None, - ) -> dict: - model = model or self.default_simple_model - client = self._get_client() - tool = _make_structured_tool(schema_name, response_schema) - - # Acquire rate limit capacity before making the call - self._acquire_rate_limit(prompt, model, max_output=max_tokens or 4096) - - logger.info( - f"[Claude] simple_call starting - model={model}, schema={schema_name}" - ) - - response = self._with_retry( - lambda: client.messages.create( - model=model, - max_tokens=max_tokens or 4096, - tools=[tool], - tool_choice={"type": "tool", "name": schema_name}, - messages=[{"role": "user", "content": prompt}], - ) - ) - - structured_data = _extract_tool_input(response) - - if log: - log_request_response( - function_name="simple_call", - request={"model": model, "prompt_length": len(prompt)}, - response=response, - provider="claude", - ) - - return structured_data or {} - - async def simple_call_async( - self, - prompt: str, - response_schema: dict, - schema_name: str = "response", - model: str | None = None, - max_tokens: int | None = None, - ) -> tuple[dict, TokenUsage]: - model = model or self.default_simple_model - client = self._get_async_client() - 
tool = _make_structured_tool(schema_name, response_schema) - - response = await self._with_retry_async( - lambda: client.messages.create( - model=model, - max_tokens=max_tokens or 4096, - tools=[tool], - tool_choice={"type": "tool", "name": schema_name}, - messages=[{"role": "user", "content": prompt}], - ) - ) - - # Extract token usage - usage = TokenUsage() - if hasattr(response, "usage") and response.usage is not None: - usage = TokenUsage( - input_tokens=getattr(response.usage, "input_tokens", 0) or 0, - output_tokens=getattr(response.usage, "output_tokens", 0) or 0, - ) - - return _extract_tool_input(response) or {}, usage - - def reasoning_call( - self, - prompt: str, - response_schema: dict, - schema_name: str = "response", - model: str | None = None, - reasoning_effort: str = "low", - log: bool = True, - previous_errors: str | None = None, - validator: ValidatorCallback | None = None, - max_retries: int = 2, - on_retry: RetryCallback | None = None, - ) -> dict: - """Claude reasoning call with tool-based structured output.""" - model = model or self.default_reasoning_model - client = self._get_client() - tool = _make_structured_tool(schema_name, response_schema) - - effective_prompt = prompt - if previous_errors: - effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" - - def _call(ep: str) -> dict: - # Acquire rate limit capacity before each API call - self._acquire_rate_limit(ep, model, max_output=16384) - - response = self._with_retry( - lambda: client.messages.create( - model=model, - max_tokens=16384, - tools=[tool], - tool_choice={"type": "tool", "name": schema_name}, - messages=[{"role": "user", "content": ep}], - ) - ) - structured_data = _extract_tool_input(response) - if log: - log_request_response( - function_name="reasoning_call", - request={"model": model, "prompt_length": len(ep)}, - response=response, - provider="claude", - ) - return structured_data or {} - - return self._retry_with_validation( - call_fn=_call, - prompt=prompt, - 
validator=validator, - max_retries=max_retries, - on_retry=on_retry, - extract_error_summary_fn=extract_error_summary, - initial_prompt=effective_prompt if previous_errors else None, - ) - - def agentic_research( - self, - prompt: str, - response_schema: dict, - schema_name: str = "research_data", - model: str | None = None, - reasoning_effort: str = "low", - log: bool = True, - previous_errors: str | None = None, - validator: ValidatorCallback | None = None, - max_retries: int = 2, - on_retry: RetryCallback | None = None, - ) -> tuple[dict, list[str]]: - """Claude agentic research with web search + tool-based structured output. - - Uses web_search tool for research and a structured output tool for the response. - Claude first searches, then calls the output tool with results. - """ - model = model or self.default_research_model - client = self._get_client() - output_tool = _make_structured_tool(schema_name, response_schema) - - effective_prompt = prompt - if previous_errors: - effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" - - all_sources: list[str] = [] - - def _call(ep: str) -> dict: - research_prompt = ( - f"{ep}\n\n" - f"After researching, call the '{schema_name}' tool with your structured findings." 
- ) - - # Acquire rate limit capacity before each API call - self._acquire_rate_limit(research_prompt, model, max_output=16384) - - logger.info(f"[Claude] agentic_research - model={model}") - - response = self._with_retry( - lambda: client.messages.create( - model=model, - max_tokens=16384, - tools=[ - { - "type": "web_search_20250305", - "name": "web_search", - "max_uses": 5, - }, - output_tool, - ], - messages=[{"role": "user", "content": research_prompt}], - ) - ) - - structured_data = None - sources: list[str] = [] - - for block in response.content: - if block.type == "web_search_tool_result": - if hasattr(block, "content") and block.content: - for res in block.content: - if hasattr(res, "url"): - sources.append(res.url) - - if block.type == "tool_use" and block.name == schema_name: - structured_data = block.input - - if block.type == "text": - if hasattr(block, "citations") and block.citations: - for citation in block.citations: - if hasattr(citation, "url"): - sources.append(citation.url) - - all_sources.extend(sources) - logger.info(f"[Claude] Web search completed, found {len(sources)} sources") - - if log: - log_request_response( - function_name="agentic_research", - request={"model": model, "prompt_length": len(research_prompt)}, - response=response, - provider="claude", - sources=list(set(sources)), - ) - - return structured_data or {} - - result = self._retry_with_validation( - call_fn=_call, - prompt=prompt, - validator=validator, - max_retries=max_retries, - on_retry=on_retry, - extract_error_summary_fn=extract_error_summary, - initial_prompt=effective_prompt if previous_errors else None, - ) - - return result, list(set(all_sources)) +__all__ = ["ClaudeProvider", "AnthropicProvider"] diff --git a/extropy/core/providers/openai.py b/extropy/core/providers/openai.py index 871ad18..2d31307 100644 --- a/extropy/core/providers/openai.py +++ b/extropy/core/providers/openai.py @@ -193,15 +193,11 @@ async def _with_retry_async(self, fn, max_retries: int = 
_MAX_API_RETRIES): await asyncio.sleep(wait) @property - def default_simple_model(self) -> str: + def default_fast_model(self) -> str: return "gpt-5-mini" @property - def default_reasoning_model(self) -> str: - return "gpt-5" - - @property - def default_research_model(self) -> str: + def default_strong_model(self) -> str: return "gpt-5" def _get_client(self) -> OpenAI: diff --git a/extropy/core/providers/openai_compat.py b/extropy/core/providers/openai_compat.py new file mode 100644 index 0000000..04ed82f --- /dev/null +++ b/extropy/core/providers/openai_compat.py @@ -0,0 +1,346 @@ +"""OpenAI-compatible LLM Provider for third-party endpoints. + +Supports any provider that implements the OpenAI Chat Completions API: +- OpenRouter, DeepSeek, Together, Groq, Azure OpenAI, etc. + +Uses `openai.OpenAI(base_url=...)` with Chat Completions API for all calls. +Supports `json_schema` response format for structured output. +For agentic_research, appends `:online` to model name if provider supports search, +and parses `url_citation` annotations for sources. +""" + +import json +import logging +import random +import time + +import openai +from openai import OpenAI, AsyncOpenAI + +from .base import LLMProvider, TokenUsage, ValidatorCallback, RetryCallback +from .logging import log_request_response, extract_error_summary + +_TRANSIENT_ERRORS = ( + openai.APIConnectionError, + openai.InternalServerError, + openai.RateLimitError, +) +_MAX_API_RETRIES = 3 + +logger = logging.getLogger(__name__) + + +class OpenAICompatProvider(LLMProvider): + """OpenAI-compatible provider for third-party endpoints. + + Uses the Chat Completions API with json_schema response format. + """ + + def __init__( + self, + api_key: str = "", + *, + base_url: str = "", + supports_search: bool = False, + provider_label: str = "openai_compat", + default_fast: str = "gpt-5-mini", + default_strong: str = "gpt-5", + ) -> None: + if not api_key: + raise ValueError( + f"API key not found for {provider_label}. 
" + f"Set it as an environment variable." + ) + super().__init__(api_key) + self._base_url = base_url + self._supports_search = supports_search + self.provider_name = provider_label + self._default_fast = default_fast + self._default_strong = default_strong + + @property + def default_fast_model(self) -> str: + return self._default_fast + + @property + def default_strong_model(self) -> str: + return self._default_strong + + def _get_client(self) -> OpenAI: + kwargs: dict = {"api_key": self._api_key} + if self._base_url: + kwargs["base_url"] = self._base_url + return OpenAI(**kwargs) + + def _get_async_client(self) -> AsyncOpenAI: + if self._cached_async_client is None: + kwargs: dict = {"api_key": self._api_key} + if self._base_url: + kwargs["base_url"] = self._base_url + self._cached_async_client = AsyncOpenAI(**kwargs) + return self._cached_async_client + + def _build_params( + self, + model: str, + prompt: str, + schema: dict, + schema_name: str, + max_tokens: int | None, + ) -> dict: + """Build Chat Completions API request parameters.""" + params: dict = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": schema_name, + "strict": True, + "schema": schema, + }, + }, + } + if max_tokens is not None: + params["max_tokens"] = max_tokens + return params + + @staticmethod + def _extract_text(response) -> str | None: + """Extract text from Chat Completions response.""" + if response.choices and len(response.choices) > 0: + content = response.choices[0].message.content + if content: + return content + return None + + @staticmethod + def _extract_sources(response) -> list[str]: + """Extract citation URLs from response annotations.""" + sources: list[str] = [] + if not response.choices: + return sources + message = response.choices[0].message + if hasattr(message, "annotations") and message.annotations: + for annotation in message.annotations: + if hasattr(annotation, 
"type") and annotation.type == "url_citation": + if hasattr(annotation, "url"): + sources.append(annotation.url) + return sources + + def _with_retry(self, fn, max_retries: int = _MAX_API_RETRIES): + """Retry on transient errors with exponential backoff.""" + for attempt in range(max_retries + 1): + try: + return fn() + except _TRANSIENT_ERRORS as e: + if attempt == max_retries: + raise + wait = (2**attempt) + random.random() + lbl = self.provider_name + att = f"{attempt + 1}/{max_retries + 1}" + logger.warning( + f"[{lbl}] Transient error ({att}): " + f"{type(e).__name__}: {e}. " + f"Retrying in {wait:.1f}s" + ) + time.sleep(wait) + + async def _with_retry_async(self, fn, max_retries: int = _MAX_API_RETRIES): + """Async retry on transient errors.""" + import asyncio + + for attempt in range(max_retries + 1): + try: + return await fn() + except _TRANSIENT_ERRORS as e: + if attempt == max_retries: + raise + wait = (2**attempt) + random.random() + lbl = self.provider_name + att = f"{attempt + 1}/{max_retries + 1}" + logger.warning( + f"[{lbl}] Transient error ({att}): " + f"{type(e).__name__}: {e}. 
" + f"Retrying in {wait:.1f}s" + ) + await asyncio.sleep(wait) + + def simple_call( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + log: bool = True, + max_tokens: int | None = None, + ) -> dict: + model = model or self.default_fast_model + client = self._get_client() + + self._acquire_rate_limit(prompt, model, max_output=max_tokens or 4096) + + params = self._build_params( + model, prompt, response_schema, schema_name, max_tokens, + ) + lbl = self.provider_name + logger.info(f"[{lbl}] simple_call model={model} schema={schema_name}") + + api_start = time.time() + response = self._with_retry( + lambda: client.chat.completions.create(**params) + ) + api_elapsed = time.time() - api_start + logger.info(f"[{self.provider_name}] API response in {api_elapsed:.2f}s") + + raw_text = self._extract_text(response) + structured_data = json.loads(raw_text) if raw_text else None + + if log: + log_request_response( + function_name="simple_call", + request=params, + response=response, + provider=self.provider_name, + ) + + return structured_data or {} + + async def simple_call_async( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + max_tokens: int | None = None, + ) -> tuple[dict, TokenUsage]: + model = model or self.default_fast_model + client = self._get_async_client() + + params = self._build_params( + model, prompt, response_schema, schema_name, max_tokens, + ) + + response = await self._with_retry_async( + lambda: client.chat.completions.create(**params) + ) + + raw_text = self._extract_text(response) + structured_data = json.loads(raw_text) if raw_text else None + + usage = TokenUsage() + if hasattr(response, "usage") and response.usage is not None: + usage = TokenUsage( + input_tokens=getattr(response.usage, "prompt_tokens", 0) or 0, + output_tokens=getattr(response.usage, "completion_tokens", 0) or 0, + ) + + return structured_data or {}, usage + + 
def reasoning_call( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + reasoning_effort: str = "low", + log: bool = True, + previous_errors: str | None = None, + validator: ValidatorCallback | None = None, + max_retries: int = 2, + on_retry: RetryCallback | None = None, + ) -> dict: + model = model or self.default_strong_model + client = self._get_client() + + effective_prompt = prompt + if previous_errors: + effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" + + def _call(ep: str) -> dict: + self._acquire_rate_limit(ep, model, max_output=16384) + params = self._build_params(model, ep, response_schema, schema_name, None) + response = self._with_retry( + lambda: client.chat.completions.create(**params) + ) + raw_text = self._extract_text(response) + structured_data = json.loads(raw_text) if raw_text else None + if log: + log_request_response( + function_name="reasoning_call", + request=params, + response=response, + provider=self.provider_name, + ) + return structured_data or {} + + return self._retry_with_validation( + call_fn=_call, + prompt=prompt, + validator=validator, + max_retries=max_retries, + on_retry=on_retry, + extract_error_summary_fn=extract_error_summary, + initial_prompt=effective_prompt if previous_errors else None, + ) + + def agentic_research( + self, + prompt: str, + response_schema: dict, + schema_name: str = "research_data", + model: str | None = None, + reasoning_effort: str = "low", + log: bool = True, + previous_errors: str | None = None, + validator: ValidatorCallback | None = None, + max_retries: int = 2, + on_retry: RetryCallback | None = None, + ) -> tuple[dict, list[str]]: + model = model or self.default_strong_model + client = self._get_client() + + effective_prompt = prompt + if previous_errors: + effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" + + # For providers that support search, append :online suffix + search_model = f"{model}:online" if 
self._supports_search else model + + all_sources: list[str] = [] + + def _call(ep: str) -> dict: + self._acquire_rate_limit(ep, model, max_output=16384) + params = self._build_params( + search_model, ep, response_schema, schema_name, None + ) + response = self._with_retry( + lambda: client.chat.completions.create(**params) + ) + raw_text = self._extract_text(response) + structured_data = json.loads(raw_text) if raw_text else None + sources = self._extract_sources(response) + all_sources.extend(sources) + + if log: + log_request_response( + function_name="agentic_research", + request=params, + response=response, + provider=self.provider_name, + sources=list(set(sources)), + ) + + return structured_data or {} + + result = self._retry_with_validation( + call_fn=_call, + prompt=prompt, + validator=validator, + max_retries=max_retries, + on_retry=on_retry, + extract_error_summary_fn=extract_error_summary, + initial_prompt=effective_prompt if previous_errors else None, + ) + + return result, list(set(all_sources)) From ac4c47062623e023e46945b49d90a2d96f56cf49 Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 01:00:48 -0500 Subject: [PATCH 07/15] feat(routing): wire 2-tier config through LLM routing, engine, and estimator Update llm.py for fast/strong routing via provider/model strings. Add pricing and rate limit entries for new providers (OpenRouter, DeepSeek, Together, Groq). Update DualRateLimiter for mixed-provider support. Rename engine.py and estimator.py params from model/pivotal_model/ routine_model to strong/fast. Add backward compat properties on SimulationRunConfig. 
Co-Authored-By: Claude Opus 4.6 --- extropy/core/llm.py | 91 ++++++------ extropy/core/models/simulation.py | 25 ++-- extropy/core/pricing.py | 54 +++++-- extropy/core/rate_limiter.py | 70 ++++++--- extropy/core/rate_limits.py | 27 +++- extropy/simulation/engine.py | 230 ++++++++++++++++++++---------- extropy/simulation/estimator.py | 21 +-- extropy/simulation/reasoning.py | 6 +- extropy/simulation/state.py | 59 +++++--- 9 files changed, 385 insertions(+), 198 deletions(-) diff --git a/extropy/core/llm.py b/extropy/core/llm.py index ecfa0a5..dfbc3db 100644 --- a/extropy/core/llm.py +++ b/extropy/core/llm.py @@ -1,20 +1,19 @@ """LLM clients for Extropy - Facade Layer. -This module provides a unified interface to LLM providers with two-zone routing: -- Pipeline (sync calls): simple_call, reasoning_call, agentic_research - → Uses the pipeline provider (configured for phases 1-2) -- Simulation (async calls): simple_call_async - → Uses the simulation provider (configured for phase 3) +This module provides a unified interface to LLM providers with two-tier routing: +- fast: simple_call → uses models.fast (cheap, fast tasks) +- strong: reasoning_call, agentic_research → uses models.strong (complex tasks) +- simulation: simple_call_async → uses simulation.strong/fast -Configure via `extropy config` CLI or programmatically via extropy.config.configure(). +Model strings use "provider/model" format. The provider is extracted to route +to the correct backend; the model name is passed through. -Each function supports retry with error feedback via the `previous_errors` parameter. -When validation fails, pass the error message back to let the LLM self-correct. +Configure via `extropy config` CLI or programmatically via extropy.config.configure(). 
""" -from .providers import get_pipeline_provider, get_simulation_provider +from .providers import get_provider from .providers.base import TokenUsage, ValidatorCallback, RetryCallback -from ..config import get_config +from ..config import get_config, parse_model_string __all__ = [ @@ -28,25 +27,14 @@ ] -def _get_pipeline_model_override(tier: str) -> str | None: - """Get pipeline model override from config if configured.""" +def _resolve_provider_and_model( + model_string: str, +) -> tuple: + """Resolve a "provider/model" string to (provider_instance, model_name).""" config = get_config() - pipeline = config.pipeline - if tier == "simple" and pipeline.model_simple: - return pipeline.model_simple - elif tier == "reasoning" and pipeline.model_reasoning: - return pipeline.model_reasoning - elif tier == "research" and pipeline.model_research: - return pipeline.model_research - return None - - -def _get_simulation_model_override() -> str | None: - """Get simulation model override from config if configured.""" - config = get_config() - if config.simulation.model: - return config.simulation.model - return None + provider_name, model_name = parse_model_string(model_string) + provider = get_provider(provider_name, config.providers) + return provider, model_name def simple_call( @@ -59,20 +47,21 @@ def simple_call( ) -> dict: """Simple LLM call with structured output, no reasoning, no web search. - Routed through the PIPELINE provider. + Uses the FAST tier (config.models.fast). 
Use for fast, cheap tasks: - Context sufficiency checks - Simple classification - Validation """ - provider = get_pipeline_provider() - effective_model = model or _get_pipeline_model_override("simple") + config = get_config() + model_string = model or config.resolve_pipeline_fast() + provider, model_name = _resolve_provider_and_model(model_string) return provider.simple_call( prompt=prompt, response_schema=response_schema, schema_name=schema_name, - model=effective_model, + model=model_name, log=log, max_tokens=max_tokens, ) @@ -87,18 +76,21 @@ async def simple_call_async( ) -> tuple[dict, TokenUsage]: """Async version of simple_call for concurrent API requests. - Routed through the SIMULATION provider. - Used for batch agent reasoning during simulation. + Model is passed explicitly from simulation caller (provider/model format). Returns (structured_data, token_usage) tuple. """ - provider = get_simulation_provider() - effective_model = model or _get_simulation_model_override() + if model: + provider, model_name = _resolve_provider_and_model(model) + else: + config = get_config() + model_string = config.resolve_sim_strong() + provider, model_name = _resolve_provider_and_model(model_string) return await provider.simple_call_async( prompt=prompt, response_schema=response_schema, schema_name=schema_name, - model=effective_model, + model=model_name, max_tokens=max_tokens, ) @@ -117,20 +109,21 @@ def reasoning_call( ) -> dict: """LLM call with reasoning and structured output, but NO web search. - Routed through the PIPELINE provider. + Uses the STRONG tier (config.models.strong). 
Use for tasks that require reasoning but not external data: - Attribute selection/categorization - Schema design - Logical analysis """ - provider = get_pipeline_provider() - effective_model = model or _get_pipeline_model_override("reasoning") + config = get_config() + model_string = model or config.resolve_pipeline_strong() + provider, model_name = _resolve_provider_and_model(model_string) return provider.reasoning_call( prompt=prompt, response_schema=response_schema, schema_name=schema_name, - model=effective_model, + model=model_name, reasoning_effort=reasoning_effort, log=log, previous_errors=previous_errors, @@ -154,21 +147,17 @@ def agentic_research( ) -> tuple[dict, list[str]]: """Perform agentic research with web search and structured output. - Routed through the PIPELINE provider. - - The model will: - 1. Decide what to search for - 2. Search the web (possibly multiple times) - 3. Reason about the results - 4. Return structured data matching the schema + Uses the STRONG tier (config.models.strong). + Web search is a provider capability, not a tier distinction. 
""" - provider = get_pipeline_provider() - effective_model = model or _get_pipeline_model_override("research") + config = get_config() + model_string = model or config.resolve_pipeline_strong() + provider, model_name = _resolve_provider_and_model(model_string) return provider.agentic_research( prompt=prompt, response_schema=response_schema, schema_name=schema_name, - model=effective_model, + model=model_name, reasoning_effort=reasoning_effort, log=log, previous_errors=previous_errors, diff --git a/extropy/core/models/simulation.py b/extropy/core/models/simulation.py index 5532686..88d5617 100644 --- a/extropy/core/models/simulation.py +++ b/extropy/core/models/simulation.py @@ -338,17 +338,13 @@ class SimulationRunConfig(BaseModel): scenario_path: str = Field(description="Path to scenario YAML") output_dir: str = Field(description="Directory for results output") - model: str = Field( + strong: str = Field( default="", - description="LLM model for agent reasoning (empty = use config default)", + description="Strong model for Pass 1 role-play reasoning (provider/model format, empty = config default)", ) - pivotal_model: str = Field( + fast: str = Field( default="", - description="Model for pivotal reasoning (default: same as model)", - ) - routine_model: str = Field( - default="", - description="Cheap model for routine reasoning + classification (default: provider cheap tier)", + description="Fast model for Pass 2 classification (provider/model format, empty = config default)", ) reasoning_effort: str = Field(default="low", description="Reasoning effort level") multi_touch_threshold: int = Field( @@ -362,6 +358,19 @@ class SimulationRunConfig(BaseModel): default=50, description="Agents per reasoning chunk for checkpointing" ) + # Backward compat aliases + @property + def model(self) -> str: + return self.strong + + @property + def pivotal_model(self) -> str: + return self.strong + + @property + def routine_model(self) -> str: + return self.fast + # 
============================================================================= # Timestep Summary diff --git a/extropy/core/pricing.py b/extropy/core/pricing.py index 616a25d..d2de21a 100644 --- a/extropy/core/pricing.py +++ b/extropy/core/pricing.py @@ -40,21 +40,49 @@ class ModelPricing: ), "claude-haiku-4.5": ModelPricing(input_per_mtok=0.80, output_per_mtok=4.00), "claude-haiku-4": ModelPricing(input_per_mtok=0.80, output_per_mtok=4.00), + # DeepSeek (direct API) + "deepseek-chat": ModelPricing(input_per_mtok=0.14, output_per_mtok=0.28), + "deepseek-reasoner": ModelPricing(input_per_mtok=0.55, output_per_mtok=2.19), } -# Provider default models (matches provider classes, no API key needed) +# Provider default models — 2-tier (fast/strong) PROVIDER_DEFAULTS: dict[str, dict[str, str]] = { "openai": { - "simple": "gpt-5-mini", - "reasoning": "gpt-5", + "fast": "gpt-5-mini", + "strong": "gpt-5", }, + "anthropic": { + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", + }, + "azure": { + "fast": "gpt-5-mini", + "strong": "gpt-5", + }, + "openrouter": { + "fast": "openai/gpt-5-mini", + "strong": "openai/gpt-5", + }, + "deepseek": { + "fast": "deepseek-chat", + "strong": "deepseek-reasoner", + }, + "together": { + "fast": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + "strong": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + }, + "groq": { + "fast": "llama-3.3-70b-versatile", + "strong": "llama-3.3-70b-versatile", + }, + # Legacy aliases "claude": { - "simple": "claude-haiku-4-5-20251001", - "reasoning": "claude-sonnet-4-5-20250929", + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", }, "azure_openai": { - "simple": "gpt-5-mini", - "reasoning": "gpt-5", + "fast": "gpt-5-mini", + "strong": "gpt-5", }, } @@ -64,15 +92,19 @@ def get_pricing(model: str) -> ModelPricing | None: return MODEL_PRICING.get(model) -def resolve_default_model(provider: str, tier: str = "reasoning") -> str: +def 
resolve_default_model(provider: str, tier: str = "strong") -> str: """Resolve default model name for a provider without instantiating it. Args: - provider: Provider name ('openai' or 'claude') - tier: 'simple' or 'reasoning' + provider: Provider name ('openai', 'anthropic', etc.) + tier: 'fast' or 'strong' (also accepts legacy 'simple'/'reasoning') Returns: Model name string """ + # Map legacy tier names + tier_map = {"simple": "fast", "reasoning": "strong"} + tier = tier_map.get(tier, tier) + defaults = PROVIDER_DEFAULTS.get(provider, PROVIDER_DEFAULTS["openai"]) - return defaults.get(tier, defaults["reasoning"]) + return defaults.get(tier, defaults["strong"]) diff --git a/extropy/core/rate_limiter.py b/extropy/core/rate_limiter.py index 256cdb1..7aba1b9 100644 --- a/extropy/core/rate_limiter.py +++ b/extropy/core/rate_limiter.py @@ -486,10 +486,11 @@ def stats(self) -> dict: class DualRateLimiter: - """Manages separate rate limiters for pivotal (Pass 1) and routine (Pass 2) models. + """Manages separate rate limiters for strong (Pass 1) and fast (Pass 2) models. - When pivotal and routine models are the same, uses a single shared limiter. + When strong and fast models are the same, uses a single shared limiter. When they differ, uses independent limiters since API limits are per-model. + Supports mixed providers (e.g., strong=anthropic, fast=openai). """ def __init__( @@ -499,51 +500,82 @@ def __init__( ): self.pivotal = pivotal self.routine = routine + # Aliases for new naming convention + self.strong = pivotal + self.fast = routine @classmethod def create( cls, - provider: str, + provider: str = "", pivotal_model: str = "", routine_model: str = "", tier: int | None = None, rpm_override: int | None = None, tpm_override: int | None = None, + *, + strong_model_string: str = "", + fast_model_string: str = "", ) -> "DualRateLimiter": """Create dual rate limiter for two-pass reasoning. 
- If both models are the same (or routine is empty), a single - shared limiter is used for both passes. + Accepts either: + - Legacy: provider + pivotal_model + routine_model (single provider) + - New: strong_model_string + fast_model_string (provider/model format, mixed providers) Args: - provider: Provider name - pivotal_model: Model for Pass 1 (role-play reasoning) - routine_model: Model for Pass 2 (classification) + provider: Provider name (legacy, used if model strings not provided) + pivotal_model: Model for Pass 1 (legacy) + routine_model: Model for Pass 2 (legacy) tier: Rate limit tier (1-4) - rpm_override: Override RPM (applies to pivotal limiter) - tpm_override: Override TPM (applies to pivotal limiter) + rpm_override: Override RPM + tpm_override: Override TPM + strong_model_string: "provider/model" for strong/pivotal (new) + fast_model_string: "provider/model" for fast/routine (new) Returns: DualRateLimiter instance """ + # Resolve strong limiter + if strong_model_string and "/" in strong_model_string: + from ..config import parse_model_string + + strong_provider, strong_model = parse_model_string(strong_model_string) + else: + strong_provider = provider + strong_model = pivotal_model + pivotal_limiter = RateLimiter.for_provider( - provider=provider, - model=pivotal_model, + provider=strong_provider, + model=strong_model, tier=tier, rpm_override=rpm_override, tpm_override=tpm_override, ) - # If routine model is the same as pivotal (or not specified), share the limiter - effective_routine = routine_model or pivotal_model - if effective_routine == pivotal_model or not effective_routine: + # Resolve fast limiter + if fast_model_string and "/" in fast_model_string: + from ..config import parse_model_string + + fast_provider, fast_model = parse_model_string(fast_model_string) + else: + fast_provider = provider + fast_model = routine_model + + # If same provider+model, share the limiter + effective_fast_model = fast_model or strong_model + if ( + fast_provider 
== strong_provider + and effective_fast_model == strong_model + ): + return cls(pivotal=pivotal_limiter, routine=pivotal_limiter) + + if not effective_fast_model and not fast_provider: return cls(pivotal=pivotal_limiter, routine=pivotal_limiter) - # Different models — create separate limiter for routine - # Overrides apply to both (on Azure, limits are per-resource not per-model) routine_limiter = RateLimiter.for_provider( - provider=provider, - model=effective_routine, + provider=fast_provider or strong_provider, + model=effective_fast_model, tier=tier, rpm_override=rpm_override, tpm_override=tpm_override, diff --git a/extropy/core/rate_limits.py b/extropy/core/rate_limits.py index d081e60..ec53025 100644 --- a/extropy/core/rate_limits.py +++ b/extropy/core/rate_limits.py @@ -85,11 +85,32 @@ }, } -# Map "claude" provider name to anthropic profiles +# Provider aliases — map alternate names to canonical profiles RATE_LIMIT_PROFILES["claude"] = RATE_LIMIT_PROFILES["anthropic"] - -# Azure OpenAI uses the same rate limit profiles as standard OpenAI RATE_LIMIT_PROFILES["azure_openai"] = RATE_LIMIT_PROFILES["openai"] +RATE_LIMIT_PROFILES["azure"] = RATE_LIMIT_PROFILES["openai"] + +# Third-party providers — conservative defaults +# These providers typically have per-key limits; adjust via rate_tier/rpm_override. 
+_THIRD_PARTY_DEFAULT = { + "default": { + 1: {"rpm": 60, "tpm": 100_000}, + 2: {"rpm": 200, "tpm": 500_000}, + 3: {"rpm": 500, "tpm": 1_000_000}, + 4: {"rpm": 1_000, "tpm": 2_000_000}, + }, +} +RATE_LIMIT_PROFILES["openrouter"] = _THIRD_PARTY_DEFAULT +RATE_LIMIT_PROFILES["deepseek"] = _THIRD_PARTY_DEFAULT +RATE_LIMIT_PROFILES["together"] = _THIRD_PARTY_DEFAULT +RATE_LIMIT_PROFILES["groq"] = { + "default": { + 1: {"rpm": 30, "tpm": 15_000}, + 2: {"rpm": 60, "tpm": 50_000}, + 3: {"rpm": 200, "tpm": 100_000}, + 4: {"rpm": 500, "tpm": 500_000}, + }, +} def get_limits( diff --git a/extropy/simulation/engine.py b/extropy/simulation/engine.py index 5237dc6..2bbd3b0 100644 --- a/extropy/simulation/engine.py +++ b/extropy/simulation/engine.py @@ -14,8 +14,10 @@ import json import logging +import queue import random import sqlite3 +import threading import time import uuid from datetime import datetime @@ -149,6 +151,8 @@ def __init__( run_id: str | None = None, checkpoint_every_chunks: int = 1, retention_lite: bool = False, + writer_queue_size: int = 256, + db_write_batch_size: int = 100, ): """Initialize simulation engine. 
@@ -173,6 +177,8 @@ def __init__( self.run_id = run_id or f"run_{uuid.uuid4().hex[:12]}" self.checkpoint_every_chunks = max(1, checkpoint_every_chunks) self.retention_lite = retention_lite + self.writer_queue_size = max(1, writer_queue_size) + self.db_write_batch_size = max(1, db_write_batch_size) # Build agent map for quick lookup self.agent_map = {a.get("_id", str(i)): a for i, a in enumerate(agents)} @@ -202,6 +208,7 @@ def __init__( self.state_manager = StateManager( state_db_file, agents, + run_id=self.run_id, ) self.study_db = open_study_db(state_db_file) @@ -574,16 +581,87 @@ def _on_agent_done(agent_id: str, result: Any) -> None: context = self._build_reasoning_context(agent_id, old_state) contexts.append(context) - # Split into chunks - total_reasoned = 0 - total_changes = 0 - total_shares = 0 - completed_chunks = self.study_db.get_completed_simulation_chunks( self.run_id, timestep ) + totals = {"reasoned": 0, "changes": 0, "shares": 0} + + work_queue: queue.Queue[tuple[int, list[tuple[str, Any]], bool] | object] = ( + queue.Queue(maxsize=self.writer_queue_size) + ) + sentinel = object() + writer_error: list[Exception] = [] + + def _writer_loop() -> None: + chunks_since_checkpoint = 0 + pending_chunks: list[tuple[int, list[tuple[str, Any]], bool]] = [] + + def _flush_pending() -> None: + nonlocal chunks_since_checkpoint + if not pending_chunks: + return + with self.state_manager.transaction(): + for chunk_index, chunk_results, _is_last_chunk in pending_chunks: + reasoned, changes, shares = self._process_reasoning_chunk( + timestep, chunk_results, old_states + ) + totals["reasoned"] += reasoned + totals["changes"] += changes + totals["shares"] += shares + + for chunk_index, _chunk_results, is_last_chunk in pending_chunks: + self.study_db.save_simulation_checkpoint( + run_id=self.run_id, + timestep=timestep, + chunk_index=chunk_index, + status="done", + ) + + chunks_since_checkpoint += 1 + if ( + chunks_since_checkpoint >= self.checkpoint_every_chunks + or 
is_last_chunk + ): + self.study_db.set_run_metadata( + self.run_id, + "last_checkpoint", + f"{timestep}:{chunk_index}", + ) + chunks_since_checkpoint = 0 + + pending_chunks.clear() + + try: + while True: + item = work_queue.get() + try: + if item is sentinel: + _flush_pending() + break + chunk_index, chunk_results, is_last_chunk = item + if chunk_index in completed_chunks: + continue + pending_chunks.append((chunk_index, chunk_results, is_last_chunk)) + if ( + len(pending_chunks) >= self.db_write_batch_size + or is_last_chunk + ): + _flush_pending() + finally: + work_queue.task_done() + except Exception as e: # pragma: no cover - surfaced to caller + writer_error.append(e) + + writer_thread = threading.Thread( + target=_writer_loop, + name=f"sim-writer-{self.run_id}-{timestep}", + daemon=True, + ) + writer_thread.start() for chunk_start in range(0, len(contexts), self.chunk_size): + if writer_error: + break chunk_index = chunk_start // self.chunk_size if chunk_index in completed_chunks: logger.info( @@ -603,7 +681,6 @@ def _on_agent_done(agent_id: str, result: Any) -> None: reasoning_elapsed = time.time() - reasoning_start self.total_reasoning_calls += len(chunk_results) - # Accumulate token usage self.pivotal_input_tokens += chunk_usage.pivotal_input_tokens self.pivotal_output_tokens += chunk_usage.pivotal_output_tokens self.routine_input_tokens += chunk_usage.routine_input_tokens @@ -616,27 +693,26 @@ def _on_agent_done(agent_id: str, result: Any) -> None: else f"[TIMESTEP {timestep}] Chunk empty" ) - # Process and commit this chunk - with self.state_manager.transaction(): - reasoned, changes, shares = self._process_reasoning_chunk( - timestep, chunk_results, old_states - ) - if ( - ((chunk_index + 1) % self.checkpoint_every_chunks == 0) - or (chunk_start + self.chunk_size >= len(contexts)) - ): - self.study_db.save_simulation_checkpoint( - run_id=self.run_id, - timestep=timestep, - chunk_index=chunk_index, - status="done", - ) - - total_reasoned += reasoned - 
total_changes += changes - total_shares += shares - - return total_reasoned, total_changes, total_shares + is_last_chunk = chunk_start + self.chunk_size >= len(contexts) + work_queue.put((chunk_index, chunk_results, is_last_chunk)) + + work_queue.put(sentinel) + while work_queue.unfinished_tasks > 0: + if writer_error: + while True: + try: + work_queue.get_nowait() + work_queue.task_done() + except queue.Empty: + break + break + time.sleep(0.01) + work_queue.join() + writer_thread.join(timeout=1) + if writer_error: + raise writer_error[0] + + return totals["reasoned"], totals["changes"], totals["shares"] def _process_reasoning_chunk( self, @@ -1112,7 +1188,7 @@ def _finalize( final_exposure_rate=self.state_manager.get_exposure_rate(), outcome_distributions=outcome_dists, runtime_seconds=runtime, - model_used=self.config.model, + model_used=self.config.strong, completed_at=datetime.now(), ) @@ -1137,19 +1213,14 @@ def _compute_cost(self) -> dict[str, Any]: # Resolve effective model names for pricing config = get_config() - provider = config.simulation.provider - pivotal_model = ( - self.config.pivotal_model - or self.config.model - or config.simulation.pivotal_model - or config.simulation.model - or resolve_default_model(provider, "reasoning") - ) - routine_model = ( - self.config.routine_model - or config.simulation.routine_model - or resolve_default_model(provider, "simple") - ) + from ..config import parse_model_string + + strong_model_str = self.config.strong or config.resolve_sim_strong() + fast_model_str = self.config.fast or config.resolve_sim_fast() + + # Strip provider prefix for pricing lookup (pricing is keyed by bare model name) + _, pivotal_model = parse_model_string(strong_model_str) + _, routine_model = parse_model_string(fast_model_str) cost["pivotal_model"] = pivotal_model cost["routine_model"] = routine_model @@ -1238,9 +1309,8 @@ def _export_results(self) -> None: "scenario_name": self.scenario.meta.name, "scenario_path": 
self.config.scenario_path, "population_size": len(self.agents), - "model": self.config.model, - "pivotal_model": self.config.pivotal_model, - "routine_model": self.config.routine_model, + "strong_model": self.config.strong, + "fast_model": self.config.fast, "seed": self.seed, "multi_touch_threshold": self.config.multi_touch_threshold, "completed_at": datetime.now().isoformat(), @@ -1260,9 +1330,8 @@ def run_simulation( scenario_path: str | Path, output_dir: str | Path, study_db_path: str | Path | None = None, - model: str = "", - pivotal_model: str = "", - routine_model: str = "", + strong: str = "", + fast: str = "", multi_touch_threshold: int = 3, random_seed: int | None = None, on_progress: TimestepProgressCallback | None = None, @@ -1276,6 +1345,8 @@ def run_simulation( resume: bool = False, checkpoint_every_chunks: int = 1, retention_lite: bool = False, + writer_queue_size: int = 256, + db_write_batch_size: int = 100, ) -> SimulationSummary: """Run a simulation from a scenario file. 
@@ -1284,9 +1355,8 @@ def run_simulation( Args: scenario_path: Path to scenario YAML file output_dir: Directory for results output - model: LLM model for agent reasoning - pivotal_model: Model for pivotal reasoning (default: same as model) - routine_model: Cheap model for routine + classification + strong: Strong model for Pass 1 reasoning (provider/model format) + fast: Fast model for Pass 2 classification (provider/model format) multi_touch_threshold: Re-reason after N new exposures random_seed: Random seed for reproducibility on_progress: Progress callback(timestep, max, status) @@ -1300,6 +1370,8 @@ def run_simulation( resume: Resume a prior run from DB checkpoints checkpoint_every_chunks: Mark simulation checkpoint every N chunks retention_lite: Reduce payload volume by dropping full raw reasoning text + writer_queue_size: Max buffered reasoning chunks before writer backpressure + db_write_batch_size: Number of chunks applied per DB writer transaction Returns: SimulationSummary with results @@ -1309,21 +1381,26 @@ def run_simulation( if resume and not run_id: raise ValueError("--resume requires --run-id") - def _reset_runtime_tables(path: Path) -> None: + def _reset_runtime_tables(path: Path, run_key: str) -> None: conn = sqlite3.connect(str(path)) try: cur = conn.cursor() - cur.executescript( - """ - DELETE FROM agent_states; - DELETE FROM exposures; - DELETE FROM memory_traces; - DELETE FROM timeline; - DELETE FROM timestep_summaries; - DELETE FROM shared_to; - DELETE FROM simulation_metadata; - """ - ) + statements = [ + "DELETE FROM agent_states WHERE run_id = ?", + "DELETE FROM exposures WHERE run_id = ?", + "DELETE FROM memory_traces WHERE run_id = ?", + "DELETE FROM timeline WHERE run_id = ?", + "DELETE FROM timestep_summaries WHERE run_id = ?", + "DELETE FROM shared_to WHERE run_id = ?", + "DELETE FROM simulation_metadata WHERE run_id = ?", + ] + for sql in statements: + try: + cur.execute(sql, (run_key,)) + except sqlite3.OperationalError: + # Legacy 
tables without run_id columns. + table = sql.split()[2] + cur.execute(f"DELETE FROM {table}") conn.commit() except sqlite3.OperationalError: # First run on this DB may not have simulation tables yet. @@ -1379,13 +1456,14 @@ def _reset_runtime_tables(path: Path) -> None: config={ "scenario_path": str(scenario_path), "output_dir": str(output_dir), - "model": model, - "pivotal_model": pivotal_model, - "routine_model": routine_model, + "strong": strong, + "fast": fast, "multi_touch_threshold": multi_touch_threshold, "chunk_size": chunk_size, "checkpoint_every_chunks": checkpoint_every_chunks, "retention_lite": retention_lite, + "writer_queue_size": writer_queue_size, + "db_write_batch_size": db_write_batch_size, "resume": resume, }, seed=random_seed, @@ -1395,7 +1473,7 @@ def _reset_runtime_tables(path: Path) -> None: db.set_run_metadata(resolved_run_id, "study_db", str(study_db_resolved)) if not resume: - _reset_runtime_tables(study_db_resolved) + _reset_runtime_tables(study_db_resolved, resolved_run_id) # Load persona config if provided persona_config = None @@ -1415,26 +1493,22 @@ def _reset_runtime_tables(path: Path) -> None: config = SimulationRunConfig( scenario_path=str(scenario_path), output_dir=str(output_dir), - model=model, - pivotal_model=pivotal_model, - routine_model=routine_model, + strong=strong, + fast=fast, multi_touch_threshold=multi_touch_threshold, random_seed=random_seed, ) - # Create dual rate limiter (separate limiters for pivotal and routine models) + # Resolve effective model strings for rate limiting from ..config import get_config entropy_config = get_config() - provider = entropy_config.simulation.provider - effective_model = model or entropy_config.simulation.model or "" - effective_pivotal = pivotal_model or effective_model - effective_routine = routine_model or entropy_config.simulation.routine_model or "" + effective_strong = strong or entropy_config.resolve_sim_strong() + effective_fast = fast or entropy_config.resolve_sim_fast() 
rate_limiter = DualRateLimiter.create( - provider=provider, - pivotal_model=effective_pivotal, - routine_model=effective_routine, + strong_model_string=effective_strong, + fast_model_string=effective_fast, tier=rate_tier, rpm_override=rpm_override, tpm_override=tpm_override, @@ -1454,6 +1528,8 @@ def _reset_runtime_tables(path: Path) -> None: run_id=resolved_run_id, checkpoint_every_chunks=checkpoint_every_chunks, retention_lite=retention_lite, + writer_queue_size=writer_queue_size, + db_write_batch_size=db_write_batch_size, ) if on_progress: diff --git a/extropy/simulation/estimator.py b/extropy/simulation/estimator.py index 376245a..2830447 100644 --- a/extropy/simulation/estimator.py +++ b/extropy/simulation/estimator.py @@ -138,9 +138,8 @@ def estimate_simulation_cost( population_spec: PopulationSpec, agents: list[dict[str, Any]], network: dict[str, Any], - provider: str = "openai", - pivotal_model: str = "", - routine_model: str = "", + strong_model: str = "", + fast_model: str = "", multi_touch_threshold: int = 3, ) -> CostEstimate: """Estimate the cost of running a simulation. 
@@ -153,9 +152,8 @@ def estimate_simulation_cost( population_spec: Population specification agents: List of agent dictionaries network: Network data dict - provider: LLM provider name - pivotal_model: Model for Pass 1 (empty = provider default) - routine_model: Model for Pass 2 (empty = provider cheap tier) + strong_model: Model for Pass 1 (provider/model format, empty = config default) + fast_model: Model for Pass 2 (provider/model format, empty = config default) multi_touch_threshold: Re-reasoning threshold Returns: @@ -167,9 +165,14 @@ def estimate_simulation_cost( share_prob = scenario.spread.share_probability will_share_rate = 0.55 # accounts for conviction-gated sharing - # Resolve models - eff_pivotal = pivotal_model or resolve_default_model(provider, "reasoning") - eff_routine = routine_model or resolve_default_model(provider, "simple") + # Resolve models — strip provider prefix for pricing lookup + from ..config import get_config, parse_model_string + + config = get_config() + eff_strong_str = strong_model or config.resolve_sim_strong() + eff_fast_str = fast_model or config.resolve_sim_fast() + _, eff_pivotal = parse_model_string(eff_strong_str) + _, eff_routine = parse_model_string(eff_fast_str) # Pre-compute seed exposure schedule: timestep -> expected new seed exposures seed_schedule: dict[int, float] = {} diff --git a/extropy/simulation/reasoning.py b/extropy/simulation/reasoning.py index 7c51151..5d1d8e6 100644 --- a/extropy/simulation/reasoning.py +++ b/extropy/simulation/reasoning.py @@ -455,8 +455,8 @@ async def _reason_agent_two_pass_async( position_outcome = _get_primary_position_outcome(scenario) # Determine models - main_model = config.model or None # None = provider default - classify_model = config.routine_model or None # None = provider default (cheap) + main_model = config.strong or None # None = provider default + classify_model = config.fast or None # None = provider default (cheap) # === Pass 1: Role-play === pass1_usage = TokenUsage() 
@@ -687,7 +687,7 @@ def reason_agent( if pass2_schema: pass2_prompt = build_pass2_prompt(reasoning, scenario) - classify_model = config.routine_model or None + classify_model = config.fast or None for attempt in range(config.max_retries): try: diff --git a/extropy/simulation/state.py b/extropy/simulation/state.py index 94a4350..c583471 100644 --- a/extropy/simulation/state.py +++ b/extropy/simulation/state.py @@ -27,16 +27,23 @@ class StateManager: for frequently accessed data. """ - def __init__(self, db_path: Path | str, agents: list[dict[str, Any]] | None = None): + def __init__( + self, + db_path: Path | str, + agents: list[dict[str, Any]] | None = None, + run_id: str = "default", + ): """Initialize state manager with database path. Args: db_path: Path to SQLite database file agents: Optional list of agents to initialize + run_id: Run scope for all state reads/writes """ self.db_path = Path(db_path) + self.run_id = run_id self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path)) + self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) self.conn.row_factory = sqlite3.Row self.conn.execute("PRAGMA foreign_keys = ON") @@ -54,7 +61,8 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS agent_states ( - agent_id TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + agent_id TEXT NOT NULL, aware INTEGER DEFAULT 0, exposure_count INTEGER DEFAULT 0, last_reasoning_timestep INTEGER DEFAULT -1, @@ -73,7 +81,8 @@ def _create_schema(self) -> None: private_conviction REAL, private_outcomes_json TEXT, raw_reasoning TEXT, - updated_at INTEGER DEFAULT 0 + updated_at INTEGER DEFAULT 0, + PRIMARY KEY (run_id, agent_id) ) """ ) @@ -82,6 +91,7 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS exposures ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, agent_id TEXT, timestep INTEGER, @@ -89,7 +99,7 @@ def _create_schema(self) -> None: source_agent_id 
TEXT, content TEXT, credibility REAL, - FOREIGN KEY (agent_id) REFERENCES agent_states(agent_id) + FOREIGN KEY (run_id, agent_id) REFERENCES agent_states(run_id, agent_id) ) """ ) @@ -98,13 +108,14 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS memory_traces ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, agent_id TEXT, timestep INTEGER, sentiment REAL, conviction REAL, summary TEXT, - FOREIGN KEY (agent_id) REFERENCES agent_states(agent_id) + FOREIGN KEY (run_id, agent_id) REFERENCES agent_states(run_id, agent_id) ) """ ) @@ -113,6 +124,7 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS timeline ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, timestep INTEGER, event_type TEXT, @@ -127,7 +139,8 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS timestep_summaries ( - timestep INTEGER PRIMARY KEY, + run_id TEXT NOT NULL, + timestep INTEGER NOT NULL, new_exposures INTEGER, agents_reasoned INTEGER, shares_occurred INTEGER, @@ -136,7 +149,8 @@ def _create_schema(self) -> None: position_distribution_json TEXT, average_sentiment REAL, average_conviction REAL, - sentiment_variance REAL + sentiment_variance REAL, + PRIMARY KEY (run_id, timestep) ) """ ) @@ -145,37 +159,37 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_exposures_agent - ON exposures(agent_id) + ON exposures(run_id, agent_id) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_exposures_timestep - ON exposures(timestep) + ON exposures(run_id, timestep) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_timeline_timestep - ON timeline(timestep) + ON timeline(run_id, timestep) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_agent_states_aware - ON agent_states(aware) + ON agent_states(run_id, aware) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_agent_states_will_share - ON agent_states(will_share) + ON 
agent_states(run_id, will_share) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_memory_traces_agent - ON memory_traces(agent_id) + ON memory_traces(run_id, agent_id) """ ) @@ -183,18 +197,19 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS shared_to ( + run_id TEXT NOT NULL, source_agent_id TEXT, target_agent_id TEXT, timestep INTEGER, position TEXT, - PRIMARY KEY (source_agent_id, target_agent_id) + PRIMARY KEY (run_id, source_agent_id, target_agent_id) ) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_shared_to_source - ON shared_to(source_agent_id) + ON shared_to(run_id, source_agent_id) """ ) @@ -202,8 +217,11 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS simulation_metadata ( - key TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + key TEXT NOT NULL, value TEXT + , + PRIMARY KEY (run_id, key) ) """ ) @@ -215,6 +233,13 @@ def _upgrade_schema(self) -> None: cursor = self.conn.cursor() migrations = [ + ("agent_states", "run_id", "TEXT DEFAULT 'default'"), + ("exposures", "run_id", "TEXT DEFAULT 'default'"), + ("memory_traces", "run_id", "TEXT DEFAULT 'default'"), + ("timeline", "run_id", "TEXT DEFAULT 'default'"), + ("timestep_summaries", "run_id", "TEXT DEFAULT 'default'"), + ("shared_to", "run_id", "TEXT DEFAULT 'default'"), + ("simulation_metadata", "run_id", "TEXT DEFAULT 'default'"), ("agent_states", "conviction", "REAL"), ("agent_states", "public_statement", "TEXT"), ("timestep_summaries", "average_conviction", "REAL"), From 36d00320737ed7515cfbaf91d08caa585033a04f Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 01:00:54 -0500 Subject: [PATCH 08/15] feat(cli): update config/simulate/estimate commands for fast/strong flags Replace --model/--pivotal-model/--routine-model with --strong/--fast. Rewrite config show/set for new keys (models.fast, models.strong, simulation.fast, simulation.strong, providers.*). 
Update tests for new abstract method names and parameter changes. Co-Authored-By: Claude Opus 4.6 --- extropy/cli/commands/config_cmd.py | 146 +++++++++++++++-------------- extropy/cli/commands/estimate.py | 38 +++----- extropy/cli/commands/simulate.py | 44 +++------ tests/test_cli.py | 2 +- tests/test_estimator.py | 18 ++-- tests/test_providers.py | 36 +++---- 6 files changed, 127 insertions(+), 157 deletions(-) diff --git a/extropy/cli/commands/config_cmd.py b/extropy/cli/commands/config_cmd.py index c2945e7..9c80170 100644 --- a/extropy/cli/commands/config_cmd.py +++ b/extropy/cli/commands/config_cmd.py @@ -7,28 +7,25 @@ get_config, reset_config, CONFIG_FILE, - get_api_key, - get_azure_config, + get_api_key_for_provider, ) VALID_KEYS = { - "pipeline.provider", - "pipeline.model_simple", - "pipeline.model_reasoning", - "pipeline.model_research", - "simulation.provider", - "simulation.model", - "simulation.pivotal_model", - "simulation.routine_model", + "models.fast", + "models.strong", + "simulation.fast", + "simulation.strong", "simulation.max_concurrent", "simulation.rate_tier", "simulation.rpm_override", "simulation.tpm_override", - "simulation.api_format", } -INT_FIELDS = {"max_concurrent", "rate_tier", "rpm_override", "tpm_override"} +INT_FIELDS = { + "max_concurrent", "rate_tier", "rpm_override", + "tpm_override", +} @app.command("config") @@ -39,7 +36,7 @@ def config_command( ), key: str | None = typer.Argument( None, - help="Config key (e.g. pipeline.provider, simulation.model)", + help="Config key (e.g. 
models.fast, simulation.strong)", ), value: str | None = typer.Argument( None, @@ -50,9 +47,9 @@ def config_command( Examples: extropy config show - extropy config set pipeline.provider claude - extropy config set simulation.provider openai - extropy config set simulation.model gpt-5-mini + extropy config set models.fast openai/gpt-5-mini + extropy config set models.strong anthropic/claude-sonnet-4.5 + extropy config set simulation.strong openrouter/anthropic/claude-sonnet-4.5 extropy config reset """ if action == "show": @@ -82,35 +79,22 @@ def _show_config(): console.print("[bold]Extropy Configuration[/bold]") console.print("─" * 40) - # Pipeline zone + # Models (pipeline) console.print() - console.print("[bold cyan]Pipeline[/bold cyan] (spec, extend, persona, scenario)") - console.print(f" provider = {config.pipeline.provider}") - console.print( - f" model_simple = {config.pipeline.model_simple or '[dim](provider default)[/dim]'}" - ) console.print( - f" model_reasoning = {config.pipeline.model_reasoning or '[dim](provider default)[/dim]'}" - ) - console.print( - f" model_research = {config.pipeline.model_research or '[dim](provider default)[/dim]'}" + "[bold cyan]Models[/bold cyan] " + "(pipeline: spec, extend, persona, scenario)" ) + console.print(f" fast = {config.models.fast}") + console.print(f" strong = {config.models.strong}") - # Simulation zone + # Simulation console.print() console.print("[bold cyan]Simulation[/bold cyan] (agent reasoning)") - console.print(f" provider = {config.simulation.provider}") - console.print( - f" model = {config.simulation.model or '[dim](provider default)[/dim]'}" - ) - console.print( - f" pivotal_model = {config.simulation.pivotal_model or '[dim](same as model)[/dim]'}" - ) - console.print( - f" routine_model = {config.simulation.routine_model or '[dim](provider default)[/dim]'}" - ) - console.print( - f" api_format = {config.simulation.api_format or '[dim](auto)[/dim]'}" + strong_val = config.simulation.strong or "[dim](= 
models.strong)[/dim]" + fast_val = config.simulation.fast or "[dim](= models.fast)[/dim]" + console.print(f" strong = {strong_val}") + console.print(f" fast = {fast_val}" ) console.print(f" max_concurrent = {config.simulation.max_concurrent}") console.print( @@ -121,25 +105,24 @@ def _show_config(): if config.simulation.tpm_override: console.print(f" tpm_override = {config.simulation.tpm_override}") + # Custom providers + if config.providers: + console.print() + console.print("[bold cyan]Custom Providers[/bold cyan]") + for name, provider_cfg in config.providers.items(): + console.print(f" {name}:") + console.print(f" base_url = {provider_cfg.base_url}") + if provider_cfg.api_key_env: + console.print(f" api_key_env = {provider_cfg.api_key_env}") + # API keys status console.print() console.print("[bold cyan]API Keys[/bold cyan] (from env vars)") _show_key_status("openai", "OPENAI_API_KEY") - _show_key_status("claude", "ANTHROPIC_API_KEY") - _show_key_status("azure_openai", "AZURE_OPENAI_API_KEY") - - # Azure-specific config (show when Azure provider is in use) - active_providers = {config.pipeline.provider, config.simulation.provider} - if "azure_openai" in active_providers: - azure_cfg = get_azure_config("azure_openai") - console.print() - console.print("[bold cyan]Azure OpenAI[/bold cyan]") - console.print( - f" endpoint = {azure_cfg['azure_endpoint'] or '[dim]not set[/dim]'}" - ) - console.print(f" api_version = {azure_cfg['api_version']}") - if azure_cfg["azure_deployment"]: - console.print(f" deployment = {azure_cfg['azure_deployment']}") + _show_key_status("anthropic", "ANTHROPIC_API_KEY") + _show_key_status("azure", "AZURE_OPENAI_API_KEY") + _show_key_status("openrouter", "OPENROUTER_API_KEY") + _show_key_status("deepseek", "DEEPSEEK_API_KEY") # Config file console.print() @@ -152,7 +135,7 @@ def _show_config(): def _show_key_status(provider: str, env_var_label: str): """Show whether an API key is configured.""" - key = get_api_key(provider) + key = 
get_api_key_for_provider(provider) if key: masked = key[:8] + "..." + key[-4:] if len(key) > 16 else "***" console.print(f" {env_var_label}: [green]{masked}[/green]") @@ -162,35 +145,54 @@ def _show_key_status(provider: str, env_var_label: str): def _set_config(key: str, value: str): """Set a config value and save.""" - if key not in VALID_KEYS: + # Allow dynamic provider keys like providers.mycompany.base_url + is_provider_key = key.startswith("providers.") + if key not in VALID_KEYS and not is_provider_key: console.print(f"[red]Unknown key:[/red] {key}") console.print() console.print("Available keys:") for k in sorted(VALID_KEYS): console.print(f" {k}") + console.print(" providers..base_url") + console.print(" providers..api_key_env") raise typer.Exit(1) # Load current config (or defaults if no file) config = get_config() - zone, field = key.split(".", 1) - if zone == "pipeline": - target = config.pipeline - elif zone == "simulation": - target = config.simulation - else: - console.print(f"[red]Unknown zone:[/red] {zone}") - raise typer.Exit(1) - - # Type coercion - if field in INT_FIELDS: - try: - setattr(target, field, int(value)) - except ValueError: - console.print(f"[red]Invalid integer value:[/red] {value}") + if is_provider_key: + parts = key.split(".", 2) + if len(parts) != 3 or parts[2] not in ("base_url", "api_key_env"): + console.print( + f"[red]Invalid provider key:[/red] {key}\n" + "Expected: providers..base_url or providers..api_key_env" + ) raise typer.Exit(1) + provider_name = parts[1] + field = parts[2] + from ...config import CustomProviderConfig + if provider_name not in config.providers: + config.providers[provider_name] = CustomProviderConfig() + setattr(config.providers[provider_name], field, value) else: - setattr(target, field, value) + zone, field_name = key.split(".", 1) + if zone == "models": + target = config.models + elif zone == "simulation": + target = config.simulation + else: + console.print(f"[red]Unknown zone:[/red] {zone}") + 
raise typer.Exit(1) + + # Type coercion + if field_name in INT_FIELDS: + try: + setattr(target, field_name, int(value)) + except ValueError: + console.print(f"[red]Invalid integer value:[/red] {value}") + raise typer.Exit(1) + else: + setattr(target, field_name, value) config.save() reset_config() # Clear cached singleton so next get_config() reloads diff --git a/extropy/cli/commands/estimate.py b/extropy/cli/commands/estimate.py index 90f26a0..d175484 100644 --- a/extropy/cli/commands/estimate.py +++ b/extropy/cli/commands/estimate.py @@ -11,21 +11,16 @@ def estimate_command( scenario_file: Path = typer.Argument(..., help="Scenario spec YAML file"), study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), - model: str = typer.Option( + strong: str = typer.Option( "", - "--model", + "--strong", "-m", - help="LLM model for agent reasoning (empty = use config default)", + help="Strong model for Pass 1 (provider/model format)", ), - pivotal_model: str = typer.Option( + fast: str = typer.Option( "", - "--pivotal-model", - help="Model for pivotal/first-pass reasoning (default: same as --model)", - ), - routine_model: str = typer.Option( - "", - "--routine-model", - help="Cheap model for classification pass (default: provider cheap tier)", + "--fast", + help="Fast model for Pass 2 (provider/model format)", ), threshold: int = typer.Option( 3, "--threshold", "-t", help="Multi-touch threshold for re-reasoning" @@ -42,8 +37,9 @@ def estimate_command( Example: extropy estimate scenario.yaml --study-db study.db - extropy estimate scenario.yaml --study-db study.db --model gpt-5-mini - extropy estimate scenario.yaml --study-db study.db --pivotal-model gpt-5 --routine-model gpt-5-mini -v + extropy estimate scenario.yaml --study-db study.db --strong openai/gpt-5 + extropy estimate scenario.yaml --study-db study.db \\ + --strong openai/gpt-5 --fast openai/gpt-5-mini -v """ from ...config import get_config from ...core.models import ScenarioSpec, 
PopulationSpec @@ -90,11 +86,8 @@ def estimate_command( # Resolve config config = get_config() - provider = config.simulation.provider - - eff_model = model or config.simulation.model - eff_pivotal = pivotal_model or config.simulation.pivotal_model or eff_model - eff_routine = routine_model or config.simulation.routine_model + effective_strong = strong or config.resolve_sim_strong() + effective_fast = fast or config.resolve_sim_fast() # Run estimation est = estimate_simulation_cost( @@ -102,9 +95,8 @@ def estimate_command( population_spec=population_spec, agents=agents, network=network, - provider=provider, - pivotal_model=eff_pivotal, - routine_model=eff_routine, + strong_model=effective_strong, + fast_model=effective_fast, multi_touch_threshold=threshold, ) @@ -129,10 +121,10 @@ def estimate_command( # Models section console.print("[bold]Models[/bold]") _print_model_line( - console, "Pass 1 (pivotal)", est.pivotal_model, est.pivotal_pricing + console, "Pass 1 (strong)", est.pivotal_model, est.pivotal_pricing ) _print_model_line( - console, "Pass 2 (routine)", est.routine_model, est.routine_pricing + console, "Pass 2 (fast)", est.routine_model, est.routine_pricing ) console.print() diff --git a/extropy/cli/commands/simulate.py b/extropy/cli/commands/simulate.py index 7d6168e..9746ec7 100644 --- a/extropy/cli/commands/simulate.py +++ b/extropy/cli/commands/simulate.py @@ -102,21 +102,16 @@ def simulate_command( scenario_file: Path = typer.Argument(..., help="Scenario spec YAML file"), output: Path = typer.Option(..., "--output", "-o", help="Output results directory"), study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), - model: str = typer.Option( + strong: str = typer.Option( "", - "--model", + "--strong", "-m", - help="LLM model for agent reasoning (empty = use config default)", + help="Strong model for Pass 1 (provider/model format)", ), - pivotal_model: str = typer.Option( + fast: str = typer.Option( "", - "--pivotal-model", - 
help="Model for pivotal/first-pass reasoning (default: same as --model)", - ), - routine_model: str = typer.Option( - "", - "--routine-model", - help="Cheap model for classification pass (default: provider cheap tier)", + "--fast", + help="Fast model for Pass 2 (provider/model format)", ), threshold: int = typer.Option( 3, "--threshold", "-t", help="Multi-touch threshold for re-reasoning" @@ -240,29 +235,18 @@ def simulate_command( config = get_config() # Resolve models from CLI args > config > defaults - effective_model = model or config.simulation.model - effective_pivotal = pivotal_model or config.simulation.pivotal_model - effective_routine = routine_model or config.simulation.routine_model + effective_strong = strong or config.resolve_sim_strong() + effective_fast = fast or config.resolve_sim_fast() effective_tier = rate_tier or config.simulation.rate_tier effective_rpm = rpm_override or config.simulation.rpm_override effective_tpm = tpm_override or config.simulation.tpm_override - display_model = effective_model or f"({config.simulation.provider} default)" - display_provider = config.simulation.provider - console.print(f"Simulating: [bold]{scenario_file}[/bold]") console.print(f"Output: {output}") console.print(f"Study DB: {study_db}") console.print( - f"Provider: {display_provider} | Model: {display_model} | Threshold: {threshold}" + f"Strong: {effective_strong} | Fast: {effective_fast} | Threshold: {threshold}" ) - if effective_pivotal or effective_routine: - parts = [] - if effective_pivotal: - parts.append(f"Pivotal: {effective_pivotal}") - if effective_routine: - parts.append(f"Routine: {effective_routine}") - console.print(" | ".join(parts)) if effective_tier: console.print(f"Rate tier: {effective_tier}") if effective_rpm or effective_tpm: @@ -319,9 +303,8 @@ def on_progress(timestep: int, max_timesteps: int, status: str): scenario_path=scenario_file, output_dir=output, study_db_path=study_db, - model=effective_model, - pivotal_model=effective_pivotal, 
- routine_model=effective_routine, + strong=effective_strong, + fast=effective_fast, multi_touch_threshold=threshold, random_seed=seed, on_progress=on_progress, @@ -353,9 +336,8 @@ def do_simulation(): scenario_path=scenario_file, output_dir=output, study_db_path=study_db, - model=effective_model, - pivotal_model=effective_pivotal, - routine_model=effective_routine, + strong=effective_strong, + fast=effective_fast, multi_touch_threshold=threshold, random_seed=seed, on_progress=on_progress if not quiet else None, diff --git a/tests/test_cli.py b/tests/test_cli.py index 3c624af..15f33da 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -17,7 +17,7 @@ class TestConfigCommand: def test_config_show(self): result = runner.invoke(app, ["config", "show"]) assert result.exit_code == 0 - assert "Pipeline" in result.output + assert "Models" in result.output assert "Simulation" in result.output def test_config_set_invalid_key(self): diff --git a/tests/test_estimator.py b/tests/test_estimator.py index 9bfe443..a40a053 100644 --- a/tests/test_estimator.py +++ b/tests/test_estimator.py @@ -323,7 +323,8 @@ def test_basic_estimate( population_spec=small_pop_spec, agents=small_agents, network=small_network, - provider="openai", + strong_model="openai/gpt-5", + fast_model="openai/gpt-5-mini", ) assert est.population_size == 10 @@ -370,7 +371,8 @@ def test_model_resolution_openai( population_spec=small_pop_spec, agents=small_agents, network=small_network, - provider="openai", + strong_model="openai/gpt-5", + fast_model="openai/gpt-5-mini", ) assert est.pivotal_model == "gpt-5" assert est.routine_model == "gpt-5-mini" @@ -383,7 +385,8 @@ def test_model_resolution_claude( population_spec=small_pop_spec, agents=small_agents, network=small_network, - provider="claude", + strong_model="anthropic/claude-sonnet-4-5-20250929", + fast_model="anthropic/claude-haiku-4-5-20251001", ) assert est.pivotal_model == "claude-sonnet-4-5-20250929" assert est.routine_model == 
"claude-haiku-4-5-20251001" @@ -396,9 +399,8 @@ def test_explicit_model_override( population_spec=small_pop_spec, agents=small_agents, network=small_network, - provider="openai", - pivotal_model="gpt-5-mini", - routine_model="gpt-5-mini", + strong_model="openai/gpt-5-mini", + fast_model="openai/gpt-5-mini", ) assert est.pivotal_model == "gpt-5-mini" assert est.routine_model == "gpt-5-mini" @@ -411,8 +413,8 @@ def test_unknown_model_pricing_none( population_spec=small_pop_spec, agents=small_agents, network=small_network, - pivotal_model="unknown-model-x", - routine_model="unknown-model-y", + strong_model="openai/unknown-model-x", + fast_model="openai/unknown-model-y", ) assert est.pivotal_pricing is None assert est.routine_pricing is None diff --git a/tests/test_providers.py b/tests/test_providers.py index 291d846..752eb77 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -380,9 +380,8 @@ def test_no_validator_returns_immediately(self): """With no validator, first result is returned.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -421,9 +420,8 @@ def test_initial_prompt_used_on_first_call(self): """When initial_prompt is provided, it should be used for the first call.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -463,9 +461,8 @@ def test_validation_retries_use_base_prompt_not_initial(self): """Validation retries should use prompt, not initial_prompt.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = 
"test" def simple_call(self, *a, **kw): return {} @@ -528,9 +525,8 @@ def test_validator_succeeds_on_first_attempt_with_initial_prompt(self): """When validator passes on first try with initial_prompt, no retries occur.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -574,9 +570,8 @@ def test_on_retry_callback_invoked_correctly(self): """Test that on_retry callback is invoked with correct parameters.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -640,9 +635,8 @@ def test_no_initial_prompt_defaults_to_prompt(self): """When initial_prompt is None, prompt is used for first call.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -813,12 +807,10 @@ class TestProviderFactoryAzure: ) def test_create_azure_openai_provider(self): from extropy.core.providers import _create_provider + from extropy.core.providers.openai_compat import OpenAICompatProvider provider = _create_provider("azure_openai") - assert isinstance(provider, OpenAIProvider) - assert provider._is_azure is True - assert provider._azure_endpoint == "https://my-resource.openai.azure.com" - assert provider._azure_deployment == "my-deployment" + assert isinstance(provider, OpenAICompatProvider) @patch.dict( "os.environ", From ce7b2b6b9da57e59cc6a3ac46059390daae5c638 Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 01:07:07 -0500 Subject: [PATCH 09/15] feat(cost): add cost tracking package, wire 
providers, CLI --cost flag MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Create extropy/core/cost/ package (pricing.py, tracker.py, ledger.py) with three-tier pricing resolution (OpenRouter API → cache → fallback), session-scoped CostTracker singleton, and SQLite cost ledger - All models use Pydantic BaseModel (ModelPricing, CallRecord, ModelUsage, CostEntry, TokenUsage) - Wire _record_usage() hook into LLMProvider base class - Extract + record token usage from all OpenAI and Anthropic provider methods (simple_call, reasoning_call, agentic_research, simple_call_async) - Add --cost CLI flag and defaults.show_cost config option for cost footer - Add atexit hook to print cost summary and persist to ledger - Includes ruff format/lint fixes across codebase Co-authored-by: Cursor --- extropy/cli/app.py | 49 ++- extropy/cli/commands/chat.py | 32 +- extropy/cli/commands/config_cmd.py | 160 +++---- extropy/cli/commands/estimate.py | 40 +- extropy/cli/commands/inspect.py | 12 +- extropy/cli/commands/migrate.py | 8 +- extropy/cli/commands/network.py | 4 +- extropy/cli/commands/query.py | 4 +- extropy/cli/commands/results.py | 4 +- extropy/cli/commands/simulate.py | 44 +- extropy/config.py | 555 +++++++++++++++++++----- extropy/core/cost/__init__.py | 39 ++ extropy/core/cost/ledger.py | 278 ++++++++++++ extropy/core/cost/pricing.py | 369 ++++++++++++++++ extropy/core/cost/tracker.py | 288 ++++++++++++ extropy/core/llm.py | 91 ++-- extropy/core/models/scenario.py | 8 +- extropy/core/models/simulation.py | 25 +- extropy/core/pricing.py | 54 ++- extropy/core/providers/__init__.py | 248 ++++++++--- extropy/core/providers/anthropic.py | 388 +++++++++++++++++ extropy/core/providers/base.py | 48 +- extropy/core/providers/claude.py | 370 +--------------- extropy/core/providers/openai.py | 51 ++- extropy/core/rate_limiter.py | 67 ++- extropy/core/rate_limits.py | 27 +- extropy/population/network/generator.py | 26 +- 
extropy/simulation/engine.py | 238 ++++++---- extropy/simulation/estimator.py | 23 +- extropy/simulation/reasoning.py | 10 +- extropy/simulation/state.py | 59 ++- extropy/utils/resource_governor.py | 4 +- tests/test_cli.py | 7 +- tests/test_compiler.py | 5 +- tests/test_estimator.py | 18 +- tests/test_providers.py | 36 +- tests/test_scenario_validator.py | 5 +- 37 files changed, 2750 insertions(+), 944 deletions(-) create mode 100644 extropy/core/cost/__init__.py create mode 100644 extropy/core/cost/ledger.py create mode 100644 extropy/core/cost/pricing.py create mode 100644 extropy/core/cost/tracker.py create mode 100644 extropy/core/providers/anthropic.py diff --git a/extropy/cli/app.py b/extropy/cli/app.py index 561f0c4..327b535 100644 --- a/extropy/cli/app.py +++ b/extropy/cli/app.py @@ -1,5 +1,6 @@ """Core CLI app definition and global state.""" +import atexit from typing import Annotated import typer @@ -15,6 +16,7 @@ # Global state for JSON mode (set by callback) _json_mode = False +_show_cost = False def get_json_mode() -> bool: @@ -30,6 +32,28 @@ def _version_callback(value: bool) -> None: raise typer.Exit() +def _print_cost_footer() -> None: + """Print cost summary footer at CLI exit (if enabled and there are records).""" + try: + from ..core.cost.tracker import CostTracker + from ..core.cost.ledger import record_session + + tracker = CostTracker.get() + if not tracker.has_records: + return + + # Persist to ledger + summary = tracker.summary() + record_session(summary) + + # Print footer + line = tracker.summary_line() + if line: + console.print(f"\n[dim]Cost: {line}[/dim]") + except Exception: + pass # Never let cost display crash the CLI + + @app.callback() def main_callback( json_output: Annotated[ @@ -49,14 +73,37 @@ def main_callback( is_eager=True, ), ] = False, + cost: Annotated[ + bool, + typer.Option( + "--cost", + help="Show cost summary after command completes", + is_eager=True, + ), + ] = False, ): """Extropy: Population simulation engine for 
agent-based modeling. Use --json for machine-readable output suitable for scripting and AI tools. + Use --cost to show token usage and cost summary after each command. """ - global _json_mode + global _json_mode, _show_cost _json_mode = json_output + # Determine if cost footer should be shown: --cost flag or config setting + show = cost + if not show: + try: + from ..config import get_config + + show = get_config().defaults.show_cost + except Exception: + pass + + _show_cost = show + if _show_cost: + atexit.register(_print_cost_footer) + # Import commands to register them with the app from .commands import ( # noqa: E402, F401 diff --git a/extropy/cli/commands/chat.py b/extropy/cli/commands/chat.py index 6ea4cdf..eee7867 100644 --- a/extropy/cli/commands/chat.py +++ b/extropy/cli/commands/chat.py @@ -100,7 +100,9 @@ def _summarize_context(context: dict[str, Any], prompt: str) -> str: top_attrs = [(k, v) for k, v in attrs.items() if not str(k).startswith("_")] top_attrs = sorted(top_attrs)[:8] if top_attrs: - lines.append("- Key attributes: " + ", ".join(f"{k}={v}" for k, v in top_attrs)) + lines.append( + "- Key attributes: " + ", ".join(f"{k}={v}" for k, v in top_attrs) + ) if timeline: lines.append("- Recent timeline events:") @@ -184,19 +186,25 @@ def chat_interactive( n = int(parts[1]) if len(parts) > 1 else 10 except ValueError: n = 10 - context, _ = _load_agent_chat_context(conn, run_id, agent_id, timeline_n=max(1, n)) + context, _ = _load_agent_chat_context( + conn, run_id, agent_id, timeline_n=max(1, n) + ) for item in context.get("timeline", []): console.print( f"t={item.get('timestep')} {item.get('event_type')} {item.get('details_json') or '{}'}" ) continue if prompt == "/context": - context, _ = _load_agent_chat_context(conn, run_id, agent_id, timeline_n=10) + context, _ = _load_agent_chat_context( + conn, run_id, agent_id, timeline_n=10 + ) console.print_json(data=context) continue started = time.time() - context, citations = 
_load_agent_chat_context(conn, run_id, agent_id, timeline_n=12) + context, citations = _load_agent_chat_context( + conn, run_id, agent_id, timeline_n=12 + ) answer = _summarize_context(context, prompt) latency_ms = int((time.time() - started) * 1000) @@ -207,7 +215,11 @@ def chat_interactive( "assistant", answer, citations={"sources": citations}, - token_usage={"input_tokens": 0, "output_tokens": 0, "latency_ms": latency_ms}, + token_usage={ + "input_tokens": 0, + "output_tokens": 0, + "latency_ms": latency_ms, + }, ) console.print(answer) @@ -247,7 +259,9 @@ def chat_ask( conn = sqlite3.connect(str(study_db)) conn.row_factory = sqlite3.Row try: - context, citations = _load_agent_chat_context(conn, run_id, agent_id, timeline_n=12) + context, citations = _load_agent_chat_context( + conn, run_id, agent_id, timeline_n=12 + ) answer = _summarize_context(context, prompt) finally: conn.close() @@ -261,7 +275,11 @@ def chat_ask( "assistant", answer, citations={"sources": citations}, - token_usage={"input_tokens": 0, "output_tokens": 0, "latency_ms": latency_ms}, + token_usage={ + "input_tokens": 0, + "output_tokens": 0, + "latency_ms": latency_ms, + }, ) payload = { diff --git a/extropy/cli/commands/config_cmd.py b/extropy/cli/commands/config_cmd.py index c2945e7..eeb8be9 100644 --- a/extropy/cli/commands/config_cmd.py +++ b/extropy/cli/commands/config_cmd.py @@ -7,28 +7,30 @@ get_config, reset_config, CONFIG_FILE, - get_api_key, - get_azure_config, + get_api_key_for_provider, ) VALID_KEYS = { - "pipeline.provider", - "pipeline.model_simple", - "pipeline.model_reasoning", - "pipeline.model_research", - "simulation.provider", - "simulation.model", - "simulation.pivotal_model", - "simulation.routine_model", + "models.fast", + "models.strong", + "simulation.fast", + "simulation.strong", "simulation.max_concurrent", "simulation.rate_tier", "simulation.rpm_override", "simulation.tpm_override", - "simulation.api_format", + "defaults.population_size", + "defaults.db_path", } 
-INT_FIELDS = {"max_concurrent", "rate_tier", "rpm_override", "tpm_override"} +INT_FIELDS = { + "max_concurrent", + "rate_tier", + "rpm_override", + "tpm_override", + "population_size", +} @app.command("config") @@ -39,7 +41,7 @@ def config_command( ), key: str | None = typer.Argument( None, - help="Config key (e.g. pipeline.provider, simulation.model)", + help="Config key (e.g. models.fast, simulation.strong)", ), value: str | None = typer.Argument( None, @@ -50,9 +52,9 @@ def config_command( Examples: extropy config show - extropy config set pipeline.provider claude - extropy config set simulation.provider openai - extropy config set simulation.model gpt-5-mini + extropy config set models.fast openai/gpt-5-mini + extropy config set models.strong anthropic/claude-sonnet-4.5 + extropy config set simulation.strong openrouter/anthropic/claude-sonnet-4.5 extropy config reset """ if action == "show": @@ -82,36 +84,21 @@ def _show_config(): console.print("[bold]Extropy Configuration[/bold]") console.print("─" * 40) - # Pipeline zone + # Models (pipeline) console.print() - console.print("[bold cyan]Pipeline[/bold cyan] (spec, extend, persona, scenario)") - console.print(f" provider = {config.pipeline.provider}") - console.print( - f" model_simple = {config.pipeline.model_simple or '[dim](provider default)[/dim]'}" - ) console.print( - f" model_reasoning = {config.pipeline.model_reasoning or '[dim](provider default)[/dim]'}" - ) - console.print( - f" model_research = {config.pipeline.model_research or '[dim](provider default)[/dim]'}" + "[bold cyan]Models[/bold cyan] (pipeline: spec, extend, persona, scenario)" ) + console.print(f" fast = {config.models.fast}") + console.print(f" strong = {config.models.strong}") - # Simulation zone + # Simulation console.print() console.print("[bold cyan]Simulation[/bold cyan] (agent reasoning)") - console.print(f" provider = {config.simulation.provider}") - console.print( - f" model = {config.simulation.model or '[dim](provider 
default)[/dim]'}" - ) - console.print( - f" pivotal_model = {config.simulation.pivotal_model or '[dim](same as model)[/dim]'}" - ) - console.print( - f" routine_model = {config.simulation.routine_model or '[dim](provider default)[/dim]'}" - ) - console.print( - f" api_format = {config.simulation.api_format or '[dim](auto)[/dim]'}" - ) + strong_val = config.simulation.strong or "[dim](= models.strong)[/dim]" + fast_val = config.simulation.fast or "[dim](= models.fast)[/dim]" + console.print(f" strong = {strong_val}") + console.print(f" fast = {fast_val}") console.print(f" max_concurrent = {config.simulation.max_concurrent}") console.print( f" rate_tier = {config.simulation.rate_tier or '[dim](tier 1)[/dim]'}" @@ -121,25 +108,30 @@ def _show_config(): if config.simulation.tpm_override: console.print(f" tpm_override = {config.simulation.tpm_override}") + # Custom providers + if config.providers: + console.print() + console.print("[bold cyan]Custom Providers[/bold cyan]") + for name, provider_cfg in config.providers.items(): + console.print(f" {name}:") + console.print(f" base_url = {provider_cfg.base_url}") + if provider_cfg.api_key_env: + console.print(f" api_key_env = {provider_cfg.api_key_env}") + + # Defaults + console.print() + console.print("[bold cyan]Defaults[/bold cyan]") + console.print(f" population_size = {config.defaults.population_size}") + console.print(f" db_path = {config.defaults.db_path}") + # API keys status console.print() console.print("[bold cyan]API Keys[/bold cyan] (from env vars)") _show_key_status("openai", "OPENAI_API_KEY") - _show_key_status("claude", "ANTHROPIC_API_KEY") - _show_key_status("azure_openai", "AZURE_OPENAI_API_KEY") - - # Azure-specific config (show when Azure provider is in use) - active_providers = {config.pipeline.provider, config.simulation.provider} - if "azure_openai" in active_providers: - azure_cfg = get_azure_config("azure_openai") - console.print() - console.print("[bold cyan]Azure OpenAI[/bold cyan]") - 
console.print( - f" endpoint = {azure_cfg['azure_endpoint'] or '[dim]not set[/dim]'}" - ) - console.print(f" api_version = {azure_cfg['api_version']}") - if azure_cfg["azure_deployment"]: - console.print(f" deployment = {azure_cfg['azure_deployment']}") + _show_key_status("anthropic", "ANTHROPIC_API_KEY") + _show_key_status("azure", "AZURE_OPENAI_API_KEY") + _show_key_status("openrouter", "OPENROUTER_API_KEY") + _show_key_status("deepseek", "DEEPSEEK_API_KEY") # Config file console.print() @@ -152,7 +144,7 @@ def _show_config(): def _show_key_status(provider: str, env_var_label: str): """Show whether an API key is configured.""" - key = get_api_key(provider) + key = get_api_key_for_provider(provider) if key: masked = key[:8] + "..." + key[-4:] if len(key) > 16 else "***" console.print(f" {env_var_label}: [green]{masked}[/green]") @@ -162,35 +154,57 @@ def _show_key_status(provider: str, env_var_label: str): def _set_config(key: str, value: str): """Set a config value and save.""" - if key not in VALID_KEYS: + # Allow dynamic provider keys like providers.mycompany.base_url + is_provider_key = key.startswith("providers.") + if key not in VALID_KEYS and not is_provider_key: console.print(f"[red]Unknown key:[/red] {key}") console.print() console.print("Available keys:") for k in sorted(VALID_KEYS): console.print(f" {k}") + console.print(" providers..base_url") + console.print(" providers..api_key_env") raise typer.Exit(1) # Load current config (or defaults if no file) config = get_config() - zone, field = key.split(".", 1) - if zone == "pipeline": - target = config.pipeline - elif zone == "simulation": - target = config.simulation - else: - console.print(f"[red]Unknown zone:[/red] {zone}") - raise typer.Exit(1) - - # Type coercion - if field in INT_FIELDS: - try: - setattr(target, field, int(value)) - except ValueError: - console.print(f"[red]Invalid integer value:[/red] {value}") + if is_provider_key: + parts = key.split(".", 2) + if len(parts) != 3 or parts[2] not in 
("base_url", "api_key_env"): + console.print( + f"[red]Invalid provider key:[/red] {key}\n" + "Expected: providers..base_url or providers..api_key_env" + ) raise typer.Exit(1) + provider_name = parts[1] + field = parts[2] + from ...config import CustomProviderConfig + + if provider_name not in config.providers: + config.providers[provider_name] = CustomProviderConfig() + setattr(config.providers[provider_name], field, value) else: - setattr(target, field, value) + zone, field_name = key.split(".", 1) + if zone == "models": + target = config.models + elif zone == "simulation": + target = config.simulation + elif zone == "defaults": + target = config.defaults + else: + console.print(f"[red]Unknown zone:[/red] {zone}") + raise typer.Exit(1) + + # Type coercion + if field_name in INT_FIELDS: + try: + setattr(target, field_name, int(value)) + except ValueError: + console.print(f"[red]Invalid integer value:[/red] {value}") + raise typer.Exit(1) + else: + setattr(target, field_name, value) config.save() reset_config() # Clear cached singleton so next get_config() reloads diff --git a/extropy/cli/commands/estimate.py b/extropy/cli/commands/estimate.py index 90f26a0..71f5597 100644 --- a/extropy/cli/commands/estimate.py +++ b/extropy/cli/commands/estimate.py @@ -11,21 +11,16 @@ def estimate_command( scenario_file: Path = typer.Argument(..., help="Scenario spec YAML file"), study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), - model: str = typer.Option( + strong: str = typer.Option( "", - "--model", + "--strong", "-m", - help="LLM model for agent reasoning (empty = use config default)", + help="Strong model for Pass 1 (provider/model format)", ), - pivotal_model: str = typer.Option( + fast: str = typer.Option( "", - "--pivotal-model", - help="Model for pivotal/first-pass reasoning (default: same as --model)", - ), - routine_model: str = typer.Option( - "", - "--routine-model", - help="Cheap model for classification pass (default: provider cheap 
tier)", + "--fast", + help="Fast model for Pass 2 (provider/model format)", ), threshold: int = typer.Option( 3, "--threshold", "-t", help="Multi-touch threshold for re-reasoning" @@ -42,8 +37,9 @@ def estimate_command( Example: extropy estimate scenario.yaml --study-db study.db - extropy estimate scenario.yaml --study-db study.db --model gpt-5-mini - extropy estimate scenario.yaml --study-db study.db --pivotal-model gpt-5 --routine-model gpt-5-mini -v + extropy estimate scenario.yaml --study-db study.db --strong openai/gpt-5 + extropy estimate scenario.yaml --study-db study.db \\ + --strong openai/gpt-5 --fast openai/gpt-5-mini -v """ from ...config import get_config from ...core.models import ScenarioSpec, PopulationSpec @@ -90,11 +86,8 @@ def estimate_command( # Resolve config config = get_config() - provider = config.simulation.provider - - eff_model = model or config.simulation.model - eff_pivotal = pivotal_model or config.simulation.pivotal_model or eff_model - eff_routine = routine_model or config.simulation.routine_model + effective_strong = strong or config.resolve_sim_strong() + effective_fast = fast or config.resolve_sim_fast() # Run estimation est = estimate_simulation_cost( @@ -102,9 +95,8 @@ def estimate_command( population_spec=population_spec, agents=agents, network=network, - provider=provider, - pivotal_model=eff_pivotal, - routine_model=eff_routine, + strong_model=effective_strong, + fast_model=effective_fast, multi_touch_threshold=threshold, ) @@ -129,11 +121,9 @@ def estimate_command( # Models section console.print("[bold]Models[/bold]") _print_model_line( - console, "Pass 1 (pivotal)", est.pivotal_model, est.pivotal_pricing - ) - _print_model_line( - console, "Pass 2 (routine)", est.routine_model, est.routine_pricing + console, "Pass 1 (strong)", est.pivotal_model, est.pivotal_pricing ) + _print_model_line(console, "Pass 2 (fast)", est.routine_model, est.routine_pricing) console.print() # Calls table diff --git 
a/extropy/cli/commands/inspect.py b/extropy/cli/commands/inspect.py index 3f83101..71600b2 100644 --- a/extropy/cli/commands/inspect.py +++ b/extropy/cli/commands/inspect.py @@ -56,11 +56,15 @@ def inspect_agent( conn.row_factory = sqlite3.Row try: cur = conn.cursor() - cur.execute("SELECT attrs_json FROM agents WHERE agent_id = ? LIMIT 1", (agent_id,)) + cur.execute( + "SELECT attrs_json FROM agents WHERE agent_id = ? LIMIT 1", (agent_id,) + ) attrs_row = cur.fetchone() attrs = json.loads(attrs_row["attrs_json"]) if attrs_row else {} - cur.execute("SELECT * FROM agent_states WHERE agent_id = ? LIMIT 1", (agent_id,)) + cur.execute( + "SELECT * FROM agent_states WHERE agent_id = ? LIMIT 1", (agent_id,) + ) state = cur.fetchone() cur.execute( @@ -81,7 +85,9 @@ def inspect_agent( if state: console.print("[bold]State[/bold]") - console.print(f" aware={bool(state['aware'])} will_share={bool(state['will_share'])}") + console.print( + f" aware={bool(state['aware'])} will_share={bool(state['will_share'])}" + ) console.print( f" position={state['private_position'] or state['position']} " f"sentiment={state['private_sentiment'] if state['private_sentiment'] is not None else state['sentiment']}" diff --git a/extropy/cli/commands/migrate.py b/extropy/cli/commands/migrate.py index e26bb91..e0d3643 100644 --- a/extropy/cli/commands/migrate.py +++ b/extropy/cli/commands/migrate.py @@ -40,13 +40,17 @@ def migrate_legacy_artifacts( ): """Ingest legacy `agents.json`/`network.json` into `study.db`.""" if agents_file is None and network_file is None: - console.print("[red]✗[/red] Provide at least one of --agents-file or --network-file") + console.print( + "[red]✗[/red] Provide at least one of --agents-file or --network-file" + ) raise typer.Exit(1) with open_study_db(study_db) as db: if population_spec is not None: if not population_spec.exists(): - console.print(f"[red]✗[/red] population spec not found: {population_spec}") + console.print( + f"[red]✗[/red] population spec not found: 
{population_spec}" + ) raise typer.Exit(1) db.save_population_spec( population_id=population_id, diff --git a/extropy/cli/commands/network.py b/extropy/cli/commands/network.py index d1b8eaa..32cd401 100644 --- a/extropy/cli/commands/network.py +++ b/extropy/cli/commands/network.py @@ -14,9 +14,7 @@ @app.command("network") def network_command( - study_db: Path = typer.Option( - ..., "--study-db", help="Canonical study DB file" - ), + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), population_id: str = typer.Option( "default", "--population-id", help="Population ID in study DB" ), diff --git a/extropy/cli/commands/query.py b/extropy/cli/commands/query.py index f15210c..9bf158c 100644 --- a/extropy/cli/commands/query.py +++ b/extropy/cli/commands/query.py @@ -39,7 +39,9 @@ def query_sql( req = ReadOnlySQLRequest(sql=sql, limit=limit) normalized = req.sql.strip().lower() if not normalized.startswith(_ALLOWED_PREFIXES): - console.print("[red]✗[/red] Only read-only SELECT/WITH/EXPLAIN queries are allowed") + console.print( + "[red]✗[/red] Only read-only SELECT/WITH/EXPLAIN queries are allowed" + ) raise typer.Exit(1) padded = f" {normalized} " if ";" in req.sql.strip().rstrip(";"): diff --git a/extropy/cli/commands/results.py b/extropy/cli/commands/results.py index b6ece60..c555479 100644 --- a/extropy/cli/commands/results.py +++ b/extropy/cli/commands/results.py @@ -170,7 +170,9 @@ def _display_agent(conn: sqlite3.Connection, agent_id: str) -> None: ) row = cur.fetchone() if not row: - console.print(f"[yellow]Agent not found in simulation state: {agent_id}[/yellow]") + console.print( + f"[yellow]Agent not found in simulation state: {agent_id}[/yellow]" + ) return cur.execute("SELECT attrs_json FROM agents WHERE agent_id = ? 
LIMIT 1", (agent_id,)) diff --git a/extropy/cli/commands/simulate.py b/extropy/cli/commands/simulate.py index 7d6168e..9746ec7 100644 --- a/extropy/cli/commands/simulate.py +++ b/extropy/cli/commands/simulate.py @@ -102,21 +102,16 @@ def simulate_command( scenario_file: Path = typer.Argument(..., help="Scenario spec YAML file"), output: Path = typer.Option(..., "--output", "-o", help="Output results directory"), study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), - model: str = typer.Option( + strong: str = typer.Option( "", - "--model", + "--strong", "-m", - help="LLM model for agent reasoning (empty = use config default)", + help="Strong model for Pass 1 (provider/model format)", ), - pivotal_model: str = typer.Option( + fast: str = typer.Option( "", - "--pivotal-model", - help="Model for pivotal/first-pass reasoning (default: same as --model)", - ), - routine_model: str = typer.Option( - "", - "--routine-model", - help="Cheap model for classification pass (default: provider cheap tier)", + "--fast", + help="Fast model for Pass 2 (provider/model format)", ), threshold: int = typer.Option( 3, "--threshold", "-t", help="Multi-touch threshold for re-reasoning" @@ -240,29 +235,18 @@ def simulate_command( config = get_config() # Resolve models from CLI args > config > defaults - effective_model = model or config.simulation.model - effective_pivotal = pivotal_model or config.simulation.pivotal_model - effective_routine = routine_model or config.simulation.routine_model + effective_strong = strong or config.resolve_sim_strong() + effective_fast = fast or config.resolve_sim_fast() effective_tier = rate_tier or config.simulation.rate_tier effective_rpm = rpm_override or config.simulation.rpm_override effective_tpm = tpm_override or config.simulation.tpm_override - display_model = effective_model or f"({config.simulation.provider} default)" - display_provider = config.simulation.provider - console.print(f"Simulating: 
[bold]{scenario_file}[/bold]") console.print(f"Output: {output}") console.print(f"Study DB: {study_db}") console.print( - f"Provider: {display_provider} | Model: {display_model} | Threshold: {threshold}" + f"Strong: {effective_strong} | Fast: {effective_fast} | Threshold: {threshold}" ) - if effective_pivotal or effective_routine: - parts = [] - if effective_pivotal: - parts.append(f"Pivotal: {effective_pivotal}") - if effective_routine: - parts.append(f"Routine: {effective_routine}") - console.print(" | ".join(parts)) if effective_tier: console.print(f"Rate tier: {effective_tier}") if effective_rpm or effective_tpm: @@ -319,9 +303,8 @@ def on_progress(timestep: int, max_timesteps: int, status: str): scenario_path=scenario_file, output_dir=output, study_db_path=study_db, - model=effective_model, - pivotal_model=effective_pivotal, - routine_model=effective_routine, + strong=effective_strong, + fast=effective_fast, multi_touch_threshold=threshold, random_seed=seed, on_progress=on_progress, @@ -353,9 +336,8 @@ def do_simulation(): scenario_path=scenario_file, output_dir=output, study_db_path=study_db, - model=effective_model, - pivotal_model=effective_pivotal, - routine_model=effective_routine, + strong=effective_strong, + fast=effective_fast, multi_touch_threshold=threshold, random_seed=seed, on_progress=on_progress if not quiet else None, diff --git a/extropy/config.py b/extropy/config.py index 55b7cb0..e157132 100644 --- a/extropy/config.py +++ b/extropy/config.py @@ -1,12 +1,14 @@ """Configuration management for Extropy. -Two-zone config system: -- pipeline: provider + models for phases 1-2 (spec, extend, sample, network, persona, scenario) -- simulation: provider + model for phase 3 (agent reasoning) +Two-tier config system: +- models: fast/strong model strings for pipeline phases 1-2 +- simulation: fast/strong model strings for phase 3 (agent reasoning) + +Model strings use "provider/model" format (e.g., "openai/gpt-5-mini"). 
def parse_model_string(model_string: str) -> tuple[str, str]:
    """Parse a "provider/model" string into a ``(provider, model)`` tuple.

    The string is split on the FIRST "/" only, so model names that themselves
    contain slashes are preserved intact. Leading/trailing whitespace around
    each component is stripped before validation, so values copied from
    environment variables or config files with stray spaces/newlines still
    parse, and whitespace-only components are rejected as empty.

    Examples:
        "openai/gpt-5-mini" → ("openai", "gpt-5-mini")
        "anthropic/claude-sonnet-4.5" → ("anthropic", "claude-sonnet-4.5")
        "openrouter/anthropic/claude-sonnet-4.5" → ("openrouter", "anthropic/claude-sonnet-4.5")
        " openai/gpt-5 " → ("openai", "gpt-5")

    Args:
        model_string: Model identifier in "provider/model" format.

    Returns:
        A ``(provider, model)`` tuple of non-empty, whitespace-trimmed strings.

    Raises:
        ValueError: If the string doesn't contain a '/' separator, or if the
            provider or model part is empty (or whitespace only).
    """
    if "/" not in model_string:
        raise ValueError(
            f"Invalid model string: {model_string!r}. "
            f"Expected format: 'provider/model' (e.g., 'openai/gpt-5-mini')"
        )
    raw_provider, _, raw_model = model_string.partition("/")
    # Strip before the emptiness check: otherwise " /gpt-5" would be accepted
    # with a provider name consisting of a single space.
    provider = raw_provider.strip()
    model = raw_model.strip()
    if not provider or not model:
        raise ValueError(
            f"Invalid model string: {model_string!r}. "
            f"Both provider and model must be non-empty."
        )
    return provider, model
+ - fast: used for simple_call (cheap, fast tasks) + - strong: used for reasoning_call, agentic_research (complex tasks) + """ + + fast: str = "openai/gpt-5-mini" + strong: str = "openai/gpt-5" + + +@dataclass +class SimulationConfig: + """Simulation model + tuning configuration (phase 3). + + Uses "provider/model" format strings. + - fast: used for Pass 2 (classification/routine) + - strong: used for Pass 1 (pivotal/role-play reasoning) + """ + + fast: str = "" # empty = same as models.fast + strong: str = "" # empty = same as models.strong + max_concurrent: int = 50 + rate_tier: int | None = None + rpm_override: int | None = None + tpm_override: int | None = None + + +@dataclass +class CustomProviderConfig: + """Configuration for a custom OpenAI-compatible provider endpoint.""" + + base_url: str = "" + api_key_env: str = "" + + +@dataclass +class DefaultsConfig: + """Non-zone default settings.""" + + population_size: int = 1000 + db_path: str = "./storage/extropy.db" + show_cost: bool = False # Show cost footer after every CLI command + + +# ============================================================================= +# Legacy config dataclasses (kept for migration) # ============================================================================= @dataclass class PipelineConfig: - """Config for phases 1-2: spec, extend, sample, network, persona, scenario.""" + """DEPRECATED: Config for phases 1-2. Use ModelsConfig instead.""" provider: str = "openai" - model_simple: str = "" # empty = provider default - model_reasoning: str = "" # empty = provider default - model_research: str = "" # empty = provider default + model_simple: str = "" + model_reasoning: str = "" + model_research: str = "" @dataclass class SimZoneConfig: - """Config for phase 3: agent reasoning during simulation.""" + """DEPRECATED: Config for phase 3. 
Use SimulationConfig instead.""" provider: str = "openai" - model: str = "" # empty = provider default - pivotal_model: str = "" # model for pivotal reasoning (default: same as model) - routine_model: str = ( - "" # cheap model for classification (default: provider cheap tier) - ) + model: str = "" + pivotal_model: str = "" + routine_model: str = "" max_concurrent: int = 50 - rate_tier: int | None = None # rate limit tier (1-4, None = Tier 1) - rpm_override: int | None = None # override RPM limit - tpm_override: int | None = None # override TPM limit - api_format: str = ( - "" # empty = auto (responses for openai, chat_completions for azure) - ) + rate_tier: int | None = None + rpm_override: int | None = None + tpm_override: int | None = None + api_format: str = "" + + +# ============================================================================= +# Main config class +# ============================================================================= @dataclass @@ -75,8 +161,7 @@ class ExtropyConfig: Examples: # Package use — no files needed config = ExtropyConfig( - pipeline=PipelineConfig(provider="claude"), - simulation=SimZoneConfig(provider="openai", model="gpt-5-mini"), + models=ModelsConfig(fast="openai/gpt-5-mini", strong="anthropic/claude-sonnet-4.5"), ) # CLI use — loads from ~/.config/extropy/config.json @@ -84,21 +169,20 @@ class ExtropyConfig: # Override just simulation config = ExtropyConfig.load() - config.simulation.model = "gpt-5-nano" + config.simulation.strong = "openrouter/anthropic/claude-sonnet-4.5" """ - pipeline: PipelineConfig = field(default_factory=PipelineConfig) - simulation: SimZoneConfig = field(default_factory=SimZoneConfig) - - # Non-zone settings - db_path: str = "./storage/extropy.db" - default_population_size: int = 1000 + models: ModelsConfig = field(default_factory=ModelsConfig) + simulation: SimulationConfig = field(default_factory=SimulationConfig) + providers: dict[str, CustomProviderConfig] = field(default_factory=dict) + 
defaults: DefaultsConfig = field(default_factory=DefaultsConfig) @classmethod def load(cls) -> "ExtropyConfig": """Load config from file + env vars. Priority: env var values > config.json values > defaults. + Auto-migrates v1 config format if detected. """ config = cls() @@ -107,31 +191,35 @@ def load(cls) -> "ExtropyConfig": try: with open(CONFIG_FILE) as f: data = json.load(f) + + # Auto-migrate v1 config + if _is_v1_config(data): + warnings.warn( + "Detected legacy config format. Migrating to v2. " + "Run `extropy config show` to verify, then `extropy config set` to update.", + DeprecationWarning, + stacklevel=2, + ) + data = _migrate_v1_to_v2(data) + _apply_dict(config, data) except (json.JSONDecodeError, OSError) as exc: logger.warning("Failed to load config from %s: %s", CONFIG_FILE, exc) - # Layer 2: Env var overrides - if provider := os.environ.get("LLM_PROVIDER"): - # Legacy: single provider applied to both zones - config.pipeline.provider = provider - config.simulation.provider = provider - if val := os.environ.get("PIPELINE_PROVIDER"): - config.pipeline.provider = val - if val := os.environ.get("SIMULATION_PROVIDER"): - config.simulation.provider = val - if val := os.environ.get("MODEL_SIMPLE"): - config.pipeline.model_simple = val - if val := os.environ.get("MODEL_REASONING"): - config.pipeline.model_reasoning = val - if val := os.environ.get("MODEL_RESEARCH"): - config.pipeline.model_research = val - if val := os.environ.get("SIMULATION_MODEL"): - config.simulation.model = val - if val := os.environ.get("SIMULATION_PIVOTAL_MODEL"): - config.simulation.pivotal_model = val - if val := os.environ.get("SIMULATION_ROUTINE_MODEL"): - config.simulation.routine_model = val + # Layer 2: Env var overrides (new format) + if val := os.environ.get("MODELS_FAST"): + config.models.fast = val + if val := os.environ.get("MODELS_STRONG"): + config.models.strong = val + if val := os.environ.get("SIMULATION_FAST"): + config.simulation.fast = val + if val := 
os.environ.get("SIMULATION_STRONG"): + config.simulation.strong = val + if val := os.environ.get("SIMULATION_MAX_CONCURRENT"): + try: + config.simulation.max_concurrent = int(val) + except ValueError: + logger.warning("Invalid SIMULATION_MAX_CONCURRENT=%r, ignoring", val) if val := os.environ.get("SIMULATION_RATE_TIER"): try: config.simulation.rate_tier = int(val) @@ -147,36 +235,88 @@ def load(cls) -> "ExtropyConfig": config.simulation.tpm_override = int(val) except ValueError: logger.warning("Invalid SIMULATION_TPM_OVERRIDE=%r, ignoring", val) - if val := os.environ.get("SIMULATION_API_FORMAT"): - config.simulation.api_format = val if val := os.environ.get("DB_PATH"): - config.db_path = val + config.defaults.db_path = val if val := os.environ.get("DEFAULT_POPULATION_SIZE"): try: - config.default_population_size = int(val) + config.defaults.population_size = int(val) except ValueError: logger.warning("Invalid DEFAULT_POPULATION_SIZE=%r, ignoring", val) + # Layer 3: Legacy env var overrides (emit deprecation warnings) + _apply_legacy_env_vars(config) + return config def save(self) -> None: """Save config to ~/.config/extropy/config.json.""" CONFIG_DIR.mkdir(parents=True, exist_ok=True) - data = asdict(self) - # Don't persist non-zone settings that are better as env vars - data.pop("db_path", None) - data.pop("default_population_size", None) + data: dict[str, Any] = { + "models": asdict(self.models), + "simulation": asdict(self.simulation), + } + if self.providers: + data["providers"] = { + name: asdict(cfg) for name, cfg in self.providers.items() + } + if self.defaults != DefaultsConfig(): + data["defaults"] = asdict(self.defaults) with open(CONFIG_FILE, "w") as f: json.dump(data, f, indent=2) def to_dict(self) -> dict[str, Any]: """Convert to dict for display.""" - return asdict(self) + result = { + "models": asdict(self.models), + "simulation": asdict(self.simulation), + "defaults": asdict(self.defaults), + } + if self.providers: + result["providers"] = { + name: 
asdict(cfg) for name, cfg in self.providers.items() + } + return result + + # ── Convenience resolution methods ── + + def resolve_pipeline_fast(self) -> str: + """Resolve the fast model string for pipeline use.""" + return self.models.fast + + def resolve_pipeline_strong(self) -> str: + """Resolve the strong model string for pipeline use.""" + return self.models.strong + + def resolve_sim_strong(self) -> str: + """Resolve the strong model string for simulation.""" + return self.simulation.strong or self.models.strong + + def resolve_sim_fast(self) -> str: + """Resolve the fast model string for simulation.""" + return self.simulation.fast or self.models.fast + + # ── Backward compat properties ── + + @property + def db_path(self) -> str: + return self.defaults.db_path + + @db_path.setter + def db_path(self, value: str) -> None: + self.defaults.db_path = value + + @property + def default_population_size(self) -> int: + return self.defaults.population_size + + @default_population_size.setter + def default_population_size(self, value: int) -> None: + self.defaults.population_size = value @property def db_path_resolved(self) -> Path: """Resolve database path.""" - path = Path(self.db_path) + path = Path(self.defaults.db_path) path.parent.mkdir(parents=True, exist_ok=True) return path @@ -188,24 +328,225 @@ def cache_dir(self) -> Path: return path +# ============================================================================= +# Config dict application +# ============================================================================= + + def _apply_dict(config: ExtropyConfig, data: dict) -> None: - """Apply a dict of values onto an ExtropyConfig.""" - if "pipeline" in data and isinstance(data["pipeline"], dict): - for k, v in data["pipeline"].items(): - if hasattr(config.pipeline, k): - setattr(config.pipeline, k, v) + """Apply a dict of values onto an ExtropyConfig (v2 format).""" + if "models" in data and isinstance(data["models"], dict): + for k, v in 
data["models"].items(): + if hasattr(config.models, k): + setattr(config.models, k, v) if "simulation" in data and isinstance(data["simulation"], dict): for k, v in data["simulation"].items(): if hasattr(config.simulation, k): setattr(config.simulation, k, v) + if "providers" in data and isinstance(data["providers"], dict): + for name, provider_data in data["providers"].items(): + if isinstance(provider_data, dict): + config.providers[name] = CustomProviderConfig( + base_url=provider_data.get("base_url", ""), + api_key_env=provider_data.get("api_key_env", ""), + ) + if "defaults" in data and isinstance(data["defaults"], dict): + for k, v in data["defaults"].items(): + if hasattr(config.defaults, k): + if k == "population_size": + v = int(v) + setattr(config.defaults, k, v) + # Backward compat: top-level db_path / default_population_size + if "db_path" in data: + config.defaults.db_path = data["db_path"] + if "default_population_size" in data: + config.defaults.population_size = int(data["default_population_size"]) + + +# ============================================================================= +# V1 → V2 migration +# ============================================================================= + +# Provider name mapping for migration +_PROVIDER_CANONICAL = { + "openai": "openai", + "claude": "anthropic", + "anthropic": "anthropic", + "azure_openai": "azure", +} + +# Default model names per old provider +_V1_PROVIDER_DEFAULTS = { + "openai": {"fast": "gpt-5-mini", "strong": "gpt-5"}, + "claude": { + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", + }, + "anthropic": { + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", + }, + "azure_openai": {"fast": "gpt-5-mini", "strong": "gpt-5"}, +} + + +def _is_v1_config(data: dict) -> bool: + """Detect if config data is in v1 format (has 'pipeline' key).""" + return "pipeline" in data and "models" not in data + + +def _migrate_v1_to_v2(data: dict) -> dict: + 
"""Convert v1 config format to v2. + + v1 format: + {"pipeline": {"provider": "openai", "model_simple": "...", ...}, + "simulation": {"provider": "openai", "model": "...", ...}} + + v2 format: + {"models": {"fast": "openai/gpt-5-mini", "strong": "openai/gpt-5"}, + "simulation": {"fast": "...", "strong": "...", ...}} + """ + result: dict[str, Any] = {} + + # Migrate pipeline → models + pipeline = data.get("pipeline", {}) + old_provider = pipeline.get("provider", "openai") + canonical = _PROVIDER_CANONICAL.get(old_provider, old_provider) + defaults = _V1_PROVIDER_DEFAULTS.get(old_provider, _V1_PROVIDER_DEFAULTS["openai"]) + + fast_model = pipeline.get("model_simple") or defaults["fast"] + strong_model = pipeline.get("model_reasoning") or defaults["strong"] + + result["models"] = { + "fast": f"{canonical}/{fast_model}", + "strong": f"{canonical}/{strong_model}", + } + + # Migrate simulation + sim = data.get("simulation", {}) + sim_provider = sim.get("provider", "openai") + sim_canonical = _PROVIDER_CANONICAL.get(sim_provider, sim_provider) + sim_result: dict[str, Any] = {} + + # Map model/pivotal_model → strong, routine_model → fast + pivotal = sim.get("pivotal_model") or sim.get("model") or "" + routine = sim.get("routine_model") or "" + + if pivotal: + sim_result["strong"] = f"{sim_canonical}/{pivotal}" + if routine: + sim_result["fast"] = f"{sim_canonical}/{routine}" + + for k in ("max_concurrent", "rate_tier", "rpm_override", "tpm_override"): + if k in sim and sim[k] is not None: + sim_result[k] = sim[k] + + result["simulation"] = sim_result + + # Carry forward non-zone settings if "db_path" in data: - config.db_path = data["db_path"] + result.setdefault("defaults", {})["db_path"] = data["db_path"] if "default_population_size" in data: - config.default_population_size = int(data["default_population_size"]) + result.setdefault("defaults", {})["population_size"] = data[ + "default_population_size" + ] + + return result # 
_LEGACY_ENV_WARNED: set[str] = set()


def _warn_legacy_env(name: str, replacement: str) -> None:
    """Emit a one-time deprecation warning for a legacy env var.

    Args:
        name: The deprecated environment variable that was found set.
        replacement: Human-readable description of what to use instead.
    """
    if name in _LEGACY_ENV_WARNED:
        return
    _LEGACY_ENV_WARNED.add(name)
    warnings.warn(
        f"Environment variable {name} is deprecated. Use {replacement} instead.",
        DeprecationWarning,
        # Attribute the warning three frames up. This assumes the call chain
        # caller -> ExtropyConfig.load -> _apply_legacy_env_vars -> here, so
        # this helper must be called directly from _apply_legacy_env_vars.
        stacklevel=4,
    )


def _apply_legacy_env_vars(config: ExtropyConfig) -> None:
    """Apply legacy (v1) env vars onto ``config`` with deprecation warnings.

    Two families of legacy variables are honoured, in this order:

    1. Provider selectors (LLM_PROVIDER, PIPELINE_PROVIDER,
       SIMULATION_PROVIDER): set the zone's fast/strong strings to the
       provider's default models.
    2. Model-name overrides (MODEL_SIMPLE, MODEL_REASONING, SIMULATION_MODEL,
       SIMULATION_PIVOTAL_MODEL, SIMULATION_ROUTINE_MODEL): keep the provider
       of the currently effective model string and replace only the model
       name.

    A legacy variable never overrides a value whose new-format env var
    (MODELS_FAST, MODELS_STRONG, SIMULATION_FAST, SIMULATION_STRONG) is set.
    SIMULATION_API_FORMAT has no replacement and only triggers a warning.

    If the currently effective model string is not valid "provider/model"
    (e.g. a hand-edited config file), the model-name override is skipped with
    a logged warning instead of raising, so a deprecated env var cannot abort
    config loading.

    Args:
        config: Config object mutated in place.
    """
    # (legacy var, replacement hint, target zone, new-format env prefix).
    # LLM_PROVIDER only touches the pipeline models; simulation follows via
    # the empty-string fallback to models.* when it has no value of its own.
    provider_vars = (
        ("LLM_PROVIDER", "MODELS_FAST / MODELS_STRONG", config.models, "MODELS"),
        ("PIPELINE_PROVIDER", "MODELS_FAST / MODELS_STRONG", config.models, "MODELS"),
        (
            "SIMULATION_PROVIDER",
            "SIMULATION_FAST / SIMULATION_STRONG",
            config.simulation,
            "SIMULATION",
        ),
    )
    for env_name, replacement, zone, prefix in provider_vars:
        if val := os.environ.get(env_name):
            _warn_legacy_env(env_name, replacement)
            canonical = _PROVIDER_CANONICAL.get(val, val)
            defaults = _V1_PROVIDER_DEFAULTS.get(val, _V1_PROVIDER_DEFAULTS["openai"])
            # Only override if no new-format env vars set
            if not os.environ.get(f"{prefix}_FAST"):
                zone.fast = f"{canonical}/{defaults['fast']}"
            if not os.environ.get(f"{prefix}_STRONG"):
                zone.strong = f"{canonical}/{defaults['strong']}"

    # (legacy var, new-format var, target zone, tier attribute).
    # SIMULATION_PIVOTAL_MODEL comes after SIMULATION_MODEL so it wins when
    # both are set.
    model_vars = (
        ("MODEL_SIMPLE", "MODELS_FAST", config.models, "fast"),
        ("MODEL_REASONING", "MODELS_STRONG", config.models, "strong"),
        ("SIMULATION_MODEL", "SIMULATION_STRONG", config.simulation, "strong"),
        ("SIMULATION_PIVOTAL_MODEL", "SIMULATION_STRONG", config.simulation, "strong"),
        ("SIMULATION_ROUTINE_MODEL", "SIMULATION_FAST", config.simulation, "fast"),
    )
    for env_name, new_env, zone, tier in model_vars:
        val = os.environ.get(env_name)
        if not val:
            continue
        _warn_legacy_env(env_name, new_env)
        if os.environ.get(new_env):
            continue
        # Resolve the provider from the zone's own value, falling back to the
        # pipeline model of the same tier when the zone value is empty.
        base = getattr(zone, tier) or getattr(config.models, tier)
        try:
            provider, _ = parse_model_string(base)
        except ValueError:
            logger.warning(
                "Ignoring %s=%r: cannot derive a provider from %r "
                "(expected 'provider/model')",
                env_name,
                val,
                base,
            )
            continue
        setattr(zone, tier, f"{provider}/{val}")

    # SIMULATION_API_FORMAT — no direct replacement, just warn
    if os.environ.get("SIMULATION_API_FORMAT"):
        _warn_legacy_env(
            "SIMULATION_API_FORMAT",
            "provider-based routing (api_format is now automatic)",
        )
def get_api_key_for_provider(
    provider_name: str,
    custom_providers: dict[str, CustomProviderConfig] | None = None,
) -> str:
    """Get API key for a provider from the environment (after loading .env).

    Resolution order:
    1. Custom provider ``api_key_env`` override. When set, only that variable
       is consulted; there is no fallback to the convention.
    2. Fixed mapping for Azure: "azure" / "azure_openai" → AZURE_OPENAI_API_KEY
    3. Convention: {PROVIDER_UPPER}_API_KEY (e.g. "anthropic" →
       ANTHROPIC_API_KEY, "openai" → OPENAI_API_KEY)
    4. The conventional name with every non-alphanumeric character replaced
       by "_" (e.g. "together-ai" → TOGETHER_AI_API_KEY). Names containing
       "-" or "." cannot be exported from a POSIX shell, so without this step
       such providers could not receive a key by convention.

    Args:
        provider_name: Provider prefix of a "provider/model" string.
        custom_providers: Optional mapping of provider name to its custom
            endpoint configuration.

    Returns:
        The key value, or an empty string if not found.
    """
    _ensure_dotenv()

    # Check custom provider override first
    if custom_providers and provider_name in custom_providers:
        custom = custom_providers[provider_name]
        if custom.api_key_env:
            return os.environ.get(custom.api_key_env, "")

    # Special cases for backward compat
    key_map = {
        "azure": "AZURE_OPENAI_API_KEY",
        "azure_openai": "AZURE_OPENAI_API_KEY",
    }
    if provider_name in key_map:
        return os.environ.get(key_map[provider_name], "")

    # Convention: {PROVIDER}_API_KEY. The literal name is tried first so any
    # existing setup keeps working; the sanitized name is a fallback only.
    conventional = f"{provider_name.upper()}_API_KEY"
    sanitized = "".join(ch if ch.isalnum() else "_" for ch in conventional)
    return os.environ.get(conventional) or os.environ.get(sanitized, "")
Use this when extropy is used as a package: - from extropy.config import configure, ExtropyConfig, PipelineConfig - configure(ExtropyConfig(pipeline=PipelineConfig(provider="claude"))) + from extropy.config import configure, ExtropyConfig, ModelsConfig + configure(ExtropyConfig(models=ModelsConfig(fast="openai/gpt-5-mini"))) """ global _config _config = config diff --git a/extropy/core/cost/__init__.py b/extropy/core/cost/__init__.py new file mode 100644 index 0000000..1377de8 --- /dev/null +++ b/extropy/core/cost/__init__.py @@ -0,0 +1,39 @@ +"""Cost tracking, pricing resolution, and persistent ledger. + +This package provides: +- CostTracker: Session-scoped accumulator (auto-records from providers) +- Pricing: Three-tier model pricing resolution (OpenRouter → cache → fallback) +- Ledger: Persistent cost history (~/.config/extropy/cost_ledger.db) +""" + +from .tracker import CostTracker, CallRecord, ModelUsage +from .pricing import ( + ModelPricing, + get_pricing, + resolve_default_model, + refresh_pricing, + get_cache_info, + FALLBACK_PRICING, + PROVIDER_DEFAULTS, +) +from .ledger import CostEntry, record_session, query_entries, query_totals + +__all__ = [ + # Tracker + "CostTracker", + "CallRecord", + "ModelUsage", + # Pricing + "ModelPricing", + "get_pricing", + "resolve_default_model", + "refresh_pricing", + "get_cache_info", + "FALLBACK_PRICING", + "PROVIDER_DEFAULTS", + # Ledger + "CostEntry", + "record_session", + "query_entries", + "query_totals", +] diff --git a/extropy/core/cost/ledger.py b/extropy/core/cost/ledger.py new file mode 100644 index 0000000..c6d7e08 --- /dev/null +++ b/extropy/core/cost/ledger.py @@ -0,0 +1,278 @@ +"""Persistent cost ledger. + +Appends session cost summaries to a local SQLite database. +Provides query methods for the `extropy cost` command. 
+""" + +import json +import logging +import sqlite3 +import time +from datetime import datetime +from pathlib import Path +from typing import Any + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +_LEDGER_DIR = Path.home() / ".config" / "extropy" +_LEDGER_FILE = _LEDGER_DIR / "cost_ledger.db" + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS cost_entries ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp REAL NOT NULL, + date TEXT NOT NULL, + command TEXT NOT NULL, + scenario TEXT NOT NULL DEFAULT '', + total_calls INTEGER NOT NULL DEFAULT 0, + total_input_tokens INTEGER NOT NULL DEFAULT 0, + total_output_tokens INTEGER NOT NULL DEFAULT 0, + total_cost REAL, + models_json TEXT NOT NULL DEFAULT '{}', + elapsed_seconds REAL +); + +CREATE INDEX IF NOT EXISTS idx_cost_entries_date ON cost_entries(date); +CREATE INDEX IF NOT EXISTS idx_cost_entries_command ON cost_entries(command); +""" + + +class CostEntry(BaseModel): + """A single cost ledger entry.""" + + timestamp: float + date: str + command: str + scenario: str + total_calls: int + total_input_tokens: int + total_output_tokens: int + total_cost: float | None + models: dict[str, Any] + elapsed_seconds: float | None + + +def _get_connection() -> sqlite3.Connection: + """Get a connection to the ledger database, creating it if needed.""" + _LEDGER_DIR.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(_LEDGER_FILE)) + conn.row_factory = sqlite3.Row + conn.executescript(_SCHEMA) + return conn + + +def record_session(summary: dict[str, Any]) -> None: + """Append a session cost summary to the ledger. 
+ + Args: + summary: Dict from CostTracker.summary() + """ + if summary.get("total_calls", 0) == 0: + return + + try: + conn = _get_connection() + try: + now = time.time() + conn.execute( + """ + INSERT INTO cost_entries + (timestamp, date, command, scenario, total_calls, + total_input_tokens, total_output_tokens, total_cost, + models_json, elapsed_seconds) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + now, + datetime.fromtimestamp(now).strftime("%Y-%m-%d"), + summary.get("command", ""), + summary.get("scenario", ""), + summary.get("total_calls", 0), + summary.get("total_input_tokens", 0), + summary.get("total_output_tokens", 0), + summary.get("total_cost"), + json.dumps(summary.get("by_model", {})), + summary.get("elapsed_seconds"), + ), + ) + conn.commit() + finally: + conn.close() + except (sqlite3.Error, OSError) as e: + logger.debug(f"Failed to record cost to ledger: {e}") + + +def query_entries( + days: int | None = 7, + command: str | None = None, + limit: int = 100, +) -> list[CostEntry]: + """Query cost ledger entries. + + Args: + days: Number of days to look back (None = all time) + command: Filter by command name (None = all commands) + limit: Max entries to return + + Returns: + List of CostEntry, newest first. + """ + try: + conn = _get_connection() + except (sqlite3.Error, OSError): + return [] + + try: + clauses = [] + params: list[Any] = [] + + if days is not None: + cutoff = time.time() - (days * 86400) + clauses.append("timestamp >= ?") + params.append(cutoff) + + if command: + clauses.append("command = ?") + params.append(command) + + where = f"WHERE {' AND '.join(clauses)}" if clauses else "" + params.append(limit) + + rows = conn.execute( + f""" + SELECT * FROM cost_entries + {where} + ORDER BY timestamp DESC + LIMIT ? 
+ """, + params, + ).fetchall() + + entries = [] + for row in rows: + try: + models = json.loads(row["models_json"]) + except (json.JSONDecodeError, TypeError): + models = {} + + entries.append( + CostEntry( + timestamp=row["timestamp"], + date=row["date"], + command=row["command"], + scenario=row["scenario"], + total_calls=row["total_calls"], + total_input_tokens=row["total_input_tokens"], + total_output_tokens=row["total_output_tokens"], + total_cost=row["total_cost"], + models=models, + elapsed_seconds=row["elapsed_seconds"], + ) + ) + + return entries + finally: + conn.close() + + +def query_totals( + days: int | None = 7, + group_by: str | None = None, +) -> dict[str, Any]: + """Query aggregated cost totals. + + Args: + days: Number of days to look back (None = all time) + group_by: Group results by "command", "date", or "model" (None = totals only) + + Returns: + Dict with total and optional grouped breakdowns. + """ + try: + conn = _get_connection() + except (sqlite3.Error, OSError): + return {"total_cost": None, "total_calls": 0} + + try: + where = "" + params: list[Any] = [] + if days is not None: + cutoff = time.time() - (days * 86400) + where = "WHERE timestamp >= ?" 
+ params.append(cutoff) + + # Overall totals + row = conn.execute( + f""" + SELECT + COUNT(*) as sessions, + SUM(total_calls) as calls, + SUM(total_input_tokens) as input_tokens, + SUM(total_output_tokens) as output_tokens, + SUM(total_cost) as cost + FROM cost_entries + {where} + """, + params, + ).fetchone() + + result: dict[str, Any] = { + "sessions": row["sessions"] or 0, + "total_calls": row["calls"] or 0, + "total_input_tokens": row["input_tokens"] or 0, + "total_output_tokens": row["output_tokens"] or 0, + "total_cost": round(row["cost"], 4) if row["cost"] else None, + } + + # Grouped breakdown + if group_by == "command": + rows = conn.execute( + f""" + SELECT command, + COUNT(*) as sessions, + SUM(total_calls) as calls, + SUM(total_cost) as cost + FROM cost_entries + {where} + GROUP BY command + ORDER BY cost DESC + """, + params, + ).fetchall() + result["by_command"] = { + r["command"]: { + "sessions": r["sessions"], + "calls": r["calls"] or 0, + "cost": round(r["cost"], 4) if r["cost"] else None, + } + for r in rows + } + + elif group_by == "date": + rows = conn.execute( + f""" + SELECT date, + COUNT(*) as sessions, + SUM(total_calls) as calls, + SUM(total_cost) as cost + FROM cost_entries + {where} + GROUP BY date + ORDER BY date DESC + """, + params, + ).fetchall() + result["by_date"] = { + r["date"]: { + "sessions": r["sessions"], + "calls": r["calls"] or 0, + "cost": round(r["cost"], 4) if r["cost"] else None, + } + for r in rows + } + + return result + finally: + conn.close() diff --git a/extropy/core/cost/pricing.py b/extropy/core/cost/pricing.py new file mode 100644 index 0000000..a87085e --- /dev/null +++ b/extropy/core/cost/pricing.py @@ -0,0 +1,369 @@ +"""Model pricing resolution for cost estimation and tracking. + +Three-tier pricing resolution: +1. OpenRouter API (free, no auth, covers 200+ models) → cached locally +2. Local cache file (~/.config/extropy/pricing_cache.json) with 24h TTL +3. 
Hardcoded fallback table for offline/known models + +Provides per-model input/output pricing (USD per million tokens) +and provider default model resolution without needing API keys. +""" + +import json +import logging +import time +from pathlib import Path +from typing import Any + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +# Cache location and TTL +_CACHE_DIR = Path.home() / ".config" / "extropy" +_CACHE_FILE = _CACHE_DIR / "pricing_cache.json" +_CACHE_TTL_SECONDS = 24 * 60 * 60 # 24 hours + + +class ModelPricing(BaseModel, frozen=True): + """Pricing for a single model (USD per million tokens).""" + + input_per_mtok: float + output_per_mtok: float + + +# ── Hardcoded fallback (Tier 3) ────────────────────────────────────────────── + +# Known model pricing (USD per million tokens) +# Sources: OpenAI and Anthropic pricing pages as of 2025 +FALLBACK_PRICING: dict[str, ModelPricing] = { + # OpenAI + "gpt-5": ModelPricing(input_per_mtok=2.50, output_per_mtok=10.00), + "gpt-5-mini": ModelPricing(input_per_mtok=0.30, output_per_mtok=1.50), + "gpt-5-nano": ModelPricing(input_per_mtok=0.10, output_per_mtok=0.40), + "gpt-5.2": ModelPricing(input_per_mtok=2.50, output_per_mtok=10.00), + # Azure-hosted models + "DeepSeek-V3.2": ModelPricing(input_per_mtok=0.80, output_per_mtok=2.00), + "Kimi-K2.5": ModelPricing(input_per_mtok=1.00, output_per_mtok=4.00), + # Claude + "claude-sonnet-4-5-20250929": ModelPricing( + input_per_mtok=3.00, output_per_mtok=15.00 + ), + "claude-sonnet-4-5-20250514": ModelPricing( + input_per_mtok=3.00, output_per_mtok=15.00 + ), + "claude-sonnet-4.5": ModelPricing(input_per_mtok=3.00, output_per_mtok=15.00), + "claude-sonnet-4": ModelPricing(input_per_mtok=3.00, output_per_mtok=15.00), + "claude-haiku-4-5-20251001": ModelPricing( + input_per_mtok=0.80, output_per_mtok=4.00 + ), + "claude-haiku-4.5": ModelPricing(input_per_mtok=0.80, output_per_mtok=4.00), + "claude-haiku-4": ModelPricing(input_per_mtok=0.80, 
output_per_mtok=4.00), + # DeepSeek (direct API) + "deepseek-chat": ModelPricing(input_per_mtok=0.14, output_per_mtok=0.28), + "deepseek-reasoner": ModelPricing(input_per_mtok=0.55, output_per_mtok=2.19), +} + +# Provider default models — 2-tier (fast/strong) +PROVIDER_DEFAULTS: dict[str, dict[str, str]] = { + "openai": { + "fast": "gpt-5-mini", + "strong": "gpt-5", + }, + "anthropic": { + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", + }, + "azure": { + "fast": "gpt-5-mini", + "strong": "gpt-5", + }, + "openrouter": { + "fast": "openai/gpt-5-mini", + "strong": "openai/gpt-5", + }, + "deepseek": { + "fast": "deepseek-chat", + "strong": "deepseek-reasoner", + }, + "together": { + "fast": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + "strong": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + }, + "groq": { + "fast": "llama-3.3-70b-versatile", + "strong": "llama-3.3-70b-versatile", + }, + # Legacy aliases + "claude": { + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", + }, + "azure_openai": { + "fast": "gpt-5-mini", + "strong": "gpt-5", + }, +} + + +# ── In-memory cache ────────────────────────────────────────────────────────── + +_memory_cache: dict[str, ModelPricing] = {} +_memory_cache_loaded: bool = False + + +# ── Tier 1: OpenRouter API ─────────────────────────────────────────────────── + + +def _fetch_openrouter_pricing() -> dict[str, ModelPricing] | None: + """Fetch pricing from OpenRouter API (no auth required). + + Returns: + Dict of model_id → ModelPricing, or None if fetch failed. 
+ """ + try: + import urllib.request + import urllib.error + + url = "https://openrouter.ai/api/v1/models" + req = urllib.request.Request(url, headers={"User-Agent": "extropy"}) + with urllib.request.urlopen(req, timeout=10) as resp: + data = json.loads(resp.read().decode()) + + result: dict[str, ModelPricing] = {} + for model in data.get("data", []): + model_id = model.get("id", "") + pricing = model.get("pricing", {}) + + # OpenRouter returns pricing as string USD per token (not per MTok) + prompt_price = pricing.get("prompt") + completion_price = pricing.get("completion") + + if prompt_price is not None and completion_price is not None: + try: + input_per_tok = float(prompt_price) + output_per_tok = float(completion_price) + except (ValueError, TypeError): + continue + + # Skip free/zero-cost models + if input_per_tok == 0 and output_per_tok == 0: + continue + + # Convert per-token to per-million-tokens + result[model_id] = ModelPricing( + input_per_mtok=input_per_tok * 1_000_000, + output_per_mtok=output_per_tok * 1_000_000, + ) + + if result: + logger.debug(f"Fetched pricing for {len(result)} models from OpenRouter") + return result + + except Exception as e: + logger.debug(f"Failed to fetch OpenRouter pricing: {e}") + + return None + + +# ── Tier 2: Local cache file ───────────────────────────────────────────────── + + +def _load_cache() -> dict[str, ModelPricing] | None: + """Load pricing from local cache file if it exists and is fresh. + + Returns: + Dict of model_id → ModelPricing, or None if cache is stale/missing. 
+ """ + if not _CACHE_FILE.exists(): + return None + + try: + with open(_CACHE_FILE) as f: + data = json.load(f) + + # Check TTL + cached_at = data.get("cached_at", 0) + if time.time() - cached_at > _CACHE_TTL_SECONDS: + logger.debug("Pricing cache expired") + return None + + result: dict[str, ModelPricing] = {} + for model_id, pricing in data.get("models", {}).items(): + result[model_id] = ModelPricing( + input_per_mtok=pricing["input_per_mtok"], + output_per_mtok=pricing["output_per_mtok"], + ) + + logger.debug(f"Loaded {len(result)} models from pricing cache") + return result + + except (json.JSONDecodeError, KeyError, OSError) as e: + logger.debug(f"Failed to load pricing cache: {e}") + return None + + +def _save_cache(pricing: dict[str, ModelPricing]) -> None: + """Save pricing to local cache file.""" + try: + _CACHE_DIR.mkdir(parents=True, exist_ok=True) + data = { + "cached_at": time.time(), + "models": {model_id: p.model_dump() for model_id, p in pricing.items()}, + } + with open(_CACHE_FILE, "w") as f: + json.dump(data, f, indent=2) + logger.debug(f"Saved pricing cache with {len(pricing)} models") + except OSError as e: + logger.debug(f"Failed to save pricing cache: {e}") + + +# ── Resolution logic ───────────────────────────────────────────────────────── + + +def _ensure_cache_loaded() -> None: + """Lazily load the pricing cache into memory (once per process).""" + global _memory_cache, _memory_cache_loaded + if _memory_cache_loaded: + return + + # Try local cache first (fast, no network) + cached = _load_cache() + if cached: + _memory_cache = cached + _memory_cache_loaded = True + return + + # Try OpenRouter API + fetched = _fetch_openrouter_pricing() + if fetched: + _memory_cache = fetched + _save_cache(fetched) + _memory_cache_loaded = True + return + + # No dynamic pricing available — will fall through to hardcoded + _memory_cache_loaded = True + + +def _normalize_model_id(model: str) -> list[str]: + """Generate candidate lookup keys for a model name. 
+ + Handles the mapping between bare model names (used by providers) + and OpenRouter-style IDs (provider/model). + + Args: + model: Model name (e.g., "gpt-5-mini" or "openai/gpt-5-mini") + + Returns: + List of candidate keys to try, in priority order. + """ + candidates = [model] + + # If it's already a provider/model format, also try the bare model name + if "/" in model: + bare = model.rsplit("/", 1)[-1] + candidates.append(bare) + else: + # Try common provider prefixes for bare model names + if model.startswith("gpt-"): + candidates.append(f"openai/{model}") + elif model.startswith("claude-"): + candidates.append(f"anthropic/{model}") + elif model.startswith("deepseek-"): + candidates.append(f"deepseek/{model}") + elif model.startswith("llama-") or model.startswith("meta-llama/"): + candidates.append(f"meta-llama/{model}") + + return candidates + + +def get_pricing(model: str) -> ModelPricing | None: + """Get pricing for a model using three-tier resolution. + + Resolution order: + 1. OpenRouter API cache (refreshed every 24h) + 2. Local cache file + 3. Hardcoded fallback table + + Args: + model: Model name (bare like "gpt-5-mini" or qualified like "openai/gpt-5-mini") + + Returns: + ModelPricing or None if no pricing found. + """ + _ensure_cache_loaded() + + candidates = _normalize_model_id(model) + + # Try dynamic cache first + for candidate in candidates: + if candidate in _memory_cache: + return _memory_cache[candidate] + + # Fall back to hardcoded + for candidate in candidates: + if candidate in FALLBACK_PRICING: + return FALLBACK_PRICING[candidate] + + return None + + +def resolve_default_model(provider: str, tier: str = "strong") -> str: + """Resolve default model name for a provider without instantiating it. + + Args: + provider: Provider name ('openai', 'anthropic', etc.) 
+ tier: 'fast' or 'strong' (also accepts legacy 'simple'/'reasoning') + + Returns: + Model name string + """ + # Map legacy tier names + tier_map = {"simple": "fast", "reasoning": "strong"} + tier = tier_map.get(tier, tier) + + defaults = PROVIDER_DEFAULTS.get(provider, PROVIDER_DEFAULTS["openai"]) + return defaults.get(tier, defaults["strong"]) + + +def refresh_pricing() -> bool: + """Force-refresh pricing from OpenRouter API. + + Returns: + True if refresh succeeded. + """ + global _memory_cache, _memory_cache_loaded + + fetched = _fetch_openrouter_pricing() + if fetched: + _memory_cache = fetched + _memory_cache_loaded = True + _save_cache(fetched) + return True + return False + + +def get_cache_info() -> dict[str, Any]: + """Get info about the pricing cache state (for diagnostics).""" + info: dict[str, Any] = { + "cache_file": str(_CACHE_FILE), + "cache_exists": _CACHE_FILE.exists(), + "memory_loaded": _memory_cache_loaded, + "memory_models": len(_memory_cache), + "fallback_models": len(FALLBACK_PRICING), + } + + if _CACHE_FILE.exists(): + try: + with open(_CACHE_FILE) as f: + data = json.load(f) + cached_at = data.get("cached_at", 0) + age_hours = (time.time() - cached_at) / 3600 + info["cache_age_hours"] = round(age_hours, 1) + info["cache_fresh"] = age_hours < (_CACHE_TTL_SECONDS / 3600) + info["cached_models"] = len(data.get("models", {})) + except (json.JSONDecodeError, OSError): + info["cache_corrupt"] = True + + return info diff --git a/extropy/core/cost/tracker.py b/extropy/core/cost/tracker.py new file mode 100644 index 0000000..58375e4 --- /dev/null +++ b/extropy/core/cost/tracker.py @@ -0,0 +1,288 @@ +"""Session-scoped cost accumulator. + +Automatically records token usage from every LLM provider call within +a CLI session. Providers push usage via CostTracker.record(); the CLI +reads the totals at exit via CostTracker.summary(). + +Thread-safe — simulation calls record() from async workers concurrently. 
+""" + +import logging +import threading +import time +from typing import Any + +from pydantic import BaseModel + +from ..providers.base import TokenUsage +from .pricing import get_pricing + +logger = logging.getLogger(__name__) + + +class CallRecord(BaseModel): + """A single LLM API call's token usage.""" + + model: str + input_tokens: int + output_tokens: int + timestamp: float + call_type: str = "" # "simple", "reasoning", "agentic_research", "async" + + +class ModelUsage(BaseModel): + """Accumulated usage for a single model.""" + + calls: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 + + +class CostTracker: + """Session-scoped cost accumulator. + + Singleton per process. Providers auto-record into this after each call. + The CLI reads summary/cost at session end. + + Thread-safe: uses a lock for mutation since simulation workers + call record() concurrently. + + Note: This is not a Pydantic model because it manages mutable state + with thread locks and singleton lifecycle — patterns that don't fit + Pydantic's immutable-validation model. + """ + + _instance: "CostTracker | None" = None + _lock_cls = threading.Lock() # Class-level lock for singleton creation + + def __init__(self) -> None: + self._records: list[CallRecord] = [] + self._by_model: dict[str, ModelUsage] = {} + self._lock = threading.Lock() + self._started_at = time.time() + self._command: str = "" # Set by CLI (e.g., "spec", "simulate") + self._scenario: str = "" # Set by CLI for ledger tagging + + @classmethod + def get(cls) -> "CostTracker": + """Get or create the singleton instance.""" + if cls._instance is None: + with cls._lock_cls: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def reset(cls) -> None: + """Reset the singleton (for testing or new session).""" + with cls._lock_cls: + cls._instance = None + + def set_context(self, command: str = "", scenario: str = "") -> None: + """Set session context for ledger tagging. 
+ + Args: + command: CLI command name (e.g., "spec", "simulate") + scenario: Scenario/population name for identification + """ + self._command = command + self._scenario = scenario + + def record( + self, + model: str, + usage: TokenUsage, + call_type: str = "", + ) -> None: + """Record token usage from a single LLM API call. + + Called automatically by provider base class after each call. + + Args: + model: Model name used for the call + usage: Token usage from the API response + call_type: Type of call ("simple", "reasoning", etc.) + """ + if usage.input_tokens == 0 and usage.output_tokens == 0: + return + + record = CallRecord( + model=model, + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens, + timestamp=time.time(), + call_type=call_type, + ) + + with self._lock: + self._records.append(record) + + if model not in self._by_model: + self._by_model[model] = ModelUsage() + + mu = self._by_model[model] + mu.calls += 1 + mu.input_tokens += usage.input_tokens + mu.output_tokens += usage.output_tokens + + @property + def total_calls(self) -> int: + """Total number of LLM calls recorded.""" + with self._lock: + return sum(mu.calls for mu in self._by_model.values()) + + @property + def total_input_tokens(self) -> int: + """Total input tokens across all models.""" + with self._lock: + return sum(mu.input_tokens for mu in self._by_model.values()) + + @property + def total_output_tokens(self) -> int: + """Total output tokens across all models.""" + with self._lock: + return sum(mu.output_tokens for mu in self._by_model.values()) + + def total_cost(self) -> float | None: + """Compute total USD cost from recorded usage. + + Returns: + Total cost in USD, or None if no pricing available for any model. 
+ """ + with self._lock: + total = 0.0 + has_any_pricing = False + + for model, mu in self._by_model.items(): + pricing = get_pricing(model) + if pricing: + has_any_pricing = True + total += ( + mu.input_tokens * pricing.input_per_mtok + + mu.output_tokens * pricing.output_per_mtok + ) / 1_000_000 + + return total if has_any_pricing else None + + def cost_by_model(self) -> dict[str, dict[str, Any]]: + """Get cost breakdown by model. + + Returns: + Dict of model → {calls, input_tokens, output_tokens, cost} + """ + with self._lock: + result: dict[str, dict[str, Any]] = {} + for model, mu in self._by_model.items(): + pricing = get_pricing(model) + cost = None + if pricing: + cost = ( + mu.input_tokens * pricing.input_per_mtok + + mu.output_tokens * pricing.output_per_mtok + ) / 1_000_000 + + result[model] = { + "calls": mu.calls, + "input_tokens": mu.input_tokens, + "output_tokens": mu.output_tokens, + "cost": cost, + } + return result + + def summary(self) -> dict[str, Any]: + """Full session summary for export/display. + + Returns: + Dict with total and per-model breakdowns. 
+ """ + with self._lock: + by_model = {} + total_cost = 0.0 + has_pricing = False + + for model, mu in self._by_model.items(): + pricing = get_pricing(model) + model_cost = None + if pricing: + has_pricing = True + model_cost = ( + mu.input_tokens * pricing.input_per_mtok + + mu.output_tokens * pricing.output_per_mtok + ) / 1_000_000 + total_cost += model_cost + + by_model[model] = { + "calls": mu.calls, + "input_tokens": mu.input_tokens, + "output_tokens": mu.output_tokens, + "cost": round(model_cost, 4) if model_cost is not None else None, + } + + total_in = sum(mu.input_tokens for mu in self._by_model.values()) + total_out = sum(mu.output_tokens for mu in self._by_model.values()) + + return { + "command": self._command, + "scenario": self._scenario, + "total_calls": sum(mu.calls for mu in self._by_model.values()), + "total_input_tokens": total_in, + "total_output_tokens": total_out, + "total_cost": round(total_cost, 4) if has_pricing else None, + "by_model": by_model, + "elapsed_seconds": round(time.time() - self._started_at, 1), + } + + def summary_line(self) -> str | None: + """One-line cost summary for CLI footer. + + Returns: + Formatted string like "$0.38 · openai/gpt-5 · 8 calls · 87k in / 12k out", + or None if no calls were recorded. 
+ """ + with self._lock: + total_calls = sum(mu.calls for mu in self._by_model.values()) + if total_calls == 0: + return None + + total_in = sum(mu.input_tokens for mu in self._by_model.values()) + total_out = sum(mu.output_tokens for mu in self._by_model.values()) + models = list(self._by_model.keys()) + + cost = self.total_cost() + + parts = [] + + # Cost + if cost is not None: + parts.append(f"${cost:.2f}") + else: + parts.append("cost unknown") + + # Model(s) + if len(models) == 1: + parts.append(models[0]) + elif len(models) > 1: + parts.append(f"{len(models)} models") + + # Call count + parts.append(f"{total_calls} call{'s' if total_calls != 1 else ''}") + + # Token counts + parts.append(f"{_format_tokens(total_in)} in / {_format_tokens(total_out)} out") + + return " · ".join(parts) + + @property + def has_records(self) -> bool: + """Whether any calls have been recorded.""" + with self._lock: + return len(self._records) > 0 + + +def _format_tokens(n: int) -> str: + """Format token count for display (e.g., 87k, 1.5M).""" + if n >= 1_000_000: + return f"{n / 1_000_000:.1f}M" + elif n >= 1_000: + return f"{n / 1_000:.0f}k" + return str(n) diff --git a/extropy/core/llm.py b/extropy/core/llm.py index ecfa0a5..dfbc3db 100644 --- a/extropy/core/llm.py +++ b/extropy/core/llm.py @@ -1,20 +1,19 @@ """LLM clients for Extropy - Facade Layer. 
-This module provides a unified interface to LLM providers with two-zone routing: -- Pipeline (sync calls): simple_call, reasoning_call, agentic_research - → Uses the pipeline provider (configured for phases 1-2) -- Simulation (async calls): simple_call_async - → Uses the simulation provider (configured for phase 3) +This module provides a unified interface to LLM providers with two-tier routing: +- fast: simple_call → uses models.fast (cheap, fast tasks) +- strong: reasoning_call, agentic_research → uses models.strong (complex tasks) +- simulation: simple_call_async → uses simulation.strong/fast -Configure via `extropy config` CLI or programmatically via extropy.config.configure(). +Model strings use "provider/model" format. The provider is extracted to route +to the correct backend; the model name is passed through. -Each function supports retry with error feedback via the `previous_errors` parameter. -When validation fails, pass the error message back to let the LLM self-correct. +Configure via `extropy config` CLI or programmatically via extropy.config.configure(). 
""" -from .providers import get_pipeline_provider, get_simulation_provider +from .providers import get_provider from .providers.base import TokenUsage, ValidatorCallback, RetryCallback -from ..config import get_config +from ..config import get_config, parse_model_string __all__ = [ @@ -28,25 +27,14 @@ ] -def _get_pipeline_model_override(tier: str) -> str | None: - """Get pipeline model override from config if configured.""" +def _resolve_provider_and_model( + model_string: str, +) -> tuple: + """Resolve a "provider/model" string to (provider_instance, model_name).""" config = get_config() - pipeline = config.pipeline - if tier == "simple" and pipeline.model_simple: - return pipeline.model_simple - elif tier == "reasoning" and pipeline.model_reasoning: - return pipeline.model_reasoning - elif tier == "research" and pipeline.model_research: - return pipeline.model_research - return None - - -def _get_simulation_model_override() -> str | None: - """Get simulation model override from config if configured.""" - config = get_config() - if config.simulation.model: - return config.simulation.model - return None + provider_name, model_name = parse_model_string(model_string) + provider = get_provider(provider_name, config.providers) + return provider, model_name def simple_call( @@ -59,20 +47,21 @@ def simple_call( ) -> dict: """Simple LLM call with structured output, no reasoning, no web search. - Routed through the PIPELINE provider. + Uses the FAST tier (config.models.fast). 
Use for fast, cheap tasks: - Context sufficiency checks - Simple classification - Validation """ - provider = get_pipeline_provider() - effective_model = model or _get_pipeline_model_override("simple") + config = get_config() + model_string = model or config.resolve_pipeline_fast() + provider, model_name = _resolve_provider_and_model(model_string) return provider.simple_call( prompt=prompt, response_schema=response_schema, schema_name=schema_name, - model=effective_model, + model=model_name, log=log, max_tokens=max_tokens, ) @@ -87,18 +76,21 @@ async def simple_call_async( ) -> tuple[dict, TokenUsage]: """Async version of simple_call for concurrent API requests. - Routed through the SIMULATION provider. - Used for batch agent reasoning during simulation. + Model is passed explicitly from simulation caller (provider/model format). Returns (structured_data, token_usage) tuple. """ - provider = get_simulation_provider() - effective_model = model or _get_simulation_model_override() + if model: + provider, model_name = _resolve_provider_and_model(model) + else: + config = get_config() + model_string = config.resolve_sim_strong() + provider, model_name = _resolve_provider_and_model(model_string) return await provider.simple_call_async( prompt=prompt, response_schema=response_schema, schema_name=schema_name, - model=effective_model, + model=model_name, max_tokens=max_tokens, ) @@ -117,20 +109,21 @@ def reasoning_call( ) -> dict: """LLM call with reasoning and structured output, but NO web search. - Routed through the PIPELINE provider. + Uses the STRONG tier (config.models.strong). 
Use for tasks that require reasoning but not external data: - Attribute selection/categorization - Schema design - Logical analysis """ - provider = get_pipeline_provider() - effective_model = model or _get_pipeline_model_override("reasoning") + config = get_config() + model_string = model or config.resolve_pipeline_strong() + provider, model_name = _resolve_provider_and_model(model_string) return provider.reasoning_call( prompt=prompt, response_schema=response_schema, schema_name=schema_name, - model=effective_model, + model=model_name, reasoning_effort=reasoning_effort, log=log, previous_errors=previous_errors, @@ -154,21 +147,17 @@ def agentic_research( ) -> tuple[dict, list[str]]: """Perform agentic research with web search and structured output. - Routed through the PIPELINE provider. - - The model will: - 1. Decide what to search for - 2. Search the web (possibly multiple times) - 3. Reason about the results - 4. Return structured data matching the schema + Uses the STRONG tier (config.models.strong). + Web search is a provider capability, not a tier distinction. 
""" - provider = get_pipeline_provider() - effective_model = model or _get_pipeline_model_override("research") + config = get_config() + model_string = model or config.resolve_pipeline_strong() + provider, model_name = _resolve_provider_and_model(model_string) return provider.agentic_research( prompt=prompt, response_schema=response_schema, schema_name=schema_name, - model=effective_model, + model=model_name, reasoning_effort=reasoning_effort, log=log, previous_errors=previous_errors, diff --git a/extropy/core/models/scenario.py b/extropy/core/models/scenario.py index 5cabe1f..ab592df 100644 --- a/extropy/core/models/scenario.py +++ b/extropy/core/models/scenario.py @@ -269,7 +269,9 @@ class ScenarioMeta(BaseModel): description: str = Field(description="Full scenario description") population_spec: str = Field(description="Path to population YAML") study_db: str = Field(description="Path to canonical study DB") - population_id: str = Field(default="default", description="Population ID in study DB") + population_id: str = Field( + default="default", description="Population ID in study DB" + ) network_id: str = Field(default="default", description="Network ID in study DB") created_at: datetime = Field(default_factory=datetime.now) @@ -310,9 +312,7 @@ def from_yaml(cls, path: Path | str) -> "ScenarioSpec": raise ValueError("Scenario YAML must parse to an object") meta = data.get("meta", {}) - if isinstance(meta, dict) and ( - "agents_file" in meta or "network_file" in meta - ): + if isinstance(meta, dict) and ("agents_file" in meta or "network_file" in meta): raise ValueError( "Legacy scenario schema detected (meta.agents_file/meta.network_file). 
" "Migrate with: extropy migrate scenario --input " diff --git a/extropy/core/models/simulation.py b/extropy/core/models/simulation.py index 5532686..88d5617 100644 --- a/extropy/core/models/simulation.py +++ b/extropy/core/models/simulation.py @@ -338,17 +338,13 @@ class SimulationRunConfig(BaseModel): scenario_path: str = Field(description="Path to scenario YAML") output_dir: str = Field(description="Directory for results output") - model: str = Field( + strong: str = Field( default="", - description="LLM model for agent reasoning (empty = use config default)", + description="Strong model for Pass 1 role-play reasoning (provider/model format, empty = config default)", ) - pivotal_model: str = Field( + fast: str = Field( default="", - description="Model for pivotal reasoning (default: same as model)", - ) - routine_model: str = Field( - default="", - description="Cheap model for routine reasoning + classification (default: provider cheap tier)", + description="Fast model for Pass 2 classification (provider/model format, empty = config default)", ) reasoning_effort: str = Field(default="low", description="Reasoning effort level") multi_touch_threshold: int = Field( @@ -362,6 +358,19 @@ class SimulationRunConfig(BaseModel): default=50, description="Agents per reasoning chunk for checkpointing" ) + # Backward compat aliases + @property + def model(self) -> str: + return self.strong + + @property + def pivotal_model(self) -> str: + return self.strong + + @property + def routine_model(self) -> str: + return self.fast + # ============================================================================= # Timestep Summary diff --git a/extropy/core/pricing.py b/extropy/core/pricing.py index 616a25d..d2de21a 100644 --- a/extropy/core/pricing.py +++ b/extropy/core/pricing.py @@ -40,21 +40,49 @@ class ModelPricing: ), "claude-haiku-4.5": ModelPricing(input_per_mtok=0.80, output_per_mtok=4.00), "claude-haiku-4": ModelPricing(input_per_mtok=0.80, output_per_mtok=4.00), + # 
DeepSeek (direct API) + "deepseek-chat": ModelPricing(input_per_mtok=0.14, output_per_mtok=0.28), + "deepseek-reasoner": ModelPricing(input_per_mtok=0.55, output_per_mtok=2.19), } -# Provider default models (matches provider classes, no API key needed) +# Provider default models — 2-tier (fast/strong) PROVIDER_DEFAULTS: dict[str, dict[str, str]] = { "openai": { - "simple": "gpt-5-mini", - "reasoning": "gpt-5", + "fast": "gpt-5-mini", + "strong": "gpt-5", }, + "anthropic": { + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", + }, + "azure": { + "fast": "gpt-5-mini", + "strong": "gpt-5", + }, + "openrouter": { + "fast": "openai/gpt-5-mini", + "strong": "openai/gpt-5", + }, + "deepseek": { + "fast": "deepseek-chat", + "strong": "deepseek-reasoner", + }, + "together": { + "fast": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + "strong": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + }, + "groq": { + "fast": "llama-3.3-70b-versatile", + "strong": "llama-3.3-70b-versatile", + }, + # Legacy aliases "claude": { - "simple": "claude-haiku-4-5-20251001", - "reasoning": "claude-sonnet-4-5-20250929", + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", }, "azure_openai": { - "simple": "gpt-5-mini", - "reasoning": "gpt-5", + "fast": "gpt-5-mini", + "strong": "gpt-5", }, } @@ -64,15 +92,19 @@ def get_pricing(model: str) -> ModelPricing | None: return MODEL_PRICING.get(model) -def resolve_default_model(provider: str, tier: str = "reasoning") -> str: +def resolve_default_model(provider: str, tier: str = "strong") -> str: """Resolve default model name for a provider without instantiating it. Args: - provider: Provider name ('openai' or 'claude') - tier: 'simple' or 'reasoning' + provider: Provider name ('openai', 'anthropic', etc.) 
+ tier: 'fast' or 'strong' (also accepts legacy 'simple'/'reasoning') Returns: Model name string """ + # Map legacy tier names + tier_map = {"simple": "fast", "reasoning": "strong"} + tier = tier_map.get(tier, tier) + defaults = PROVIDER_DEFAULTS.get(provider, PROVIDER_DEFAULTS["openai"]) - return defaults.get(tier, defaults["reasoning"]) + return defaults.get(tier, defaults["strong"]) diff --git a/extropy/core/providers/__init__.py b/extropy/core/providers/__init__.py index 195f949..1fb5c39 100644 --- a/extropy/core/providers/__init__.py +++ b/extropy/core/providers/__init__.py @@ -1,97 +1,237 @@ -"""LLM Provider factory. +"""LLM Provider registry and factory. -Provides two-zone provider routing: -- Pipeline provider: used for phases 1-2 (spec, extend, persona, scenario) -- Simulation provider: used for phase 3 (agent reasoning) +Provides: +- BUILTIN_PROVIDERS: Registry of known provider names → factory info +- get_provider(): Create a provider instance from a provider name +- get_pipeline_provider() / get_simulation_provider(): Zone-based provider access The simulation provider is cached so its async client can be reused across batch calls and closed cleanly before the event loop shuts down. """ -from .base import LLMProvider -from ...config import get_config, get_api_key, get_azure_config - +import os -# Cached simulation provider — reused across batch calls so the async -# client isn't re-created per request, and can be closed cleanly. -_simulation_provider: LLMProvider | None = None +from .base import LLMProvider +from ...config import ( + get_config, + get_api_key_for_provider, + parse_model_string, + CustomProviderConfig, +) + + +# ============================================================================= +# Provider Registry +# ============================================================================= + +# Each entry: (module, class_name, default_kwargs) +# Lazy-imported to avoid loading all SDKs at startup. 
+_BUILTIN_REGISTRY: dict[str, dict] = { + "openai": { + "module": ".openai", + "class": "OpenAIProvider", + }, + "anthropic": { + "module": ".anthropic", + "class": "AnthropicProvider", + }, + "openrouter": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "https://openrouter.ai/api/v1", + "supports_search": True, + "provider_label": "openrouter", + "default_fast": "openai/gpt-5-mini", + "default_strong": "openai/gpt-5", + }, + }, + "azure": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "", # resolved from env + "supports_search": False, + "provider_label": "azure", + "default_fast": "gpt-5-mini", + "default_strong": "gpt-5", + }, + }, + "deepseek": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "https://api.deepseek.com/v1", + "supports_search": False, + "provider_label": "deepseek", + "default_fast": "deepseek-chat", + "default_strong": "deepseek-reasoner", + }, + }, + "together": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "https://api.together.xyz/v1", + "supports_search": False, + "provider_label": "together", + "default_fast": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + "default_strong": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + }, + }, + "groq": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "https://api.groq.com/openai/v1", + "supports_search": False, + "provider_label": "groq", + "default_fast": "llama-3.3-70b-versatile", + "default_strong": "llama-3.3-70b-versatile", + }, + }, +} + + +def get_provider( + provider_name: str, + custom_providers: dict[str, CustomProviderConfig] | None = None, +) -> LLMProvider: + """Create a provider instance by name. + + Checks custom providers first, then built-in registry. 
+ + Args: + provider_name: Provider name (e.g., "openai", "anthropic", "openrouter") + custom_providers: Optional custom provider configs from ExtropyConfig + + Returns: + LLMProvider instance + + Raises: + ValueError: If provider is unknown + """ + api_key = get_api_key_for_provider(provider_name, custom_providers) + # Check custom providers first + if custom_providers and provider_name in custom_providers: + from .openai_compat import OpenAICompatProvider -def _create_provider(provider_name: str) -> LLMProvider: - """Create a provider instance by name.""" - api_key = get_api_key(provider_name) - - if provider_name == "openai": - from .openai import OpenAIProvider + custom = custom_providers[provider_name] + return OpenAICompatProvider( + api_key=api_key, + base_url=custom.base_url, + supports_search=False, + provider_label=provider_name, + ) - return OpenAIProvider(api_key=api_key) - elif provider_name == "claude": - from .claude import ClaudeProvider + # Check built-in registry + if provider_name not in _BUILTIN_REGISTRY: + available = sorted( + set(list(_BUILTIN_REGISTRY.keys()) + list((custom_providers or {}).keys())) + ) + raise ValueError( + f"Unknown LLM provider: {provider_name!r}. " + f"Available: {', '.join(available)}" + ) - return ClaudeProvider(api_key=api_key) - elif provider_name == "azure_openai": - from .openai import OpenAIProvider + entry = _BUILTIN_REGISTRY[provider_name] - azure_cfg = get_azure_config(provider_name) - if not azure_cfg.get("azure_endpoint"): + # Special case: Azure needs endpoint from env + if provider_name == "azure": + endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "") + if not endpoint: raise ValueError( "AZURE_OPENAI_ENDPOINT not found. 
Set it as an environment variable.\n" " export AZURE_OPENAI_ENDPOINT=https://.cognitiveservices.azure.com/" ) - # Resolve api_format: config value > auto-default (chat_completions for Azure) + entry = dict(entry) + entry["kwargs"] = dict(entry.get("kwargs", {})) + entry["kwargs"]["base_url"] = endpoint + + # Lazy import + import importlib + + module = importlib.import_module(entry["module"], package=__package__) + cls = getattr(module, entry["class"]) + + kwargs = dict(entry.get("kwargs", {})) + kwargs["api_key"] = api_key + + return cls(**kwargs) + + +# ============================================================================= +# Zone-based provider access (backward compat) +# ============================================================================= + +# Cached providers — reused across calls for connection reuse +_cached_providers: dict[str, LLMProvider] = {} + + +def _get_or_create_provider(provider_name: str, cache_key: str = "") -> LLMProvider: + """Get or create a cached provider instance.""" + key = cache_key or provider_name + if key not in _cached_providers: config = get_config() - api_format = config.simulation.api_format or "chat_completions" - return OpenAIProvider( - api_key=api_key, - azure_endpoint=azure_cfg["azure_endpoint"], - api_version=azure_cfg.get("api_version", "2025-03-01-preview"), - azure_deployment=azure_cfg.get("azure_deployment", ""), - api_format=api_format, - ) - else: - raise ValueError( - f"Unknown LLM provider: {provider_name}. " - f"Valid options: 'openai', 'claude', 'azure_openai'" - ) + _cached_providers[key] = get_provider(provider_name, config.providers) + return _cached_providers[key] def get_pipeline_provider() -> LLMProvider: - """Get the provider for pipeline phases (spec, extend, persona, scenario).""" + """Get the provider for pipeline phases (spec, extend, persona, scenario). 
+ + Uses the provider from models.fast (pipeline calls use both fast and strong, + but the provider is determined by the fast model string). + """ config = get_config() - return _create_provider(config.pipeline.provider) + provider, _ = parse_model_string(config.models.fast) + return _get_or_create_provider(provider, f"pipeline:{provider}") def get_simulation_provider() -> LLMProvider: """Get the cached provider for simulation phase (agent reasoning). - Caches the provider so the underlying async HTTP client is reused - across all calls in a batch, avoiding orphaned connections. + Uses the provider from the resolved simulation strong model. """ - global _simulation_provider config = get_config() - provider_name = config.simulation.provider - - if _simulation_provider is None: - _simulation_provider = _create_provider(provider_name) - - return _simulation_provider + strong_model = config.resolve_sim_strong() + provider, _ = parse_model_string(strong_model) + return _get_or_create_provider(provider, f"simulation:{provider}") async def close_simulation_provider() -> None: - """Close the cached simulation provider's async client. + """Close cached providers' async clients. Call this before the event loop shuts down to cleanly release HTTP connections and avoid 'Event loop is closed' errors. 
""" - global _simulation_provider - if _simulation_provider is not None: - await _simulation_provider.close_async() - _simulation_provider = None + for key, provider in list(_cached_providers.items()): + await provider.close_async() + _cached_providers.clear() + + +def reset_provider_cache() -> None: + """Reset the provider cache (for testing).""" + _cached_providers.clear() + + +# Legacy factory (kept for backward compat in tests) +def _create_provider(provider_name: str) -> LLMProvider: + """DEPRECATED: Use get_provider() instead.""" + # Map old names + name_map = {"claude": "anthropic", "azure_openai": "azure"} + canonical = name_map.get(provider_name, provider_name) + config = get_config() + return get_provider(canonical, config.providers) __all__ = [ "LLMProvider", + "get_provider", "get_pipeline_provider", "get_simulation_provider", "close_simulation_provider", + "reset_provider_cache", + "parse_model_string", ] diff --git a/extropy/core/providers/anthropic.py b/extropy/core/providers/anthropic.py new file mode 100644 index 0000000..d68ee01 --- /dev/null +++ b/extropy/core/providers/anthropic.py @@ -0,0 +1,388 @@ +"""Anthropic (Claude) LLM Provider implementation. + +Uses the tool use pattern for reliable structured output: +instead of asking Claude to output JSON in text, we define a tool +with the response schema. Claude "calls" the tool, returning structured +data guaranteed to match the schema. +""" + +import logging +import random +import time + +import anthropic + +from .base import LLMProvider, TokenUsage, ValidatorCallback, RetryCallback +from .logging import log_request_response, extract_error_summary + +_TRANSIENT_ANTHROPIC_ERRORS = ( + anthropic.APIConnectionError, + anthropic.InternalServerError, + anthropic.RateLimitError, +) +_MAX_API_RETRIES = 3 + + +logger = logging.getLogger(__name__) + + +def _clean_schema_for_tool(schema: dict) -> dict: + """Clean a JSON schema for use as a tool input_schema. 
+ + Removes fields that aren't valid in tool input schemas + (like 'additionalProperties' in nested objects that Claude + doesn't support in tool definitions). + """ + cleaned = {} + for key, value in schema.items(): + if key == "additionalProperties": + continue + if isinstance(value, dict): + cleaned[key] = _clean_schema_for_tool(value) + elif isinstance(value, list): + cleaned[key] = [ + _clean_schema_for_tool(item) if isinstance(item, dict) else item + for item in value + ] + else: + cleaned[key] = value + return cleaned + + +def _make_structured_tool(schema_name: str, response_schema: dict) -> dict: + """Create a tool definition that forces structured output.""" + return { + "name": schema_name, + "description": ( + "Return your response as structured data. " + "You MUST call this tool with your complete response." + ), + "input_schema": _clean_schema_for_tool(response_schema), + } + + +def _extract_tool_input(response) -> dict | None: + """Extract tool_use input from a Claude response.""" + for block in response.content: + if block.type == "tool_use": + return block.input + return None + + +def _extract_usage(response) -> TokenUsage: + """Extract token usage from an Anthropic API response.""" + if not hasattr(response, "usage") or response.usage is None: + return TokenUsage() + return TokenUsage( + input_tokens=getattr(response.usage, "input_tokens", 0) or 0, + output_tokens=getattr(response.usage, "output_tokens", 0) or 0, + ) + + +class AnthropicProvider(LLMProvider): + """Anthropic (Claude) LLM provider. + + Uses the tool use pattern for structured output — Claude "calls" a tool + with the response data, guaranteeing valid JSON matching the schema. + """ + + provider_name = "anthropic" + + def __init__(self, api_key: str = "") -> None: + if not api_key: + raise ValueError( + "Anthropic API key not found. 
Set it via:\n" + " export ANTHROPIC_API_KEY=sk-ant-...\n" + "Get your key from: https://console.anthropic.com/settings/keys" + ) + super().__init__(api_key) + + def _with_retry(self, fn, max_retries: int = _MAX_API_RETRIES): + """Retry a sync API call on transient errors with exponential backoff.""" + for attempt in range(max_retries + 1): + try: + return fn() + except _TRANSIENT_ANTHROPIC_ERRORS as e: + if attempt == max_retries: + raise + wait = (2**attempt) + random.random() + logger.warning( + f"[Claude] Transient error (attempt {attempt + 1}/{max_retries + 1}): " + f"{type(e).__name__}: {e}. Retrying in {wait:.1f}s" + ) + time.sleep(wait) + + async def _with_retry_async(self, fn, max_retries: int = _MAX_API_RETRIES): + """Retry an async API call on transient errors with exponential backoff.""" + import asyncio + + for attempt in range(max_retries + 1): + try: + return await fn() + except _TRANSIENT_ANTHROPIC_ERRORS as e: + if attempt == max_retries: + raise + wait = (2**attempt) + random.random() + logger.warning( + f"[Claude] Transient error (attempt {attempt + 1}/{max_retries + 1}): " + f"{type(e).__name__}: {e}. 
Retrying in {wait:.1f}s" + ) + await asyncio.sleep(wait) + + @property + def default_fast_model(self) -> str: + return "claude-haiku-4-5-20251001" + + @property + def default_strong_model(self) -> str: + return "claude-sonnet-4-5-20250929" + + def _get_client(self) -> anthropic.Anthropic: + return anthropic.Anthropic(api_key=self._api_key) + + def _get_async_client(self) -> anthropic.AsyncAnthropic: + if self._cached_async_client is None: + self._cached_async_client = anthropic.AsyncAnthropic(api_key=self._api_key) + return self._cached_async_client + + def simple_call( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + log: bool = True, + max_tokens: int | None = None, + ) -> dict: + model = model or self.default_simple_model + client = self._get_client() + tool = _make_structured_tool(schema_name, response_schema) + + # Acquire rate limit capacity before making the call + self._acquire_rate_limit(prompt, model, max_output=max_tokens or 4096) + + logger.info( + f"[Claude] simple_call starting - model={model}, schema={schema_name}" + ) + + response = self._with_retry( + lambda: client.messages.create( + model=model, + max_tokens=max_tokens or 4096, + tools=[tool], + tool_choice={"type": "tool", "name": schema_name}, + messages=[{"role": "user", "content": prompt}], + ) + ) + + structured_data = _extract_tool_input(response) + + # Record token usage + usage = _extract_usage(response) + self._record_usage(model, usage, call_type="simple") + + if log: + log_request_response( + function_name="simple_call", + request={"model": model, "prompt_length": len(prompt)}, + response=response, + provider="claude", + ) + + return structured_data or {} + + async def simple_call_async( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + max_tokens: int | None = None, + ) -> tuple[dict, TokenUsage]: + model = model or self.default_simple_model + client = 
self._get_async_client() + tool = _make_structured_tool(schema_name, response_schema) + + response = await self._with_retry_async( + lambda: client.messages.create( + model=model, + max_tokens=max_tokens or 4096, + tools=[tool], + tool_choice={"type": "tool", "name": schema_name}, + messages=[{"role": "user", "content": prompt}], + ) + ) + + # Extract and record token usage + usage = _extract_usage(response) + self._record_usage(model, usage, call_type="async") + + return _extract_tool_input(response) or {}, usage + + def reasoning_call( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + reasoning_effort: str = "low", + log: bool = True, + previous_errors: str | None = None, + validator: ValidatorCallback | None = None, + max_retries: int = 2, + on_retry: RetryCallback | None = None, + ) -> dict: + """Claude reasoning call with tool-based structured output.""" + model = model or self.default_reasoning_model + client = self._get_client() + tool = _make_structured_tool(schema_name, response_schema) + + effective_prompt = prompt + if previous_errors: + effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" + + def _call(ep: str) -> dict: + # Acquire rate limit capacity before each API call + self._acquire_rate_limit(ep, model, max_output=16384) + + response = self._with_retry( + lambda: client.messages.create( + model=model, + max_tokens=16384, + tools=[tool], + tool_choice={"type": "tool", "name": schema_name}, + messages=[{"role": "user", "content": ep}], + ) + ) + structured_data = _extract_tool_input(response) + + # Record token usage + ru = _extract_usage(response) + self._record_usage(model, ru, call_type="reasoning") + + if log: + log_request_response( + function_name="reasoning_call", + request={"model": model, "prompt_length": len(ep)}, + response=response, + provider="claude", + ) + return structured_data or {} + + return self._retry_with_validation( + call_fn=_call, + prompt=prompt, + 
validator=validator, + max_retries=max_retries, + on_retry=on_retry, + extract_error_summary_fn=extract_error_summary, + initial_prompt=effective_prompt if previous_errors else None, + ) + + def agentic_research( + self, + prompt: str, + response_schema: dict, + schema_name: str = "research_data", + model: str | None = None, + reasoning_effort: str = "low", + log: bool = True, + previous_errors: str | None = None, + validator: ValidatorCallback | None = None, + max_retries: int = 2, + on_retry: RetryCallback | None = None, + ) -> tuple[dict, list[str]]: + """Claude agentic research with web search + tool-based structured output. + + Uses web_search tool for research and a structured output tool for the response. + Claude first searches, then calls the output tool with results. + """ + model = model or self.default_research_model + client = self._get_client() + output_tool = _make_structured_tool(schema_name, response_schema) + + effective_prompt = prompt + if previous_errors: + effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" + + all_sources: list[str] = [] + + def _call(ep: str) -> dict: + research_prompt = ( + f"{ep}\n\n" + f"After researching, call the '{schema_name}' tool with your structured findings." 
+ ) + + # Acquire rate limit capacity before each API call + self._acquire_rate_limit(research_prompt, model, max_output=16384) + + logger.info(f"[Claude] agentic_research - model={model}") + + response = self._with_retry( + lambda: client.messages.create( + model=model, + max_tokens=16384, + tools=[ + { + "type": "web_search_20250305", + "name": "web_search", + "max_uses": 5, + }, + output_tool, + ], + messages=[{"role": "user", "content": research_prompt}], + ) + ) + + structured_data = None + sources: list[str] = [] + + for block in response.content: + if block.type == "web_search_tool_result": + if hasattr(block, "content") and block.content: + for res in block.content: + if hasattr(res, "url"): + sources.append(res.url) + + if block.type == "tool_use" and block.name == schema_name: + structured_data = block.input + + if block.type == "text": + if hasattr(block, "citations") and block.citations: + for citation in block.citations: + if hasattr(citation, "url"): + sources.append(citation.url) + + all_sources.extend(sources) + logger.info(f"[Claude] Web search completed, found {len(sources)} sources") + + # Record token usage + ru = _extract_usage(response) + self._record_usage(model, ru, call_type="agentic_research") + + if log: + log_request_response( + function_name="agentic_research", + request={"model": model, "prompt_length": len(research_prompt)}, + response=response, + provider="claude", + sources=list(set(sources)), + ) + + return structured_data or {} + + result = self._retry_with_validation( + call_fn=_call, + prompt=prompt, + validator=validator, + max_retries=max_retries, + on_retry=on_retry, + extract_error_summary_fn=extract_error_summary, + initial_prompt=effective_prompt if previous_errors else None, + ) + + return result, list(set(all_sources)) + + +# Backward compat alias +ClaudeProvider = AnthropicProvider diff --git a/extropy/core/providers/base.py b/extropy/core/providers/base.py index f33fbcf..d796ca7 100644 --- 
a/extropy/core/providers/base.py +++ b/extropy/core/providers/base.py @@ -1,15 +1,18 @@ """Abstract base class for LLM providers.""" +import logging from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Callable, TYPE_CHECKING +from pydantic import BaseModel + if TYPE_CHECKING: from ..rate_limiter import RateLimiter +logger = logging.getLogger(__name__) + -@dataclass -class TokenUsage: +class TokenUsage(BaseModel): """Token usage from a single LLM API call.""" input_tokens: int = 0 @@ -34,6 +37,8 @@ class LLMProvider(ABC): All providers must implement these methods with the same signatures to ensure drop-in compatibility. + Automatically records token usage into CostTracker after each call. + Args: api_key: API key or access token for the provider. """ @@ -89,6 +94,22 @@ def _acquire_rate_limit( estimated_output_tokens=max_output, ) + def _record_usage(self, model: str, usage: TokenUsage, call_type: str = "") -> None: + """Record token usage into the session CostTracker. + + Called after each API call. Safe to call even if no CostTracker + is active (e.g., in tests or library use without CLI). + """ + if usage.input_tokens == 0 and usage.output_tokens == 0: + return + try: + from ..cost.tracker import CostTracker + + CostTracker.get().record(model=model, usage=usage, call_type=call_type) + except Exception: + # Never let cost tracking break actual LLM calls + pass + async def close_async(self) -> None: """Close the cached async client to release connections cleanly. @@ -101,21 +122,28 @@ async def close_async(self) -> None: @property @abstractmethod - def default_simple_model(self) -> str: - """Default model for simple_call (fast, cheap).""" + def default_fast_model(self) -> str: + """Default model for fast/cheap calls (simple_call, Pass 2).""" ... 
@property @abstractmethod - def default_reasoning_model(self) -> str: - """Default model for reasoning_call (balanced).""" + def default_strong_model(self) -> str: + """Default model for strong/reasoning calls (reasoning_call, agentic_research, Pass 1).""" ... + # Backward-compat aliases (read-only) + @property + def default_simple_model(self) -> str: + return self.default_fast_model + + @property + def default_reasoning_model(self) -> str: + return self.default_strong_model + @property - @abstractmethod def default_research_model(self) -> str: - """Default model for agentic_research (with web search).""" - ... + return self.default_strong_model @abstractmethod def simple_call( diff --git a/extropy/core/providers/claude.py b/extropy/core/providers/claude.py index c06691a..15aa201 100644 --- a/extropy/core/providers/claude.py +++ b/extropy/core/providers/claude.py @@ -1,370 +1,8 @@ -"""Claude (Anthropic) LLM Provider implementation. +"""DEPRECATED: Use extropy.core.providers.anthropic instead. -Uses the tool use pattern for reliable structured output: -instead of asking Claude to output JSON in text, we define a tool -with the response schema. Claude "calls" the tool, returning structured -data guaranteed to match the schema. +This module re-exports AnthropicProvider as ClaudeProvider for backward compatibility. """ -import logging -import random -import time +from .anthropic import AnthropicProvider, ClaudeProvider # noqa: F401 -import anthropic - -from .base import LLMProvider, TokenUsage, ValidatorCallback, RetryCallback -from .logging import log_request_response, extract_error_summary - -_TRANSIENT_ANTHROPIC_ERRORS = ( - anthropic.APIConnectionError, - anthropic.InternalServerError, - anthropic.RateLimitError, -) -_MAX_API_RETRIES = 3 - - -logger = logging.getLogger(__name__) - - -def _clean_schema_for_tool(schema: dict) -> dict: - """Clean a JSON schema for use as a tool input_schema. 
- - Removes fields that aren't valid in tool input schemas - (like 'additionalProperties' in nested objects that Claude - doesn't support in tool definitions). - """ - cleaned = {} - for key, value in schema.items(): - if key == "additionalProperties": - continue - if isinstance(value, dict): - cleaned[key] = _clean_schema_for_tool(value) - elif isinstance(value, list): - cleaned[key] = [ - _clean_schema_for_tool(item) if isinstance(item, dict) else item - for item in value - ] - else: - cleaned[key] = value - return cleaned - - -def _make_structured_tool(schema_name: str, response_schema: dict) -> dict: - """Create a tool definition that forces structured output.""" - return { - "name": schema_name, - "description": ( - "Return your response as structured data. " - "You MUST call this tool with your complete response." - ), - "input_schema": _clean_schema_for_tool(response_schema), - } - - -def _extract_tool_input(response) -> dict | None: - """Extract tool_use input from a Claude response.""" - for block in response.content: - if block.type == "tool_use": - return block.input - return None - - -class ClaudeProvider(LLMProvider): - """Claude (Anthropic) LLM provider. - - Uses the tool use pattern for structured output — Claude "calls" a tool - with the response data, guaranteeing valid JSON matching the schema. - - """ - - provider_name = "anthropic" - - def __init__(self, api_key: str = "") -> None: - if not api_key: - raise ValueError( - "Anthropic API key not found. 
Set it via:\n" - " export ANTHROPIC_API_KEY=sk-ant-...\n" - "Get your key from: https://console.anthropic.com/settings/keys" - ) - super().__init__(api_key) - - def _with_retry(self, fn, max_retries: int = _MAX_API_RETRIES): - """Retry a sync API call on transient errors with exponential backoff.""" - for attempt in range(max_retries + 1): - try: - return fn() - except _TRANSIENT_ANTHROPIC_ERRORS as e: - if attempt == max_retries: - raise - wait = (2**attempt) + random.random() - logger.warning( - f"[Claude] Transient error (attempt {attempt + 1}/{max_retries + 1}): " - f"{type(e).__name__}: {e}. Retrying in {wait:.1f}s" - ) - time.sleep(wait) - - async def _with_retry_async(self, fn, max_retries: int = _MAX_API_RETRIES): - """Retry an async API call on transient errors with exponential backoff.""" - import asyncio - - for attempt in range(max_retries + 1): - try: - return await fn() - except _TRANSIENT_ANTHROPIC_ERRORS as e: - if attempt == max_retries: - raise - wait = (2**attempt) + random.random() - logger.warning( - f"[Claude] Transient error (attempt {attempt + 1}/{max_retries + 1}): " - f"{type(e).__name__}: {e}. 
Retrying in {wait:.1f}s" - ) - await asyncio.sleep(wait) - - @property - def default_simple_model(self) -> str: - return "claude-haiku-4-5-20251001" - - @property - def default_reasoning_model(self) -> str: - return "claude-sonnet-4-5-20250929" - - @property - def default_research_model(self) -> str: - return "claude-sonnet-4-5-20250929" - - def _get_client(self) -> anthropic.Anthropic: - return anthropic.Anthropic(api_key=self._api_key) - - def _get_async_client(self) -> anthropic.AsyncAnthropic: - if self._cached_async_client is None: - self._cached_async_client = anthropic.AsyncAnthropic(api_key=self._api_key) - return self._cached_async_client - - def simple_call( - self, - prompt: str, - response_schema: dict, - schema_name: str = "response", - model: str | None = None, - log: bool = True, - max_tokens: int | None = None, - ) -> dict: - model = model or self.default_simple_model - client = self._get_client() - tool = _make_structured_tool(schema_name, response_schema) - - # Acquire rate limit capacity before making the call - self._acquire_rate_limit(prompt, model, max_output=max_tokens or 4096) - - logger.info( - f"[Claude] simple_call starting - model={model}, schema={schema_name}" - ) - - response = self._with_retry( - lambda: client.messages.create( - model=model, - max_tokens=max_tokens or 4096, - tools=[tool], - tool_choice={"type": "tool", "name": schema_name}, - messages=[{"role": "user", "content": prompt}], - ) - ) - - structured_data = _extract_tool_input(response) - - if log: - log_request_response( - function_name="simple_call", - request={"model": model, "prompt_length": len(prompt)}, - response=response, - provider="claude", - ) - - return structured_data or {} - - async def simple_call_async( - self, - prompt: str, - response_schema: dict, - schema_name: str = "response", - model: str | None = None, - max_tokens: int | None = None, - ) -> tuple[dict, TokenUsage]: - model = model or self.default_simple_model - client = self._get_async_client() - 
tool = _make_structured_tool(schema_name, response_schema) - - response = await self._with_retry_async( - lambda: client.messages.create( - model=model, - max_tokens=max_tokens or 4096, - tools=[tool], - tool_choice={"type": "tool", "name": schema_name}, - messages=[{"role": "user", "content": prompt}], - ) - ) - - # Extract token usage - usage = TokenUsage() - if hasattr(response, "usage") and response.usage is not None: - usage = TokenUsage( - input_tokens=getattr(response.usage, "input_tokens", 0) or 0, - output_tokens=getattr(response.usage, "output_tokens", 0) or 0, - ) - - return _extract_tool_input(response) or {}, usage - - def reasoning_call( - self, - prompt: str, - response_schema: dict, - schema_name: str = "response", - model: str | None = None, - reasoning_effort: str = "low", - log: bool = True, - previous_errors: str | None = None, - validator: ValidatorCallback | None = None, - max_retries: int = 2, - on_retry: RetryCallback | None = None, - ) -> dict: - """Claude reasoning call with tool-based structured output.""" - model = model or self.default_reasoning_model - client = self._get_client() - tool = _make_structured_tool(schema_name, response_schema) - - effective_prompt = prompt - if previous_errors: - effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" - - def _call(ep: str) -> dict: - # Acquire rate limit capacity before each API call - self._acquire_rate_limit(ep, model, max_output=16384) - - response = self._with_retry( - lambda: client.messages.create( - model=model, - max_tokens=16384, - tools=[tool], - tool_choice={"type": "tool", "name": schema_name}, - messages=[{"role": "user", "content": ep}], - ) - ) - structured_data = _extract_tool_input(response) - if log: - log_request_response( - function_name="reasoning_call", - request={"model": model, "prompt_length": len(ep)}, - response=response, - provider="claude", - ) - return structured_data or {} - - return self._retry_with_validation( - call_fn=_call, - prompt=prompt, - 
validator=validator, - max_retries=max_retries, - on_retry=on_retry, - extract_error_summary_fn=extract_error_summary, - initial_prompt=effective_prompt if previous_errors else None, - ) - - def agentic_research( - self, - prompt: str, - response_schema: dict, - schema_name: str = "research_data", - model: str | None = None, - reasoning_effort: str = "low", - log: bool = True, - previous_errors: str | None = None, - validator: ValidatorCallback | None = None, - max_retries: int = 2, - on_retry: RetryCallback | None = None, - ) -> tuple[dict, list[str]]: - """Claude agentic research with web search + tool-based structured output. - - Uses web_search tool for research and a structured output tool for the response. - Claude first searches, then calls the output tool with results. - """ - model = model or self.default_research_model - client = self._get_client() - output_tool = _make_structured_tool(schema_name, response_schema) - - effective_prompt = prompt - if previous_errors: - effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" - - all_sources: list[str] = [] - - def _call(ep: str) -> dict: - research_prompt = ( - f"{ep}\n\n" - f"After researching, call the '{schema_name}' tool with your structured findings." 
- ) - - # Acquire rate limit capacity before each API call - self._acquire_rate_limit(research_prompt, model, max_output=16384) - - logger.info(f"[Claude] agentic_research - model={model}") - - response = self._with_retry( - lambda: client.messages.create( - model=model, - max_tokens=16384, - tools=[ - { - "type": "web_search_20250305", - "name": "web_search", - "max_uses": 5, - }, - output_tool, - ], - messages=[{"role": "user", "content": research_prompt}], - ) - ) - - structured_data = None - sources: list[str] = [] - - for block in response.content: - if block.type == "web_search_tool_result": - if hasattr(block, "content") and block.content: - for res in block.content: - if hasattr(res, "url"): - sources.append(res.url) - - if block.type == "tool_use" and block.name == schema_name: - structured_data = block.input - - if block.type == "text": - if hasattr(block, "citations") and block.citations: - for citation in block.citations: - if hasattr(citation, "url"): - sources.append(citation.url) - - all_sources.extend(sources) - logger.info(f"[Claude] Web search completed, found {len(sources)} sources") - - if log: - log_request_response( - function_name="agentic_research", - request={"model": model, "prompt_length": len(research_prompt)}, - response=response, - provider="claude", - sources=list(set(sources)), - ) - - return structured_data or {} - - result = self._retry_with_validation( - call_fn=_call, - prompt=prompt, - validator=validator, - max_retries=max_retries, - on_retry=on_retry, - extract_error_summary_fn=extract_error_summary, - initial_prompt=effective_prompt if previous_errors else None, - ) - - return result, list(set(all_sources)) +__all__ = ["ClaudeProvider", "AnthropicProvider"] diff --git a/extropy/core/providers/openai.py b/extropy/core/providers/openai.py index 871ad18..45d5599 100644 --- a/extropy/core/providers/openai.py +++ b/extropy/core/providers/openai.py @@ -110,6 +110,20 @@ def _extract_chat_completions_text(response) -> str | None: 
return content return None + def _extract_usage(self, response, use_chat: bool = False) -> TokenUsage: + """Extract token usage from an OpenAI API response.""" + if not hasattr(response, "usage") or response.usage is None: + return TokenUsage() + if use_chat: + return TokenUsage( + input_tokens=getattr(response.usage, "prompt_tokens", 0) or 0, + output_tokens=getattr(response.usage, "completion_tokens", 0) or 0, + ) + return TokenUsage( + input_tokens=getattr(response.usage, "input_tokens", 0) or 0, + output_tokens=getattr(response.usage, "output_tokens", 0) or 0, + ) + def _build_responses_params( self, model: str, @@ -193,15 +207,11 @@ async def _with_retry_async(self, fn, max_retries: int = _MAX_API_RETRIES): await asyncio.sleep(wait) @property - def default_simple_model(self) -> str: + def default_fast_model(self) -> str: return "gpt-5-mini" @property - def default_reasoning_model(self) -> str: - return "gpt-5" - - @property - def default_research_model(self) -> str: + def default_strong_model(self) -> str: return "gpt-5" def _get_client(self) -> OpenAI: @@ -278,6 +288,10 @@ def simple_call( raw_text = self._extract_output_text(response) structured_data = json.loads(raw_text) if raw_text else None + # Extract and record token usage + usage = self._extract_usage(response, use_chat=use_chat) + self._record_usage(model, usage, call_type="simple") + if log: log_request_response( function_name="simple_call", @@ -326,19 +340,9 @@ async def simple_call_async( raw_text = self._extract_output_text(response) structured_data = json.loads(raw_text) if raw_text else None - # Extract token usage - usage = TokenUsage() - if hasattr(response, "usage") and response.usage is not None: - if use_chat: - usage = TokenUsage( - input_tokens=getattr(response.usage, "prompt_tokens", 0) or 0, - output_tokens=getattr(response.usage, "completion_tokens", 0) or 0, - ) - else: - usage = TokenUsage( - input_tokens=getattr(response.usage, "input_tokens", 0) or 0, - 
output_tokens=getattr(response.usage, "output_tokens", 0) or 0, - ) + # Extract and record token usage + usage = self._extract_usage(response, use_chat=use_chat) + self._record_usage(model, usage, call_type="async") return structured_data or {}, usage @@ -384,6 +388,11 @@ def _call(ep: str) -> dict: ) raw_text = self._extract_output_text(response) structured_data = json.loads(raw_text) if raw_text else None + + # Record token usage + usage = self._extract_usage(response) + self._record_usage(model, usage, call_type="reasoning") + if log: log_request_response( function_name="reasoning_call", @@ -480,6 +489,10 @@ def _call(ep: str) -> dict: all_sources.extend(sources) + # Record token usage + usage = self._extract_usage(response) + self._record_usage(model, usage, call_type="agentic_research") + if log: log_request_response( function_name="agentic_research", diff --git a/extropy/core/rate_limiter.py b/extropy/core/rate_limiter.py index 256cdb1..8ec6063 100644 --- a/extropy/core/rate_limiter.py +++ b/extropy/core/rate_limiter.py @@ -486,10 +486,11 @@ def stats(self) -> dict: class DualRateLimiter: - """Manages separate rate limiters for pivotal (Pass 1) and routine (Pass 2) models. + """Manages separate rate limiters for strong (Pass 1) and fast (Pass 2) models. - When pivotal and routine models are the same, uses a single shared limiter. + When strong and fast models are the same, uses a single shared limiter. When they differ, uses independent limiters since API limits are per-model. + Supports mixed providers (e.g., strong=anthropic, fast=openai). 
""" def __init__( @@ -499,51 +500,79 @@ def __init__( ): self.pivotal = pivotal self.routine = routine + # Aliases for new naming convention + self.strong = pivotal + self.fast = routine @classmethod def create( cls, - provider: str, + provider: str = "", pivotal_model: str = "", routine_model: str = "", tier: int | None = None, rpm_override: int | None = None, tpm_override: int | None = None, + *, + strong_model_string: str = "", + fast_model_string: str = "", ) -> "DualRateLimiter": """Create dual rate limiter for two-pass reasoning. - If both models are the same (or routine is empty), a single - shared limiter is used for both passes. + Accepts either: + - Legacy: provider + pivotal_model + routine_model (single provider) + - New: strong_model_string + fast_model_string (provider/model format, mixed providers) Args: - provider: Provider name - pivotal_model: Model for Pass 1 (role-play reasoning) - routine_model: Model for Pass 2 (classification) + provider: Provider name (legacy, used if model strings not provided) + pivotal_model: Model for Pass 1 (legacy) + routine_model: Model for Pass 2 (legacy) tier: Rate limit tier (1-4) - rpm_override: Override RPM (applies to pivotal limiter) - tpm_override: Override TPM (applies to pivotal limiter) + rpm_override: Override RPM + tpm_override: Override TPM + strong_model_string: "provider/model" for strong/pivotal (new) + fast_model_string: "provider/model" for fast/routine (new) Returns: DualRateLimiter instance """ + # Resolve strong limiter + if strong_model_string and "/" in strong_model_string: + from ..config import parse_model_string + + strong_provider, strong_model = parse_model_string(strong_model_string) + else: + strong_provider = provider + strong_model = pivotal_model + pivotal_limiter = RateLimiter.for_provider( - provider=provider, - model=pivotal_model, + provider=strong_provider, + model=strong_model, tier=tier, rpm_override=rpm_override, tpm_override=tpm_override, ) - # If routine model is the same as 
pivotal (or not specified), share the limiter - effective_routine = routine_model or pivotal_model - if effective_routine == pivotal_model or not effective_routine: + # Resolve fast limiter + if fast_model_string and "/" in fast_model_string: + from ..config import parse_model_string + + fast_provider, fast_model = parse_model_string(fast_model_string) + else: + fast_provider = provider + fast_model = routine_model + + # If same provider+model, share the limiter + effective_fast_model = fast_model or strong_model + if fast_provider == strong_provider and effective_fast_model == strong_model: + return cls(pivotal=pivotal_limiter, routine=pivotal_limiter) + + if not effective_fast_model and not fast_provider: return cls(pivotal=pivotal_limiter, routine=pivotal_limiter) - # Different models — create separate limiter for routine - # Overrides apply to both (on Azure, limits are per-resource not per-model) routine_limiter = RateLimiter.for_provider( - provider=provider, - model=effective_routine, + provider=fast_provider or strong_provider, + model=effective_fast_model, tier=tier, rpm_override=rpm_override, tpm_override=tpm_override, diff --git a/extropy/core/rate_limits.py b/extropy/core/rate_limits.py index d081e60..ec53025 100644 --- a/extropy/core/rate_limits.py +++ b/extropy/core/rate_limits.py @@ -85,11 +85,32 @@ }, } -# Map "claude" provider name to anthropic profiles +# Provider aliases — map alternate names to canonical profiles RATE_LIMIT_PROFILES["claude"] = RATE_LIMIT_PROFILES["anthropic"] - -# Azure OpenAI uses the same rate limit profiles as standard OpenAI RATE_LIMIT_PROFILES["azure_openai"] = RATE_LIMIT_PROFILES["openai"] +RATE_LIMIT_PROFILES["azure"] = RATE_LIMIT_PROFILES["openai"] + +# Third-party providers — conservative defaults +# These providers typically have per-key limits; adjust via rate_tier/rpm_override. 
+_THIRD_PARTY_DEFAULT = { + "default": { + 1: {"rpm": 60, "tpm": 100_000}, + 2: {"rpm": 200, "tpm": 500_000}, + 3: {"rpm": 500, "tpm": 1_000_000}, + 4: {"rpm": 1_000, "tpm": 2_000_000}, + }, +} +RATE_LIMIT_PROFILES["openrouter"] = _THIRD_PARTY_DEFAULT +RATE_LIMIT_PROFILES["deepseek"] = _THIRD_PARTY_DEFAULT +RATE_LIMIT_PROFILES["together"] = _THIRD_PARTY_DEFAULT +RATE_LIMIT_PROFILES["groq"] = { + "default": { + 1: {"rpm": 30, "tpm": 15_000}, + 2: {"rpm": 60, "tpm": 50_000}, + 3: {"rpm": 200, "tpm": 100_000}, + 4: {"rpm": 500, "tpm": 500_000}, + }, +} def get_limits( diff --git a/extropy/population/network/generator.py b/extropy/population/network/generator.py index 4e412df..b96876c 100644 --- a/extropy/population/network/generator.py +++ b/extropy/population/network/generator.py @@ -78,7 +78,10 @@ def _build_blocked_candidate_map( continue blocks[attr].setdefault(val, []).append(idx) - target_pool = max(config.min_candidate_pool, int(config.avg_degree * config.candidate_pool_multiplier)) + target_pool = max( + config.min_candidate_pool, + int(config.avg_degree * config.candidate_pool_multiplier), + ) target_pool = max(1, min(n - 1, target_pool)) candidate_map: list[list[int]] = [[] for _ in range(n)] @@ -258,7 +261,9 @@ def _init_similarity_worker( _SIM_WORKER_CANDIDATE_MAP = candidate_map -def _compute_similarity_chunk(task: tuple[int, int]) -> tuple[int, list[tuple[int, int, float]]]: +def _compute_similarity_chunk( + task: tuple[int, int], +) -> tuple[int, list[tuple[int, int, float]]]: """Compute similarities for a chunk of row indices in a worker process.""" start, end = task if _SIM_WORKER_AGENTS is None: @@ -333,7 +338,8 @@ def _compute_similarities_parallel( ), ) as ex: futures = { - ex.submit(_compute_similarity_chunk, task): task for task in pending_tasks + ex.submit(_compute_similarity_chunk, task): task + for task in pending_tasks } pending_results: dict[int, list[tuple[int, int, float]]] = {} sorted_starts = [start for start, _ in tasks] @@ -361,10 
+367,7 @@ def _compute_similarities_parallel( completed_row_count += current_end - current_start completed_rows = max(completed_rows, current_end) - if ( - checkpoint_path is not None - and checkpoint_signature is not None - ): + if checkpoint_path is not None and checkpoint_signature is not None: if _is_db_checkpoint(checkpoint_path) and checkpoint_job_id: with open_study_db(checkpoint_path) as db: db.save_similarity_chunk_rows( @@ -451,7 +454,10 @@ def _compute_similarities_serial( if j <= i: continue sim = compute_similarity( - agents[i], agents[j], config.attribute_weights, config.ordinal_levels + agents[i], + agents[j], + config.attribute_weights, + config.ordinal_levels, ) if sim >= threshold: similarities[(i, j)] = sim @@ -1237,7 +1243,9 @@ def generate_network( if config.candidate_mode == "blocked": if on_progress: on_progress("Preparing candidate blocks", 0, n) - candidate_map, blocking_attrs = _build_blocked_candidate_map(agents, config, seed) + candidate_map, blocking_attrs = _build_blocked_candidate_map( + agents, config, seed + ) if on_progress: on_progress("Preparing candidate blocks", n, n) if candidate_map is None: diff --git a/extropy/simulation/engine.py b/extropy/simulation/engine.py index 5237dc6..7f0f73c 100644 --- a/extropy/simulation/engine.py +++ b/extropy/simulation/engine.py @@ -14,8 +14,10 @@ import json import logging +import queue import random import sqlite3 +import threading import time import uuid from datetime import datetime @@ -149,6 +151,8 @@ def __init__( run_id: str | None = None, checkpoint_every_chunks: int = 1, retention_lite: bool = False, + writer_queue_size: int = 256, + db_write_batch_size: int = 100, ): """Initialize simulation engine. 
@@ -173,6 +177,8 @@ def __init__( self.run_id = run_id or f"run_{uuid.uuid4().hex[:12]}" self.checkpoint_every_chunks = max(1, checkpoint_every_chunks) self.retention_lite = retention_lite + self.writer_queue_size = max(1, writer_queue_size) + self.db_write_batch_size = max(1, db_write_batch_size) # Build agent map for quick lookup self.agent_map = {a.get("_id", str(i)): a for i, a in enumerate(agents)} @@ -198,10 +204,13 @@ def __init__( self.output_dir.mkdir(parents=True, exist_ok=True) # Initialize state manager - state_db_file = Path(state_db_path) if state_db_path else self.output_dir / "study.db" + state_db_file = ( + Path(state_db_path) if state_db_path else self.output_dir / "study.db" + ) self.state_manager = StateManager( state_db_file, agents, + run_id=self.run_id, ) self.study_db = open_study_db(state_db_file) @@ -574,16 +583,89 @@ def _on_agent_done(agent_id: str, result: Any) -> None: context = self._build_reasoning_context(agent_id, old_state) contexts.append(context) - # Split into chunks - total_reasoned = 0 - total_changes = 0 - total_shares = 0 - completed_chunks = self.study_db.get_completed_simulation_chunks( self.run_id, timestep ) + totals = {"reasoned": 0, "changes": 0, "shares": 0} + + work_queue: queue.Queue[tuple[int, list[tuple[str, Any]], bool] | object] = ( + queue.Queue(maxsize=self.writer_queue_size) + ) + sentinel = object() + writer_error: list[Exception] = [] + + def _writer_loop() -> None: + chunks_since_checkpoint = 0 + pending_chunks: list[tuple[int, list[tuple[str, Any]], bool]] = [] + + def _flush_pending() -> None: + nonlocal chunks_since_checkpoint + if not pending_chunks: + return + with self.state_manager.transaction(): + for chunk_index, chunk_results, _is_last_chunk in pending_chunks: + reasoned, changes, shares = self._process_reasoning_chunk( + timestep, chunk_results, old_states + ) + totals["reasoned"] += reasoned + totals["changes"] += changes + totals["shares"] += shares + + for chunk_index, _chunk_results, 
is_last_chunk in pending_chunks: + self.study_db.save_simulation_checkpoint( + run_id=self.run_id, + timestep=timestep, + chunk_index=chunk_index, + status="done", + ) + + chunks_since_checkpoint += 1 + if ( + chunks_since_checkpoint >= self.checkpoint_every_chunks + or is_last_chunk + ): + self.study_db.set_run_metadata( + self.run_id, + "last_checkpoint", + f"{timestep}:{chunk_index}", + ) + chunks_since_checkpoint = 0 + + pending_chunks.clear() + + try: + while True: + item = work_queue.get() + try: + if item is sentinel: + _flush_pending() + break + chunk_index, chunk_results, is_last_chunk = item + if chunk_index in completed_chunks: + continue + pending_chunks.append( + (chunk_index, chunk_results, is_last_chunk) + ) + if ( + len(pending_chunks) >= self.db_write_batch_size + or is_last_chunk + ): + _flush_pending() + finally: + work_queue.task_done() + except Exception as e: # pragma: no cover - surfaced to caller + writer_error.append(e) + + writer_thread = threading.Thread( + target=_writer_loop, + name=f"sim-writer-{self.run_id}-{timestep}", + daemon=True, + ) + writer_thread.start() for chunk_start in range(0, len(contexts), self.chunk_size): + if writer_error: + break chunk_index = chunk_start // self.chunk_size if chunk_index in completed_chunks: logger.info( @@ -603,7 +685,6 @@ def _on_agent_done(agent_id: str, result: Any) -> None: reasoning_elapsed = time.time() - reasoning_start self.total_reasoning_calls += len(chunk_results) - # Accumulate token usage self.pivotal_input_tokens += chunk_usage.pivotal_input_tokens self.pivotal_output_tokens += chunk_usage.pivotal_output_tokens self.routine_input_tokens += chunk_usage.routine_input_tokens @@ -616,27 +697,26 @@ def _on_agent_done(agent_id: str, result: Any) -> None: else f"[TIMESTEP {timestep}] Chunk empty" ) - # Process and commit this chunk - with self.state_manager.transaction(): - reasoned, changes, shares = self._process_reasoning_chunk( - timestep, chunk_results, old_states - ) - if ( - 
((chunk_index + 1) % self.checkpoint_every_chunks == 0) - or (chunk_start + self.chunk_size >= len(contexts)) - ): - self.study_db.save_simulation_checkpoint( - run_id=self.run_id, - timestep=timestep, - chunk_index=chunk_index, - status="done", - ) - - total_reasoned += reasoned - total_changes += changes - total_shares += shares - - return total_reasoned, total_changes, total_shares + is_last_chunk = chunk_start + self.chunk_size >= len(contexts) + work_queue.put((chunk_index, chunk_results, is_last_chunk)) + + work_queue.put(sentinel) + while work_queue.unfinished_tasks > 0: + if writer_error: + while True: + try: + work_queue.get_nowait() + work_queue.task_done() + except queue.Empty: + break + break + time.sleep(0.01) + work_queue.join() + writer_thread.join(timeout=1) + if writer_error: + raise writer_error[0] + + return totals["reasoned"], totals["changes"], totals["shares"] def _process_reasoning_chunk( self, @@ -1112,7 +1192,7 @@ def _finalize( final_exposure_rate=self.state_manager.get_exposure_rate(), outcome_distributions=outcome_dists, runtime_seconds=runtime, - model_used=self.config.model, + model_used=self.config.strong, completed_at=datetime.now(), ) @@ -1122,7 +1202,7 @@ def _compute_cost(self) -> dict[str, Any]: Returns: Cost dictionary with token counts and estimated USD. 
""" - from ..core.pricing import get_pricing, resolve_default_model + from ..core.pricing import get_pricing from ..config import get_config cost: dict[str, Any] = { @@ -1137,19 +1217,14 @@ def _compute_cost(self) -> dict[str, Any]: # Resolve effective model names for pricing config = get_config() - provider = config.simulation.provider - pivotal_model = ( - self.config.pivotal_model - or self.config.model - or config.simulation.pivotal_model - or config.simulation.model - or resolve_default_model(provider, "reasoning") - ) - routine_model = ( - self.config.routine_model - or config.simulation.routine_model - or resolve_default_model(provider, "simple") - ) + from ..config import parse_model_string + + strong_model_str = self.config.strong or config.resolve_sim_strong() + fast_model_str = self.config.fast or config.resolve_sim_fast() + + # Strip provider prefix for pricing lookup (pricing is keyed by bare model name) + _, pivotal_model = parse_model_string(strong_model_str) + _, routine_model = parse_model_string(fast_model_str) cost["pivotal_model"] = pivotal_model cost["routine_model"] = routine_model @@ -1238,9 +1313,8 @@ def _export_results(self) -> None: "scenario_name": self.scenario.meta.name, "scenario_path": self.config.scenario_path, "population_size": len(self.agents), - "model": self.config.model, - "pivotal_model": self.config.pivotal_model, - "routine_model": self.config.routine_model, + "strong_model": self.config.strong, + "fast_model": self.config.fast, "seed": self.seed, "multi_touch_threshold": self.config.multi_touch_threshold, "completed_at": datetime.now().isoformat(), @@ -1260,9 +1334,8 @@ def run_simulation( scenario_path: str | Path, output_dir: str | Path, study_db_path: str | Path | None = None, - model: str = "", - pivotal_model: str = "", - routine_model: str = "", + strong: str = "", + fast: str = "", multi_touch_threshold: int = 3, random_seed: int | None = None, on_progress: TimestepProgressCallback | None = None, @@ -1276,6 +1349,8 
@@ def run_simulation( resume: bool = False, checkpoint_every_chunks: int = 1, retention_lite: bool = False, + writer_queue_size: int = 256, + db_write_batch_size: int = 100, ) -> SimulationSummary: """Run a simulation from a scenario file. @@ -1284,9 +1359,8 @@ def run_simulation( Args: scenario_path: Path to scenario YAML file output_dir: Directory for results output - model: LLM model for agent reasoning - pivotal_model: Model for pivotal reasoning (default: same as model) - routine_model: Cheap model for routine + classification + strong: Strong model for Pass 1 reasoning (provider/model format) + fast: Fast model for Pass 2 classification (provider/model format) multi_touch_threshold: Re-reason after N new exposures random_seed: Random seed for reproducibility on_progress: Progress callback(timestep, max, status) @@ -1300,6 +1374,8 @@ def run_simulation( resume: Resume a prior run from DB checkpoints checkpoint_every_chunks: Mark simulation checkpoint every N chunks retention_lite: Reduce payload volume by dropping full raw reasoning text + writer_queue_size: Max buffered reasoning chunks before writer backpressure + db_write_batch_size: Number of chunks applied per DB writer transaction Returns: SimulationSummary with results @@ -1309,21 +1385,26 @@ def run_simulation( if resume and not run_id: raise ValueError("--resume requires --run-id") - def _reset_runtime_tables(path: Path) -> None: + def _reset_runtime_tables(path: Path, run_key: str) -> None: conn = sqlite3.connect(str(path)) try: cur = conn.cursor() - cur.executescript( - """ - DELETE FROM agent_states; - DELETE FROM exposures; - DELETE FROM memory_traces; - DELETE FROM timeline; - DELETE FROM timestep_summaries; - DELETE FROM shared_to; - DELETE FROM simulation_metadata; - """ - ) + statements = [ + "DELETE FROM agent_states WHERE run_id = ?", + "DELETE FROM exposures WHERE run_id = ?", + "DELETE FROM memory_traces WHERE run_id = ?", + "DELETE FROM timeline WHERE run_id = ?", + "DELETE FROM 
timestep_summaries WHERE run_id = ?", + "DELETE FROM shared_to WHERE run_id = ?", + "DELETE FROM simulation_metadata WHERE run_id = ?", + ] + for sql in statements: + try: + cur.execute(sql, (run_key,)) + except sqlite3.OperationalError: + # Legacy tables without run_id columns. + table = sql.split()[2] + cur.execute(f"DELETE FROM {table}") conn.commit() except sqlite3.OperationalError: # First run on this DB may not have simulation tables yet. @@ -1379,13 +1460,14 @@ def _reset_runtime_tables(path: Path) -> None: config={ "scenario_path": str(scenario_path), "output_dir": str(output_dir), - "model": model, - "pivotal_model": pivotal_model, - "routine_model": routine_model, + "strong": strong, + "fast": fast, "multi_touch_threshold": multi_touch_threshold, "chunk_size": chunk_size, "checkpoint_every_chunks": checkpoint_every_chunks, "retention_lite": retention_lite, + "writer_queue_size": writer_queue_size, + "db_write_batch_size": db_write_batch_size, "resume": resume, }, seed=random_seed, @@ -1395,7 +1477,7 @@ def _reset_runtime_tables(path: Path) -> None: db.set_run_metadata(resolved_run_id, "study_db", str(study_db_resolved)) if not resume: - _reset_runtime_tables(study_db_resolved) + _reset_runtime_tables(study_db_resolved, resolved_run_id) # Load persona config if provided persona_config = None @@ -1415,26 +1497,22 @@ def _reset_runtime_tables(path: Path) -> None: config = SimulationRunConfig( scenario_path=str(scenario_path), output_dir=str(output_dir), - model=model, - pivotal_model=pivotal_model, - routine_model=routine_model, + strong=strong, + fast=fast, multi_touch_threshold=multi_touch_threshold, random_seed=random_seed, ) - # Create dual rate limiter (separate limiters for pivotal and routine models) + # Resolve effective model strings for rate limiting from ..config import get_config entropy_config = get_config() - provider = entropy_config.simulation.provider - effective_model = model or entropy_config.simulation.model or "" - effective_pivotal = 
pivotal_model or effective_model - effective_routine = routine_model or entropy_config.simulation.routine_model or "" + effective_strong = strong or entropy_config.resolve_sim_strong() + effective_fast = fast or entropy_config.resolve_sim_fast() rate_limiter = DualRateLimiter.create( - provider=provider, - pivotal_model=effective_pivotal, - routine_model=effective_routine, + strong_model_string=effective_strong, + fast_model_string=effective_fast, tier=rate_tier, rpm_override=rpm_override, tpm_override=tpm_override, @@ -1454,6 +1532,8 @@ def _reset_runtime_tables(path: Path) -> None: run_id=resolved_run_id, checkpoint_every_chunks=checkpoint_every_chunks, retention_lite=retention_lite, + writer_queue_size=writer_queue_size, + db_write_batch_size=db_write_batch_size, ) if on_progress: diff --git a/extropy/simulation/estimator.py b/extropy/simulation/estimator.py index 376245a..5a4a103 100644 --- a/extropy/simulation/estimator.py +++ b/extropy/simulation/estimator.py @@ -9,7 +9,7 @@ from typing import Any from ..core.models import ScenarioSpec, PopulationSpec -from ..core.pricing import ModelPricing, get_pricing, resolve_default_model +from ..core.pricing import ModelPricing, get_pricing from ..utils.eval_safe import eval_condition, ConditionError @@ -138,9 +138,8 @@ def estimate_simulation_cost( population_spec: PopulationSpec, agents: list[dict[str, Any]], network: dict[str, Any], - provider: str = "openai", - pivotal_model: str = "", - routine_model: str = "", + strong_model: str = "", + fast_model: str = "", multi_touch_threshold: int = 3, ) -> CostEstimate: """Estimate the cost of running a simulation. 
@@ -153,9 +152,8 @@ def estimate_simulation_cost( population_spec: Population specification agents: List of agent dictionaries network: Network data dict - provider: LLM provider name - pivotal_model: Model for Pass 1 (empty = provider default) - routine_model: Model for Pass 2 (empty = provider cheap tier) + strong_model: Model for Pass 1 (provider/model format, empty = config default) + fast_model: Model for Pass 2 (provider/model format, empty = config default) multi_touch_threshold: Re-reasoning threshold Returns: @@ -167,9 +165,14 @@ def estimate_simulation_cost( share_prob = scenario.spread.share_probability will_share_rate = 0.55 # accounts for conviction-gated sharing - # Resolve models - eff_pivotal = pivotal_model or resolve_default_model(provider, "reasoning") - eff_routine = routine_model or resolve_default_model(provider, "simple") + # Resolve models — strip provider prefix for pricing lookup + from ..config import get_config, parse_model_string + + config = get_config() + eff_strong_str = strong_model or config.resolve_sim_strong() + eff_fast_str = fast_model or config.resolve_sim_fast() + _, eff_pivotal = parse_model_string(eff_strong_str) + _, eff_routine = parse_model_string(eff_fast_str) # Pre-compute seed exposure schedule: timestep -> expected new seed exposures seed_schedule: dict[int, float] = {} diff --git a/extropy/simulation/reasoning.py b/extropy/simulation/reasoning.py index 7c51151..4f87339 100644 --- a/extropy/simulation/reasoning.py +++ b/extropy/simulation/reasoning.py @@ -455,8 +455,8 @@ async def _reason_agent_two_pass_async( position_outcome = _get_primary_position_outcome(scenario) # Determine models - main_model = config.model or None # None = provider default - classify_model = config.routine_model or None # None = provider default (cheap) + main_model = config.strong or None # None = provider default + classify_model = config.fast or None # None = provider default (cheap) # === Pass 1: Role-play === pass1_usage = TokenUsage() 
@@ -687,7 +687,7 @@ def reason_agent( if pass2_schema: pass2_prompt = build_pass2_prompt(reasoning, scenario) - classify_model = config.routine_model or None + classify_model = config.fast or None for attempt in range(config.max_retries): try: @@ -794,7 +794,9 @@ async def reason_with_pacing( ctx: ReasoningContext, ) -> tuple[int, str, ReasoningResponse | None, float]: start = time.time() - result = await _reason_agent_two_pass_async(ctx, scenario, config, rate_limiter) + result = await _reason_agent_two_pass_async( + ctx, scenario, config, rate_limiter + ) elapsed = time.time() - start completed[0] += 1 diff --git a/extropy/simulation/state.py b/extropy/simulation/state.py index 94a4350..c583471 100644 --- a/extropy/simulation/state.py +++ b/extropy/simulation/state.py @@ -27,16 +27,23 @@ class StateManager: for frequently accessed data. """ - def __init__(self, db_path: Path | str, agents: list[dict[str, Any]] | None = None): + def __init__( + self, + db_path: Path | str, + agents: list[dict[str, Any]] | None = None, + run_id: str = "default", + ): """Initialize state manager with database path. 
Args: db_path: Path to SQLite database file agents: Optional list of agents to initialize + run_id: Run scope for all state reads/writes """ self.db_path = Path(db_path) + self.run_id = run_id self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path)) + self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) self.conn.row_factory = sqlite3.Row self.conn.execute("PRAGMA foreign_keys = ON") @@ -54,7 +61,8 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS agent_states ( - agent_id TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + agent_id TEXT NOT NULL, aware INTEGER DEFAULT 0, exposure_count INTEGER DEFAULT 0, last_reasoning_timestep INTEGER DEFAULT -1, @@ -73,7 +81,8 @@ def _create_schema(self) -> None: private_conviction REAL, private_outcomes_json TEXT, raw_reasoning TEXT, - updated_at INTEGER DEFAULT 0 + updated_at INTEGER DEFAULT 0, + PRIMARY KEY (run_id, agent_id) ) """ ) @@ -82,6 +91,7 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS exposures ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, agent_id TEXT, timestep INTEGER, @@ -89,7 +99,7 @@ def _create_schema(self) -> None: source_agent_id TEXT, content TEXT, credibility REAL, - FOREIGN KEY (agent_id) REFERENCES agent_states(agent_id) + FOREIGN KEY (run_id, agent_id) REFERENCES agent_states(run_id, agent_id) ) """ ) @@ -98,13 +108,14 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS memory_traces ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, agent_id TEXT, timestep INTEGER, sentiment REAL, conviction REAL, summary TEXT, - FOREIGN KEY (agent_id) REFERENCES agent_states(agent_id) + FOREIGN KEY (run_id, agent_id) REFERENCES agent_states(run_id, agent_id) ) """ ) @@ -113,6 +124,7 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS timeline ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY 
AUTOINCREMENT, timestep INTEGER, event_type TEXT, @@ -127,7 +139,8 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS timestep_summaries ( - timestep INTEGER PRIMARY KEY, + run_id TEXT NOT NULL, + timestep INTEGER NOT NULL, new_exposures INTEGER, agents_reasoned INTEGER, shares_occurred INTEGER, @@ -136,7 +149,8 @@ def _create_schema(self) -> None: position_distribution_json TEXT, average_sentiment REAL, average_conviction REAL, - sentiment_variance REAL + sentiment_variance REAL, + PRIMARY KEY (run_id, timestep) ) """ ) @@ -145,37 +159,37 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_exposures_agent - ON exposures(agent_id) + ON exposures(run_id, agent_id) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_exposures_timestep - ON exposures(timestep) + ON exposures(run_id, timestep) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_timeline_timestep - ON timeline(timestep) + ON timeline(run_id, timestep) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_agent_states_aware - ON agent_states(aware) + ON agent_states(run_id, aware) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_agent_states_will_share - ON agent_states(will_share) + ON agent_states(run_id, will_share) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_memory_traces_agent - ON memory_traces(agent_id) + ON memory_traces(run_id, agent_id) """ ) @@ -183,18 +197,19 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS shared_to ( + run_id TEXT NOT NULL, source_agent_id TEXT, target_agent_id TEXT, timestep INTEGER, position TEXT, - PRIMARY KEY (source_agent_id, target_agent_id) + PRIMARY KEY (run_id, source_agent_id, target_agent_id) ) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_shared_to_source - ON shared_to(source_agent_id) + ON shared_to(run_id, source_agent_id) """ ) @@ -202,8 +217,11 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE 
TABLE IF NOT EXISTS simulation_metadata ( - key TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + key TEXT NOT NULL, value TEXT + , + PRIMARY KEY (run_id, key) ) """ ) @@ -215,6 +233,13 @@ def _upgrade_schema(self) -> None: cursor = self.conn.cursor() migrations = [ + ("agent_states", "run_id", "TEXT DEFAULT 'default'"), + ("exposures", "run_id", "TEXT DEFAULT 'default'"), + ("memory_traces", "run_id", "TEXT DEFAULT 'default'"), + ("timeline", "run_id", "TEXT DEFAULT 'default'"), + ("timestep_summaries", "run_id", "TEXT DEFAULT 'default'"), + ("shared_to", "run_id", "TEXT DEFAULT 'default'"), + ("simulation_metadata", "run_id", "TEXT DEFAULT 'default'"), ("agent_states", "conviction", "REAL"), ("agent_states", "public_statement", "TEXT"), ("timestep_summaries", "average_conviction", "REAL"), diff --git a/extropy/utils/resource_governor.py b/extropy/utils/resource_governor.py index 33b81a2..5287cb4 100644 --- a/extropy/utils/resource_governor.py +++ b/extropy/utils/resource_governor.py @@ -71,7 +71,9 @@ def recommend_workers( return requested_workers snap = self.snapshot() - cpu_cap = max(1, snap.cpu_count - 1) if self.safe_auto_workers else snap.cpu_count + cpu_cap = ( + max(1, snap.cpu_count - 1) if self.safe_auto_workers else snap.cpu_count + ) mem_cap = max(1, int(snap.memory_budget_gb / max(0.1, memory_per_worker_gb))) if self.safe_auto_workers: diff --git a/tests/test_cli.py b/tests/test_cli.py index 3c624af..0a3eb91 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,5 @@ """CLI smoke tests using typer's CliRunner.""" + from pathlib import Path from typer.testing import CliRunner @@ -17,7 +18,7 @@ class TestConfigCommand: def test_config_show(self): result = runner.invoke(app, ["config", "show"]) assert result.exit_code == 0 - assert "Pipeline" in result.output + assert "Models" in result.output assert "Simulation" in result.output def test_config_set_invalid_key(self): @@ -78,7 +79,9 @@ def test_network_command_supports_fast_mode_and_checkpoint(self, 
tmp_path): {"_id": "a3", "role": "y", "team": "beta"}, ] with open_study_db(study_db) as db: - db.save_sample_result(population_id="default", agents=agents, meta={"source": "test"}) + db.save_sample_result( + population_id="default", agents=agents, meta={"source": "test"} + ) NetworkConfig(seed=42, avg_degree=2.0).to_yaml(config_path) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 32c0728..eec5597 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -85,7 +85,10 @@ def mock_files(self, minimal_population_spec, tmp_path): pop_path = tmp_path / "population.yaml" minimal_population_spec.to_yaml(pop_path) - agents = [{"_id": f"agent_{i:03d}", "age": 30 + i, "gender": "male"} for i in range(10)] + agents = [ + {"_id": f"agent_{i:03d}", "age": 30 + i, "gender": "male"} + for i in range(10) + ] edges = [ { "source": f"agent_{i:03d}", diff --git a/tests/test_estimator.py b/tests/test_estimator.py index 9bfe443..a40a053 100644 --- a/tests/test_estimator.py +++ b/tests/test_estimator.py @@ -323,7 +323,8 @@ def test_basic_estimate( population_spec=small_pop_spec, agents=small_agents, network=small_network, - provider="openai", + strong_model="openai/gpt-5", + fast_model="openai/gpt-5-mini", ) assert est.population_size == 10 @@ -370,7 +371,8 @@ def test_model_resolution_openai( population_spec=small_pop_spec, agents=small_agents, network=small_network, - provider="openai", + strong_model="openai/gpt-5", + fast_model="openai/gpt-5-mini", ) assert est.pivotal_model == "gpt-5" assert est.routine_model == "gpt-5-mini" @@ -383,7 +385,8 @@ def test_model_resolution_claude( population_spec=small_pop_spec, agents=small_agents, network=small_network, - provider="claude", + strong_model="anthropic/claude-sonnet-4-5-20250929", + fast_model="anthropic/claude-haiku-4-5-20251001", ) assert est.pivotal_model == "claude-sonnet-4-5-20250929" assert est.routine_model == "claude-haiku-4-5-20251001" @@ -396,9 +399,8 @@ def test_explicit_model_override( 
population_spec=small_pop_spec, agents=small_agents, network=small_network, - provider="openai", - pivotal_model="gpt-5-mini", - routine_model="gpt-5-mini", + strong_model="openai/gpt-5-mini", + fast_model="openai/gpt-5-mini", ) assert est.pivotal_model == "gpt-5-mini" assert est.routine_model == "gpt-5-mini" @@ -411,8 +413,8 @@ def test_unknown_model_pricing_none( population_spec=small_pop_spec, agents=small_agents, network=small_network, - pivotal_model="unknown-model-x", - routine_model="unknown-model-y", + strong_model="openai/unknown-model-x", + fast_model="openai/unknown-model-y", ) assert est.pivotal_pricing is None assert est.routine_pricing is None diff --git a/tests/test_providers.py b/tests/test_providers.py index 291d846..752eb77 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -380,9 +380,8 @@ def test_no_validator_returns_immediately(self): """With no validator, first result is returned.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -421,9 +420,8 @@ def test_initial_prompt_used_on_first_call(self): """When initial_prompt is provided, it should be used for the first call.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -463,9 +461,8 @@ def test_validation_retries_use_base_prompt_not_initial(self): """Validation retries should use prompt, not initial_prompt.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -528,9 +525,8 @@ def 
test_validator_succeeds_on_first_attempt_with_initial_prompt(self): """When validator passes on first try with initial_prompt, no retries occur.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -574,9 +570,8 @@ def test_on_retry_callback_invoked_correctly(self): """Test that on_retry callback is invoked with correct parameters.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -640,9 +635,8 @@ def test_no_initial_prompt_defaults_to_prompt(self): """When initial_prompt is None, prompt is used for first call.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -813,12 +807,10 @@ class TestProviderFactoryAzure: ) def test_create_azure_openai_provider(self): from extropy.core.providers import _create_provider + from extropy.core.providers.openai_compat import OpenAICompatProvider provider = _create_provider("azure_openai") - assert isinstance(provider, OpenAIProvider) - assert provider._is_azure is True - assert provider._azure_endpoint == "https://my-resource.openai.azure.com" - assert provider._azure_deployment == "my-deployment" + assert isinstance(provider, OpenAICompatProvider) @patch.dict( "os.environ", diff --git a/tests/test_scenario_validator.py b/tests/test_scenario_validator.py index 5464228..7f23cde 100644 --- a/tests/test_scenario_validator.py +++ b/tests/test_scenario_validator.py @@ -1,6 +1,5 @@ """Tests for scenario validation behavior.""" -import json from pathlib 
import Path from extropy.core.models.scenario import ( @@ -163,7 +162,9 @@ def test_validate_scenario_allows_edge_weight_in_spread_modifier(tmp_path: Path) population_path.write_text("placeholder: true\n") with open_study_db(study_db) as db: - db.save_sample_result(population_id="default", agents=[], meta={"source": "test"}) + db.save_sample_result( + population_id="default", agents=[], meta={"source": "test"} + ) spec = _make_scenario_spec( str(population_path), From d992aef5b7ae5e8ee39240d7c83fa562840b90d9 Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 01:08:13 -0500 Subject: [PATCH 10/15] refactor(config): remove all legacy/deprecated code from config module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drop PipelineConfig, SimZoneConfig, v1→v2 migration, legacy env var handling, get_api_key(), and get_azure_config(). Pure Pydantic models with no backward-compat cruft. ~200 lines removed. Co-Authored-By: Claude Opus 4.6 --- extropy/config.py | 307 ++++------------------------------------------ 1 file changed, 26 insertions(+), 281 deletions(-) diff --git a/extropy/config.py b/extropy/config.py index 056ff4b..be66c43 100644 --- a/extropy/config.py +++ b/extropy/config.py @@ -18,11 +18,11 @@ import json import logging import os -import warnings -from dataclasses import dataclass, field, asdict from pathlib import Path from typing import Any +from pydantic import BaseModel, ConfigDict, Field + logger = logging.getLogger(__name__) @@ -66,12 +66,11 @@ def parse_model_string(model_string: str) -> tuple[str, str]: # ============================================================================= -# New two-tier config dataclasses +# Two-tier config models # ============================================================================= -@dataclass -class ModelsConfig: +class ModelsConfig(BaseModel): """Pipeline model configuration (phases 1-2). Uses "provider/model" format strings. 
@@ -79,12 +78,13 @@ class ModelsConfig: - strong: used for reasoning_call, agentic_research (complex tasks) """ + model_config = ConfigDict(populate_by_name=True) + fast: str = "openai/gpt-5-mini" strong: str = "openai/gpt-5" -@dataclass -class SimulationConfig: +class SimulationConfig(BaseModel): """Simulation model + tuning configuration (phase 3). Uses "provider/model" format strings. @@ -92,6 +92,8 @@ class SimulationConfig: - strong: used for Pass 1 (pivotal/role-play reasoning) """ + model_config = ConfigDict(populate_by_name=True) + fast: str = "" # empty = same as models.fast strong: str = "" # empty = same as models.strong max_concurrent: int = 50 @@ -100,53 +102,19 @@ class SimulationConfig: tpm_override: int | None = None -@dataclass -class CustomProviderConfig: - """Configuration for a custom OpenAI-compatible provider endpoint.""" +class CustomProviderConfig(BaseModel): + """Config for a custom OpenAI-compatible provider endpoint.""" base_url: str = "" api_key_env: str = "" - - -# ============================================================================= -# Legacy config dataclasses (kept for migration) -# ============================================================================= - - -@dataclass -class PipelineConfig: - """DEPRECATED: Config for phases 1-2. Use ModelsConfig instead.""" - - provider: str = "openai" - model_simple: str = "" - model_reasoning: str = "" - model_research: str = "" - - -@dataclass -class SimZoneConfig: - """DEPRECATED: Config for phase 3. 
Use SimulationConfig instead.""" - - provider: str = "openai" - model: str = "" - pivotal_model: str = "" - routine_model: str = "" - max_concurrent: int = 50 - rate_tier: int | None = None - rpm_override: int | None = None - tpm_override: int | None = None - api_format: str = "" - - # ============================================================================= # Main config class # ============================================================================= -@dataclass -class ExtropyConfig: +class ExtropyConfig(BaseModel): """Top-level extropy configuration. Construct programmatically for package use, or load from config file for CLI use. @@ -165,40 +133,30 @@ class ExtropyConfig: config.simulation.strong = "openrouter/anthropic/claude-sonnet-4.5" """ - models: ModelsConfig = field(default_factory=ModelsConfig) - simulation: SimulationConfig = field(default_factory=SimulationConfig) - providers: dict[str, CustomProviderConfig] = field(default_factory=dict) + model_config = ConfigDict(populate_by_name=True) + + models: ModelsConfig = Field(default_factory=ModelsConfig) + simulation: SimulationConfig = Field(default_factory=SimulationConfig) + providers: dict[str, CustomProviderConfig] = Field(default_factory=dict) @classmethod def load(cls) -> "ExtropyConfig": """Load config from file + env vars. Priority: env var values > config.json values > defaults. - Auto-migrates v1 config format if detected. """ config = cls() - # Layer 1: Load from config file if it exists + # Load from config file if it exists if CONFIG_FILE.exists(): try: with open(CONFIG_FILE) as f: data = json.load(f) - - # Auto-migrate v1 config - if _is_v1_config(data): - warnings.warn( - "Detected legacy config format. Migrating to v2. 
" - "Run `extropy config show` to verify, then `extropy config set` to update.", - DeprecationWarning, - stacklevel=2, - ) - data = _migrate_v1_to_v2(data) - _apply_dict(config, data) except (json.JSONDecodeError, OSError) as exc: logger.warning("Failed to load config from %s: %s", CONFIG_FILE, exc) - # Layer 2: Env var overrides (new format) + # Env var overrides if val := os.environ.get("MODELS_FAST"): config.models.fast = val if val := os.environ.get("MODELS_STRONG"): @@ -227,8 +185,6 @@ def load(cls) -> "ExtropyConfig": config.simulation.tpm_override = int(val) except ValueError: logger.warning("Invalid SIMULATION_TPM_OVERRIDE=%r, ignoring", val) - # Layer 3: Legacy env var overrides (emit deprecation warnings) - _apply_legacy_env_vars(config) return config @@ -236,12 +192,12 @@ def save(self) -> None: """Save config to ~/.config/extropy/config.json.""" CONFIG_DIR.mkdir(parents=True, exist_ok=True) data: dict[str, Any] = { - "models": asdict(self.models), - "simulation": asdict(self.simulation), + "models": self.models.model_dump(), + "simulation": self.simulation.model_dump(), } if self.providers: data["providers"] = { - name: asdict(cfg) for name, cfg in self.providers.items() + name: cfg.model_dump() for name, cfg in self.providers.items() } with open(CONFIG_FILE, "w") as f: json.dump(data, f, indent=2) @@ -249,12 +205,12 @@ def save(self) -> None: def to_dict(self) -> dict[str, Any]: """Convert to dict for display.""" result = { - "models": asdict(self.models), - "simulation": asdict(self.simulation), + "models": self.models.model_dump(), + "simulation": self.simulation.model_dump(), } if self.providers: result["providers"] = { - name: asdict(cfg) for name, cfg in self.providers.items() + name: cfg.model_dump() for name, cfg in self.providers.items() } return result @@ -310,184 +266,6 @@ def _apply_dict(config: ExtropyConfig, data: dict) -> None: ) -# ============================================================================= -# V1 → V2 migration -# 
============================================================================= - -# Provider name mapping for migration -_PROVIDER_CANONICAL = { - "openai": "openai", - "claude": "anthropic", - "anthropic": "anthropic", - "azure_openai": "azure", -} - -# Default model names per old provider -_V1_PROVIDER_DEFAULTS = { - "openai": {"fast": "gpt-5-mini", "strong": "gpt-5"}, - "claude": { - "fast": "claude-haiku-4-5-20251001", - "strong": "claude-sonnet-4-5-20250929", - }, - "anthropic": { - "fast": "claude-haiku-4-5-20251001", - "strong": "claude-sonnet-4-5-20250929", - }, - "azure_openai": {"fast": "gpt-5-mini", "strong": "gpt-5"}, -} - - -def _is_v1_config(data: dict) -> bool: - """Detect if config data is in v1 format (has 'pipeline' key).""" - return "pipeline" in data and "models" not in data - - -def _migrate_v1_to_v2(data: dict) -> dict: - """Convert v1 config format to v2. - - v1 format: - {"pipeline": {"provider": "openai", "model_simple": "...", ...}, - "simulation": {"provider": "openai", "model": "...", ...}} - - v2 format: - {"models": {"fast": "openai/gpt-5-mini", "strong": "openai/gpt-5"}, - "simulation": {"fast": "...", "strong": "...", ...}} - """ - result: dict[str, Any] = {} - - # Migrate pipeline → models - pipeline = data.get("pipeline", {}) - old_provider = pipeline.get("provider", "openai") - canonical = _PROVIDER_CANONICAL.get(old_provider, old_provider) - defaults = _V1_PROVIDER_DEFAULTS.get(old_provider, _V1_PROVIDER_DEFAULTS["openai"]) - - fast_model = pipeline.get("model_simple") or defaults["fast"] - strong_model = pipeline.get("model_reasoning") or defaults["strong"] - - result["models"] = { - "fast": f"{canonical}/{fast_model}", - "strong": f"{canonical}/{strong_model}", - } - - # Migrate simulation - sim = data.get("simulation", {}) - sim_provider = sim.get("provider", "openai") - sim_canonical = _PROVIDER_CANONICAL.get(sim_provider, sim_provider) - sim_defaults = _V1_PROVIDER_DEFAULTS.get( - sim_provider, _V1_PROVIDER_DEFAULTS["openai"] 
- ) - - sim_result: dict[str, Any] = {} - - # Map model/pivotal_model → strong, routine_model → fast - pivotal = sim.get("pivotal_model") or sim.get("model") or "" - routine = sim.get("routine_model") or "" - - if pivotal: - sim_result["strong"] = f"{sim_canonical}/{pivotal}" - if routine: - sim_result["fast"] = f"{sim_canonical}/{routine}" - - for k in ("max_concurrent", "rate_tier", "rpm_override", "tpm_override"): - if k in sim and sim[k] is not None: - sim_result[k] = sim[k] - - result["simulation"] = sim_result - - return result - - -# ============================================================================= -# Legacy env var handling -# ============================================================================= - -_LEGACY_ENV_WARNED: set[str] = set() - - -def _warn_legacy_env(name: str, replacement: str) -> None: - """Emit a one-time deprecation warning for a legacy env var.""" - if name not in _LEGACY_ENV_WARNED: - _LEGACY_ENV_WARNED.add(name) - warnings.warn( - f"Environment variable {name} is deprecated. 
Use {replacement} instead.", - DeprecationWarning, - stacklevel=4, - ) - - -def _apply_legacy_env_vars(config: ExtropyConfig) -> None: - """Apply legacy env vars with deprecation warnings.""" - # LLM_PROVIDER → both zones - if val := os.environ.get("LLM_PROVIDER"): - _warn_legacy_env("LLM_PROVIDER", "MODELS_FAST / MODELS_STRONG") - canonical = _PROVIDER_CANONICAL.get(val, val) - defaults = _V1_PROVIDER_DEFAULTS.get(val, _V1_PROVIDER_DEFAULTS["openai"]) - # Only override if no new-format env vars set - if not os.environ.get("MODELS_FAST"): - config.models.fast = f"{canonical}/{defaults['fast']}" - if not os.environ.get("MODELS_STRONG"): - config.models.strong = f"{canonical}/{defaults['strong']}" - - if val := os.environ.get("PIPELINE_PROVIDER"): - _warn_legacy_env("PIPELINE_PROVIDER", "MODELS_FAST / MODELS_STRONG") - canonical = _PROVIDER_CANONICAL.get(val, val) - defaults = _V1_PROVIDER_DEFAULTS.get(val, _V1_PROVIDER_DEFAULTS["openai"]) - if not os.environ.get("MODELS_FAST"): - config.models.fast = f"{canonical}/{defaults['fast']}" - if not os.environ.get("MODELS_STRONG"): - config.models.strong = f"{canonical}/{defaults['strong']}" - - if val := os.environ.get("SIMULATION_PROVIDER"): - _warn_legacy_env("SIMULATION_PROVIDER", "SIMULATION_FAST / SIMULATION_STRONG") - canonical = _PROVIDER_CANONICAL.get(val, val) - defaults = _V1_PROVIDER_DEFAULTS.get(val, _V1_PROVIDER_DEFAULTS["openai"]) - if not os.environ.get("SIMULATION_FAST"): - config.simulation.fast = f"{canonical}/{defaults['fast']}" - if not os.environ.get("SIMULATION_STRONG"): - config.simulation.strong = f"{canonical}/{defaults['strong']}" - - if val := os.environ.get("MODEL_SIMPLE"): - _warn_legacy_env("MODEL_SIMPLE", "MODELS_FAST") - if not os.environ.get("MODELS_FAST"): - provider, _ = parse_model_string(config.models.fast) - config.models.fast = f"{provider}/{val}" - - if val := os.environ.get("MODEL_REASONING"): - _warn_legacy_env("MODEL_REASONING", "MODELS_STRONG") - if not 
os.environ.get("MODELS_STRONG"): - provider, _ = parse_model_string(config.models.strong) - config.models.strong = f"{provider}/{val}" - - if val := os.environ.get("SIMULATION_MODEL"): - _warn_legacy_env("SIMULATION_MODEL", "SIMULATION_STRONG") - if not os.environ.get("SIMULATION_STRONG"): - # Resolve provider from sim strong or models strong - base = config.simulation.strong or config.models.strong - provider, _ = parse_model_string(base) - config.simulation.strong = f"{provider}/{val}" - - if val := os.environ.get("SIMULATION_PIVOTAL_MODEL"): - _warn_legacy_env("SIMULATION_PIVOTAL_MODEL", "SIMULATION_STRONG") - if not os.environ.get("SIMULATION_STRONG"): - base = config.simulation.strong or config.models.strong - provider, _ = parse_model_string(base) - config.simulation.strong = f"{provider}/{val}" - - if val := os.environ.get("SIMULATION_ROUTINE_MODEL"): - _warn_legacy_env("SIMULATION_ROUTINE_MODEL", "SIMULATION_FAST") - if not os.environ.get("SIMULATION_FAST"): - base = config.simulation.fast or config.models.fast - provider, _ = parse_model_string(base) - config.simulation.fast = f"{provider}/{val}" - - # SIMULATION_API_FORMAT — no direct replacement, just warn - if os.environ.get("SIMULATION_API_FORMAT"): - _warn_legacy_env( - "SIMULATION_API_FORMAT", - "provider-based routing (api_format is now automatic)", - ) - - # ============================================================================= # API key resolution # ============================================================================= @@ -544,43 +322,10 @@ def get_api_key_for_provider( "azure": "AZURE_OPENAI_API_KEY", "azure_openai": "AZURE_OPENAI_API_KEY", } - env_var = key_map.get( - provider_name, f"{provider_name.upper()}_API_KEY" - ) + env_var = key_map.get(provider_name, f"{provider_name.upper()}_API_KEY") return os.environ.get(env_var, "") -def get_api_key(provider: str) -> str: - """DEPRECATED: Get API key for a provider. Use get_api_key_for_provider instead. 
- - Kept for backward compatibility. - """ - # Map old provider names - mapping = { - "claude": "anthropic", - "azure_openai": "azure", - } - canonical = mapping.get(provider, provider) - return get_api_key_for_provider(canonical) - - -def get_azure_config(provider: str) -> dict[str, str]: - """DEPRECATED: Get Azure-specific configuration. - - Azure is now handled as an OpenAI-compatible provider. - """ - _ensure_dotenv() - if provider in ("azure_openai", "azure"): - return { - "azure_endpoint": os.environ.get("AZURE_OPENAI_ENDPOINT", ""), - "api_version": os.environ.get( - "AZURE_OPENAI_API_VERSION", "2025-03-01-preview" - ), - "azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT", ""), - } - return {} - - # ============================================================================= # Global config singleton # ============================================================================= From f87dacd1caab23843f90043627a5af7a04730061 Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 01:13:46 -0500 Subject: [PATCH 11/15] feat(network): enforce DB-only similarity checkpoints and resume --- README.md | 23 +-- extropy/cli/commands/network.py | 21 ++- extropy/population/network/__init__.py | 3 +- extropy/population/network/generator.py | 186 +++++++----------------- extropy/scenario/__init__.py | 5 +- extropy/storage/study_db.py | 28 +--- tests/test_cli.py | 26 +++- tests/test_network.py | 69 +++++++-- 8 files changed, 165 insertions(+), 196 deletions(-) diff --git a/README.md b/README.md index 33ddbb7..cbf69ee 100644 --- a/README.md +++ b/README.md @@ -61,22 +61,22 @@ extropy config show ```bash mkdir -p austin +STUDY_DB=austin/study.db # Build a population extropy spec "500 Austin TX commuters who drive into downtown for work" -o austin/base.yaml extropy extend austin/base.yaml -s "Response to a $15/day downtown congestion tax" -o austin/population.yaml -extropy sample austin/population.yaml -o austin/agents.json --seed 42 -extropy network 
austin/agents.json -o austin/network.json -p austin/population.yaml --seed 42 -extropy persona austin/population.yaml --agents austin/agents.json -o austin/population.persona.yaml +extropy sample austin/population.yaml --study-db "$STUDY_DB" --seed 42 +extropy network --study-db "$STUDY_DB" -p austin/population.yaml --seed 42 --checkpoint "$STUDY_DB" # Compile and run a scenario -extropy scenario -p austin/population.yaml -a austin/agents.json -n austin/network.json -o austin/scenario.yaml -extropy estimate austin/scenario.yaml -extropy simulate austin/scenario.yaml -o austin/results/ --seed 42 +extropy scenario -p austin/population.yaml --study-db "$STUDY_DB" -o austin/scenario.yaml +extropy estimate austin/scenario.yaml --study-db "$STUDY_DB" +extropy simulate austin/scenario.yaml --study-db "$STUDY_DB" -o austin/results/ --seed 42 # View results -extropy results austin/results/ -extropy results austin/results/ --segment income +extropy results --study-db "$STUDY_DB" +extropy results --study-db "$STUDY_DB" --segment income ``` ### What Comes Out @@ -126,11 +126,14 @@ Each agent reasoned individually. 
A low-income commuter with no transit access r Simulation output directory (`austin/results/`) contains: - `study.db` (canonical state + checkpoint store) -- `agent_states.json` (final per-agent states) - `by_timestep.json` (time-series aggregates) -- `outcome_distributions.json` (final distributions) - `meta.json` (run metadata + token/cost summary) +For full datasets, use explicit exports from `study.db`: +- `extropy export states --study-db "$STUDY_DB" --to austin/results/states.jsonl` +- `extropy export agents --study-db "$STUDY_DB" --to austin/results/agents.jsonl` +- `extropy export edges --study-db "$STUDY_DB" --to austin/results/edges.jsonl` + The scenario YAML controls what gets tracked: ```yaml diff --git a/extropy/cli/commands/network.py b/extropy/cli/commands/network.py index d1b8eaa..1c49a49 100644 --- a/extropy/cli/commands/network.py +++ b/extropy/cli/commands/network.py @@ -88,12 +88,12 @@ def network_command( checkpoint: Path | None = typer.Option( None, "--checkpoint", - help="Path to similarity checkpoint file (.pkl) or study DB (.db)", + help="DB path for similarity checkpointing (must be the same as --study-db)", ), resume_checkpoint: bool = typer.Option( False, "--resume-checkpoint", - help="Resume similarity stage from --checkpoint file", + help="Resume similarity stage from checkpoint tables in --study-db", ), checkpoint_every: int = typer.Option( 250, @@ -149,8 +149,15 @@ def network_command( start_time = time.time() console.print() - if resume_checkpoint and checkpoint is None: - checkpoint = study_db + if ( + checkpoint is not None + and checkpoint.expanduser().resolve() != study_db.expanduser().resolve() + ): + console.print( + "[red]✗[/red] --checkpoint must point to the same canonical file as --study-db" + ) + raise typer.Exit(1) + checkpoint_db = study_db if (resume_checkpoint or checkpoint is not None) else None # Load Agents if not study_db.exists(): @@ -293,7 +300,7 @@ def network_command( console.print( f"[dim]Mode: 
{config.candidate_mode} | workers={config.similarity_workers} " - f"| checkpoint={'on' if checkpoint else 'off'}[/dim]" + f"| checkpoint={'on' if checkpoint_db else 'off'}[/dim]" ) if resource_mode == "auto": snap = governor.snapshot() @@ -323,7 +330,7 @@ def do_generation(): agents, config, on_progress, - checkpoint_path=checkpoint, + checkpoint_path=checkpoint_db, resume_from_checkpoint=resume_checkpoint, ) else: @@ -331,7 +338,7 @@ def do_generation(): agents, config, on_progress, - checkpoint_path=checkpoint, + checkpoint_path=checkpoint_db, resume_from_checkpoint=resume_checkpoint, ) except Exception as e: diff --git a/extropy/population/network/__init__.py b/extropy/population/network/__init__.py index 7772d0a..e6f03c6 100644 --- a/extropy/population/network/__init__.py +++ b/extropy/population/network/__init__.py @@ -7,7 +7,8 @@ Usage: from extropy.network import generate_network, NetworkConfig, NetworkResult - # Load agents from JSON + # Agents are typically loaded from study.db via CLI, then passed here. + # (load_agents_json is kept for explicit import/export workflows.) 
agents = load_agents_json("agents.json") # Generate network with default config (flat — no similarity structure) diff --git a/extropy/population/network/generator.py b/extropy/population/network/generator.py index 4e412df..a2a1d98 100644 --- a/extropy/population/network/generator.py +++ b/extropy/population/network/generator.py @@ -7,7 +7,6 @@ import logging import hashlib import multiprocessing as mp -import pickle import random from concurrent.futures import ProcessPoolExecutor, as_completed from datetime import datetime @@ -33,10 +32,6 @@ _SIM_WORKER_CANDIDATE_MAP: list[list[int]] | None = None -def _is_db_checkpoint(path: Path | None) -> bool: - return path is not None and path.suffix.lower() == ".db" - - def _choose_blocking_attributes(config: NetworkConfig) -> list[str]: """Choose blocking attributes for candidate pruning.""" if config.blocking_attributes: @@ -137,104 +132,33 @@ def _similarity_checkpoint_job_id(signature: dict[str, Any]) -> str: return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:24] -def _save_similarity_checkpoint( - path: Path, - similarities: dict[tuple[int, int], float], - completed_rows: int, - signature: dict[str, Any], - completed_chunks: list[tuple[int, int]] | None = None, -) -> None: - """Persist sparse similarities so generation can resume after interruption.""" - payload = { - "version": 1, - "completed_rows": completed_rows, - "completed_chunks": completed_chunks or [], - "signature": signature, - "similarities": similarities, - "saved_at": datetime.now().isoformat(), - } - if _is_db_checkpoint(path): - job_id = _similarity_checkpoint_job_id(signature) - with open_study_db(path) as db: - db.init_network_similarity_job( - network_run_id=f"checkpoint:{job_id}", - signature=signature, - job_id=job_id, - ) - db.save_similarity_snapshot(job_id=job_id, payload=pickle.dumps(payload)) - return - - path.parent.mkdir(parents=True, exist_ok=True) - tmp_path = path.with_suffix(path.suffix + ".tmp") - with open(tmp_path, "wb") as f: - 
pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL) - tmp_path.replace(path) - - def _load_similarity_checkpoint( - path: Path, + checkpoint_db: Path, expected_signature: dict[str, Any], ) -> tuple[dict[tuple[int, int], float], int, set[int]]: """Load checkpoint and validate compatibility with current run settings.""" - if _is_db_checkpoint(path): - job_id = _similarity_checkpoint_job_id(expected_signature) - with open_study_db(path) as db: - signature = db.get_network_similarity_job_signature(job_id) - if signature is not None: - if signature != expected_signature: - raise ValueError( - "Checkpoint settings do not match current run. " - "Delete checkpoint or run with matching config." - ) - - done_chunks = db.list_completed_similarity_chunks(job_id) - done_starts = {start for start, _ in done_chunks} - similarities = db.load_similarity_pairs(job_id) - - # Resume serial fallback only from contiguous completed prefix. - contiguous_rows = 0 - for start, end in done_chunks: - if start != contiguous_rows: - break - contiguous_rows = end - - return similarities, max(0, contiguous_rows), done_starts - - payload_bytes = db.get_similarity_snapshot(job_id) - if payload_bytes is None: - raise ValueError(f"Checkpoint not found in study DB: job_id={job_id}") - payload = pickle.loads(payload_bytes) - else: - with open(path, "rb") as f: - payload = pickle.load(f) + job_id = _similarity_checkpoint_job_id(expected_signature) + with open_study_db(checkpoint_db) as db: + signature = db.get_network_similarity_job_signature(job_id) + if signature is None: + raise ValueError(f"Checkpoint not found in study DB: job_id={job_id}") + if signature != expected_signature: + raise ValueError( + "Checkpoint settings do not match current run. " + "Delete checkpoint or run with matching config." + ) - signature = payload.get("signature", {}) - if signature != expected_signature: - raise ValueError( - "Checkpoint settings do not match current run. 
" - "Delete checkpoint or run with matching config." - ) + done_chunks = db.list_completed_similarity_chunks(job_id) + done_starts = {start for start, _ in done_chunks} + similarities = db.load_similarity_pairs(job_id) - similarities = payload.get("similarities", {}) - completed_rows = int(payload.get("completed_rows", 0)) - completed_chunk_starts: set[int] = set() - allowed_completed_rows = max(0, completed_rows) - for item in payload.get("completed_chunks", []): - if ( - isinstance(item, (list, tuple)) - and len(item) == 2 - and isinstance(item[0], int) - and isinstance(item[1], int) - ): - # Guard against stale/inconsistent payloads where completed_chunks - # were not truncated with completed_rows. - if item[0] < allowed_completed_rows and item[1] <= allowed_completed_rows: - completed_chunk_starts.add(item[0]) - - if not isinstance(similarities, dict): - raise ValueError("Invalid checkpoint similarities payload") + contiguous_rows = 0 + for start, end in done_chunks: + if start != contiguous_rows: + break + contiguous_rows = end - return similarities, max(0, completed_rows), completed_chunk_starts + return similarities, max(0, contiguous_rows), done_starts def _init_similarity_worker( @@ -364,25 +288,14 @@ def _compute_similarities_parallel( if ( checkpoint_path is not None and checkpoint_signature is not None + and checkpoint_job_id is not None ): - if _is_db_checkpoint(checkpoint_path) and checkpoint_job_id: - with open_study_db(checkpoint_path) as db: - db.save_similarity_chunk_rows( - job_id=checkpoint_job_id, - chunk_start=current_start, - chunk_end=current_end, - rows=chunk_rows, - ) - else: - completed_chunks = [ - (s, e) for s, e in tasks if s in completed_starts - ] - _save_similarity_checkpoint( - path=checkpoint_path, - similarities=similarities, - completed_rows=min(completed_row_count, n), - signature=checkpoint_signature, - completed_chunks=completed_chunks, + with open_study_db(checkpoint_path) as db: + db.save_similarity_chunk_rows( + 
job_id=checkpoint_job_id, + chunk_start=current_start, + chunk_end=current_end, + rows=chunk_rows, ) if on_progress: @@ -392,12 +305,22 @@ def _compute_similarities_parallel( next_commit_idx += 1 except Exception as e: + downgraded_config = config.model_copy( + update={ + "similarity_workers": 1, + "similarity_chunk_size": max(8, config.similarity_chunk_size // 2), + } + ) logger.warning( - "Parallel similarity failed (%s). Falling back to serial mode.", e + "Parallel similarity failed (%s). Falling back to serial mode " + "(chunk_size %d -> %d).", + e, + config.similarity_chunk_size, + downgraded_config.similarity_chunk_size, ) return _compute_similarities_serial( agents=agents, - config=config, + config=downgraded_config, candidate_map=candidate_map, on_progress=on_progress, checkpoint_path=checkpoint_path, @@ -460,8 +383,15 @@ def _compute_similarities_serial( completed_starts.add(start) completed_row_count += end - start - if checkpoint_path is not None and checkpoint_signature is not None: - if _is_db_checkpoint(checkpoint_path) and checkpoint_job_id: + if ( + checkpoint_path is not None + and checkpoint_signature is not None + and checkpoint_job_id is not None + ): + if ( + completed_row_count % checkpoint_every == 0 + or chunk_idx == len(tasks) - 1 + ): with open_study_db(checkpoint_path) as db: db.save_similarity_chunk_rows( job_id=checkpoint_job_id, @@ -469,18 +399,6 @@ def _compute_similarities_serial( chunk_end=end, rows=local_rows, ) - elif ( - completed_row_count % checkpoint_every == 0 - or chunk_idx == len(tasks) - 1 - ): - completed_chunks = [(s, e) for s, e in tasks if s in completed_starts] - _save_similarity_checkpoint( - path=checkpoint_path, - similarities=similarities, - completed_rows=min(completed_row_count, n), - signature=checkpoint_signature, - completed_chunks=completed_chunks, - ) if on_progress: on_progress("Computing similarities", min(completed_row_count, n), n) @@ -1225,6 +1143,10 @@ def generate_network( n = len(agents) agent_ids = 
[a.get("_id", f"agent_{i}") for i, a in enumerate(agents)] checkpoint_file = Path(checkpoint_path) if checkpoint_path else None + if checkpoint_file is not None and checkpoint_file.suffix.lower() != ".db": + raise ValueError( + "Network checkpoints are DB-only now. Use --study-db (or --checkpoint )." + ) # Step 1: Compute degree factors degree_factors = [compute_degree_factor(a, config) for a in agents] @@ -1256,7 +1178,7 @@ def generate_network( blocking_attrs=blocking_attrs, ) checkpoint_job_id: str | None = None - if _is_db_checkpoint(checkpoint_file): + if checkpoint_file is not None: checkpoint_job_id = _similarity_checkpoint_job_id(checkpoint_signature) if not resume_from_checkpoint: with open_study_db(checkpoint_file) as db: @@ -1272,7 +1194,7 @@ def generate_network( completed_chunk_starts: set[int] = set() if resume_from_checkpoint and checkpoint_file is None: - raise ValueError("--resume-checkpoint requires --checkpoint path") + raise ValueError("--resume-checkpoint requires a checkpoint DB path") if resume_from_checkpoint: if checkpoint_file is None or not checkpoint_file.exists(): diff --git a/extropy/scenario/__init__.py b/extropy/scenario/__init__.py index 6dd793d..31c8893 100644 --- a/extropy/scenario/__init__.py +++ b/extropy/scenario/__init__.py @@ -16,8 +16,9 @@ >>> spec, result = create_scenario( ... "Netflix announces $3 price increase", ... "population.yaml", - ... "agents.json", - ... "network.json", + ... study_db_path="study.db", + ... population_id="default", + ... network_id="default", ... "scenario.yaml" ... 
) >>> result.valid diff --git a/extropy/storage/study_db.py b/extropy/storage/study_db.py index 8e4731b..627d288 100644 --- a/extropy/storage/study_db.py +++ b/extropy/storage/study_db.py @@ -133,12 +133,6 @@ def init_schema(self) -> None: PRIMARY KEY (job_id, i, j) ) WITHOUT ROWID; - CREATE TABLE IF NOT EXISTS network_similarity_snapshots ( - job_id TEXT PRIMARY KEY, - payload BLOB NOT NULL, - updated_at TEXT NOT NULL - ); - CREATE TABLE IF NOT EXISTS simulation_runs ( run_id TEXT PRIMARY KEY, scenario_name TEXT, @@ -297,6 +291,8 @@ def init_schema(self) -> None: CREATE INDEX IF NOT EXISTS idx_agent_states_aware ON agent_states(run_id, aware); CREATE INDEX IF NOT EXISTS idx_agent_states_will_share ON agent_states(run_id, will_share); CREATE INDEX IF NOT EXISTS idx_agent_states_last_reasoning ON agent_states(run_id, last_reasoning_timestep); + CREATE INDEX IF NOT EXISTS idx_agent_states_run_awws + ON agent_states(run_id, aware, will_share, last_reasoning_timestep); CREATE INDEX IF NOT EXISTS idx_exposures_agent_timestep ON exposures(run_id, agent_id, timestep); CREATE INDEX IF NOT EXISTS idx_timeline_timestep ON timeline(run_id, timestep); CREATE INDEX IF NOT EXISTS idx_shared_to_source ON shared_to(run_id, source_agent_id); @@ -691,26 +687,6 @@ def mark_similarity_job_complete(self, job_id: str, drop_pairs: bool = False) -> cursor.execute("DELETE FROM network_similarity_pairs WHERE job_id = ?", (job_id,)) self.conn.commit() - def save_similarity_snapshot(self, job_id: str, payload: bytes) -> None: - cursor = self.conn.cursor() - cursor.execute( - """ - INSERT OR REPLACE INTO network_similarity_snapshots (job_id, payload, updated_at) - VALUES (?, ?, ?) 
- """, - (job_id, payload, _now_iso()), - ) - self.conn.commit() - - def get_similarity_snapshot(self, job_id: str) -> bytes | None: - cursor = self.conn.cursor() - cursor.execute( - "SELECT payload FROM network_similarity_snapshots WHERE job_id = ?", - (job_id,), - ) - row = cursor.fetchone() - return bytes(row["payload"]) if row else None - def create_simulation_run( self, run_id: str, diff --git a/tests/test_cli.py b/tests/test_cli.py index 87cc13d..9d3de05 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -71,7 +71,6 @@ def test_network_command_supports_fast_mode_and_checkpoint(self, tmp_path): study_db = tmp_path / "study.db" config_path = tmp_path / "network-config.yaml" output_path = tmp_path / "network.json" - checkpoint_path = tmp_path / "network-checkpoint.pkl" agents = [ {"_id": "a0", "role": "x", "team": "alpha"}, @@ -106,7 +105,7 @@ def test_network_command_supports_fast_mode_and_checkpoint(self, tmp_path): "--similarity-chunk-size", "8", "--checkpoint", - str(checkpoint_path), + str(study_db), "--checkpoint-every", "1", ], @@ -114,7 +113,9 @@ def test_network_command_supports_fast_mode_and_checkpoint(self, tmp_path): assert result.exit_code == 0 assert output_path.exists() - assert checkpoint_path.exists() + with open_study_db(study_db) as db: + rows = db.run_select("SELECT COUNT(*) AS cnt FROM network_similarity_jobs") + assert rows and int(rows[0]["cnt"]) >= 1 def test_network_resume_requires_checkpoint(self): result = runner.invoke( @@ -131,6 +132,25 @@ def test_network_resume_requires_checkpoint(self): assert result.exit_code == 1 assert "Study DB not found" in result.output + def test_network_checkpoint_must_match_study_db(self, tmp_path): + study_db = tmp_path / "study.db" + other_db = tmp_path / "other.db" + with open_study_db(study_db) as db: + db.save_sample_result(population_id="default", agents=[{"_id": "a0"}], meta={}) + + result = runner.invoke( + app, + [ + "network", + "--study-db", + str(study_db), + "--checkpoint", + 
str(other_db), + ], + ) + assert result.exit_code == 1 + assert "--checkpoint must point to the same canonical file as --study-db" in result.output + def _seed_run_scoped_state(study_db: Path) -> None: agents = [ diff --git a/tests/test_network.py b/tests/test_network.py index eb1effe..749f023 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -733,12 +733,18 @@ def test_generate_network_blocked_mode_reproducibility(self, sample_agents): def test_generate_network_resume_from_checkpoint_matches_fresh(self, sample_agents): """Resuming from a saved similarity checkpoint should match a fresh run.""" - import pickle + import sqlite3 - config = REFERENCE_NETWORK_CONFIG.model_copy(update={"seed": 42}) + config = REFERENCE_NETWORK_CONFIG.model_copy( + update={ + "seed": 42, + "similarity_chunk_size": 8, + "checkpoint_every_rows": 1, + } + ) with tempfile.TemporaryDirectory() as tmpdir: - checkpoint_path = Path(tmpdir) / "network-similarity.pkl" + checkpoint_path = Path(tmpdir) / "study.db" # Build and persist checkpoint from a full run. result_checkpointed = generate_network( @@ -748,18 +754,39 @@ def test_generate_network_resume_from_checkpoint_matches_fresh(self, sample_agen ) assert checkpoint_path.exists() - # Simulate interruption by truncating completed_rows in checkpoint metadata. - with open(checkpoint_path, "rb") as f: - payload = pickle.load(f) - completed_rows = max(1, len(sample_agents) // 2) - payload["completed_rows"] = completed_rows - payload["similarities"] = { - pair: sim - for pair, sim in payload["similarities"].items() - if pair[0] < completed_rows - } - with open(checkpoint_path, "wb") as f: - pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL) + # Simulate interruption by dropping the latter half of completed chunks. 
+ conn = sqlite3.connect(str(checkpoint_path)) + cur = conn.cursor() + cur.execute( + "SELECT job_id FROM network_similarity_jobs ORDER BY created_at DESC LIMIT 1" + ) + job_id = cur.fetchone()[0] + cutoff = max(8, (len(sample_agents) // 2)) + cur.execute( + """ + SELECT MIN(chunk_start) + FROM network_similarity_chunks + WHERE job_id = ? AND chunk_start >= ? + """, + (job_id, cutoff), + ) + drop_start = cur.fetchone()[0] + if drop_start is None: + drop_start = 0 + cur.execute( + "DELETE FROM network_similarity_chunks WHERE job_id = ? AND chunk_start >= ?", + (job_id, drop_start), + ) + cur.execute( + "DELETE FROM network_similarity_pairs WHERE job_id = ? AND i >= ?", + (job_id, drop_start), + ) + cur.execute( + "UPDATE network_similarity_jobs SET status = 'running' WHERE job_id = ?", + (job_id,), + ) + conn.commit() + conn.close() resumed = generate_network( sample_agents, @@ -776,6 +803,18 @@ def test_generate_network_resume_from_checkpoint_matches_fresh(self, sample_agen assert resumed_edges == fresh_edges assert len(resumed.edges) == len(result_checkpointed.edges) + def test_generate_network_checkpoint_requires_db_path(self, sample_agents): + """Checkpoint path must be a SQLite DB path.""" + config = REFERENCE_NETWORK_CONFIG.model_copy(update={"seed": 42}) + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "network-similarity.pkl" + with pytest.raises(ValueError, match="DB-only"): + generate_network( + sample_agents, + config, + checkpoint_path=checkpoint_path, + ) + class TestGenerateNetworkWithMetrics: """Tests for network generation with metrics.""" From 3bf939ab569b613e371053723fcf07eeb6a3b96b Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 01:13:57 -0500 Subject: [PATCH 12/15] feat(sim): add runtime memory guardrails and compact default exports --- extropy/cli/commands/simulate.py | 2 + extropy/simulation/__init__.py | 2 - extropy/simulation/engine.py | 95 ++++++++++++++++--------- 
extropy/simulation/reasoning.py | 5 +- extropy/utils/resource_governor.py | 20 ++++++ tests/test_engine.py | 109 +++++++++++++++++++++++++++-- tests/test_resource_governor.py | 26 +++++++ 7 files changed, 217 insertions(+), 42 deletions(-) create mode 100644 tests/test_resource_governor.py diff --git a/extropy/cli/commands/simulate.py b/extropy/cli/commands/simulate.py index bd592eb..41f475e 100644 --- a/extropy/cli/commands/simulate.py +++ b/extropy/cli/commands/simulate.py @@ -332,6 +332,7 @@ def on_progress(timestep: int, max_timesteps: int, status: str): retention_lite=retention_lite, writer_queue_size=writer_queue_size, db_write_batch_size=db_write_batch_size, + resource_governor=governor, ) simulation_error = None except Exception as e: @@ -368,6 +369,7 @@ def do_simulation(): retention_lite=retention_lite, writer_queue_size=writer_queue_size, db_write_batch_size=db_write_batch_size, + resource_governor=governor, ) except Exception as e: simulation_error = e diff --git a/extropy/simulation/__init__.py b/extropy/simulation/__init__.py index 2d6ed67..1d3ecb8 100644 --- a/extropy/simulation/__init__.py +++ b/extropy/simulation/__init__.py @@ -26,9 +26,7 @@ Output: Results directory containing: - study.db: Canonical SQLite database with simulation state/checkpoints - - agent_states.json: Final state per agent - by_timestep.json: Metrics over time - - outcome_distributions.json: Final outcome distributions - meta.json: Run configuration """ diff --git a/extropy/simulation/engine.py b/extropy/simulation/engine.py index fe85c0e..acec1b8 100644 --- a/extropy/simulation/engine.py +++ b/extropy/simulation/engine.py @@ -49,6 +49,7 @@ from .propagation import apply_seed_exposures, propagate_through_network from .stopping import evaluate_stopping_conditions from ..utils.callbacks import TimestepProgressCallback +from ..utils.resource_governor import ResourceGovernor from .aggregation import ( compute_timestep_summary, compute_final_aggregates, @@ -156,6 +157,7 @@ def 
__init__( retention_lite: bool = False, writer_queue_size: int = 256, db_write_batch_size: int = 100, + resource_governor: ResourceGovernor | None = None, ): """Initialize simulation engine. @@ -182,6 +184,14 @@ def __init__( self.retention_lite = retention_lite self.writer_queue_size = max(1, writer_queue_size) self.db_write_batch_size = max(1, db_write_batch_size) + self.resource_governor = resource_governor + self.reasoning_max_concurrency = 50 + if self.resource_governor is not None: + self.reasoning_max_concurrency = self.resource_governor.recommend_workers( + requested_workers=50, + memory_per_worker_gb=0.2, + ) + self._last_guardrail_timestep = -1 # Build agent map for quick lookup self.agent_map = {a.get("_id", str(i)): a for i, a in enumerate(agents)} @@ -294,6 +304,49 @@ def set_progress_state(self, progress: SimulationProgress) -> None: """ self._progress = progress + def _apply_runtime_guardrails(self, timestep: int) -> None: + """Downshift runtime knobs when process memory nears configured budget.""" + if self.resource_governor is None or self.resource_governor.resource_mode != "auto": + return + + ratio = self.resource_governor.memory_pressure_ratio() + if ratio < 0.85: + return + + factor = 0.5 if ratio >= 0.98 else 0.75 + old_concurrency = self.reasoning_max_concurrency + old_batch = self.db_write_batch_size + old_queue = self.writer_queue_size + + self.reasoning_max_concurrency = self.resource_governor.downshift_int( + self.reasoning_max_concurrency, factor=factor, minimum=1 + ) + self.db_write_batch_size = self.resource_governor.downshift_int( + self.db_write_batch_size, factor=factor, minimum=1 + ) + self.writer_queue_size = self.resource_governor.downshift_int( + self.writer_queue_size, factor=factor, minimum=4 + ) + + changed = ( + old_concurrency != self.reasoning_max_concurrency + or old_batch != self.db_write_batch_size + or old_queue != self.writer_queue_size + ) + if changed and timestep != self._last_guardrail_timestep: + 
self._last_guardrail_timestep = timestep + logger.warning( + "[RESOURCE] Memory pressure %.2fx budget; " + "reasoning_concurrency %d->%d, writer_batch %d->%d, writer_queue %d->%d", + ratio, + old_concurrency, + self.reasoning_max_concurrency, + old_batch, + self.db_write_batch_size, + old_queue, + self.writer_queue_size, + ) + def _report_progress(self, timestep: int, status: str) -> None: """Report progress to callback.""" if self._on_progress: @@ -528,6 +581,7 @@ def _reason_agents(self, timestep: int) -> tuple[int, int, int]: Returns: Tuple of (agents_reasoned, state_changes, shares_occurred). """ + self._apply_runtime_guardrails(timestep) agents_to_reason = self.state_manager.get_agents_to_reason( timestep, self.config.multi_touch_threshold, @@ -665,6 +719,7 @@ def _flush_pending() -> None: for chunk_start in range(0, len(contexts), self.chunk_size): if writer_error: break + self._apply_runtime_guardrails(timestep) chunk_index = chunk_start // self.chunk_size if chunk_index in completed_chunks: logger.info( @@ -678,6 +733,7 @@ def _flush_pending() -> None: chunk_contexts, self.scenario, self.config, + max_concurrency=self.reasoning_max_concurrency, rate_limiter=self.rate_limiter, on_agent_done=_on_agent_done, ) @@ -1272,7 +1328,7 @@ def _compute_cost(self) -> dict[str, Any]: return cost def _export_results(self) -> None: - """Export all results to output directory.""" + """Export compact default artifacts to output directory.""" # Export summary summaries = self.state_manager.get_timestep_summaries() timeline_agg = compute_timeline_aggregates(summaries) @@ -1280,40 +1336,6 @@ def _export_results(self) -> None: with open(self.output_dir / "by_timestep.json", "w") as f: json.dump(timeline_agg, f, indent=2) - # Export final agent states - final_states = self.state_manager.export_final_states() - - # Merge with agent attributes - agent_results = [] - for state in final_states: - agent_id = state["agent_id"] - agent = self.agent_map.get(agent_id, {}) - - 
agent_results.append( - { - "agent_id": agent_id, - "attributes": { - k: v for k, v in agent.items() if not k.startswith("_") - }, - "final_state": state, - "reasoning_count": ( - 1 if state["last_reasoning_timestep"] >= 0 else 0 - ), - } - ) - - with open(self.output_dir / "agent_states.json", "w") as f: - json.dump(agent_results, f, indent=2) - - # Export outcome distributions - outcome_dists = compute_outcome_distributions( - self.state_manager, - self.scenario.outcomes.suggested_outcomes, - ) - - with open(self.output_dir / "outcome_distributions.json", "w") as f: - json.dump(outcome_dists, f, indent=2) - # Export meta information meta = { "scenario_name": self.scenario.meta.name, @@ -1359,6 +1381,7 @@ def run_simulation( retention_lite: bool = False, writer_queue_size: int = 256, db_write_batch_size: int = 100, + resource_governor: ResourceGovernor | None = None, ) -> SimulationSummary: """Run a simulation from a scenario file. @@ -1385,6 +1408,7 @@ def run_simulation( retention_lite: Reduce payload volume by dropping full raw reasoning text writer_queue_size: Maximum buffered chunks waiting for DB writer db_write_batch_size: Number of chunks applied per DB writer transaction + resource_governor: Optional governor for runtime downshift guardrails Returns: SimulationSummary with results @@ -1548,6 +1572,7 @@ def _reset_runtime_tables(path: Path, run_key: str) -> None: retention_lite=retention_lite, writer_queue_size=writer_queue_size, db_write_batch_size=db_write_batch_size, + resource_governor=resource_governor, ) if on_progress: diff --git a/extropy/simulation/reasoning.py b/extropy/simulation/reasoning.py index 7c51151..3c52674 100644 --- a/extropy/simulation/reasoning.py +++ b/extropy/simulation/reasoning.py @@ -777,7 +777,10 @@ def batch_reason_agents( async def run_all(): if rate_limiter: - target_concurrency = max(1, rate_limiter.max_safe_concurrent) + target_concurrency = min( + max(1, rate_limiter.max_safe_concurrent), + max(1, max_concurrency), + ) 
stagger_interval = 60.0 / rate_limiter.pivotal.rpm logger.info( f"[BATCH] Concurrency cap: {target_concurrency}, " diff --git a/extropy/utils/resource_governor.py b/extropy/utils/resource_governor.py index 33b81a2..574d9ef 100644 --- a/extropy/utils/resource_governor.py +++ b/extropy/utils/resource_governor.py @@ -4,6 +4,7 @@ import os import platform +import resource import subprocess from dataclasses import dataclass @@ -61,6 +62,25 @@ def snapshot(self) -> ResourceSnapshot: memory_budget_gb=round(budget, 2), ) + @staticmethod + def _current_process_memory_gb() -> float: + usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + system = platform.system().lower() + # Linux reports KB, macOS reports bytes. + if system == "darwin": + return float(usage) / (1024**3) + return float(usage) / (1024**2) + + def memory_pressure_ratio(self) -> float: + snap = self.snapshot() + current = self._current_process_memory_gb() + budget = max(0.1, snap.memory_budget_gb) + return current / budget + + @staticmethod + def downshift_int(current: int, factor: float, minimum: int = 1) -> int: + return max(minimum, int(max(1, current) * factor)) + def recommend_workers( self, requested_workers: int, diff --git a/tests/test_engine.py b/tests/test_engine.py index 6171f5d..7eeaf18 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -20,6 +20,7 @@ SimulationRunConfig, ) from extropy.simulation.progress import SimulationProgress +from extropy.utils.resource_governor import ResourceGovernor from extropy.core.models.scenario import ( Event, EventType, @@ -887,7 +888,14 @@ def test_chunk_checkpoints_written_with_writer_pipeline( ) engine.state_manager.record_exposure(aid, exposure) - def fake_batch(contexts, scenario, cfg, rate_limiter=None, on_agent_done=None): + def fake_batch( + contexts, + scenario, + cfg, + max_concurrency=50, + rate_limiter=None, + on_agent_done=None, + ): response = _make_reasoning_response() results = [] for ctx in contexts: @@ -1190,7 +1198,14 @@ def 
test_progress_state_updated( response_a0 = _make_reasoning_response(position="adopt", conviction=0.5) response_a1 = _make_reasoning_response(position="reject", conviction=0.7) - def fake_batch(contexts, scenario, cfg, rate_limiter=None, on_agent_done=None): + def fake_batch( + contexts, + scenario, + cfg, + max_concurrency=50, + rate_limiter=None, + on_agent_done=None, + ): results = [] for ctx in contexts: if ctx.agent_id == "a0": @@ -1253,7 +1268,14 @@ def test_on_agent_done_callback_passed( received_kwargs = {} - def fake_batch(contexts, scenario, cfg, rate_limiter=None, on_agent_done=None): + def fake_batch( + contexts, + scenario, + cfg, + max_concurrency=50, + rate_limiter=None, + on_agent_done=None, + ): received_kwargs["on_agent_done"] = on_agent_done resp = _make_reasoning_response() return [(ctx.agent_id, resp) for ctx in contexts], BatchTokenUsage() @@ -1448,7 +1470,14 @@ def test_tokens_accumulate_across_chunks( call_count = [0] - def fake_batch(contexts, scenario, cfg, rate_limiter=None, on_agent_done=None): + def fake_batch( + contexts, + scenario, + cfg, + max_concurrency=50, + rate_limiter=None, + on_agent_done=None, + ): call_count[0] += 1 resp = _make_reasoning_response() results = [(ctx.agent_id, resp) for ctx in contexts] @@ -1557,3 +1586,75 @@ def test_cost_unknown_model_returns_null_usd( meta = json.load(f) assert meta["cost"]["estimated_usd"] is None + + def test_export_results_keeps_compact_default_artifacts( + self, + minimal_scenario, + simple_agents, + simple_network, + minimal_pop_spec, + tmp_path, + ): + """Default export should keep compact summaries and skip large JSON dumps.""" + config = SimulationRunConfig( + scenario_path="test.yaml", + output_dir=str(tmp_path / "output"), + ) + engine = SimulationEngine( + scenario=minimal_scenario, + population_spec=minimal_pop_spec, + agents=simple_agents, + network=simple_network, + config=config, + ) + + engine._export_results() + + assert (tmp_path / "output" / "meta.json").exists() + assert 
(tmp_path / "output" / "by_timestep.json").exists() + assert not (tmp_path / "output" / "agent_states.json").exists() + assert not (tmp_path / "output" / "outcome_distributions.json").exists() + + def test_runtime_guardrails_downshift_under_pressure( + self, + minimal_scenario, + simple_agents, + simple_network, + minimal_pop_spec, + tmp_path, + ): + """Runtime memory pressure should downshift concurrency/write knobs.""" + + class HighPressureGovernor(ResourceGovernor): + def memory_pressure_ratio(self) -> float: + return 1.1 + + config = SimulationRunConfig( + scenario_path="test.yaml", + output_dir=str(tmp_path / "output"), + ) + governor = HighPressureGovernor(resource_mode="auto") + engine = SimulationEngine( + scenario=minimal_scenario, + population_spec=minimal_pop_spec, + agents=simple_agents, + network=simple_network, + config=config, + writer_queue_size=64, + db_write_batch_size=16, + resource_governor=governor, + ) + before = ( + engine.reasoning_max_concurrency, + engine.db_write_batch_size, + engine.writer_queue_size, + ) + engine._apply_runtime_guardrails(timestep=0) + after = ( + engine.reasoning_max_concurrency, + engine.db_write_batch_size, + engine.writer_queue_size, + ) + assert after[0] < before[0] + assert after[1] < before[1] + assert after[2] < before[2] diff --git a/tests/test_resource_governor.py b/tests/test_resource_governor.py new file mode 100644 index 0000000..fff7938 --- /dev/null +++ b/tests/test_resource_governor.py @@ -0,0 +1,26 @@ +"""Tests for resource auto-tuning and runtime guardrails.""" + +from extropy.utils.resource_governor import ResourceGovernor + + +def test_downshift_int_respects_minimum(): + assert ResourceGovernor.downshift_int(100, factor=0.5, minimum=1) == 50 + assert ResourceGovernor.downshift_int(2, factor=0.1, minimum=4) == 4 + + +def test_memory_pressure_ratio_uses_budget(monkeypatch): + governor = ResourceGovernor(resource_mode="auto", max_memory_gb=8.0) + monkeypatch.setattr( + governor, + 
"_detect_total_memory_gb", + lambda: 8.0, + ) + monkeypatch.setattr( + governor, + "_current_process_memory_gb", + lambda: 3.2, + ) + + # Budget is 80% of capped memory => 6.4 GB, so ratio should be 0.5. + assert governor.memory_pressure_ratio() == 0.5 + From c57449df3f292b73427c453f667523cc77ba6912 Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 01:24:51 -0500 Subject: [PATCH 13/15] feat(persona): add study-db agent source and db-first preview tests --- extropy/cli/commands/persona.py | 50 +++++++++++++++--- tests/test_cli.py | 92 +++++++++++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 7 deletions(-) diff --git a/extropy/cli/commands/persona.py b/extropy/cli/commands/persona.py index 5ceb07e..6f868ed 100644 --- a/extropy/cli/commands/persona.py +++ b/extropy/cli/commands/persona.py @@ -4,12 +4,14 @@ import time from pathlib import Path from threading import Event, Thread +from typing import Any import typer from rich.live import Live from rich.spinner import Spinner from ...core.models import PopulationSpec +from ...storage import open_study_db from ..app import app, console from ..utils import ( format_elapsed, @@ -22,6 +24,16 @@ def persona_command( agents_file: Path = typer.Option( None, "--agents", "-a", help="Sampled agents JSON file (for population stats)" ), + study_db: Path | None = typer.Option( + None, + "--study-db", + help="Canonical study DB file (preferred; loads sampled agents by population id)", + ), + population_id: str = typer.Option( + "default", + "--population-id", + help="Population id when loading agents from --study-db", + ), output: Path = typer.Option( None, "--output", @@ -65,10 +77,11 @@ def persona_command( 3 = Generation error EXAMPLES: - extropy persona population.yaml --agents agents.json - extropy persona population.yaml -a agents.json -o persona_config.yaml - extropy persona population.yaml -a agents.json --agent 42 -y - extropy persona population.yaml -a agents.json --show # preview existing + 
extropy persona population.yaml --study-db study.db --population-id default + extropy persona population.yaml --study-db study.db -o persona_config.yaml + extropy persona population.yaml --study-db study.db --agent 42 -y + extropy persona population.yaml --study-db study.db --show + extropy persona population.yaml --agents agents.json # legacy input """ from ...population.persona import ( generate_persona_config, @@ -96,7 +109,11 @@ def persona_command( ) # Load Agents (optional but recommended) - agents = None + agents: list[dict[str, Any]] | None = None + if agents_file and study_db: + console.print("[red]✗[/red] Use either --agents or --study-db, not both") + raise typer.Exit(1) + if agents_file: if not agents_file.exists(): console.print(f"[red]✗[/red] Agents file not found: {agents_file}") @@ -119,9 +136,28 @@ def persona_command( raise typer.Exit(1) console.print(f"[green]✓[/green] Loaded {len(agents)} agents") + elif study_db: + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") + raise typer.Exit(2) + with console.status("[cyan]Loading agents from study DB...[/cyan]"): + try: + with open_study_db(study_db) as db: + agents = db.get_agents(population_id) + except Exception as e: + console.print(f"[red]✗[/red] Failed to load agents from study DB: {e}") + raise typer.Exit(1) + if not agents: + console.print( + f"[red]✗[/red] No agents found for population_id '{population_id}' in {study_db}" + ) + raise typer.Exit(1) + console.print( + f"[green]✓[/green] Loaded {len(agents)} agents from study DB population_id={population_id}" + ) else: console.print( - "[yellow]⚠[/yellow] No agents file - population stats will use defaults" + "[yellow]⚠[/yellow] No agent source provided (--study-db or --agents) - population stats will use defaults" ) # Handle --show mode: preview existing config without regenerating @@ -149,7 +185,7 @@ def persona_command( console.print() if not agents: - console.print("[red]✗[/red] Need --agents to preview 
personas") + console.print("[red]✗[/red] Need --study-db or --agents to preview personas") raise typer.Exit(1) if agent_index >= len(agents): diff --git a/tests/test_cli.py b/tests/test_cli.py index 9d3de05..d867eb3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,7 @@ """CLI smoke tests using typer's CliRunner.""" import json import sqlite3 +from types import SimpleNamespace from pathlib import Path from typer.testing import CliRunner @@ -256,3 +257,94 @@ def test_chat_ask_reads_state_for_requested_run(self, tmp_path): assert payload["session_id"] assert "old_pos" in payload["assistant_text"] assert "new_pos" not in payload["assistant_text"] + + +class TestPersonaCommand: + def test_persona_show_loads_agents_from_study_db(self, tmp_path, monkeypatch): + import extropy.cli.commands.persona as persona_cmd + import extropy.population.persona as persona_pkg + + class DummyPopulationSpec: + @classmethod + def from_yaml(cls, _path): + return SimpleNamespace( + meta=SimpleNamespace(description="test population"), + attributes=[{"name": "age"}], + ) + + class DummyPersonaConfig: + @classmethod + def from_file(cls, _path): + return object() + + monkeypatch.setattr(persona_cmd, "PopulationSpec", DummyPopulationSpec) + monkeypatch.setattr(persona_pkg, "PersonaConfig", DummyPersonaConfig) + monkeypatch.setattr( + persona_pkg, + "preview_persona", + lambda _agent, _config, max_width=80: "I am a test persona.", + ) + + spec_file = tmp_path / "population.yaml" + spec_file.write_text("meta: {}\n", encoding="utf-8") + persona_file = spec_file.with_suffix(".persona.yaml") + persona_file.write_text("dummy: true\n", encoding="utf-8") + + study_db = tmp_path / "study.db" + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", + agents=[{"_id": "a0", "age": 30}, {"_id": "a1", "age": 41}], + meta={"source": "test"}, + ) + + result = runner.invoke( + app, + [ + "persona", + str(spec_file), + "--study-db", + str(study_db), + 
"--population-id", + "default", + "--show", + ], + ) + assert result.exit_code == 0 + assert "Loaded 2 agents from study DB population_id=default" in result.output + assert "Persona for Agent a0" in result.output + + def test_persona_rejects_agents_and_study_db_together(self, tmp_path, monkeypatch): + import extropy.cli.commands.persona as persona_cmd + + monkeypatch.setattr( + persona_cmd.PopulationSpec, + "from_yaml", + classmethod( + lambda cls, _path: SimpleNamespace( + meta=SimpleNamespace(description="test population"), + attributes=[{"name": "age"}], + ) + ), + ) + spec_file = tmp_path / "population.yaml" + spec_file.write_text("meta: {}\n", encoding="utf-8") + agents_file = tmp_path / "agents.json" + agents_file.write_text("[]\n", encoding="utf-8") + study_db = tmp_path / "study.db" + with open_study_db(study_db) as db: + db.save_sample_result(population_id="default", agents=[{"_id": "a0"}], meta={}) + + result = runner.invoke( + app, + [ + "persona", + str(spec_file), + "--agents", + str(agents_file), + "--study-db", + str(study_db), + ], + ) + assert result.exit_code == 1 + assert "Use either --agents or --study-db, not both" in result.output From cd8a23a778352eb94e25d9c6ac634b50a630c444 Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 01:25:04 -0500 Subject: [PATCH 14/15] refactor(config): remove DefaultsConfig, move show_cost to top-level Drop population_size and db_path (unused). Move show_cost to ExtropyConfig directly. Config display only shows show_cost when enabled. 
Co-Authored-By: Claude Opus 4.6 --- extropy/cli/app.py | 2 +- extropy/cli/commands/config_cmd.py | 25 +++++++------------------ extropy/config.py | 25 +++++-------------------- 3 files changed, 13 insertions(+), 39 deletions(-) diff --git a/extropy/cli/app.py b/extropy/cli/app.py index 327b535..75e6d87 100644 --- a/extropy/cli/app.py +++ b/extropy/cli/app.py @@ -96,7 +96,7 @@ def main_callback( try: from ..config import get_config - show = get_config().defaults.show_cost + show = get_config().show_cost except Exception: pass diff --git a/extropy/cli/commands/config_cmd.py b/extropy/cli/commands/config_cmd.py index 480d16e..00d0590 100644 --- a/extropy/cli/commands/config_cmd.py +++ b/extropy/cli/commands/config_cmd.py @@ -20,9 +20,7 @@ "simulation.rate_tier", "simulation.rpm_override", "simulation.tpm_override", - "defaults.population_size", - "defaults.db_path", - "defaults.show_cost", + "show_cost", } INT_FIELDS = { @@ -30,11 +28,6 @@ "rate_tier", "rpm_override", "tpm_override", - "population_size", -} - -BOOL_FIELDS = { - "show_cost", } @@ -123,12 +116,10 @@ def _show_config(): if provider_cfg.api_key_env: console.print(f" api_key_env = {provider_cfg.api_key_env}") - # Defaults - console.print() - console.print("[bold cyan]Defaults[/bold cyan]") - console.print(f" population_size = {config.defaults.population_size}") - console.print(f" db_path = {config.defaults.db_path}") - console.print(f" show_cost = {config.defaults.show_cost}") + # Cost tracking + if config.show_cost: + console.print() + console.print(f" show_cost = {config.show_cost}") # API keys status console.print() @@ -190,14 +181,14 @@ def _set_config(key: str, value: str): if provider_name not in config.providers: config.providers[provider_name] = CustomProviderConfig() setattr(config.providers[provider_name], field, value) + elif key == "show_cost": + config.show_cost = value.lower() in ("true", "1", "yes") else: zone, field_name = key.split(".", 1) if zone == "models": target = config.models elif 
zone == "simulation": target = config.simulation - elif zone == "defaults": - target = config.defaults else: console.print(f"[red]Unknown zone:[/red] {zone}") raise typer.Exit(1) @@ -209,8 +200,6 @@ def _set_config(key: str, value: str): except ValueError: console.print(f"[red]Invalid integer value:[/red] {value}") raise typer.Exit(1) - elif field_name in BOOL_FIELDS: - setattr(target, field_name, value.lower() in ("true", "1", "yes")) else: setattr(target, field_name, value) diff --git a/extropy/config.py b/extropy/config.py index de94ef4..f2e2ca7 100644 --- a/extropy/config.py +++ b/extropy/config.py @@ -109,14 +109,6 @@ class CustomProviderConfig(BaseModel): api_key_env: str = "" -class DefaultsConfig(BaseModel): - """Non-zone default settings.""" - - population_size: int = 1000 - db_path: str = "./storage/extropy.db" - show_cost: bool = False # Show cost footer after every CLI command - - # ============================================================================= # Main config class # ============================================================================= @@ -146,7 +138,7 @@ class ExtropyConfig(BaseModel): models: ModelsConfig = Field(default_factory=ModelsConfig) simulation: SimulationConfig = Field(default_factory=SimulationConfig) providers: dict[str, CustomProviderConfig] = Field(default_factory=dict) - defaults: DefaultsConfig = Field(default_factory=DefaultsConfig) + show_cost: bool = False @classmethod def load(cls) -> "ExtropyConfig": @@ -208,8 +200,8 @@ def save(self) -> None: data["providers"] = { name: cfg.model_dump() for name, cfg in self.providers.items() } - if self.defaults != DefaultsConfig(): - data["defaults"] = self.defaults.model_dump() + if self.show_cost: + data["show_cost"] = True with open(CONFIG_FILE, "w") as f: json.dump(data, f, indent=2) @@ -218,7 +210,6 @@ def to_dict(self) -> dict[str, Any]: result = { "models": self.models.model_dump(), "simulation": self.simulation.model_dump(), - "defaults": 
self.defaults.model_dump(), } if self.providers: result["providers"] = { @@ -276,14 +267,8 @@ def _apply_dict(config: ExtropyConfig, data: dict) -> None: base_url=provider_data.get("base_url", ""), api_key_env=provider_data.get("api_key_env", ""), ) - if "defaults" in data and isinstance(data["defaults"], dict): - for k, v in data["defaults"].items(): - if hasattr(config.defaults, k): - if k == "population_size": - v = int(v) - elif k == "show_cost": - v = bool(v) - setattr(config.defaults, k, v) + if "show_cost" in data: + config.show_cost = bool(data["show_cost"]) # ============================================================================= From be8d3de219f8566e9340000ab5d4c86f28d9f376 Mon Sep 17 00:00:00 2001 From: DeveshParagiri Date: Sun, 15 Feb 2026 02:00:56 -0500 Subject: [PATCH 15/15] chore: trigger claude code review