diff --git a/.claude/skills/extropy/references/ANALYSIS_PLAYBOOK.md b/.claude/skills/extropy/references/ANALYSIS_PLAYBOOK.md new file mode 100644 index 0000000..336e31f --- /dev/null +++ b/.claude/skills/extropy/references/ANALYSIS_PLAYBOOK.md @@ -0,0 +1,59 @@ +# Analysis Playbook + +Use this after runs complete. + +## Primary Artifacts + +- `results/meta.json` +- `results/by_timestep.json` +- `results/outcome_distributions.json` +- `results/agent_states.json` +- `results/timeline.jsonl` + +## Analysis Sequence + +1. Run-level summary +- population size +- timesteps completed +- stop reason +- total reasoning calls +- token/cost summary + +2. Dynamics +- exposure curve shape over time +- when state changes plateau +- whether run stopped by condition or quiescence + +3. Outcomes +- final distribution by outcome +- concentration/polarization patterns +- compare to baseline variants + +4. Segments +Use: +```bash +extropy results --segment +``` +Evaluate heterogeneous effects by segment. + +5. Agent-level deep dive +Use: +```bash +extropy results --agent +``` +Look for representative or anomalous trajectories. + +6. Convergence + uncertainty across runs +- Run the same scenario across multiple seeds. +- Report central tendency + spread for key outcomes (mean, min/max, std where possible). +- Flag unstable outcomes where between-seed variance is decision-relevant. +- Do not present one run as a definitive forecast. + +## Comparative Analysis Template + +When comparing runs, report: +1. What changed (single axis) +2. Delta in exposure speed +3. Delta in key outcomes +4. Delta in cost/runtime +5. Confidence assessment (needs more seeds or stable) diff --git a/.claude/skills/extropy/references/CAPABILITIES_AND_EXAMPLES.md b/.claude/skills/extropy/references/CAPABILITIES_AND_EXAMPLES.md new file mode 100644 index 0000000..e054c27 --- /dev/null +++ b/.claude/skills/extropy/references/CAPABILITIES_AND_EXAMPLES.md @@ -0,0 +1,142 @@ +# Capabilities and Examples + +Use this file to map user intent to what Extropy can model and how to execute it. + +## 1) Core Capability Classes + +1. Population synthesis +- Build statistically grounded synthetic populations from natural-language scope. +- Add scenario-specific behavioral/psychographic attributes. + +2. Social graph simulation +- Generate network structures and influence pathways. +- Model diffusion and exposure propagation. + +3. Scenario compilation +- Translate events/policies/product changes into executable exposure + outcome logic. + +4. Agent reasoning dynamics +- Simulate iterative belief updates, memory effects, and classification outcomes. + +5. Outcome analytics +- Produce timeline dynamics, final distributions, segment deltas, and agent-level traces. + +6. Experiment operations +- Support estimation, batching, sweeps, versioning, triage, and reporting. + +## 2) Decision Domains + +- Public policy and governance +- Market/pricing strategy +- Product launch and diffusion +- Crisis and reputation response +- Messaging and political strategy +- Community and urban planning +- Healthcare behavior change +- B2B and enterprise transformation + +## 3) Advanced Study Patterns + +1. Counterfactual suites +- Baseline vs alternatives under fixed population/config. + +2. Sensitivity analysis +- Sweep one axis at a time around baseline assumptions. + +3. Confidence sweeps +- Multi-seed reruns for stability/variance analysis. + +4. Segment stress tests +- Identify cohorts with fragile or highly parameter-sensitive outcomes. + +5. Mechanism-first analysis +- Explain outcomes from exposure paths and agent state traces. + +## 4) Practical Boundaries + +- Best for social-behavioral dynamics, not physics/logistics optimization. +- Multi-event cascades are better modeled as staged runs. +- Outputs are simulation-informed forecasts, not guaranteed outcomes. + +## 5) Trigger Phrases + +Use this skill when users ask things like: +- "simulate how people will respond to..." +- "what happens if we raise price by..." +- "which segments will churn/adopt/protest" +- "test these message variants before launch" +- "run scenario analysis with uncertainty" +- "why did this segment flip in simulation" + +## 6) Example Requests (Illustrative Only) + +These are examples, not defaults. + +1. Policy: congestion pricing alternatives +- Ask: estimate compliance/backlash across income and commute-access segments. +- Shape: baseline + alternatives + equity cuts. + +2. Public health messaging +- Ask: find least responsive groups and best message frame. +- Shape: same population, multiple message scenarios, compare adoption/sentiment. + +3. SaaS pricing +- Ask: estimate churn/downgrade/stay under +10/+20/+30 price shifts. +- Shape: counterfactual suite + revenue-risk tradeoff. + +4. Product launch +- Ask: predict enable/disable behavior for default-on AI feature. +- Shape: adoption outcomes + trust/privacy sensitivity sweep. + +5. Crisis response +- Ask: compare apology-only vs refund vs policy-change response. +- Shape: trust recovery and negative WOM dynamics by segment. + +6. Political messaging +- Ask: compare message resonance/backlash by ideology/economic exposure. +- Shape: frame variants + propagation differences. + +7. Community planning +- Ask: simulate support/neutral/oppose response to development proposal. +- Shape: concern taxonomy + coalition risk. + +8. Healthcare adoption +- Ask: model clinician switching under reimbursement changes. +- Shape: policy variants + adoption friction analysis. + +9. Enterprise change +- Ask: simulate compliance/disengagement/attrition intent under policy shift. +- Shape: role/commute/trust segment breakdown. + +10. Deep triage +- Ask: debug flat exposure curve or unstable seed outcomes. +- Shape: evidence-led root cause, minimal fix, rerun command. + +## 7) Quick Execution Templates + +1. Baseline + sensitivity +- 1 baseline +- 3 variants +- 3 seeds each +- 2 to 3 key segment cuts + +2. Message shootout +- 1 population +- 3 to 5 message scenarios +- fixed config + seeds +- rank by primary KPI + stability + +3. Decision brief inputs +- decision objective +- top findings +- segment impacts +- confidence/stability +- recommendation + caveats + +## 8) Capability to File Map + +- Can Extropy model this? -> this file +- How to run it? -> `OPERATIONS.md` +- How to validate/fix/escalate? -> `QUALITY_TRIAGE_ESCALATION.md` +- How to analyze outcomes? -> `ANALYSIS_PLAYBOOK.md` +- How to write decision report? -> `EXPERIMENT_REPORT_TEMPLATE.md` diff --git a/.claude/skills/extropy/references/EXPERIMENT_REPORT_TEMPLATE.md b/.claude/skills/extropy/references/EXPERIMENT_REPORT_TEMPLATE.md new file mode 100644 index 0000000..be1fe4e --- /dev/null +++ b/.claude/skills/extropy/references/EXPERIMENT_REPORT_TEMPLATE.md @@ -0,0 +1,63 @@ +# Experiment Report Template + +Use this template for every completed experiment batch. + +## 1. Decision Context + +- Study name: +- Decision to support: +- Primary KPI/outcome: +- Constraints (budget, timeline, policy limits): + +## 2. Experiment Setup + +- Population description: +- Scenario description: +- Run set ID: +- Variants included: +- Seed policy (single seed or multi-seed): +- Model/provider config: + +## 3. Headline Results + +- Baseline outcome distribution: +- Most important segment deltas: +- Exposure dynamics summary: +- Stop condition and total timesteps: + +## 4. Confidence and Stability + +- Number of seeds: +- Between-seed variance for key outcomes: +- Stable findings (low variance): +- Unstable findings (high variance): +- Confidence statement (high/medium/low): + +## 5. Why It Happened (Mechanism) + +- Dominant drivers inferred from traces: +- Key peer influence patterns: +- Conviction/memory effects observed: +- Outlier trajectories and interpretation: + +## 6. Cost and Operations + +- Total token usage (pivotal/routine): +- Estimated cost: +- Runtime and bottlenecks: +- Any retries/errors/resume events: + +## 7. Recommendations + +1. Immediate decision recommendation +2. Risk caveats +3. Next experiment(s) to run +4. What would change your recommendation + +## 8. Evidence Files + +- `results/meta.json` +- `results/outcome_distributions.json` +- `results/by_timestep.json` +- `results/agent_states.json` +- `results/timeline.jsonl` diff --git a/.claude/skills/extropy/references/OPERATIONS.md b/.claude/skills/extropy/references/OPERATIONS.md new file mode 100644 index 0000000..dae3163 --- /dev/null +++ b/.claude/skills/extropy/references/OPERATIONS.md @@ -0,0 +1,153 @@ +# Operations + +Use this file for end-to-end execution, run management, and repeatable experiment operations. + +## 1) Preconditions + +1. Ensure output root exists (`runs/`). +2. Verify provider/runtime config (`extropy config show`). +3. Verify required API keys for active provider(s). +4. Confirm objective, success metric, and scope; if missing, run clarification from `QUALITY_TRIAGE_ESCALATION.md`. +5. Select a realism gate profile and schema map before strict realism checks. + +## 2) End-to-End Build Order + +1. `spec` -> `base.yaml` +2. `extend` -> `population.yaml` +3. `sample` -> `agents.json` +4. `network` -> `network.json` +5. `persona` -> `population.persona.yaml` +6. `scenario` -> `scenario.yaml` +7. `estimate` -> cost/volume preview +8. `simulate` -> `results/` +9. `results` -> aggregate + segment views + +Minimal skeleton: + +```bash +# population build +extropy spec "" -o runs//base.yaml +extropy extend runs//base.yaml -s "" -o runs//population.yaml +extropy sample runs//population.yaml -o runs//agents.json --seed 42 +extropy network runs//agents.json -p runs//population.yaml -o runs//network.json --seed 42 +extropy persona runs//population.yaml --agents runs//agents.json -o runs//population.persona.yaml + +# scenario + sim +extropy scenario -p runs//population.yaml -a runs//agents.json -n runs//network.json -o runs//scenario.yaml +extropy estimate runs//scenario.yaml +extropy simulate runs//scenario.yaml -o runs//results --seed 42 +``` + +## 3) Standard Output Paths + +- `runs//base.yaml` +- `runs//population.yaml` +- `runs//agents.json` +- `runs//network.json` +- `runs//population.persona.yaml` +- `runs//scenario.yaml` +- `runs//results/` + +## 4) Reproducibility Rules + +- Always set `--seed` on `sample`, `network`, `simulate`. +- Log provider/model/threshold/chunk/rate overrides for every run. +- Prefer non-interactive flags when available (`--yes`). +- Never compare variants directly unless scenario/config/seed policy are comparable. + +## 5) Autopilot Loop (After Every Stage) + +1. Run command. +2. Wait for full exit. +3. Run stage quality checks from `QUALITY_TRIAGE_ESCALATION.md`. +4. If FAIL: apply smallest upstream fix; rerun only dependent downstream stages. +5. If same gate fails twice: escalate per policy. + +## 6) Long-Run Conduct + +For long-running `spec`, `extend`, `simulate`: +- Do not infer failure before process exit. +- Monitor health non-destructively. +- Judge outputs only after exit unless canceled/timeboxed by policy. + +## 7) Batch + Variant Management + +Use this structure: + +```text +runs/ + / + registry/ + runs.csv + latest.txt + specs/ + pop/ + persona/ + network-config/ + scenario/ + batches/ + / + manifest.yaml + variants/ + / + inputs/ + results/ +``` + +Canonical IDs: +- `study_slug`: lowercase kebab-case +- `batch_id`: `bYYYYMMDD-HHMM--vNN` +- `variant_id`: `vr---s` +- `scenario_rev`: `scn-vNN` +- `config_rev`: `cfg-vNN` + +Revision rules: +- bump `scn-vNN` when event/exposure/outcome logic changes +- bump `cfg-vNN` when provider/model/rate/logic defaults change + +## 8) Manifest Contract + +Each batch should declare: +- study + batch_id +- scenario/config revisions +- objective +- base spec paths +- variant IDs, seeds, and explicit overrides + +## 9) Registry Contract + +Append one row per variant run in `registry/runs.csv` with: +- timestamp, study, batch_id, variant_id +- scenario_rev, config_rev, seed +- status, results_dir, notes + +Update `registry/latest.txt` to the latest successful promoted baseline batch. + +## 10) Sweep Patterns + +Use one-axis sweeps unless user explicitly requests full factorial. + +Common axes: +- seed +- threshold +- chunk size +- model/routine model +- provider and rate settings + +Recommended sets: +1. Baseline + 2 to 3 sensitivity variants. +2. Confidence sweep with 5 to 10 seeds. + +## 11) Resume + Recovery + +If `simulation.db` exists and run interrupted: +- rerun same `extropy simulate ... -o ` command +- do not change seed/config mid-resume + +## 12) Minimum Batch Deliverables + +1. Updated run registry rows +2. One experiment report (`EXPERIMENT_REPORT_TEMPLATE.md`) +3. At least two segment analyses +4. Stability/confidence section (or explicit single-seed caveat) +5. Gate status summary (PASS/WARN/FAIL) for each variant diff --git a/.claude/skills/extropy/references/QUALITY_TRIAGE_ESCALATION.md b/.claude/skills/extropy/references/QUALITY_TRIAGE_ESCALATION.md new file mode 100644 index 0000000..40dce6c --- /dev/null +++ b/.claude/skills/extropy/references/QUALITY_TRIAGE_ESCALATION.md @@ -0,0 +1,239 @@ +# Quality, Triage, and Escalation + +Use this file to decide what to ask, what to validate, how to debug, and when to escalate. + +## 1) Clarify Before Expensive Runs + +Ask the fewest high-leverage questions needed to prevent wasted runs. + +Priority questions: +1. What concrete decision should this simulation inform? +2. Which outcome/metric is primary? +3. Who is in scope, and which segments must be broken out? +4. What scenario/change and time horizon should be simulated? +5. What constraints apply (cost/runtime/provider/deadline)? +6. What realism contract applies (hard constraints + benchmark priors)? + +Optional follow-ups: +- Baseline only vs alternatives/counterfactuals? +- Single run vs confidence sweep? +- Raw outputs only vs recommendations? +- If tradeoffs appear, optimize for realism or speed/cost? + +If user is unsure, use practical defaults and state assumptions before execution. + +## 2) Gate Profiles + +Select one before sample realism checks: +- `generic-all-ages` +- `adults-only` +- `benchmark-calibrated` +- `us-national-all-ages` (example benchmark profile) + +If ambiguous, ask once and proceed. + +## 3) Schema Map Requirement (for Realism) + +Before strict realism checks, define: +1. Required sampled fields +2. Bounded numeric fields + valid ranges +3. Mutually exclusive categorical sets +4. Hard conditional constraints (must always hold) +5. Soft conditional constraints (target rates/tolerances) +6. Optional benchmark priors + tolerance policy + +If schema map is incomplete: +- run structural checks +- mark representativeness as uncalibrated +- escalate for missing constraints when decision risk is high + +## 4) Stage Gates + +Status: +- PASS: continue +- WARN: continue with documented caveat +- FAIL: fix, rerun affected stages + +### Gate 1: `spec` (`base.yaml`) + +FAIL on any: +1. File missing/empty after command exit +2. `extropy validate ` fails +3. Explicit outcome leakage in pre-event attributes + +WARN/FAIL quality: +- critical distributions rely on low-authority sources + +### Gate 2: `extend` (`population.yaml`) + +FAIL on any: +1. File missing/empty +2. `extropy validate ` fails +3. New attributes violate pre-event intent unless explicitly intended + +WARN/FAIL quality: +- modifier stacking causes boundary instability + +### Gate 3: `sample` (`agents.json`) + +#### 3A) Structural integrity (hard) + +FAIL on any: +1. Requested count != generated count +2. Null values in required fields +3. Out-of-range bounded values +4. Hard exclusivity violations +5. Hard conditional constraint violations + +General hard rules: +- cohort/eligibility labels must be consistent +- strict impossible combinations must be zero + +#### 3B) Realism (schema + profile) + +1. Hard-rule violation count +- FAIL if > 0 + +2. Soft-constraint deviation +- WARN if over soft tolerance +- FAIL if over 2x soft tolerance + +3. Marginal prior drift (if priors provided) +- WARN > 3pp +- FAIL > 5pp + +4. Conditional prior drift (if conditional priors provided) +- WARN > 5pp +- FAIL > 10pp + +5. Distribution support/collapse +- WARN on unexpected support erosion +- FAIL on unexpected collapse of major categories + +Example mapping (all-ages household schema; adapt per study): +- adults (`age > 17`) with `employment_status == "not applicable/child"` -> FAIL if > 0 +- minors (`age <= 17`) with non-minor marital statuses -> FAIL if > 0 +- adult K-12 enrollment leakage -> WARN > 0.2%, FAIL > 1.0% + +Do not proceed to `network` unless Gate 3 is PASS or user accepts WARN explicitly. + +### Gate 4: `network` (`network.json`) + +FAIL on any: +1. Nodes/edges missing or unparseable +2. Orphan edge endpoints +3. Malformed edge fields used by scenario logic + +WARN/FAIL quality: +- graph plausibility (near-empty / nearly complete unintentionally) +- large disconnected components unless intended + +### Gate 5: `scenario` (`scenario.yaml`) + +FAIL on any: +1. File missing or invalid +2. Exposure logic non-executable/empty for intended channels +3. Outcomes not measurable/schema-consistent + +WARN/FAIL quality: +- contradictory stop conditions +- ambiguous outcome definitions + +### Gate 6: `simulate` (`results/`) + +FAIL on any: +1. Missing required artifacts (`meta.json`, `by_timestep.json`, `outcome_distributions.json`, `agent_states.json`) +2. Invalid/no stop reason and no max-timestep completion + +WARN/FAIL quality: +- degenerate dynamics (suspiciously flat or broken) +- runtime/cost outside acceptable operating envelope + +## 5) Auto-Fix Loop + +If any gate FAILs: +1. Apply the smallest upstream fix tied to the failing metric +2. Rerun only dependent downstream stages +3. Re-run the same gate + +If the same gate FAILs twice: +- stop autonomous iteration +- escalate with options and recommendation + +## 6) Triage Playbook + +### A) Command fails immediately +1. Check path existence +2. Check provider/config mismatch (`extropy config show`) +3. Check API key env vars +4. Retry with smallest reproducible command + +### B) Validation/spec issues +Run: +```bash +extropy validate +``` +Common causes: +- formula/condition reference errors +- invalid distribution params +- dependency cycles +- scenario references to unknown attributes/edge types + +### C) Exposure not spreading +Inspect: +- `seed_exposure.rules` probabilities/conditions +- spread settings (`share_probability`, modifiers, `max_hops`, `decay_per_hop`) +- network connectivity and edge typing + +Evidence files: +- `by_timestep.json` +- `timeline.jsonl` + +### D) Weird outcome dynamics +Check: +- outcome definitions and classification clarity +- threshold tuning +- conviction/flip-resistance effects + +Evidence files: +- `meta.json` +- `agent_states.json` +- `timeline.jsonl` + +### E) Cost/latency blowups +1. Compare `extropy estimate` against `meta.json` actuals +2. Reduce population/timesteps/reasoning frequency +3. Move routine/pass-2 to cheaper model +4. Tune rate and chunk settings + +Triage output format: +1. Symptom +2. Likely root cause +3. Evidence +4. Minimal fix +5. Re-run command + +## 7) Escalation Policy + +Escalate to human before further autonomous edits when: +1. Same gate fails twice +2. Fix changes core study assumptions +3. Accuracy vs speed/cost objectives conflict +4. Sensitive policy/ethics framing needs stakeholder decision +5. Required priors/constraints are unavailable but decision quality depends on them + +Escalation payload: +1. Current stage +2. Exact blocker +3. Evidence +4. Options (A/B/C) + tradeoffs +5. Recommended option + +## 8) Long-Run Waiting Rule + +Do not call failure only because: +- file size is unchanged mid-run +- stdout is quiet +- process still active + +Failure requires process error exit, user cancellation, or timeout escalation. diff --git a/README.md b/README.md index 9ceb30f..cbf69ee 100644 --- a/README.md +++ b/README.md @@ -61,22 +61,22 @@ extropy config show ```bash mkdir -p austin +STUDY_DB=austin/study.db # Build a population extropy spec "500 Austin TX commuters who drive into downtown for work" -o austin/base.yaml extropy extend austin/base.yaml -s "Response to a $15/day downtown congestion tax" -o austin/population.yaml -extropy sample austin/population.yaml -o austin/agents.json --seed 42 -extropy network austin/agents.json -o austin/network.json -p austin/population.yaml --seed 42 -extropy persona austin/population.yaml --agents austin/agents.json -o austin/population.persona.yaml +extropy sample austin/population.yaml --study-db "$STUDY_DB" --seed 42 +extropy network --study-db "$STUDY_DB" -p austin/population.yaml --seed 42 --checkpoint "$STUDY_DB" # Compile and run a scenario -extropy scenario -p austin/population.yaml -a austin/agents.json -n austin/network.json -o austin/scenario.yaml -extropy estimate austin/scenario.yaml -extropy simulate austin/scenario.yaml -o austin/results/ --seed 42 +extropy scenario -p austin/population.yaml --study-db "$STUDY_DB" -o austin/scenario.yaml +extropy estimate austin/scenario.yaml --study-db "$STUDY_DB" +extropy simulate austin/scenario.yaml --study-db "$STUDY_DB" -o austin/results/ --seed 42 # View results -extropy results austin/results/ -extropy results austin/results/ --segment income +extropy results --study-db "$STUDY_DB" +extropy results --study-db "$STUDY_DB" --segment income ``` ### What Comes Out @@ -125,13 +125,15 @@ $50-100k: drive_and_pay 40% | switch_to_transit 28% | shift_schedule 21% Each agent reasoned individually. A low-income commuter with no transit access reacts differently than a tech worker near a rail stop — not because we scripted it, but because their attributes, persona, and social context led them there. Simulation output directory (`austin/results/`) contains: -- `simulation.db` (checkpointable state store) -- `timeline.jsonl` (streaming event log) -- `agent_states.json` (final per-agent states) +- `study.db` (canonical state + checkpoint store) - `by_timestep.json` (time-series aggregates) -- `outcome_distributions.json` (final distributions) - `meta.json` (run metadata + token/cost summary) +For full datasets, use explicit exports from `study.db`: +- `extropy export states --study-db "$STUDY_DB" --to austin/results/states.jsonl` +- `extropy export agents --study-db "$STUDY_DB" --to austin/results/agents.jsonl` +- `extropy export edges --study-db "$STUDY_DB" --to austin/results/edges.jsonl` + The scenario YAML controls what gets tracked: ```yaml diff --git a/docs/commands.md b/docs/commands.md index c4b11d3..ff156e6 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -242,8 +242,9 @@ Auto-detection: if `{population_stem}.network-config.yaml` exists alongside the Save a generated config for inspection/editing with `--save-config`: ```bash -extropy network austin/agents.json -o austin/network.json \ - -p austin/population.yaml --save-config austin/network-config.yaml +extropy network --study-db austin/study.db --network-id baseline \ + -p austin/population.yaml --save-config austin/network-config.yaml \ + -o austin/network.json ``` ### How connections form @@ -260,7 +261,8 @@ The network uses a **Watts-Strogatz small-world model** with attribute-based sim Add `-v` to print network quality metrics: ```bash -extropy network austin/agents.json -o austin/network.json -p austin/population.yaml --validate +extropy network --study-db austin/study.db --network-id baseline \ + -p austin/population.yaml --validate ``` This shows clustering coefficient, average path length, modularity, and flags anything outside expected ranges for a realistic social network. @@ -269,8 +271,10 @@ This shows clustering coefficient, average path length, modularity, and flags an | | Name | Description | |---|---|---| -| **Arg** | `agents_file` | Agents JSON file | -| **Opt** | `--output` / `-o` | Output network JSON file **(required)** | +| **Opt** | `--study-db` | Canonical study DB path **(required)** | +| **Opt** | `--population-id` | Population ID in study DB (default: `default`) | +| **Opt** | `--network-id` | Network ID to write/read (default: `default`) | +| **Opt** | `--output` / `-o` | Optional network JSON export path | | **Opt** | `--population` / `-p` | Population spec YAML — generates network config via LLM | | **Opt** | `--network-config` / `-c` | Custom network config YAML file | | **Opt** | `--save-config` | Save the generated/loaded config to YAML | @@ -282,7 +286,7 @@ This shows clustering coefficient, average path length, modularity, and flags an ### Output -A JSON file (`network.json`) containing nodes (agent IDs) and weighted, typed edges. +Canonical output is `study.db` (`network_edges`, `network_runs`, `network_metrics`). Optional JSON export can be written with `--output`. --- @@ -461,8 +465,7 @@ These aren't scripted responses. They emerge from each agent's unique combinatio ### Output A results directory containing: -- `simulation.db` — SQLite database with full simulation state -- `timeline.jsonl` — Event-by-event timeline +- `study.db` — canonical SQLite state and checkpoint store - `agent_states.json` — Final state of every agent - `by_timestep.json` — Per-timestep metrics (exposure, sentiment, conviction, position distributions) - `outcome_distributions.json` — Aggregate outcome distributions @@ -473,7 +476,7 @@ A results directory containing: ## Viewing Results ```bash -extropy results austin/results/ +extropy results --study-db austin/study.db ``` Display a summary of simulation outcomes — exposure rates, outcome distributions, and convergence information. diff --git a/extropy/cli/app.py b/extropy/cli/app.py index aa320c3..75e6d87 100644 --- a/extropy/cli/app.py +++ b/extropy/cli/app.py @@ -1,5 +1,6 @@ """Core CLI app definition and global state.""" +import atexit from typing import Annotated import typer @@ -15,6 +16,7 @@ # Global state for JSON mode (set by callback) _json_mode = False +_show_cost = False def get_json_mode() -> bool: @@ -30,6 +32,28 @@ def _version_callback(value: bool) -> None: raise typer.Exit() +def _print_cost_footer() -> None: + """Print cost summary footer at CLI exit (if enabled and there are records).""" + try: + from ..core.cost.tracker import CostTracker + from ..core.cost.ledger import record_session + + tracker = CostTracker.get() + if not tracker.has_records: + return + + # Persist to ledger + summary = tracker.summary() + record_session(summary) + + # Print footer + line = tracker.summary_line() + if line: + console.print(f"\n[dim]Cost: {line}[/dim]") + except Exception: + pass # Never let cost display crash the CLI + + @app.callback() def main_callback( json_output: Annotated[ @@ -49,14 +73,37 @@ def main_callback( is_eager=True, ), ] = False, + cost: Annotated[ + bool, + typer.Option( + "--cost", + help="Show cost summary after command completes", + is_eager=True, + ), + ] = False, ): """Extropy: Population simulation engine for agent-based modeling. Use --json for machine-readable output suitable for scripting and AI tools. + Use --cost to show token usage and cost summary after each command. """ - global _json_mode + global _json_mode, _show_cost _json_mode = json_output + # Determine if cost footer should be shown: --cost flag or config setting + show = cost + if not show: + try: + from ..config import get_config + + show = get_config().show_cost + except Exception: + pass + + _show_cost = show + if _show_cost: + atexit.register(_print_cost_footer) + # Import commands to register them with the app from .commands import ( # noqa: E402, F401 @@ -71,4 +118,10 @@ def main_callback( estimate, results, config_cmd, + inspect, + query, + report, + export, + chat, + migrate, ) diff --git a/extropy/cli/commands/__init__.py b/extropy/cli/commands/__init__.py index 8c36014..df2e0cc 100644 --- a/extropy/cli/commands/__init__.py +++ b/extropy/cli/commands/__init__.py @@ -12,6 +12,12 @@ estimate, results, config_cmd, + inspect, + query, + report, + export, + chat, + migrate, ) __all__ = [ @@ -26,4 +32,10 @@ "estimate", "results", "config_cmd", + "inspect", + "query", + "report", + "export", + "chat", + "migrate", ] diff --git a/extropy/cli/commands/chat.py b/extropy/cli/commands/chat.py new file mode 100644 index 0000000..2a15796 --- /dev/null +++ b/extropy/cli/commands/chat.py @@ -0,0 +1,337 @@ +"""Agent chat commands backed by study DB history.""" + +from __future__ import annotations + +import json +import sqlite3 +import time +from pathlib import Path +from typing import Any + +import typer + +from ...storage import open_study_db +from ..app import app, console, get_json_mode + +chat_app = typer.Typer(help="Chat with simulated agents using DB-backed history") +app.add_typer(chat_app, name="chat") + + +def _load_agent_chat_context( + conn: sqlite3.Connection, + run_id: str, + agent_id: str, + timeline_n: int = 10, +) -> tuple[dict[str, Any], list[dict[str, Any]]]: + cur = conn.cursor() + cur.execute( + "SELECT population_id FROM simulation_runs WHERE run_id = ? LIMIT 1", + (run_id,), + ) + run_row = cur.fetchone() + if not run_row: + return {"run_id": run_id, "agent_id": agent_id, "error": "run_id not found"}, [] + population_id = str(run_row["population_id"]) + + cur.execute( + """ + SELECT attrs_json + FROM agents + WHERE population_id = ? AND agent_id = ? + ORDER BY rowid DESC + LIMIT 1 + """, + (population_id, agent_id), + ) + attrs_row = cur.fetchone() + attrs = {} + if attrs_row and attrs_row["attrs_json"]: + try: + attrs = json.loads(attrs_row["attrs_json"]) + except json.JSONDecodeError: + attrs = {} + + cur.execute( + "SELECT * FROM agent_states WHERE run_id = ? AND agent_id = ? LIMIT 1", + (run_id, agent_id), + ) + state_row = cur.fetchone() + state = dict(state_row) if state_row else {} + + cur.execute( + """ + SELECT timestep, event_type, details_json + FROM timeline + WHERE run_id = ? AND agent_id = ? + ORDER BY id DESC + LIMIT ? + """, + (run_id, agent_id, timeline_n), + ) + timeline_rows = [dict(r) for r in cur.fetchall()] + + context = { + "run_id": run_id, + "population_id": population_id, + "agent_id": agent_id, + "attributes": attrs, + "state": state, + "timeline": list(reversed(timeline_rows)), + } + + citations = [ + {"table": "agents", "population_id": population_id, "agent_id": agent_id}, + {"table": "agent_states", "run_id": run_id, "agent_id": agent_id}, + { + "table": "timeline", + "run_id": run_id, + "agent_id": agent_id, + "limit": timeline_n, + }, + ] + return context, citations + + +def _summarize_context(context: dict[str, Any], prompt: str) -> str: + state = context.get("state", {}) + attrs = context.get("attributes", {}) + timeline = context.get("timeline", []) + agent_id = context.get("agent_id") + + private_position = state.get("private_position") or state.get("position") + private_sentiment = state.get("private_sentiment") + if private_sentiment is None: + private_sentiment = state.get("sentiment") + private_conviction = state.get("private_conviction") + if private_conviction is None: + private_conviction = state.get("conviction") + + lines = [f"Agent `{agent_id}` context snapshot:"] + if private_position is not None: + lines.append(f"- Position: {private_position}") + if private_sentiment is not None: + lines.append(f"- Sentiment: {private_sentiment:.3f}") + if private_conviction is not None: + lines.append(f"- Conviction: {private_conviction:.3f}") + + if state.get("public_statement"): + lines.append(f"- Public statement: {state['public_statement']}") + if state.get("raw_reasoning"): + lines.append(f"- Latest raw reasoning: {state['raw_reasoning']}") + + if attrs: + top_attrs = [(k, v) for k, v in attrs.items() if not str(k).startswith("_")] + top_attrs = sorted(top_attrs)[:8] + if top_attrs: + lines.append( + "- Key attributes: " + ", ".join(f"{k}={v}" for k, v in top_attrs) + ) + + if timeline: + lines.append("- Recent timeline events:") + for item in timeline[-5:]: + details = item.get("details_json") or "{}" + lines.append( + f" - t={item.get('timestep')} {item.get('event_type')} details={details}" + ) + + lines.append(f"- Your prompt: {prompt}") + lines.append( + "This answer is grounded in persisted DB state and does not mutate simulation state." + ) + return "\n".join(lines) + + +def _print_repl_help() -> None: + console.print("[dim]Commands: /context, /timeline , /history, /exit[/dim]") + + +@chat_app.callback(invoke_without_command=True) +def chat_interactive( + ctx: typer.Context, + study_db: Path | None = typer.Option(None, "--study-db"), + run_id: str | None = typer.Option(None, "--run-id"), + agent_id: str | None = typer.Option(None, "--agent-id"), + session_id: str | None = typer.Option(None, "--session-id"), +): + """Interactive chat REPL. + + Example: + extropy chat --study-db study.db --run-id run_123 --agent-id a_42 + """ + if ctx.invoked_subcommand is not None: + return + + if not study_db or not run_id or not agent_id: + console.print( + "[red]✗[/red] interactive chat requires --study-db, --run-id, and --agent-id" + ) + raise typer.Exit(1) + + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") + raise typer.Exit(1) + + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + cur = conn.cursor() + cur.execute("SELECT 1 FROM simulation_runs WHERE run_id = ? LIMIT 1", (run_id,)) + if not cur.fetchone(): + conn.close() + console.print(f"[red]✗[/red] run_id not found: {run_id}") + raise typer.Exit(1) + + with open_study_db(study_db) as db: + sid = session_id or db.create_chat_session( + run_id=run_id, + agent_id=agent_id, + mode="interactive", + meta={"entrypoint": "repl"}, + ) + + console.print(f"[bold]Chat session[/bold] {sid}") + _print_repl_help() + + try: + while True: + try: + prompt = input("chat> ").strip() + except EOFError: + break + + if not prompt: + continue + if prompt == "/exit": + break + if prompt == "/history": + with open_study_db(study_db) as db: + messages = db.get_chat_messages(sid) + for m in messages: + console.print(f"[{m['role']}] {m['content']}") + continue + if prompt.startswith("/timeline"): + parts = prompt.split() + try: + n = int(parts[1]) if len(parts) > 1 else 10 + except ValueError: + n = 10 + context, _ = _load_agent_chat_context( + conn, run_id, agent_id, timeline_n=max(1, n) + ) + for item in context.get("timeline", []): + console.print( + f"t={item.get('timestep')} {item.get('event_type')} {item.get('details_json') or '{}'}" + ) + continue + if prompt == "/context": + context, _ = _load_agent_chat_context( + conn, run_id, agent_id, timeline_n=10 + ) + console.print_json(data=context) + continue + + started = time.time() + context, citations = _load_agent_chat_context( + conn, run_id, agent_id, timeline_n=12 + ) + answer = _summarize_context(context, prompt) + latency_ms = int((time.time() - started) * 1000) + + with open_study_db(study_db) as db: + db.append_chat_message(sid, "user", prompt) + db.append_chat_message( + sid, + "assistant", + answer, + citations={"sources": citations}, + token_usage={ + "input_tokens": 0, + "output_tokens": 0, + "latency_ms": latency_ms, + }, + ) + + console.print(answer) + + finally: + conn.close() + + +@chat_app.command("ask") +def chat_ask( + study_db: Path = typer.Option(..., "--study-db"), + run_id: str = typer.Option(..., "--run-id"), + agent_id: str = typer.Option(..., "--agent-id"), + prompt: str = typer.Option(..., "--prompt"), + session_id: str | None = typer.Option(None, "--session-id"), + json_output: bool = typer.Option(False, "--json"), +): + """Non-interactive chat API for automation. + + Example: + extropy chat ask --study-db study.db --run-id r1 --agent-id a1 --prompt "What changed?" --json + """ + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") + raise typer.Exit(1) + + started = time.time() + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute("SELECT 1 FROM simulation_runs WHERE run_id = ? LIMIT 1", (run_id,)) + if not cur.fetchone(): + console.print(f"[red]✗[/red] run_id not found: {run_id}") + raise typer.Exit(1) + finally: + conn.close() + + with open_study_db(study_db) as db: + sid = session_id or db.create_chat_session( + run_id=run_id, + agent_id=agent_id, + mode="machine", + meta={"entrypoint": "ask"}, + ) + + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + context, citations = _load_agent_chat_context( + conn, run_id, agent_id, timeline_n=12 + ) + answer = _summarize_context(context, prompt) + finally: + conn.close() + + latency_ms = int((time.time() - started) * 1000) + + with open_study_db(study_db) as db: + user_turn = db.append_chat_message(sid, "user", prompt) + assistant_turn = db.append_chat_message( + sid, + "assistant", + answer, + citations={"sources": citations}, + token_usage={ + "input_tokens": 0, + "output_tokens": 0, + "latency_ms": latency_ms, + }, + ) + + payload = { + "session_id": sid, + "user_turn_index": user_turn, + "turn_index": assistant_turn, + "assistant_text": answer, + "citations": {"sources": citations}, + "token_usage": {"input_tokens": 0, "output_tokens": 0}, + "latency_ms": latency_ms, + } + + if json_output or get_json_mode(): + console.print_json(data=payload) + else: + console.print(answer) diff --git a/extropy/cli/commands/config_cmd.py b/extropy/cli/commands/config_cmd.py index c2945e7..00d0590 100644 --- a/extropy/cli/commands/config_cmd.py +++ b/extropy/cli/commands/config_cmd.py @@ -7,28 +7,28 @@ get_config, reset_config, CONFIG_FILE, - get_api_key, - get_azure_config, + get_api_key_for_provider, ) VALID_KEYS = { - "pipeline.provider", - "pipeline.model_simple", - "pipeline.model_reasoning", - "pipeline.model_research", - "simulation.provider", - "simulation.model", - "simulation.pivotal_model", - "simulation.routine_model", + "models.fast", + "models.strong", + "simulation.fast", + "simulation.strong", "simulation.max_concurrent", "simulation.rate_tier", "simulation.rpm_override", "simulation.tpm_override", - "simulation.api_format", + "show_cost", } -INT_FIELDS = {"max_concurrent", "rate_tier", "rpm_override", "tpm_override"} +INT_FIELDS = { + "max_concurrent", + "rate_tier", + "rpm_override", + "tpm_override", +} @app.command("config") @@ -39,7 +39,7 @@ def config_command( ), key: str | None = typer.Argument( None, - help="Config key (e.g. pipeline.provider, simulation.model)", + help="Config key (e.g. models.fast, simulation.strong)", ), value: str | None = typer.Argument( None, @@ -50,9 +50,9 @@ def config_command( Examples: extropy config show - extropy config set pipeline.provider claude - extropy config set simulation.provider openai - extropy config set simulation.model gpt-5-mini + extropy config set models.fast openai/gpt-5-mini + extropy config set models.strong anthropic/claude-sonnet-4.5 + extropy config set simulation.strong openrouter/anthropic/claude-sonnet-4.5 extropy config reset """ if action == "show": @@ -82,36 +82,21 @@ def _show_config(): console.print("[bold]Extropy Configuration[/bold]") console.print("─" * 40) - # Pipeline zone + # Models (pipeline) console.print() - console.print("[bold cyan]Pipeline[/bold cyan] (spec, extend, persona, scenario)") - console.print(f" provider = {config.pipeline.provider}") - console.print( - f" model_simple = {config.pipeline.model_simple or '[dim](provider default)[/dim]'}" - ) console.print( - f" model_reasoning = {config.pipeline.model_reasoning or '[dim](provider default)[/dim]'}" - ) - console.print( - f" model_research = {config.pipeline.model_research or '[dim](provider default)[/dim]'}" + "[bold cyan]Models[/bold cyan] (pipeline: spec, extend, persona, scenario)" ) + console.print(f" fast = {config.models.fast}") + console.print(f" strong = {config.models.strong}") - # Simulation zone + # Simulation console.print() console.print("[bold cyan]Simulation[/bold cyan] (agent reasoning)") - console.print(f" provider = {config.simulation.provider}") - console.print( - f" model = {config.simulation.model or '[dim](provider default)[/dim]'}" - ) - console.print( - f" pivotal_model = {config.simulation.pivotal_model or '[dim](same as model)[/dim]'}" - ) - console.print( - f" routine_model = {config.simulation.routine_model or '[dim](provider default)[/dim]'}" - ) - console.print( - f" api_format = {config.simulation.api_format or '[dim](auto)[/dim]'}" - ) + strong_val = config.simulation.strong or "[dim](= models.strong)[/dim]" + fast_val = config.simulation.fast or "[dim](= models.fast)[/dim]" + console.print(f" strong = {strong_val}") + console.print(f" fast = {fast_val}") console.print(f" max_concurrent = {config.simulation.max_concurrent}") console.print( f" rate_tier = {config.simulation.rate_tier or '[dim](tier 1)[/dim]'}" @@ -121,25 +106,29 @@ def _show_config(): if config.simulation.tpm_override: console.print(f" tpm_override = {config.simulation.tpm_override}") + # Custom providers + if config.providers: + console.print() + console.print("[bold cyan]Custom Providers[/bold cyan]") + for name, provider_cfg in config.providers.items(): + console.print(f" {name}:") + console.print(f" base_url = {provider_cfg.base_url}") + if provider_cfg.api_key_env: + console.print(f" api_key_env = {provider_cfg.api_key_env}") + + # Cost tracking + if config.show_cost: + console.print() + console.print(f" show_cost = {config.show_cost}") + # API keys status console.print() console.print("[bold cyan]API Keys[/bold cyan] (from env vars)") _show_key_status("openai", "OPENAI_API_KEY") - _show_key_status("claude", "ANTHROPIC_API_KEY") - _show_key_status("azure_openai", "AZURE_OPENAI_API_KEY") - - # Azure-specific config (show when Azure provider is in use) - active_providers = {config.pipeline.provider, config.simulation.provider} - if "azure_openai" in active_providers: - azure_cfg = get_azure_config("azure_openai") - console.print() - console.print("[bold cyan]Azure OpenAI[/bold cyan]") - console.print( - f" endpoint = {azure_cfg['azure_endpoint'] or '[dim]not set[/dim]'}" - ) - console.print(f" api_version = {azure_cfg['api_version']}") - if azure_cfg["azure_deployment"]: - console.print(f" deployment = {azure_cfg['azure_deployment']}") + _show_key_status("anthropic", "ANTHROPIC_API_KEY") + _show_key_status("azure", "AZURE_OPENAI_API_KEY") + _show_key_status("openrouter", "OPENROUTER_API_KEY") + _show_key_status("deepseek", "DEEPSEEK_API_KEY") # Config file console.print() @@ -152,7 +141,7 @@ def _show_config(): def _show_key_status(provider: str, env_var_label: str): """Show whether an API key is configured.""" - key = get_api_key(provider) + key = get_api_key_for_provider(provider) if key: masked = key[:8] + "..." + key[-4:] if len(key) > 16 else "***" console.print(f" {env_var_label}: [green]{masked}[/green]") @@ -162,35 +151,57 @@ def _show_key_status(provider: str, env_var_label: str): def _set_config(key: str, value: str): """Set a config value and save.""" - if key not in VALID_KEYS: + # Allow dynamic provider keys like providers.mycompany.base_url + is_provider_key = key.startswith("providers.") + if key not in VALID_KEYS and not is_provider_key: console.print(f"[red]Unknown key:[/red] {key}") console.print() console.print("Available keys:") for k in sorted(VALID_KEYS): console.print(f" {k}") + console.print(" providers..base_url") + console.print(" providers..api_key_env") raise typer.Exit(1) # Load current config (or defaults if no file) config = get_config() - zone, field = key.split(".", 1) - if zone == "pipeline": - target = config.pipeline - elif zone == "simulation": - target = config.simulation - else: - console.print(f"[red]Unknown zone:[/red] {zone}") - raise typer.Exit(1) - - # Type coercion - if field in INT_FIELDS: - try: - setattr(target, field, int(value)) - except ValueError: - console.print(f"[red]Invalid integer value:[/red] {value}") + if is_provider_key: + parts = key.split(".", 2) + if len(parts) != 3 or parts[2] not in ("base_url", "api_key_env"): + console.print( + f"[red]Invalid provider key:[/red] {key}\n" + "Expected: providers..base_url or providers..api_key_env" + ) raise typer.Exit(1) + provider_name = parts[1] + field = parts[2] + from ...config import CustomProviderConfig + + if provider_name not in config.providers: + config.providers[provider_name] = CustomProviderConfig() + setattr(config.providers[provider_name], field, value) + elif key == "show_cost": + config.show_cost = value.lower() in ("true", "1", "yes") else: - setattr(target, field, value) + zone, field_name = key.split(".", 1) + if zone == "models": + target = config.models + elif zone == "simulation": + target = config.simulation + else: + console.print(f"[red]Unknown zone:[/red] {zone}") + raise typer.Exit(1) + + # Type coercion + if field_name in INT_FIELDS: + try: + setattr(target, field_name, int(value)) + except ValueError: + console.print(f"[red]Invalid integer value:[/red] {value}") + raise typer.Exit(1) + else: + setattr(target, field_name, value) config.save() reset_config() # Clear cached singleton so next get_config() reloads diff --git a/extropy/cli/commands/estimate.py b/extropy/cli/commands/estimate.py index ec734e5..71f5597 100644 --- a/extropy/cli/commands/estimate.py +++ b/extropy/cli/commands/estimate.py @@ -1,6 +1,5 @@ """Estimate command for predicting simulation costs before running.""" -import json from pathlib import Path import typer @@ -11,21 +10,17 @@ @app.command("estimate") def estimate_command( scenario_file: Path = typer.Argument(..., help="Scenario spec YAML file"), - model: str = typer.Option( + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), + strong: str = typer.Option( "", - "--model", + "--strong", "-m", - help="LLM model for agent reasoning (empty = use config default)", + help="Strong model for Pass 1 (provider/model format)", ), - pivotal_model: str = typer.Option( + fast: str = typer.Option( "", - "--pivotal-model", - help="Model for pivotal/first-pass reasoning (default: same as --model)", - ), - routine_model: str = typer.Option( - "", - "--routine-model", - help="Cheap model for classification pass (default: provider cheap tier)", + "--fast", + help="Fast model for Pass 2 (provider/model format)", ), threshold: int = typer.Option( 3, "--threshold", "-t", help="Multi-touch threshold for re-reasoning" @@ -41,19 +36,23 @@ def estimate_command( model, and predicts LLM calls, tokens, and USD cost. No API keys required. Example: - extropy estimate scenario.yaml - extropy estimate scenario.yaml --model gpt-5-mini - extropy estimate scenario.yaml --pivotal-model gpt-5 --routine-model gpt-5-mini -v + extropy estimate scenario.yaml --study-db study.db + extropy estimate scenario.yaml --study-db study.db --strong openai/gpt-5 + extropy estimate scenario.yaml --study-db study.db \\ + --strong openai/gpt-5 --fast openai/gpt-5-mini -v """ from ...config import get_config from ...core.models import ScenarioSpec, PopulationSpec - from ...population.network import load_agents_json from ...simulation.estimator import estimate_simulation_cost + from ...storage import open_study_db # Validate input file if not scenario_file.exists(): console.print(f"[red]x[/red] Scenario file not found: {scenario_file}") raise typer.Exit(1) + if not study_db.exists(): + console.print(f"[red]x[/red] Study DB not found: {study_db}") + raise typer.Exit(1) # Load scenario try: @@ -71,32 +70,24 @@ def estimate_command( raise typer.Exit(1) population_spec = PopulationSpec.from_yaml(pop_path) - # Load agents - agents_path = Path(scenario.meta.agents_file) - if not agents_path.is_absolute(): - agents_path = scenario_file.parent / agents_path - if not agents_path.exists(): - console.print(f"[red]x[/red] Agents file not found: {agents_path}") + with open_study_db(study_db) as db: + agents = db.get_agents(scenario.meta.population_id) + network = db.get_network(scenario.meta.network_id) + if not agents: + console.print( + f"[red]x[/red] Population ID not found in study DB: {scenario.meta.population_id}" + ) raise typer.Exit(1) - agents = load_agents_json(agents_path) - - # Load network - network_path = Path(scenario.meta.network_file) - if not network_path.is_absolute(): - network_path = scenario_file.parent / network_path - if not network_path.exists(): - console.print(f"[red]x[/red] Network file not found: {network_path}") + if not network.get("edges"): + console.print( + f"[red]x[/red] Network ID not found in study DB: {scenario.meta.network_id}" + ) raise typer.Exit(1) - with open(network_path) as f: - network = json.load(f) # Resolve config config = get_config() - provider = config.simulation.provider - - eff_model = model or config.simulation.model - eff_pivotal = pivotal_model or config.simulation.pivotal_model or eff_model - eff_routine = routine_model or config.simulation.routine_model + effective_strong = strong or config.resolve_sim_strong() + effective_fast = fast or config.resolve_sim_fast() # Run estimation est = estimate_simulation_cost( @@ -104,9 +95,8 @@ def estimate_command( population_spec=population_spec, agents=agents, network=network, - provider=provider, - pivotal_model=eff_pivotal, - routine_model=eff_routine, + strong_model=effective_strong, + fast_model=effective_fast, multi_touch_threshold=threshold, ) @@ -131,11 +121,9 @@ def estimate_command( # Models section console.print("[bold]Models[/bold]") _print_model_line( - console, "Pass 1 (pivotal)", est.pivotal_model, est.pivotal_pricing - ) - _print_model_line( - console, "Pass 2 (routine)", est.routine_model, est.routine_pricing + console, "Pass 1 (strong)", est.pivotal_model, est.pivotal_pricing ) + _print_model_line(console, "Pass 2 (fast)", est.routine_model, est.routine_pricing) console.print() # Calls table diff --git a/extropy/cli/commands/export.py b/extropy/cli/commands/export.py new file mode 100644 index 0000000..dbef866 --- /dev/null +++ b/extropy/cli/commands/export.py @@ -0,0 +1,112 @@ +"""Explicit exports from study DB.""" + +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path + +import typer + +from ..app import app, console + +export_app = typer.Typer(help="Export datasets from study DB") +app.add_typer(export_app, name="export") + + +def _write_jsonl(path: Path, rows: list[dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, default=str) + "\n") + + +@export_app.command("agents") +def export_agents( + study_db: Path = typer.Option(..., "--study-db"), + population_id: str = typer.Option("default", "--population-id"), + output: Path = typer.Option(..., "--to"), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute( + "SELECT agent_id, attrs_json FROM agents WHERE population_id = ? ORDER BY agent_id", + (population_id,), + ) + rows = [] + for row in cur.fetchall(): + try: + rows.append(json.loads(row["attrs_json"])) + except json.JSONDecodeError: + rows.append({"_id": row["agent_id"]}) + finally: + conn.close() + + _write_jsonl(output, rows) + console.print(f"[green]✓[/green] Exported {len(rows)} agents -> {output}") + + +@export_app.command("edges") +def export_edges( + study_db: Path = typer.Option(..., "--study-db"), + network_id: str = typer.Option("default", "--network-id"), + output: Path = typer.Option(..., "--to"), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute( + """ + SELECT source_id, target_id, weight, edge_type, influence_st, influence_ts + FROM network_edges + WHERE network_id = ? + ORDER BY source_id, target_id + """, + (network_id,), + ) + rows = [dict(row) for row in cur.fetchall()] + finally: + conn.close() + + _write_jsonl(output, rows) + console.print(f"[green]✓[/green] Exported {len(rows)} edges -> {output}") + + +@export_app.command("states") +def export_states( + study_db: Path = typer.Option(..., "--study-db"), + run_id: str | None = typer.Option(None, "--run-id"), + output: Path = typer.Option(..., "--to"), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + if run_id: + cur.execute( + "SELECT run_id FROM simulation_runs WHERE run_id = ?", + (run_id,), + ) + else: + cur.execute( + "SELECT run_id FROM simulation_runs ORDER BY started_at DESC LIMIT 1" + ) + run_row = cur.fetchone() + if not run_row: + console.print("[yellow]No simulation runs found.[/yellow]") + raise typer.Exit(1) + resolved_run_id = str(run_row["run_id"]) + + cur.execute( + "SELECT * FROM agent_states WHERE run_id = ? ORDER BY agent_id", + (resolved_run_id,), + ) + rows = [dict(row) for row in cur.fetchall()] + finally: + conn.close() + + _write_jsonl(output, rows) + console.print(f"[green]✓[/green] Exported {len(rows)} agent states -> {output}") diff --git a/extropy/cli/commands/inspect.py b/extropy/cli/commands/inspect.py new file mode 100644 index 0000000..3ea95be --- /dev/null +++ b/extropy/cli/commands/inspect.py @@ -0,0 +1,206 @@ +"""Inspect commands for DB-backed artifacts.""" + +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path + +import typer + +from ...storage import open_study_db +from ..app import app, console + +inspect_app = typer.Typer(help="Inspect study DB entities") +app.add_typer(inspect_app, name="inspect") + + +def _resolve_run(conn: sqlite3.Connection, run_id: str | None) -> sqlite3.Row | None: + cur = conn.cursor() + if run_id: + cur.execute( + """ + SELECT run_id, population_id, network_id, status, started_at, completed_at + FROM simulation_runs + WHERE run_id = ? + """, + (run_id,), + ) + else: + cur.execute( + """ + SELECT run_id, population_id, network_id, status, started_at, completed_at + FROM simulation_runs + ORDER BY started_at DESC + LIMIT 1 + """ + ) + return cur.fetchone() + + +@inspect_app.command("summary") +def inspect_summary( + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), + population_id: str = typer.Option("default", "--population-id"), + network_id: str = typer.Option("default", "--network-id"), + run_id: str | None = typer.Option(None, "--run-id"), +): + with open_study_db(study_db) as db: + agent_count = db.get_agent_count(population_id) + edge_count = db.get_network_edge_count(network_id) + + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + run_row = _resolve_run(conn, run_id) + resolved_run_id = str(run_row["run_id"]) if run_row else None + if run_row: + population_id = str(run_row["population_id"]) + network_id = str(run_row["network_id"]) + + cur = conn.cursor() + if resolved_run_id: + cur.execute( + "SELECT COUNT(*) AS cnt FROM agent_states WHERE run_id = ?", + (resolved_run_id,), + ) + sim_agents = int(cur.fetchone()["cnt"]) + cur.execute( + "SELECT COUNT(*) AS cnt FROM timestep_summaries WHERE run_id = ?", + (resolved_run_id,), + ) + timesteps = int(cur.fetchone()["cnt"]) + cur.execute( + "SELECT COUNT(*) AS cnt FROM timeline WHERE run_id = ?", + (resolved_run_id,), + ) + events = int(cur.fetchone()["cnt"]) + else: + sim_agents = 0 + timesteps = 0 + events = 0 + finally: + conn.close() + + console.print("[bold]Study Summary[/bold]") + console.print(f"study_db: {study_db}") + console.print(f"population_id={population_id} agents={agent_count}") + console.print(f"network_id={network_id} edges={edge_count}") + if resolved_run_id: + console.print(f"run_id={resolved_run_id}") + console.print(f"simulation.agent_states={sim_agents}") + console.print(f"simulation.timesteps={timesteps}") + console.print(f"simulation.events={events}") + + +@inspect_app.command("agent") +def inspect_agent( + study_db: Path = typer.Option(..., "--study-db"), + agent_id: str = typer.Option(..., "--agent-id"), + run_id: str | None = typer.Option(None, "--run-id"), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + run_row = _resolve_run(conn, run_id) + if not run_row: + console.print("[yellow]No simulation runs found.[/yellow]") + return + resolved_run_id = str(run_row["run_id"]) + population_id = str(run_row["population_id"]) + + cur = conn.cursor() + cur.execute( + "SELECT attrs_json FROM agents WHERE population_id = ? AND agent_id = ? LIMIT 1", + (population_id, agent_id), + ) + attrs_row = cur.fetchone() + attrs = json.loads(attrs_row["attrs_json"]) if attrs_row else {} + + cur.execute( + "SELECT * FROM agent_states WHERE run_id = ? AND agent_id = ? LIMIT 1", + (resolved_run_id, agent_id), + ) + state = cur.fetchone() + + cur.execute( + """ + SELECT timestep, event_type, details_json + FROM timeline + WHERE run_id = ? AND agent_id = ? + ORDER BY id DESC + LIMIT 10 + """, + (resolved_run_id, agent_id), + ) + events = cur.fetchall() + finally: + conn.close() + + console.print(f"[bold]Agent {agent_id}[/bold]") + if attrs: + console.print("[bold]Attributes[/bold]") + for key in sorted(attrs.keys()): + if key.startswith("_"): + continue + console.print(f" - {key}: {attrs[key]}") + + if state: + console.print("[bold]State[/bold]") + console.print( + f" aware={bool(state['aware'])} will_share={bool(state['will_share'])}" + ) + console.print( + f" position={state['private_position'] or state['position']} " + f"sentiment={state['private_sentiment'] if state['private_sentiment'] is not None else state['sentiment']}" + ) + if state["raw_reasoning"]: + console.print("[bold]Raw reasoning[/bold]") + console.print(str(state["raw_reasoning"])) + + if events: + console.print("[bold]Recent events[/bold]") + for row in events: + details = row["details_json"] or "{}" + console.print(f" t={row['timestep']} {row['event_type']} {details}") + + +@inspect_app.command("network") +def inspect_network( + study_db: Path = typer.Option(..., "--study-db"), + network_id: str = typer.Option("default", "--network-id"), + top: int = typer.Option(10, "--top", min=1), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute( + "SELECT COUNT(*) AS cnt, AVG(weight) AS avg_w FROM network_edges WHERE network_id = ?", + (network_id,), + ) + row = cur.fetchone() + edge_count = int(row["cnt"]) if row else 0 + avg_w = float(row["avg_w"]) if row and row["avg_w"] is not None else 0.0 + + cur.execute( + """ + SELECT source_id, COUNT(*) AS degree + FROM network_edges + WHERE network_id = ? + GROUP BY source_id + ORDER BY degree DESC + LIMIT ? + """, + (network_id, top), + ) + top_rows = cur.fetchall() + finally: + conn.close() + + console.print(f"[bold]Network {network_id}[/bold]") + console.print(f"edges={edge_count} avg_weight={avg_w:.4f}") + if top_rows: + console.print("top source degrees:") + for r in top_rows: + console.print(f" - {r['source_id']}: {r['degree']}") diff --git a/extropy/cli/commands/migrate.py b/extropy/cli/commands/migrate.py new file mode 100644 index 0000000..e0d3643 --- /dev/null +++ b/extropy/cli/commands/migrate.py @@ -0,0 +1,152 @@ +"""Migration commands for DB-first runtime artifacts.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import typer +import yaml + +from ...storage import open_study_db +from ..app import app, console + +migrate_app = typer.Typer(help="Migrate legacy artifacts to DB-first schema") +app.add_typer(migrate_app, name="migrate") + + +def _load_json(path: Path) -> Any: + with open(path, encoding="utf-8") as f: + return json.load(f) + + +@migrate_app.command("legacy") +def migrate_legacy_artifacts( + study_db: Path = typer.Option(..., "--study-db", help="Target canonical study DB"), + agents_file: Path | None = typer.Option( + None, "--agents-file", help="Legacy agents JSON" + ), + network_file: Path | None = typer.Option( + None, "--network-file", help="Legacy network JSON" + ), + population_spec: Path | None = typer.Option( + None, + "--population-spec", + help="Optional population spec YAML source used for sample provenance", + ), + population_id: str = typer.Option("default", "--population-id"), + network_id: str = typer.Option("default", "--network-id"), +): + """Ingest legacy `agents.json`/`network.json` into `study.db`.""" + if agents_file is None and network_file is None: + console.print( + "[red]✗[/red] Provide at least one of --agents-file or --network-file" + ) + raise typer.Exit(1) + + with open_study_db(study_db) as db: + if population_spec is not None: + if not population_spec.exists(): + console.print( + f"[red]✗[/red] population spec not found: {population_spec}" + ) + raise typer.Exit(1) + db.save_population_spec( + population_id=population_id, + spec_yaml=population_spec.read_text(encoding="utf-8"), + source_path=str(population_spec), + ) + + if agents_file is not None: + if not agents_file.exists(): + console.print(f"[red]✗[/red] agents file not found: {agents_file}") + raise typer.Exit(1) + agents_data = _load_json(agents_file) + if not isinstance(agents_data, list): + console.print("[red]✗[/red] agents JSON must be a list") + raise typer.Exit(1) + sample_run_id = db.save_sample_result( + population_id=population_id, + agents=agents_data, + meta={ + "source": "legacy_migration", + "source_file": str(agents_file), + }, + seed=None, + ) + console.print( + f"[green]✓[/green] Imported {len(agents_data)} agents " + f"(population_id={population_id}, sample_run_id={sample_run_id})" + ) + + if network_file is not None: + if not network_file.exists(): + console.print(f"[red]✗[/red] network file not found: {network_file}") + raise typer.Exit(1) + network_data = _load_json(network_file) + if not isinstance(network_data, dict): + console.print("[red]✗[/red] network JSON must be an object") + raise typer.Exit(1) + + raw_edges = network_data.get("edges", []) + if not isinstance(raw_edges, list): + console.print("[red]✗[/red] network.edges must be a list") + raise typer.Exit(1) + + network_run_id = db.save_network_result( + population_id=population_id, + network_id=network_id, + config=network_data.get("config", {}), + result_meta=network_data.get("meta", {}), + edges=raw_edges, + seed=None, + candidate_mode="legacy", + network_metrics=network_data.get("metrics"), + ) + console.print( + f"[green]✓[/green] Imported {len(raw_edges)} edges " + f"(network_id={network_id}, network_run_id={network_run_id})" + ) + + console.print(f"[green]✓[/green] Migration complete: {study_db}") + + +@migrate_app.command("scenario") +def migrate_scenario_yaml( + input_path: Path = typer.Option(..., "--input", help="Legacy scenario YAML"), + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB path"), + population_id: str = typer.Option("default", "--population-id"), + network_id: str = typer.Option("default", "--network-id"), + output: Path | None = typer.Option(None, "--output", "-o"), +): + """Rewrite a legacy scenario YAML to DB-first metadata fields.""" + if not input_path.exists(): + console.print(f"[red]✗[/red] Scenario file not found: {input_path}") + raise typer.Exit(1) + + with open(input_path, encoding="utf-8") as f: + data = yaml.safe_load(f) or {} + + meta = data.get("meta") + if not isinstance(meta, dict): + console.print("[red]✗[/red] Invalid scenario YAML: missing meta object") + raise typer.Exit(1) + + had_legacy = "agents_file" in meta or "network_file" in meta + meta.pop("agents_file", None) + meta.pop("network_file", None) + meta["study_db"] = str(study_db) + meta["population_id"] = population_id + meta["network_id"] = network_id + data["meta"] = meta + + out = output or input_path.with_suffix(".db-first.yaml") + out.parent.mkdir(parents=True, exist_ok=True) + with open(out, "w", encoding="utf-8") as f: + yaml.safe_dump(data, f, sort_keys=False) + + if had_legacy: + console.print(f"[green]✓[/green] Migrated legacy scenario -> {out}") + else: + console.print(f"[green]✓[/green] Rewrote scenario metadata -> {out}") diff --git a/extropy/cli/commands/network.py b/extropy/cli/commands/network.py index 3444d7d..2e318cf 100644 --- a/extropy/cli/commands/network.py +++ b/extropy/cli/commands/network.py @@ -14,10 +14,16 @@ @app.command("network") def network_command( - agents_file: Path = typer.Argument( - ..., help="Agents JSON file to generate network from" + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), + population_id: str = typer.Option( + "default", "--population-id", help="Population ID in study DB" + ), + network_id: str = typer.Option( + "default", "--network-id", help="Network ID to write in study DB" + ), + output: Path | None = typer.Option( + None, "--output", "-o", help="Optional JSON export path (non-canonical)" ), - output: Path = typer.Option(..., "--output", "-o", help="Output network JSON file"), population: Path | None = typer.Option( None, "--population", @@ -50,6 +56,65 @@ def network_command( no_metrics: bool = typer.Option( False, "--no-metrics", help="Skip computing node metrics (faster)" ), + candidate_mode: str = typer.Option( + "exact", + "--candidate-mode", + help="Similarity candidate mode: exact | blocked", + ), + candidate_pool_multiplier: float = typer.Option( + 12.0, + "--candidate-pool-multiplier", + help="Blocked mode candidate pool size as multiple of avg_degree", + ), + block_attr: list[str] | None = typer.Option( + None, + "--block-attr", + help="Blocking attribute (repeatable). If omitted, auto-selects top attributes", + ), + similarity_workers: int = typer.Option( + 1, + "--similarity-workers", + min=1, + help="Worker processes for similarity computation", + ), + similarity_chunk_size: int = typer.Option( + 64, + "--similarity-chunk-size", + min=8, + help="Row chunk size for similarity worker tasks", + ), + checkpoint: Path | None = typer.Option( + None, + "--checkpoint", + help="DB path for similarity checkpointing (must be the same as --study-db)", + ), + resume_checkpoint: bool = typer.Option( + False, + "--resume-checkpoint", + help="Resume similarity stage from checkpoint tables in --study-db", + ), + checkpoint_every: int = typer.Option( + 250, + "--checkpoint-every", + min=1, + help="Write checkpoint every N processed similarity rows", + ), + resource_mode: str = typer.Option( + "auto", + "--resource-mode", + help="Resource tuning mode: auto | manual", + ), + safe_auto_workers: bool = typer.Option( + True, + "--safe-auto-workers/--unsafe-auto-workers", + help="When auto mode is enabled, keep worker count conservative for laptops/VMs", + ), + max_memory_gb: float | None = typer.Option( + None, + "--max-memory-gb", + min=0.5, + help="Optional memory budget cap for auto resource tuning", + ), ): """ Generate a social network from sampled agents. @@ -65,37 +130,54 @@ def network_command( 4. None of the above → empty config (flat network, no similarity structure) Example: - extropy network agents.json -o network.json - extropy network agents.json -o network.json -p population.yaml - extropy network agents.json -o network.json -c network-config.yaml - extropy network agents.json -o network.json -p population.yaml --save-config my-config.yaml + extropy network --study-db study.db + extropy network --study-db study.db --population-id main --network-id main + extropy network --study-db study.db -p population.yaml -c network-config.yaml """ from ...population.network import ( generate_network, generate_network_with_metrics, - load_agents_json, NetworkConfig, generate_network_config, ) from ...core.models import PopulationSpec + from ...storage import open_study_db + from ...utils import ResourceGovernor start_time = time.time() console.print() + if ( + checkpoint is not None + and checkpoint.expanduser().resolve() != study_db.expanduser().resolve() + ): + console.print( + "[red]✗[/red] --checkpoint must point to the same canonical file as --study-db" + ) + raise typer.Exit(1) + checkpoint_db = study_db if (resume_checkpoint or checkpoint is not None) else None + # Load Agents - if not agents_file.exists(): - console.print(f"[red]✗[/red] Agents file not found: {agents_file}") + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") raise typer.Exit(1) with console.status("[cyan]Loading agents...[/cyan]"): try: - agents = load_agents_json(agents_file) + with open_study_db(study_db) as db: + agents = db.get_agents(population_id) except Exception as e: console.print(f"[red]✗[/red] Failed to load agents: {e}") raise typer.Exit(1) + if not agents: + console.print( + f"[red]✗[/red] No agents found for population_id '{population_id}' in {study_db}" + ) + raise typer.Exit(1) console.print( - f"[green]✓[/green] Loaded {len(agents)} agents from [bold]{agents_file}[/bold]" + f"[green]✓[/green] Loaded {len(agents)} agents from [bold]{study_db}[/bold] " + f"(population_id={population_id})" ) # ========================================================================= @@ -165,9 +247,48 @@ def network_command( "avg_degree": avg_degree, "rewire_prob": rewire_prob, "seed": seed if seed is not None else config.seed, + "candidate_mode": candidate_mode, + "candidate_pool_multiplier": candidate_pool_multiplier, + "blocking_attributes": block_attr or config.blocking_attributes, + "similarity_workers": similarity_workers, + "similarity_chunk_size": similarity_chunk_size, + "checkpoint_every_rows": checkpoint_every, + } + ) + + if resource_mode not in {"auto", "manual"}: + console.print("[red]✗[/red] --resource-mode must be 'auto' or 'manual'") + raise typer.Exit(1) + + governor = ResourceGovernor( + resource_mode=resource_mode, + safe_auto_workers=safe_auto_workers, + max_memory_gb=max_memory_gb, + ) + tuned_workers = governor.recommend_workers( + requested_workers=config.similarity_workers, + memory_per_worker_gb=0.75, + ) + tuned_chunk = governor.recommend_chunk_size( + requested_chunk_size=config.similarity_chunk_size, + min_chunk_size=8, + max_chunk_size=2048, + ) + + config = config.model_copy( + update={ + "similarity_workers": tuned_workers, + "similarity_chunk_size": tuned_chunk, } ) + if config.candidate_mode not in {"exact", "blocked"}: + console.print( + f"[red]✗[/red] Invalid --candidate-mode '{config.candidate_mode}' " + "(expected: exact | blocked)" + ) + raise typer.Exit(1) + # Save config if requested if save_config: config.to_yaml(save_config) @@ -175,6 +296,17 @@ def network_command( f"[green]✓[/green] Saved network config to [bold]{save_config}[/bold]" ) + console.print( + f"[dim]Mode: {config.candidate_mode} | workers={config.similarity_workers} " + f"| checkpoint={'on' if checkpoint_db else 'off'}[/dim]" + ) + if resource_mode == "auto": + snap = governor.snapshot() + console.print( + f"[dim]Auto resources: cpu={snap.cpu_count}, " + f"total_mem={snap.total_memory_gb:.1f}GB, budget={snap.memory_budget_gb:.1f}GB[/dim]" + ) + console.print() generation_start = time.time() current_stage = ["Initializing", 0, 0] @@ -192,9 +324,21 @@ def do_generation(): nonlocal result, generation_error try: if no_metrics: - result = generate_network(agents, config, on_progress) + result = generate_network( + agents, + config, + on_progress, + checkpoint_path=checkpoint_db, + resume_from_checkpoint=resume_checkpoint, + ) else: - result = generate_network_with_metrics(agents, config, on_progress) + result = generate_network_with_metrics( + agents, + config, + on_progress, + checkpoint_path=checkpoint_db, + resume_from_checkpoint=resume_checkpoint, + ) except Exception as e: generation_error = e finally: @@ -274,14 +418,38 @@ def do_generation(): pct = count / len(result.edges) * 100 if result.edges else 0 console.print(f" {edge_type}: {count} ({pct:.1f}%)") - # Save Output + # Save canonical output to study DB console.print() - with console.status(f"[cyan]Saving to {output}...[/cyan]"): - result.save_json(output) + with console.status(f"[cyan]Saving network to {study_db}...[/cyan]"): + with open_study_db(study_db) as db: + network_metrics = ( + result.network_metrics.model_dump(mode="json") + if result.network_metrics + else None + ) + db.save_network_result( + population_id=population_id, + network_id=network_id, + config=config.model_dump(mode="json"), + result_meta=result.meta, + edges=[e.to_dict() for e in result.edges], + seed=config.seed, + candidate_mode=config.candidate_mode, + network_metrics=network_metrics, + ) + + if output is not None: + with console.status(f"[cyan]Exporting JSON to {output}...[/cyan]"): + result.save_json(output) elapsed = time.time() - start_time console.print("═" * 60) - console.print(f"[green]✓[/green] Network saved to [bold]{output}[/bold]") + console.print( + f"[green]✓[/green] Network saved to [bold]{study_db}[/bold] " + f"(network_id={network_id})" + ) + if output is not None: + console.print(f"[dim]Exported JSON: {output}[/dim]") console.print(f"[dim]Total time: {format_elapsed(elapsed)}[/dim]") console.print("═" * 60) diff --git a/extropy/cli/commands/persona.py b/extropy/cli/commands/persona.py index 5ceb07e..ab501e0 100644 --- a/extropy/cli/commands/persona.py +++ b/extropy/cli/commands/persona.py @@ -4,12 +4,14 @@ import time from pathlib import Path from threading import Event, Thread +from typing import Any import typer from rich.live import Live from rich.spinner import Spinner from ...core.models import PopulationSpec +from ...storage import open_study_db from ..app import app, console from ..utils import ( format_elapsed, @@ -22,6 +24,16 @@ def persona_command( agents_file: Path = typer.Option( None, "--agents", "-a", help="Sampled agents JSON file (for population stats)" ), + study_db: Path | None = typer.Option( + None, + "--study-db", + help="Canonical study DB file (preferred; loads sampled agents by population id)", + ), + population_id: str = typer.Option( + "default", + "--population-id", + help="Population id when loading agents from --study-db", + ), output: Path = typer.Option( None, "--output", @@ -65,10 +77,11 @@ def persona_command( 3 = Generation error EXAMPLES: - extropy persona population.yaml --agents agents.json - extropy persona population.yaml -a agents.json -o persona_config.yaml - extropy persona population.yaml -a agents.json --agent 42 -y - extropy persona population.yaml -a agents.json --show # preview existing + extropy persona population.yaml --study-db study.db --population-id default + extropy persona population.yaml --study-db study.db -o persona_config.yaml + extropy persona population.yaml --study-db study.db --agent 42 -y + extropy persona population.yaml --study-db study.db --show + extropy persona population.yaml --agents agents.json # legacy input """ from ...population.persona import ( generate_persona_config, @@ -96,7 +109,11 @@ def persona_command( ) # Load Agents (optional but recommended) - agents = None + agents: list[dict[str, Any]] | None = None + if agents_file and study_db: + console.print("[red]✗[/red] Use either --agents or --study-db, not both") + raise typer.Exit(1) + if agents_file: if not agents_file.exists(): console.print(f"[red]✗[/red] Agents file not found: {agents_file}") @@ -119,9 +136,28 @@ def persona_command( raise typer.Exit(1) console.print(f"[green]✓[/green] Loaded {len(agents)} agents") + elif study_db: + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") + raise typer.Exit(2) + with console.status("[cyan]Loading agents from study DB...[/cyan]"): + try: + with open_study_db(study_db) as db: + agents = db.get_agents(population_id) + except Exception as e: + console.print(f"[red]✗[/red] Failed to load agents from study DB: {e}") + raise typer.Exit(1) + if not agents: + console.print( + f"[red]✗[/red] No agents found for population_id '{population_id}' in {study_db}" + ) + raise typer.Exit(1) + console.print( + f"[green]✓[/green] Loaded {len(agents)} agents from study DB population_id={population_id}" + ) else: console.print( - "[yellow]⚠[/yellow] No agents file - population stats will use defaults" + "[yellow]⚠[/yellow] No agent source provided (--study-db or --agents) - population stats will use defaults" ) # Handle --show mode: preview existing config without regenerating @@ -149,7 +185,9 @@ def persona_command( console.print() if not agents: - console.print("[red]✗[/red] Need --agents to preview personas") + console.print( + "[red]✗[/red] Need --study-db or --agents to preview personas" + ) raise typer.Exit(1) if agent_index >= len(agents): diff --git a/extropy/cli/commands/query.py b/extropy/cli/commands/query.py new file mode 100644 index 0000000..9bf158c --- /dev/null +++ b/extropy/cli/commands/query.py @@ -0,0 +1,77 @@ +"""Ad-hoc read-only query command.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import typer + +from ...storage import open_study_db, ReadOnlySQLRequest +from ..app import app, console + +query_app = typer.Typer(help="Read-only SQL query tools") +app.add_typer(query_app, name="query") + +_ALLOWED_PREFIXES = ("select", "with", "explain") +_DENYLIST_TOKENS = ( + " insert ", + " update ", + " delete ", + " alter ", + " drop ", + " create ", + " attach ", + " vacuum ", + " pragma ", + " replace ", + " truncate ", +) + + +@query_app.command("sql") +def query_sql( + study_db: Path = typer.Option(..., "--study-db"), + sql: str = typer.Option(..., "--sql", help="Read-only SQL statement"), + limit: int = typer.Option(1000, "--limit", min=1), + format: str = typer.Option("table", "--format", help="table|json|jsonl"), +): + req = ReadOnlySQLRequest(sql=sql, limit=limit) + normalized = req.sql.strip().lower() + if not normalized.startswith(_ALLOWED_PREFIXES): + console.print( + "[red]✗[/red] Only read-only SELECT/WITH/EXPLAIN queries are allowed" + ) + raise typer.Exit(1) + padded = f" {normalized} " + if ";" in req.sql.strip().rstrip(";"): + console.print("[red]✗[/red] Multi-statement SQL is not allowed") + raise typer.Exit(1) + if any(tok in padded for tok in _DENYLIST_TOKENS): + console.print("[red]✗[/red] Mutating SQL tokens are not allowed") + raise typer.Exit(1) + + with open_study_db(study_db) as db: + try: + rows = db.run_select(req.sql, limit=req.limit) + except Exception as e: + console.print(f"[red]✗[/red] Query failed: {e}") + raise typer.Exit(1) + + if format == "json": + console.print_json(data=rows) + return + if format == "jsonl": + for row in rows: + console.print(json.dumps(row, default=str)) + return + + if not rows: + console.print("[dim](no rows)[/dim]") + return + + columns = list(rows[0].keys()) + console.print(" | ".join(columns)) + console.print("-" * max(20, len(" | ".join(columns)))) + for row in rows: + console.print(" | ".join(str(row.get(c, "")) for c in columns)) diff --git a/extropy/cli/commands/report.py b/extropy/cli/commands/report.py new file mode 100644 index 0000000..91221d8 --- /dev/null +++ b/extropy/cli/commands/report.py @@ -0,0 +1,119 @@ +"""Reusable report generation commands.""" + +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path + +import typer + +from ..app import app, console + +report_app = typer.Typer(help="Generate reusable JSON reports") +app.add_typer(report_app, name="report") + + +@report_app.command("run") +def report_run( + study_db: Path = typer.Option(..., "--study-db"), + run_id: str | None = typer.Option(None, "--run-id"), + output: Path = typer.Option(..., "--output", "-o"), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + if run_id: + cur.execute( + "SELECT run_id, population_id FROM simulation_runs WHERE run_id = ?", + (run_id,), + ) + else: + cur.execute( + """ + SELECT run_id, population_id + FROM simulation_runs + ORDER BY started_at DESC + LIMIT 1 + """ + ) + run_row = cur.fetchone() + if not run_row: + console.print("[yellow]No simulation runs found.[/yellow]") + raise typer.Exit(1) + resolved_run_id = str(run_row["run_id"]) + + cur.execute( + "SELECT COUNT(*) AS cnt FROM agent_states WHERE run_id = ?", + (resolved_run_id,), + ) + total = int(cur.fetchone()["cnt"]) + cur.execute( + "SELECT COUNT(*) AS cnt FROM agent_states WHERE run_id = ? AND aware = 1", + (resolved_run_id,), + ) + aware = int(cur.fetchone()["cnt"]) + cur.execute( + """ + SELECT COALESCE(private_position, position) AS position, COUNT(*) AS cnt + FROM agent_states + WHERE run_id = ? + AND COALESCE(private_position, position) IS NOT NULL + GROUP BY COALESCE(private_position, position) + """, + (resolved_run_id,), + ) + positions = {row["position"]: int(row["cnt"]) for row in cur.fetchall()} + finally: + conn.close() + + payload = { + "run_id": resolved_run_id, + "agent_count": total, + "aware_count": aware, + "aware_rate": (aware / total) if total else 0.0, + "positions": positions, + } + + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(json.dumps(payload, indent=2), encoding="utf-8") + console.print(f"[green]✓[/green] Wrote run report: {output}") + + +@report_app.command("network") +def report_network( + study_db: Path = typer.Option(..., "--study-db"), + network_id: str = typer.Option("default", "--network-id"), + output: Path = typer.Option(..., "--output", "-o"), +): + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + try: + cur = conn.cursor() + cur.execute( + "SELECT COUNT(*) AS cnt, AVG(weight) AS avg_w FROM network_edges WHERE network_id = ?", + (network_id,), + ) + row = cur.fetchone() + edge_count = int(row["cnt"]) if row else 0 + avg_weight = float(row["avg_w"]) if row and row["avg_w"] is not None else 0.0 + + cur.execute( + "SELECT edge_type, COUNT(*) AS cnt FROM network_edges WHERE network_id = ? GROUP BY edge_type", + (network_id,), + ) + edge_types = {r["edge_type"]: int(r["cnt"]) for r in cur.fetchall()} + finally: + conn.close() + + payload = { + "network_id": network_id, + "edge_count": edge_count, + "avg_weight": avg_weight, + "edge_types": edge_types, + } + + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(json.dumps(payload, indent=2), encoding="utf-8") + console.print(f"[green]✓[/green] Wrote network report: {output}") diff --git a/extropy/cli/commands/results.py b/extropy/cli/commands/results.py index 3011eee..1f1e037 100644 --- a/extropy/cli/commands/results.py +++ b/extropy/cli/commands/results.py @@ -1,5 +1,9 @@ -"""Results command for displaying simulation results.""" +"""Results command for DB-first simulation results.""" +from __future__ import annotations + +import json +import sqlite3 from pathlib import Path import typer @@ -9,7 +13,8 @@ @app.command("results") def results_command( - results_dir: Path = typer.Argument(..., help="Results directory from simulation"), + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), + run_id: str | None = typer.Option(None, "--run-id", help="Simulation run id"), segment: str | None = typer.Option( None, "--segment", "-s", help="Attribute to segment by" ), @@ -18,43 +23,227 @@ def results_command( None, "--agent", "-a", help="Show single agent details" ), ): - """ - Display simulation results. + """Display simulation results from the canonical study DB.""" + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") + raise typer.Exit(1) + + conn = sqlite3.connect(str(study_db)) + conn.row_factory = sqlite3.Row + + try: + cur = conn.cursor() + if run_id: + cur.execute( + """ + SELECT run_id, status, started_at, completed_at, stopped_reason, population_id + FROM simulation_runs + WHERE run_id = ? + """, + (run_id,), + ) + else: + cur.execute( + """ + SELECT run_id, status, started_at, completed_at, stopped_reason, population_id + FROM simulation_runs + ORDER BY started_at DESC + LIMIT 1 + """ + ) + run_row = cur.fetchone() + if not run_row: + console.print("[yellow]No simulation runs found in study DB.[/yellow]") + raise typer.Exit(0) + resolved_run_id = str(run_row["run_id"]) + population_id = str(run_row["population_id"]) + console.print( + f"[dim]run_id={resolved_run_id} status={run_row['status']} " + f"started_at={run_row['started_at']} completed_at={run_row['completed_at'] or '-'}[/dim]" + ) + if agent: + _display_agent(conn, resolved_run_id, population_id, agent) + return + if segment: + _display_segment(conn, resolved_run_id, population_id, segment) + return + if timeline: + _display_timeline(conn, resolved_run_id) + return + _display_summary(conn, resolved_run_id) + finally: + conn.close() + - Load and display results from a completed simulation run. +def _display_summary(conn: sqlite3.Connection, run_id: str) -> None: + cur = conn.cursor() + cur.execute("SELECT COUNT(*) AS cnt FROM agent_states WHERE run_id = ?", (run_id,)) + total = int(cur.fetchone()["cnt"]) + if total == 0: + console.print("[yellow]No simulation state found in study DB.[/yellow]") + return + + cur.execute( + "SELECT COUNT(*) AS cnt FROM agent_states WHERE run_id = ? AND aware = 1", + (run_id,), + ) + aware = int(cur.fetchone()["cnt"]) - Example: - extropy results results/ # Summary view - extropy results results/ --segment age # Breakdown by age - extropy results results/ --timeline # Timeline view - extropy results results/ --agent agent_001 # Single agent - """ - from ...results import ( - load_results, - display_summary, - display_segment_breakdown, - display_timeline, - display_agent, + cur.execute( + """ + SELECT COALESCE(private_position, position) AS position, COUNT(*) AS cnt + FROM agent_states + WHERE run_id = ? + AND COALESCE(private_position, position) IS NOT NULL + GROUP BY COALESCE(private_position, position) + ORDER BY cnt DESC + """, + (run_id,), ) + rows = cur.fetchall() console.print() + console.print("[bold]Simulation Summary[/bold]") + console.print(f"Agents: {total}") + console.print(f"Aware: {aware} ({aware / total:.1%})") + console.print("Positions:") + for row in rows: + pct = int(row["cnt"]) / total + console.print(f" - {row['position']}: {row['cnt']} ({pct:.1%})") - if not results_dir.exists(): - console.print(f"[red]✗[/red] Results directory not found: {results_dir}") - raise typer.Exit(1) - try: - reader = load_results(results_dir) - except Exception as e: - console.print(f"[red]✗[/red] Failed to load results: {e}") - raise typer.Exit(1) +def _display_timeline(conn: sqlite3.Connection, run_id: str) -> None: + cur = conn.cursor() + cur.execute( + """ + SELECT timestep, new_exposures, agents_reasoned, shares_occurred, exposure_rate + FROM timestep_summaries + WHERE run_id = ? + ORDER BY timestep + """, + (run_id,), + ) + rows = cur.fetchall() + if not rows: + console.print("[yellow]No timestep summaries found.[/yellow]") + return + + console.print() + console.print("[bold]Timeline[/bold]") + for row in rows: + console.print( + f"t={row['timestep']:>3} | new_exp={row['new_exposures']:>5} | " + f"reasoned={row['agents_reasoned']:>5} | shares={row['shares_occurred']:>5} | " + f"exposure={float(row['exposure_rate']):.1%}" + ) + - # Dispatch to appropriate view - if agent: - display_agent(console, reader, agent) - elif segment: - display_segment_breakdown(console, reader, segment) - elif timeline: - display_timeline(console, reader) - else: - display_summary(console, reader) +def _display_segment( + conn: sqlite3.Connection, + run_id: str, + population_id: str, + attribute: str, +) -> None: + cur = conn.cursor() + cur.execute( + "SELECT agent_id, attrs_json FROM agents WHERE population_id = ?", + (population_id,), + ) + attr_by_agent: dict[str, str] = {} + for row in cur.fetchall(): + try: + attrs = json.loads(row["attrs_json"]) + except json.JSONDecodeError: + continue + attr_by_agent[str(row["agent_id"])] = str(attrs.get(attribute, "unknown")) + + if not attr_by_agent: + console.print("[yellow]No agent attribute records found.[/yellow]") + return + + cur.execute( + """ + SELECT agent_id, aware, COALESCE(private_position, position) AS position + FROM agent_states + WHERE run_id = ? + """, + (run_id,), + ) + groups: dict[str, dict[str, int]] = {} + for row in cur.fetchall(): + aid = str(row["agent_id"]) + key = attr_by_agent.get(aid, "unknown") + if key not in groups: + groups[key] = {"total": 0, "aware": 0} + groups[key]["total"] += 1 + if int(row["aware"]) == 1: + groups[key]["aware"] += 1 + + console.print() + console.print(f"[bold]Segment by {attribute}[/bold]") + for key, data in sorted(groups.items(), key=lambda x: x[1]["total"], reverse=True): + total = data["total"] + aware = data["aware"] + pct = aware / total if total else 0.0 + console.print(f" - {key}: {total} agents, aware={aware} ({pct:.1%})") + + +def _display_agent( + conn: sqlite3.Connection, + run_id: str, + population_id: str, + agent_id: str, +) -> None: + cur = conn.cursor() + cur.execute( + """ + SELECT * + FROM agent_states + WHERE run_id = ? AND agent_id = ? + """, + (run_id, agent_id), + ) + row = cur.fetchone() + if not row: + console.print( + f"[yellow]Agent not found in simulation state: {agent_id}[/yellow]" + ) + return + + cur.execute( + "SELECT attrs_json FROM agents WHERE population_id = ? AND agent_id = ? LIMIT 1", + (population_id, agent_id), + ) + attrs_row = cur.fetchone() + attrs = {} + if attrs_row: + try: + attrs = json.loads(attrs_row["attrs_json"]) + except json.JSONDecodeError: + attrs = {} + + console.print() + console.print(f"[bold]Agent {agent_id}[/bold]") + console.print(f"Aware: {bool(row['aware'])}") + console.print(f"Position: {row['private_position'] or row['position']}") + console.print( + f"Sentiment: {row['private_sentiment'] if row['private_sentiment'] is not None else row['sentiment']}" + ) + console.print( + f"Conviction: {row['private_conviction'] if row['private_conviction'] is not None else row['conviction']}" + ) + if row["public_statement"]: + console.print(f"Public statement: {row['public_statement']}") + if row["action_intent"]: + console.print(f"Action intent: {row['action_intent']}") + if row["raw_reasoning"]: + console.print() + console.print("[bold]Raw Reasoning[/bold]") + console.print(str(row["raw_reasoning"])) + if attrs: + console.print() + console.print("[bold]Attributes[/bold]") + for key in sorted(attrs.keys()): + if key.startswith("_"): + continue + console.print(f" - {key}: {attrs[key]}") diff --git a/extropy/cli/commands/sample.py b/extropy/cli/commands/sample.py index 5db1f20..bbea94c 100644 --- a/extropy/cli/commands/sample.py +++ b/extropy/cli/commands/sample.py @@ -22,8 +22,9 @@ def sample_command( spec_file: Path = typer.Argument( ..., help="Population spec YAML file to sample from" ), - output: Path = typer.Option( - ..., "--output", "-o", help="Output file path (.json or .db)" + study_db: Path = typer.Option(..., "--study-db", help="Canonical study database"), + population_id: str = typer.Option( + "default", "--population-id", help="Population identifier inside study DB" ), count: int | None = typer.Option( None, "--count", "-n", help="Number of agents (default: spec.meta.size)" @@ -31,9 +32,6 @@ def sample_command( seed: int | None = typer.Option( None, "--seed", help="Random seed for reproducibility" ), - format: str = typer.Option( - "json", "--format", "-f", help="Output format: json or sqlite" - ), report: bool = typer.Option( False, "--report", "-r", help="Show distribution summaries and stats" ), @@ -54,18 +52,16 @@ def sample_command( 3 = Sampling error EXAMPLES: - extropy sample surgeons.yaml -o agents.json - extropy sample surgeons.yaml -n 500 -o agents.json --seed 42 - extropy sample surgeons.yaml -n 1000 -o agents.db --format sqlite - extropy sample surgeons.yaml -o agents.json --report - extropy --json sample surgeons.yaml -o agents.json --report + extropy sample surgeons.yaml --study-db study.db + extropy sample surgeons.yaml --study-db study.db --population-id main --seed 42 + extropy sample surgeons.yaml --study-db study.db --count 1000 --report + extropy --json sample surgeons.yaml --study-db study.db --report """ from ...population.sampler import ( sample_population, - save_json, - save_sqlite, SamplingError, ) + from ...storage import open_study_db out = Output(console, json_mode=get_json_mode()) start_time = time.time() @@ -289,35 +285,48 @@ def on_progress(current: int, total: int): ) out.blank() - # Save Output + # Save to canonical DB out.blank() - output_format = format.lower() - - if output.suffix.lower() == ".db": - output_format = "sqlite" - elif output.suffix.lower() == ".json": - output_format = "json" - if not get_json_mode(): - with console.status(f"[cyan]Saving to {output_format}...[/cyan]"): - if output_format == "sqlite": - save_sqlite(result, output) - else: - save_json(result, output) + with console.status(f"[cyan]Saving to study DB: {study_db}...[/cyan]"): + with open_study_db(study_db) as db: + db.save_population_spec( + population_id=population_id, + spec_yaml=spec_file.read_text(encoding="utf-8"), + source_path=str(spec_file), + ) + sample_run_id = db.save_sample_result( + population_id=population_id, + agents=result.agents, + meta=result.meta, + seed=result.meta.get("seed"), + ) else: - if output_format == "sqlite": - save_sqlite(result, output) - else: - save_json(result, output) + with open_study_db(study_db) as db: + db.save_population_spec( + population_id=population_id, + spec_yaml=spec_file.read_text(encoding="utf-8"), + source_path=str(spec_file), + ) + sample_run_id = db.save_sample_result( + population_id=population_id, + agents=result.agents, + meta=result.meta, + seed=result.meta.get("seed"), + ) elapsed = time.time() - start_time - out.set_data("output_file", str(output)) - out.set_data("output_format", output_format) + out.set_data("study_db", str(study_db)) + out.set_data("population_id", population_id) + out.set_data("sample_run_id", sample_run_id) out.set_data("total_time_seconds", elapsed) out.divider() - out.success(f"Saved {len(result.agents)} agents to [bold]{output}[/bold]") + out.success( + f"Saved {len(result.agents)} agents to [bold]{study_db}[/bold] " + f"(population_id={population_id}, sample_run_id={sample_run_id})" + ) out.text(f"[dim]Total time: {format_elapsed(elapsed)}[/dim]") out.divider() diff --git a/extropy/cli/commands/scenario.py b/extropy/cli/commands/scenario.py index 0a23d52..6a5d99d 100644 --- a/extropy/cli/commands/scenario.py +++ b/extropy/cli/commands/scenario.py @@ -17,8 +17,13 @@ def scenario_command( population: Path = typer.Option( ..., "--population", "-p", help="Population spec YAML file" ), - agents: Path = typer.Option(..., "--agents", "-a", help="Sampled agents JSON file"), - network: Path = typer.Option(..., "--network", "-n", help="Network JSON file"), + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), + population_id: str = typer.Option( + "default", "--population-id", help="Population ID in study DB" + ), + network_id: str = typer.Option( + "default", "--network-id", help="Network ID in study DB" + ), description: str | None = typer.Option( None, "--description", @@ -44,12 +49,12 @@ def scenario_command( - Outcome definitions (what to measure) Example: - extropy scenario -p population.yaml -a agents.json -n network.json - extropy scenario -p pop.yaml -a agents.json -n net.json -d "Custom description" -o custom.yaml + extropy scenario -p population.yaml --study-db study.db + extropy scenario -p pop.yaml --study-db study.db --population-id main --network-id main -d "Custom description" -o custom.yaml """ from ...core.models import PopulationSpec from ...scenario import create_scenario - from ...utils import make_relative_to + from ...storage import open_study_db start_time = time.time() console.print() @@ -59,13 +64,21 @@ def scenario_command( console.print(f"[red]✗[/red] Population spec not found: {population}") raise typer.Exit(1) - if not agents.exists(): - console.print(f"[red]✗[/red] Agents file not found: {agents}") + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") raise typer.Exit(1) - if not network.exists(): - console.print(f"[red]✗[/red] Network file not found: {network}") - raise typer.Exit(1) + with open_study_db(study_db) as db: + if db.get_agent_count(population_id) == 0: + console.print( + f"[red]✗[/red] No agents found for population_id '{population_id}' in {study_db}" + ) + raise typer.Exit(1) + if db.get_network_edge_count(network_id) == 0: + console.print( + f"[red]✗[/red] No network edges found for network_id '{network_id}' in {study_db}" + ) + raise typer.Exit(1) # Load population spec to get scenario description if not provided try: @@ -106,8 +119,9 @@ def run_pipeline(): result_spec, validation_result = create_scenario( description=scenario_desc, population_spec_path=population, - agents_path=agents, - network_path=network, + study_db_path=study_db, + population_id=population_id, + network_id=network_id, output_path=None, # Don't save yet on_progress=on_progress, ) @@ -239,10 +253,10 @@ def run_pipeline(): console.print("[dim]Cancelled.[/dim]") raise typer.Exit(0) - # Convert paths to be relative to output file before saving - result_spec.meta.population_spec = make_relative_to(population, output_path) - result_spec.meta.agents_file = make_relative_to(agents, output_path) - result_spec.meta.network_file = make_relative_to(network, output_path) + result_spec.meta.population_spec = str(population) + result_spec.meta.study_db = str(study_db) + result_spec.meta.population_id = population_id + result_spec.meta.network_id = network_id # Save to YAML result_spec.to_yaml(output_path) diff --git a/extropy/cli/commands/simulate.py b/extropy/cli/commands/simulate.py index 73f162f..3b7b883 100644 --- a/extropy/cli/commands/simulate.py +++ b/extropy/cli/commands/simulate.py @@ -101,21 +101,17 @@ def setup_logging(verbose: bool = False, debug: bool = False): def simulate_command( scenario_file: Path = typer.Argument(..., help="Scenario spec YAML file"), output: Path = typer.Option(..., "--output", "-o", help="Output results directory"), - model: str = typer.Option( + study_db: Path = typer.Option(..., "--study-db", help="Canonical study DB file"), + strong: str = typer.Option( "", - "--model", + "--strong", "-m", - help="LLM model for agent reasoning (empty = use config default)", + help="Strong model for Pass 1 (provider/model format)", ), - pivotal_model: str = typer.Option( + fast: str = typer.Option( "", - "--pivotal-model", - help="Model for pivotal/first-pass reasoning (default: same as --model)", - ), - routine_model: str = typer.Option( - "", - "--routine-model", - help="Cheap model for classification pass (default: provider cheap tier)", + "--fast", + help="Fast model for Pass 2 (provider/model format)", ), threshold: int = typer.Option( 3, "--threshold", "-t", help="Multi-touch threshold for re-reasoning" @@ -132,6 +128,55 @@ def simulate_command( chunk_size: int = typer.Option( 50, "--chunk-size", help="Agents per reasoning chunk for checkpointing" ), + checkpoint_every_chunks: int = typer.Option( + 1, + "--checkpoint-every-chunks", + min=1, + help="Persist simulation chunk checkpoints every N chunks", + ), + run_id: str | None = typer.Option( + None, + "--run-id", + help="Explicit run id (required with --resume)", + ), + resume: bool = typer.Option( + False, + "--resume", + help="Resume an existing run from study DB checkpoints", + ), + writer_queue_size: int = typer.Option( + 256, + "--writer-queue-size", + min=1, + help="Max reasoning chunks buffered before DB writer backpressure", + ), + db_write_batch_size: int = typer.Option( + 100, + "--db-write-batch-size", + min=1, + help="Number of chunks applied per DB writer transaction", + ), + retention_lite: bool = typer.Option( + False, + "--retention-lite", + help="Reduce retained payload volume (drops full raw reasoning text)", + ), + resource_mode: str = typer.Option( + "auto", + "--resource-mode", + help="Resource tuning mode: auto | manual", + ), + safe_auto_workers: bool = typer.Option( + True, + "--safe-auto-workers/--unsafe-auto-workers", + help="Conservative auto tuning for laptop/VM environments", + ), + max_memory_gb: float | None = typer.Option( + None, + "--max-memory-gb", + min=0.5, + help="Optional memory budget cap for auto resource tuning", + ), seed: int | None = typer.Option( None, "--seed", help="Random seed for reproducibility" ), @@ -157,12 +202,13 @@ def simulate_command( used automatically for embodied first-person personas. Example: - extropy simulate scenario.yaml -o results/ - extropy simulate scenario.yaml -o results/ --model gpt-5-nano --seed 42 - extropy simulate scenario.yaml -o results/ --persona population.persona.yaml + extropy simulate scenario.yaml --study-db study.db -o results/ + extropy simulate scenario.yaml --study-db study.db -o results/ --model gpt-5-nano --seed 42 + extropy simulate scenario.yaml --study-db study.db -o results/ --persona population.persona.yaml """ from ...simulation import run_simulation from ...simulation.progress import SimulationProgress + from ...utils import ResourceGovernor # Setup logging based on verbosity setup_logging(verbose=verbose, debug=debug) @@ -174,34 +220,33 @@ def simulate_command( if not scenario_file.exists(): console.print(f"[red]✗[/red] Scenario file not found: {scenario_file}") raise typer.Exit(1) + if not study_db.exists(): + console.print(f"[red]✗[/red] Study DB not found: {study_db}") + raise typer.Exit(1) + if resume and not run_id: + console.print("[red]✗[/red] --resume requires --run-id") + raise typer.Exit(1) + if resource_mode not in {"auto", "manual"}: + console.print("[red]✗[/red] --resource-mode must be 'auto' or 'manual'") + raise typer.Exit(1) from ...config import get_config config = get_config() # Resolve models from CLI args > config > defaults - effective_model = model or config.simulation.model - effective_pivotal = pivotal_model or config.simulation.pivotal_model - effective_routine = routine_model or config.simulation.routine_model + effective_strong = strong or config.resolve_sim_strong() + effective_fast = fast or config.resolve_sim_fast() effective_tier = rate_tier or config.simulation.rate_tier effective_rpm = rpm_override or config.simulation.rpm_override effective_tpm = tpm_override or config.simulation.tpm_override - display_model = effective_model or f"({config.simulation.provider} default)" - display_provider = config.simulation.provider - console.print(f"Simulating: [bold]{scenario_file}[/bold]") console.print(f"Output: {output}") + console.print(f"Study DB: {study_db}") console.print( - f"Provider: {display_provider} | Model: {display_model} | Threshold: {threshold}" + f"Strong: {effective_strong} | Fast: {effective_fast} | Threshold: {threshold}" ) - if effective_pivotal or effective_routine: - parts = [] - if effective_pivotal: - parts.append(f"Pivotal: {effective_pivotal}") - if effective_routine: - parts.append(f"Routine: {effective_routine}") - console.print(" | ".join(parts)) if effective_tier: console.print(f"Rate tier: {effective_tier}") if effective_rpm or effective_tpm: @@ -213,6 +258,22 @@ def simulate_command( console.print(f"Rate overrides: {' | '.join(parts)}") if seed: console.print(f"Seed: {seed}") + governor = ResourceGovernor( + resource_mode=resource_mode, + safe_auto_workers=safe_auto_workers, + max_memory_gb=max_memory_gb, + ) + tuned_chunk_size = governor.recommend_chunk_size( + requested_chunk_size=chunk_size, + min_chunk_size=8, + max_chunk_size=2000, + ) + if resource_mode == "auto": + snap = governor.snapshot() + console.print( + f"Resources(auto): cpu={snap.cpu_count} mem={snap.total_memory_gb:.1f}GB " + f"budget={snap.memory_budget_gb:.1f}GB chunk={tuned_chunk_size}" + ) if verbose or debug: console.print(f"Logging: {'DEBUG' if debug else 'VERBOSE'}") console.print() @@ -236,9 +297,9 @@ def on_progress(timestep: int, max_timesteps: int, status: str): result = run_simulation( scenario_path=scenario_file, output_dir=output, - model=effective_model, - pivotal_model=effective_pivotal, - routine_model=effective_routine, + study_db_path=study_db, + strong=effective_strong, + fast=effective_fast, multi_touch_threshold=threshold, random_seed=seed, on_progress=on_progress, @@ -246,8 +307,15 @@ def on_progress(timestep: int, max_timesteps: int, status: str): rate_tier=effective_tier, rpm_override=effective_rpm, tpm_override=effective_tpm, - chunk_size=chunk_size, + chunk_size=tuned_chunk_size, progress=progress_state, + run_id=run_id, + resume=resume, + checkpoint_every_chunks=checkpoint_every_chunks, + retention_lite=retention_lite, + writer_queue_size=writer_queue_size, + db_write_batch_size=db_write_batch_size, + resource_governor=governor, ) simulation_error = None except Exception as e: @@ -265,9 +333,9 @@ def do_simulation(): result = run_simulation( scenario_path=scenario_file, output_dir=output, - model=effective_model, - pivotal_model=effective_pivotal, - routine_model=effective_routine, + study_db_path=study_db, + strong=effective_strong, + fast=effective_fast, multi_touch_threshold=threshold, random_seed=seed, on_progress=on_progress if not quiet else None, @@ -275,8 +343,15 @@ def do_simulation(): rate_tier=effective_tier, rpm_override=effective_rpm, tpm_override=effective_tpm, - chunk_size=chunk_size, + chunk_size=tuned_chunk_size, progress=progress_state, + run_id=run_id, + resume=resume, + checkpoint_every_chunks=checkpoint_every_chunks, + retention_lite=retention_lite, + writer_queue_size=writer_queue_size, + db_write_batch_size=db_write_batch_size, + resource_governor=governor, ) except Exception as e: simulation_error = e diff --git a/extropy/cli/commands/validate.py b/extropy/cli/commands/validate.py index 0ca0fa3..d09177a 100644 --- a/extropy/cli/commands/validate.py +++ b/extropy/cli/commands/validate.py @@ -185,17 +185,13 @@ def _validate_scenario_spec(spec_file: Path, out: Output) -> int: f" [red]✗[/red] Population: {spec.meta.population_spec} (not found)" ) - agents_path = resolve_relative_to(spec.meta.agents_file, spec_file) - if agents_path.exists(): - out.text(f" [green]✓[/green] Agents: {spec.meta.agents_file}") + study_db_path = resolve_relative_to(spec.meta.study_db, spec_file) + if study_db_path.exists(): + out.text(f" [green]✓[/green] Study DB: {spec.meta.study_db}") else: - out.text(f" [red]✗[/red] Agents: {spec.meta.agents_file} (not found)") - - network_path = resolve_relative_to(spec.meta.network_file, spec_file) - if network_path.exists(): - out.text(f" [green]✓[/green] Network: {spec.meta.network_file}") - else: - out.text(f" [red]✗[/red] Network: {spec.meta.network_file} (not found)") + out.text(f" [red]✗[/red] Study DB: {spec.meta.study_db} (not found)") + out.text(f" [cyan]•[/cyan] population_id: {spec.meta.population_id}") + out.text(f" [cyan]•[/cyan] network_id: {spec.meta.network_id}") out.blank() diff --git a/extropy/config.py b/extropy/config.py index 55b7cb0..f2e2ca7 100644 --- a/extropy/config.py +++ b/extropy/config.py @@ -1,12 +1,14 @@ """Configuration management for Extropy. -Two-zone config system: -- pipeline: provider + models for phases 1-2 (spec, extend, sample, network, persona, scenario) -- simulation: provider + model for phase 3 (agent reasoning) +Two-tier config system: +- models: fast/strong model strings for pipeline phases 1-2 +- simulation: fast/strong model strings for phase 3 (agent reasoning) + +Model strings use "provider/model" format (e.g., "openai/gpt-5-mini"). Config resolution order (highest priority first): 1. Programmatic (ExtropyConfig constructed in code) -2. Environment variables (PIPELINE_PROVIDER, SIMULATION_MODEL, etc.) +2. Environment variables (MODELS_FAST, MODELS_STRONG, etc.) 3. Config file (~/.config/extropy/config.json, managed by `extropy config`) 4. Hardcoded defaults @@ -16,10 +18,11 @@ import json import logging import os -from dataclasses import dataclass, field, asdict from pathlib import Path from typing import Any +from pydantic import BaseModel, ConfigDict, Field + logger = logging.getLogger(__name__) @@ -33,41 +36,85 @@ # ============================================================================= -# Two-zone config dataclasses +# Model string parsing +# ============================================================================= + + +def parse_model_string(model_string: str) -> tuple[str, str]: + """Parse a "provider/model" string into (provider, model) tuple. + + Examples: + "openai/gpt-5-mini" → ("openai", "gpt-5-mini") + "anthropic/claude-sonnet-4.5" → ("anthropic", "claude-sonnet-4.5") + "openrouter/anthropic/claude-sonnet-4.5" → ("openrouter", "anthropic/claude-sonnet-4.5") + + Raises: + ValueError: If the string doesn't contain a '/' separator. + """ + if "/" not in model_string: + raise ValueError( + f"Invalid model string: {model_string!r}. " + f"Expected format: 'provider/model' (e.g., 'openai/gpt-5-mini')" + ) + provider, _, model = model_string.partition("/") + if not provider or not model: + raise ValueError( + f"Invalid model string: {model_string!r}. " + f"Both provider and model must be non-empty." + ) + return provider, model + + +# ============================================================================= +# Two-tier config models # ============================================================================= -@dataclass -class PipelineConfig: - """Config for phases 1-2: spec, extend, sample, network, persona, scenario.""" +class ModelsConfig(BaseModel): + """Pipeline model configuration (phases 1-2). + + Uses "provider/model" format strings. + - fast: used for simple_call (cheap, fast tasks) + - strong: used for reasoning_call, agentic_research (complex tasks) + """ + + model_config = ConfigDict(populate_by_name=True) - provider: str = "openai" - model_simple: str = "" # empty = provider default - model_reasoning: str = "" # empty = provider default - model_research: str = "" # empty = provider default + fast: str = "openai/gpt-5-mini" + strong: str = "openai/gpt-5" -@dataclass -class SimZoneConfig: - """Config for phase 3: agent reasoning during simulation.""" +class SimulationConfig(BaseModel): + """Simulation model + tuning configuration (phase 3). + + Uses "provider/model" format strings. + - fast: used for Pass 2 (classification/routine) + - strong: used for Pass 1 (pivotal/role-play reasoning) + """ + + model_config = ConfigDict(populate_by_name=True) - provider: str = "openai" - model: str = "" # empty = provider default - pivotal_model: str = "" # model for pivotal reasoning (default: same as model) - routine_model: str = ( - "" # cheap model for classification (default: provider cheap tier) - ) + fast: str = "" # empty = same as models.fast + strong: str = "" # empty = same as models.strong max_concurrent: int = 50 - rate_tier: int | None = None # rate limit tier (1-4, None = Tier 1) - rpm_override: int | None = None # override RPM limit - tpm_override: int | None = None # override TPM limit - api_format: str = ( - "" # empty = auto (responses for openai, chat_completions for azure) - ) + rate_tier: int | None = None + rpm_override: int | None = None + tpm_override: int | None = None -@dataclass -class ExtropyConfig: +class CustomProviderConfig(BaseModel): + """Config for a custom OpenAI-compatible provider endpoint.""" + + base_url: str = "" + api_key_env: str = "" + + +# ============================================================================= +# Main config class +# ============================================================================= + + +class ExtropyConfig(BaseModel): """Top-level extropy configuration. Construct programmatically for package use, or load from config file for CLI use. @@ -75,8 +122,7 @@ class ExtropyConfig: Examples: # Package use — no files needed config = ExtropyConfig( - pipeline=PipelineConfig(provider="claude"), - simulation=SimZoneConfig(provider="openai", model="gpt-5-mini"), + models=ModelsConfig(fast="openai/gpt-5-mini", strong="anthropic/claude-sonnet-4.5"), ) # CLI use — loads from ~/.config/extropy/config.json @@ -84,15 +130,15 @@ class ExtropyConfig: # Override just simulation config = ExtropyConfig.load() - config.simulation.model = "gpt-5-nano" + config.simulation.strong = "openrouter/anthropic/claude-sonnet-4.5" """ - pipeline: PipelineConfig = field(default_factory=PipelineConfig) - simulation: SimZoneConfig = field(default_factory=SimZoneConfig) + model_config = ConfigDict(populate_by_name=True) - # Non-zone settings - db_path: str = "./storage/extropy.db" - default_population_size: int = 1000 + models: ModelsConfig = Field(default_factory=ModelsConfig) + simulation: SimulationConfig = Field(default_factory=SimulationConfig) + providers: dict[str, CustomProviderConfig] = Field(default_factory=dict) + show_cost: bool = False @classmethod def load(cls) -> "ExtropyConfig": @@ -102,7 +148,7 @@ def load(cls) -> "ExtropyConfig": """ config = cls() - # Layer 1: Load from config file if it exists + # Load from config file if it exists if CONFIG_FILE.exists(): try: with open(CONFIG_FILE) as f: @@ -111,27 +157,20 @@ def load(cls) -> "ExtropyConfig": except (json.JSONDecodeError, OSError) as exc: logger.warning("Failed to load config from %s: %s", CONFIG_FILE, exc) - # Layer 2: Env var overrides - if provider := os.environ.get("LLM_PROVIDER"): - # Legacy: single provider applied to both zones - config.pipeline.provider = provider - config.simulation.provider = provider - if val := os.environ.get("PIPELINE_PROVIDER"): - config.pipeline.provider = val - if val := os.environ.get("SIMULATION_PROVIDER"): - config.simulation.provider = val - if val := os.environ.get("MODEL_SIMPLE"): - config.pipeline.model_simple = val - if val := os.environ.get("MODEL_REASONING"): - config.pipeline.model_reasoning = val - if val := os.environ.get("MODEL_RESEARCH"): - config.pipeline.model_research = val - if val := os.environ.get("SIMULATION_MODEL"): - config.simulation.model = val - if val := os.environ.get("SIMULATION_PIVOTAL_MODEL"): - config.simulation.pivotal_model = val - if val := os.environ.get("SIMULATION_ROUTINE_MODEL"): - config.simulation.routine_model = val + # Env var overrides + if val := os.environ.get("MODELS_FAST"): + config.models.fast = val + if val := os.environ.get("MODELS_STRONG"): + config.models.strong = val + if val := os.environ.get("SIMULATION_FAST"): + config.simulation.fast = val + if val := os.environ.get("SIMULATION_STRONG"): + config.simulation.strong = val + if val := os.environ.get("SIMULATION_MAX_CONCURRENT"): + try: + config.simulation.max_concurrent = int(val) + except ValueError: + logger.warning("Invalid SIMULATION_MAX_CONCURRENT=%r, ignoring", val) if val := os.environ.get("SIMULATION_RATE_TIER"): try: config.simulation.rate_tier = int(val) @@ -147,38 +186,56 @@ def load(cls) -> "ExtropyConfig": config.simulation.tpm_override = int(val) except ValueError: logger.warning("Invalid SIMULATION_TPM_OVERRIDE=%r, ignoring", val) - if val := os.environ.get("SIMULATION_API_FORMAT"): - config.simulation.api_format = val - if val := os.environ.get("DB_PATH"): - config.db_path = val - if val := os.environ.get("DEFAULT_POPULATION_SIZE"): - try: - config.default_population_size = int(val) - except ValueError: - logger.warning("Invalid DEFAULT_POPULATION_SIZE=%r, ignoring", val) return config def save(self) -> None: """Save config to ~/.config/extropy/config.json.""" CONFIG_DIR.mkdir(parents=True, exist_ok=True) - data = asdict(self) - # Don't persist non-zone settings that are better as env vars - data.pop("db_path", None) - data.pop("default_population_size", None) + data: dict[str, Any] = { + "models": self.models.model_dump(), + "simulation": self.simulation.model_dump(), + } + if self.providers: + data["providers"] = { + name: cfg.model_dump() for name, cfg in self.providers.items() + } + if self.show_cost: + data["show_cost"] = True with open(CONFIG_FILE, "w") as f: json.dump(data, f, indent=2) def to_dict(self) -> dict[str, Any]: """Convert to dict for display.""" - return asdict(self) + result = { + "models": self.models.model_dump(), + "simulation": self.simulation.model_dump(), + } + if self.providers: + result["providers"] = { + name: cfg.model_dump() for name, cfg in self.providers.items() + } + return result - @property - def db_path_resolved(self) -> Path: - """Resolve database path.""" - path = Path(self.db_path) - path.parent.mkdir(parents=True, exist_ok=True) - return path + # ── Convenience resolution methods ── + + def resolve_pipeline_fast(self) -> str: + """Resolve the fast model string for pipeline use.""" + return self.models.fast + + def resolve_pipeline_strong(self) -> str: + """Resolve the strong model string for pipeline use.""" + return self.models.strong + + def resolve_sim_strong(self) -> str: + """Resolve the strong model string for simulation.""" + return self.simulation.strong or self.models.strong + + def resolve_sim_fast(self) -> str: + """Resolve the fast model string for simulation.""" + return self.simulation.fast or self.models.fast + + # ── Backward compat properties ── @property def cache_dir(self) -> Path: @@ -188,24 +245,34 @@ def cache_dir(self) -> Path: return path +# ============================================================================= +# Config dict application +# ============================================================================= + + def _apply_dict(config: ExtropyConfig, data: dict) -> None: - """Apply a dict of values onto an ExtropyConfig.""" - if "pipeline" in data and isinstance(data["pipeline"], dict): - for k, v in data["pipeline"].items(): - if hasattr(config.pipeline, k): - setattr(config.pipeline, k, v) + """Apply a dict of values onto an ExtropyConfig (v2 format).""" + if "models" in data and isinstance(data["models"], dict): + for k, v in data["models"].items(): + if hasattr(config.models, k): + setattr(config.models, k, v) if "simulation" in data and isinstance(data["simulation"], dict): for k, v in data["simulation"].items(): if hasattr(config.simulation, k): setattr(config.simulation, k, v) - if "db_path" in data: - config.db_path = data["db_path"] - if "default_population_size" in data: - config.default_population_size = int(data["default_population_size"]) + if "providers" in data and isinstance(data["providers"], dict): + for name, provider_data in data["providers"].items(): + if isinstance(provider_data, dict): + config.providers[name] = CustomProviderConfig( + base_url=provider_data.get("base_url", ""), + api_key_env=provider_data.get("api_key_env", ""), + ) + if "show_cost" in data: + config.show_cost = bool(data["show_cost"]) # ============================================================================= -# API key resolution (env vars + .env file) +# API key resolution # ============================================================================= _dotenv_loaded = False @@ -219,60 +286,49 @@ def _ensure_dotenv() -> None: try: from dotenv import find_dotenv, load_dotenv - # Resolve from current working directory first so CLI commands run - # from study repos consistently pick up that repo's `.env`. dotenv_path = find_dotenv(usecwd=True) if dotenv_path: load_dotenv(dotenv_path=dotenv_path, override=False) else: - # Fallback for environments where no discoverable .env exists. load_dotenv(override=False) except ImportError: - pass # python-dotenv not installed, skip + pass except Exception: - # Keep config loading resilient even if dotenv discovery has runtime issues. pass -def get_api_key(provider: str) -> str: - """Get API key for a provider from environment variables or .env file. +def get_api_key_for_provider( + provider_name: str, + custom_providers: dict[str, CustomProviderConfig] | None = None, +) -> str: + """Get API key for a provider. - Supports: - - openai: OPENAI_API_KEY - - claude: ANTHROPIC_API_KEY - - azure_openai: AZURE_OPENAI_API_KEY + Resolution order: + 1. Custom provider api_key_env override + 2. Convention: {PROVIDER_UPPER}_API_KEY - Returns empty string if not found (providers will raise on missing keys). + Special cases: + - "anthropic" → ANTHROPIC_API_KEY + - "azure" → AZURE_OPENAI_API_KEY + + Returns empty string if not found. """ _ensure_dotenv() - if provider == "openai": - return os.environ.get("OPENAI_API_KEY", "") - elif provider == "claude": - return os.environ.get("ANTHROPIC_API_KEY", "") - elif provider == "azure_openai": - return os.environ.get("AZURE_OPENAI_API_KEY", "") - return "" - - -def get_azure_config(provider: str) -> dict[str, str]: - """Get Azure-specific configuration from environment variables. - Args: - provider: 'azure_openai' + # Check custom provider override first + if custom_providers and provider_name in custom_providers: + custom = custom_providers[provider_name] + if custom.api_key_env: + return os.environ.get(custom.api_key_env, "") - Returns: - Dict of Azure config values (endpoint, api_version, deployment). - """ - _ensure_dotenv() - if provider == "azure_openai": - return { - "azure_endpoint": os.environ.get("AZURE_OPENAI_ENDPOINT", ""), - "api_version": os.environ.get( - "AZURE_OPENAI_API_VERSION", "2025-03-01-preview" - ), - "azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT", ""), - } - return {} + # Convention: {PROVIDER}_API_KEY + # Special cases for backward compat + key_map = { + "azure": "AZURE_OPENAI_API_KEY", + "azure_openai": "AZURE_OPENAI_API_KEY", + } + env_var = key_map.get(provider_name, f"{provider_name.upper()}_API_KEY") + return os.environ.get(env_var, "") # ============================================================================= @@ -298,8 +354,8 @@ def configure(config: ExtropyConfig) -> None: """Set the global ExtropyConfig programmatically. Use this when extropy is used as a package: - from extropy.config import configure, ExtropyConfig, PipelineConfig - configure(ExtropyConfig(pipeline=PipelineConfig(provider="claude"))) + from extropy.config import configure, ExtropyConfig, ModelsConfig + configure(ExtropyConfig(models=ModelsConfig(fast="openai/gpt-5-mini"))) """ global _config _config = config diff --git a/extropy/core/cost/__init__.py b/extropy/core/cost/__init__.py new file mode 100644 index 0000000..1377de8 --- /dev/null +++ b/extropy/core/cost/__init__.py @@ -0,0 +1,39 @@ +"""Cost tracking, pricing resolution, and persistent ledger. + +This package provides: +- CostTracker: Session-scoped accumulator (auto-records from providers) +- Pricing: Three-tier model pricing resolution (OpenRouter → cache → fallback) +- Ledger: Persistent cost history (~/.config/extropy/cost_ledger.db) +""" + +from .tracker import CostTracker, CallRecord, ModelUsage +from .pricing import ( + ModelPricing, + get_pricing, + resolve_default_model, + refresh_pricing, + get_cache_info, + FALLBACK_PRICING, + PROVIDER_DEFAULTS, +) +from .ledger import CostEntry, record_session, query_entries, query_totals + +__all__ = [ + # Tracker + "CostTracker", + "CallRecord", + "ModelUsage", + # Pricing + "ModelPricing", + "get_pricing", + "resolve_default_model", + "refresh_pricing", + "get_cache_info", + "FALLBACK_PRICING", + "PROVIDER_DEFAULTS", + # Ledger + "CostEntry", + "record_session", + "query_entries", + "query_totals", +] diff --git a/extropy/core/cost/ledger.py b/extropy/core/cost/ledger.py new file mode 100644 index 0000000..c6d7e08 --- /dev/null +++ b/extropy/core/cost/ledger.py @@ -0,0 +1,278 @@ +"""Persistent cost ledger. + +Appends session cost summaries to a local SQLite database. +Provides query methods for the `extropy cost` command. +""" + +import json +import logging +import sqlite3 +import time +from datetime import datetime +from pathlib import Path +from typing import Any + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +_LEDGER_DIR = Path.home() / ".config" / "extropy" +_LEDGER_FILE = _LEDGER_DIR / "cost_ledger.db" + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS cost_entries ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp REAL NOT NULL, + date TEXT NOT NULL, + command TEXT NOT NULL, + scenario TEXT NOT NULL DEFAULT '', + total_calls INTEGER NOT NULL DEFAULT 0, + total_input_tokens INTEGER NOT NULL DEFAULT 0, + total_output_tokens INTEGER NOT NULL DEFAULT 0, + total_cost REAL, + models_json TEXT NOT NULL DEFAULT '{}', + elapsed_seconds REAL +); + +CREATE INDEX IF NOT EXISTS idx_cost_entries_date ON cost_entries(date); +CREATE INDEX IF NOT EXISTS idx_cost_entries_command ON cost_entries(command); +""" + + +class CostEntry(BaseModel): + """A single cost ledger entry.""" + + timestamp: float + date: str + command: str + scenario: str + total_calls: int + total_input_tokens: int + total_output_tokens: int + total_cost: float | None + models: dict[str, Any] + elapsed_seconds: float | None + + +def _get_connection() -> sqlite3.Connection: + """Get a connection to the ledger database, creating it if needed.""" + _LEDGER_DIR.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(_LEDGER_FILE)) + conn.row_factory = sqlite3.Row + conn.executescript(_SCHEMA) + return conn + + +def record_session(summary: dict[str, Any]) -> None: + """Append a session cost summary to the ledger. + + Args: + summary: Dict from CostTracker.summary() + """ + if summary.get("total_calls", 0) == 0: + return + + try: + conn = _get_connection() + try: + now = time.time() + conn.execute( + """ + INSERT INTO cost_entries + (timestamp, date, command, scenario, total_calls, + total_input_tokens, total_output_tokens, total_cost, + models_json, elapsed_seconds) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + now, + datetime.fromtimestamp(now).strftime("%Y-%m-%d"), + summary.get("command", ""), + summary.get("scenario", ""), + summary.get("total_calls", 0), + summary.get("total_input_tokens", 0), + summary.get("total_output_tokens", 0), + summary.get("total_cost"), + json.dumps(summary.get("by_model", {})), + summary.get("elapsed_seconds"), + ), + ) + conn.commit() + finally: + conn.close() + except (sqlite3.Error, OSError) as e: + logger.debug(f"Failed to record cost to ledger: {e}") + + +def query_entries( + days: int | None = 7, + command: str | None = None, + limit: int = 100, +) -> list[CostEntry]: + """Query cost ledger entries. + + Args: + days: Number of days to look back (None = all time) + command: Filter by command name (None = all commands) + limit: Max entries to return + + Returns: + List of CostEntry, newest first. + """ + try: + conn = _get_connection() + except (sqlite3.Error, OSError): + return [] + + try: + clauses = [] + params: list[Any] = [] + + if days is not None: + cutoff = time.time() - (days * 86400) + clauses.append("timestamp >= ?") + params.append(cutoff) + + if command: + clauses.append("command = ?") + params.append(command) + + where = f"WHERE {' AND '.join(clauses)}" if clauses else "" + params.append(limit) + + rows = conn.execute( + f""" + SELECT * FROM cost_entries + {where} + ORDER BY timestamp DESC + LIMIT ? + """, + params, + ).fetchall() + + entries = [] + for row in rows: + try: + models = json.loads(row["models_json"]) + except (json.JSONDecodeError, TypeError): + models = {} + + entries.append( + CostEntry( + timestamp=row["timestamp"], + date=row["date"], + command=row["command"], + scenario=row["scenario"], + total_calls=row["total_calls"], + total_input_tokens=row["total_input_tokens"], + total_output_tokens=row["total_output_tokens"], + total_cost=row["total_cost"], + models=models, + elapsed_seconds=row["elapsed_seconds"], + ) + ) + + return entries + finally: + conn.close() + + +def query_totals( + days: int | None = 7, + group_by: str | None = None, +) -> dict[str, Any]: + """Query aggregated cost totals. + + Args: + days: Number of days to look back (None = all time) + group_by: Group results by "command", "date", or "model" (None = totals only) + + Returns: + Dict with total and optional grouped breakdowns. + """ + try: + conn = _get_connection() + except (sqlite3.Error, OSError): + return {"total_cost": None, "total_calls": 0} + + try: + where = "" + params: list[Any] = [] + if days is not None: + cutoff = time.time() - (days * 86400) + where = "WHERE timestamp >= ?" + params.append(cutoff) + + # Overall totals + row = conn.execute( + f""" + SELECT + COUNT(*) as sessions, + SUM(total_calls) as calls, + SUM(total_input_tokens) as input_tokens, + SUM(total_output_tokens) as output_tokens, + SUM(total_cost) as cost + FROM cost_entries + {where} + """, + params, + ).fetchone() + + result: dict[str, Any] = { + "sessions": row["sessions"] or 0, + "total_calls": row["calls"] or 0, + "total_input_tokens": row["input_tokens"] or 0, + "total_output_tokens": row["output_tokens"] or 0, + "total_cost": round(row["cost"], 4) if row["cost"] else None, + } + + # Grouped breakdown + if group_by == "command": + rows = conn.execute( + f""" + SELECT command, + COUNT(*) as sessions, + SUM(total_calls) as calls, + SUM(total_cost) as cost + FROM cost_entries + {where} + GROUP BY command + ORDER BY cost DESC + """, + params, + ).fetchall() + result["by_command"] = { + r["command"]: { + "sessions": r["sessions"], + "calls": r["calls"] or 0, + "cost": round(r["cost"], 4) if r["cost"] else None, + } + for r in rows + } + + elif group_by == "date": + rows = conn.execute( + f""" + SELECT date, + COUNT(*) as sessions, + SUM(total_calls) as calls, + SUM(total_cost) as cost + FROM cost_entries + {where} + GROUP BY date + ORDER BY date DESC + """, + params, + ).fetchall() + result["by_date"] = { + r["date"]: { + "sessions": r["sessions"], + "calls": r["calls"] or 0, + "cost": round(r["cost"], 4) if r["cost"] else None, + } + for r in rows + } + + return result + finally: + conn.close() diff --git a/extropy/core/cost/pricing.py b/extropy/core/cost/pricing.py new file mode 100644 index 0000000..a87085e --- /dev/null +++ b/extropy/core/cost/pricing.py @@ -0,0 +1,369 @@ +"""Model pricing resolution for cost estimation and tracking. + +Three-tier pricing resolution: +1. OpenRouter API (free, no auth, covers 200+ models) → cached locally +2. Local cache file (~/.config/extropy/pricing_cache.json) with 24h TTL +3. Hardcoded fallback table for offline/known models + +Provides per-model input/output pricing (USD per million tokens) +and provider default model resolution without needing API keys. +""" + +import json +import logging +import time +from pathlib import Path +from typing import Any + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +# Cache location and TTL +_CACHE_DIR = Path.home() / ".config" / "extropy" +_CACHE_FILE = _CACHE_DIR / "pricing_cache.json" +_CACHE_TTL_SECONDS = 24 * 60 * 60 # 24 hours + + +class ModelPricing(BaseModel, frozen=True): + """Pricing for a single model (USD per million tokens).""" + + input_per_mtok: float + output_per_mtok: float + + +# ── Hardcoded fallback (Tier 3) ────────────────────────────────────────────── + +# Known model pricing (USD per million tokens) +# Sources: OpenAI and Anthropic pricing pages as of 2025 +FALLBACK_PRICING: dict[str, ModelPricing] = { + # OpenAI + "gpt-5": ModelPricing(input_per_mtok=2.50, output_per_mtok=10.00), + "gpt-5-mini": ModelPricing(input_per_mtok=0.30, output_per_mtok=1.50), + "gpt-5-nano": ModelPricing(input_per_mtok=0.10, output_per_mtok=0.40), + "gpt-5.2": ModelPricing(input_per_mtok=2.50, output_per_mtok=10.00), + # Azure-hosted models + "DeepSeek-V3.2": ModelPricing(input_per_mtok=0.80, output_per_mtok=2.00), + "Kimi-K2.5": ModelPricing(input_per_mtok=1.00, output_per_mtok=4.00), + # Claude + "claude-sonnet-4-5-20250929": ModelPricing( + input_per_mtok=3.00, output_per_mtok=15.00 + ), + "claude-sonnet-4-5-20250514": ModelPricing( + input_per_mtok=3.00, output_per_mtok=15.00 + ), + "claude-sonnet-4.5": ModelPricing(input_per_mtok=3.00, output_per_mtok=15.00), + "claude-sonnet-4": ModelPricing(input_per_mtok=3.00, output_per_mtok=15.00), + "claude-haiku-4-5-20251001": ModelPricing( + input_per_mtok=0.80, output_per_mtok=4.00 + ), + "claude-haiku-4.5": ModelPricing(input_per_mtok=0.80, output_per_mtok=4.00), + "claude-haiku-4": ModelPricing(input_per_mtok=0.80, output_per_mtok=4.00), + # DeepSeek (direct API) + "deepseek-chat": ModelPricing(input_per_mtok=0.14, output_per_mtok=0.28), + "deepseek-reasoner": ModelPricing(input_per_mtok=0.55, output_per_mtok=2.19), +} + +# Provider default models — 2-tier (fast/strong) +PROVIDER_DEFAULTS: dict[str, dict[str, str]] = { + "openai": { + "fast": "gpt-5-mini", + "strong": "gpt-5", + }, + "anthropic": { + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", + }, + "azure": { + "fast": "gpt-5-mini", + "strong": "gpt-5", + }, + "openrouter": { + "fast": "openai/gpt-5-mini", + "strong": "openai/gpt-5", + }, + "deepseek": { + "fast": "deepseek-chat", + "strong": "deepseek-reasoner", + }, + "together": { + "fast": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + "strong": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + }, + "groq": { + "fast": "llama-3.3-70b-versatile", + "strong": "llama-3.3-70b-versatile", + }, + # Legacy aliases + "claude": { + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", + }, + "azure_openai": { + "fast": "gpt-5-mini", + "strong": "gpt-5", + }, +} + + +# ── In-memory cache ────────────────────────────────────────────────────────── + +_memory_cache: dict[str, ModelPricing] = {} +_memory_cache_loaded: bool = False + + +# ── Tier 1: OpenRouter API ─────────────────────────────────────────────────── + + +def _fetch_openrouter_pricing() -> dict[str, ModelPricing] | None: + """Fetch pricing from OpenRouter API (no auth required). + + Returns: + Dict of model_id → ModelPricing, or None if fetch failed. + """ + try: + import urllib.request + import urllib.error + + url = "https://openrouter.ai/api/v1/models" + req = urllib.request.Request(url, headers={"User-Agent": "extropy"}) + with urllib.request.urlopen(req, timeout=10) as resp: + data = json.loads(resp.read().decode()) + + result: dict[str, ModelPricing] = {} + for model in data.get("data", []): + model_id = model.get("id", "") + pricing = model.get("pricing", {}) + + # OpenRouter returns pricing as string USD per token (not per MTok) + prompt_price = pricing.get("prompt") + completion_price = pricing.get("completion") + + if prompt_price is not None and completion_price is not None: + try: + input_per_tok = float(prompt_price) + output_per_tok = float(completion_price) + except (ValueError, TypeError): + continue + + # Skip free/zero-cost models + if input_per_tok == 0 and output_per_tok == 0: + continue + + # Convert per-token to per-million-tokens + result[model_id] = ModelPricing( + input_per_mtok=input_per_tok * 1_000_000, + output_per_mtok=output_per_tok * 1_000_000, + ) + + if result: + logger.debug(f"Fetched pricing for {len(result)} models from OpenRouter") + return result + + except Exception as e: + logger.debug(f"Failed to fetch OpenRouter pricing: {e}") + + return None + + +# ── Tier 2: Local cache file ───────────────────────────────────────────────── + + +def _load_cache() -> dict[str, ModelPricing] | None: + """Load pricing from local cache file if it exists and is fresh. + + Returns: + Dict of model_id → ModelPricing, or None if cache is stale/missing. + """ + if not _CACHE_FILE.exists(): + return None + + try: + with open(_CACHE_FILE) as f: + data = json.load(f) + + # Check TTL + cached_at = data.get("cached_at", 0) + if time.time() - cached_at > _CACHE_TTL_SECONDS: + logger.debug("Pricing cache expired") + return None + + result: dict[str, ModelPricing] = {} + for model_id, pricing in data.get("models", {}).items(): + result[model_id] = ModelPricing( + input_per_mtok=pricing["input_per_mtok"], + output_per_mtok=pricing["output_per_mtok"], + ) + + logger.debug(f"Loaded {len(result)} models from pricing cache") + return result + + except (json.JSONDecodeError, KeyError, OSError) as e: + logger.debug(f"Failed to load pricing cache: {e}") + return None + + +def _save_cache(pricing: dict[str, ModelPricing]) -> None: + """Save pricing to local cache file.""" + try: + _CACHE_DIR.mkdir(parents=True, exist_ok=True) + data = { + "cached_at": time.time(), + "models": {model_id: p.model_dump() for model_id, p in pricing.items()}, + } + with open(_CACHE_FILE, "w") as f: + json.dump(data, f, indent=2) + logger.debug(f"Saved pricing cache with {len(pricing)} models") + except OSError as e: + logger.debug(f"Failed to save pricing cache: {e}") + + +# ── Resolution logic ───────────────────────────────────────────────────────── + + +def _ensure_cache_loaded() -> None: + """Lazily load the pricing cache into memory (once per process).""" + global _memory_cache, _memory_cache_loaded + if _memory_cache_loaded: + return + + # Try local cache first (fast, no network) + cached = _load_cache() + if cached: + _memory_cache = cached + _memory_cache_loaded = True + return + + # Try OpenRouter API + fetched = _fetch_openrouter_pricing() + if fetched: + _memory_cache = fetched + _save_cache(fetched) + _memory_cache_loaded = True + return + + # No dynamic pricing available — will fall through to hardcoded + _memory_cache_loaded = True + + +def _normalize_model_id(model: str) -> list[str]: + """Generate candidate lookup keys for a model name. + + Handles the mapping between bare model names (used by providers) + and OpenRouter-style IDs (provider/model). + + Args: + model: Model name (e.g., "gpt-5-mini" or "openai/gpt-5-mini") + + Returns: + List of candidate keys to try, in priority order. + """ + candidates = [model] + + # If it's already a provider/model format, also try the bare model name + if "/" in model: + bare = model.rsplit("/", 1)[-1] + candidates.append(bare) + else: + # Try common provider prefixes for bare model names + if model.startswith("gpt-"): + candidates.append(f"openai/{model}") + elif model.startswith("claude-"): + candidates.append(f"anthropic/{model}") + elif model.startswith("deepseek-"): + candidates.append(f"deepseek/{model}") + elif model.startswith("llama-") or model.startswith("meta-llama/"): + candidates.append(f"meta-llama/{model}") + + return candidates + + +def get_pricing(model: str) -> ModelPricing | None: + """Get pricing for a model using three-tier resolution. + + Resolution order: + 1. OpenRouter API cache (refreshed every 24h) + 2. Local cache file + 3. Hardcoded fallback table + + Args: + model: Model name (bare like "gpt-5-mini" or qualified like "openai/gpt-5-mini") + + Returns: + ModelPricing or None if no pricing found. + """ + _ensure_cache_loaded() + + candidates = _normalize_model_id(model) + + # Try dynamic cache first + for candidate in candidates: + if candidate in _memory_cache: + return _memory_cache[candidate] + + # Fall back to hardcoded + for candidate in candidates: + if candidate in FALLBACK_PRICING: + return FALLBACK_PRICING[candidate] + + return None + + +def resolve_default_model(provider: str, tier: str = "strong") -> str: + """Resolve default model name for a provider without instantiating it. + + Args: + provider: Provider name ('openai', 'anthropic', etc.) + tier: 'fast' or 'strong' (also accepts legacy 'simple'/'reasoning') + + Returns: + Model name string + """ + # Map legacy tier names + tier_map = {"simple": "fast", "reasoning": "strong"} + tier = tier_map.get(tier, tier) + + defaults = PROVIDER_DEFAULTS.get(provider, PROVIDER_DEFAULTS["openai"]) + return defaults.get(tier, defaults["strong"]) + + +def refresh_pricing() -> bool: + """Force-refresh pricing from OpenRouter API. + + Returns: + True if refresh succeeded. + """ + global _memory_cache, _memory_cache_loaded + + fetched = _fetch_openrouter_pricing() + if fetched: + _memory_cache = fetched + _memory_cache_loaded = True + _save_cache(fetched) + return True + return False + + +def get_cache_info() -> dict[str, Any]: + """Get info about the pricing cache state (for diagnostics).""" + info: dict[str, Any] = { + "cache_file": str(_CACHE_FILE), + "cache_exists": _CACHE_FILE.exists(), + "memory_loaded": _memory_cache_loaded, + "memory_models": len(_memory_cache), + "fallback_models": len(FALLBACK_PRICING), + } + + if _CACHE_FILE.exists(): + try: + with open(_CACHE_FILE) as f: + data = json.load(f) + cached_at = data.get("cached_at", 0) + age_hours = (time.time() - cached_at) / 3600 + info["cache_age_hours"] = round(age_hours, 1) + info["cache_fresh"] = age_hours < (_CACHE_TTL_SECONDS / 3600) + info["cached_models"] = len(data.get("models", {})) + except (json.JSONDecodeError, OSError): + info["cache_corrupt"] = True + + return info diff --git a/extropy/core/cost/tracker.py b/extropy/core/cost/tracker.py new file mode 100644 index 0000000..58375e4 --- /dev/null +++ b/extropy/core/cost/tracker.py @@ -0,0 +1,288 @@ +"""Session-scoped cost accumulator. + +Automatically records token usage from every LLM provider call within +a CLI session. Providers push usage via CostTracker.record(); the CLI +reads the totals at exit via CostTracker.summary(). + +Thread-safe — simulation calls record() from async workers concurrently. +""" + +import logging +import threading +import time +from typing import Any + +from pydantic import BaseModel + +from ..providers.base import TokenUsage +from .pricing import get_pricing + +logger = logging.getLogger(__name__) + + +class CallRecord(BaseModel): + """A single LLM API call's token usage.""" + + model: str + input_tokens: int + output_tokens: int + timestamp: float + call_type: str = "" # "simple", "reasoning", "agentic_research", "async" + + +class ModelUsage(BaseModel): + """Accumulated usage for a single model.""" + + calls: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 + + +class CostTracker: + """Session-scoped cost accumulator. + + Singleton per process. Providers auto-record into this after each call. + The CLI reads summary/cost at session end. + + Thread-safe: uses a lock for mutation since simulation workers + call record() concurrently. + + Note: This is not a Pydantic model because it manages mutable state + with thread locks and singleton lifecycle — patterns that don't fit + Pydantic's immutable-validation model. + """ + + _instance: "CostTracker | None" = None + _lock_cls = threading.Lock() # Class-level lock for singleton creation + + def __init__(self) -> None: + self._records: list[CallRecord] = [] + self._by_model: dict[str, ModelUsage] = {} + self._lock = threading.Lock() + self._started_at = time.time() + self._command: str = "" # Set by CLI (e.g., "spec", "simulate") + self._scenario: str = "" # Set by CLI for ledger tagging + + @classmethod + def get(cls) -> "CostTracker": + """Get or create the singleton instance.""" + if cls._instance is None: + with cls._lock_cls: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def reset(cls) -> None: + """Reset the singleton (for testing or new session).""" + with cls._lock_cls: + cls._instance = None + + def set_context(self, command: str = "", scenario: str = "") -> None: + """Set session context for ledger tagging. + + Args: + command: CLI command name (e.g., "spec", "simulate") + scenario: Scenario/population name for identification + """ + self._command = command + self._scenario = scenario + + def record( + self, + model: str, + usage: TokenUsage, + call_type: str = "", + ) -> None: + """Record token usage from a single LLM API call. + + Called automatically by provider base class after each call. + + Args: + model: Model name used for the call + usage: Token usage from the API response + call_type: Type of call ("simple", "reasoning", etc.) + """ + if usage.input_tokens == 0 and usage.output_tokens == 0: + return + + record = CallRecord( + model=model, + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens, + timestamp=time.time(), + call_type=call_type, + ) + + with self._lock: + self._records.append(record) + + if model not in self._by_model: + self._by_model[model] = ModelUsage() + + mu = self._by_model[model] + mu.calls += 1 + mu.input_tokens += usage.input_tokens + mu.output_tokens += usage.output_tokens + + @property + def total_calls(self) -> int: + """Total number of LLM calls recorded.""" + with self._lock: + return sum(mu.calls for mu in self._by_model.values()) + + @property + def total_input_tokens(self) -> int: + """Total input tokens across all models.""" + with self._lock: + return sum(mu.input_tokens for mu in self._by_model.values()) + + @property + def total_output_tokens(self) -> int: + """Total output tokens across all models.""" + with self._lock: + return sum(mu.output_tokens for mu in self._by_model.values()) + + def total_cost(self) -> float | None: + """Compute total USD cost from recorded usage. + + Returns: + Total cost in USD, or None if no pricing available for any model. + """ + with self._lock: + total = 0.0 + has_any_pricing = False + + for model, mu in self._by_model.items(): + pricing = get_pricing(model) + if pricing: + has_any_pricing = True + total += ( + mu.input_tokens * pricing.input_per_mtok + + mu.output_tokens * pricing.output_per_mtok + ) / 1_000_000 + + return total if has_any_pricing else None + + def cost_by_model(self) -> dict[str, dict[str, Any]]: + """Get cost breakdown by model. + + Returns: + Dict of model → {calls, input_tokens, output_tokens, cost} + """ + with self._lock: + result: dict[str, dict[str, Any]] = {} + for model, mu in self._by_model.items(): + pricing = get_pricing(model) + cost = None + if pricing: + cost = ( + mu.input_tokens * pricing.input_per_mtok + + mu.output_tokens * pricing.output_per_mtok + ) / 1_000_000 + + result[model] = { + "calls": mu.calls, + "input_tokens": mu.input_tokens, + "output_tokens": mu.output_tokens, + "cost": cost, + } + return result + + def summary(self) -> dict[str, Any]: + """Full session summary for export/display. + + Returns: + Dict with total and per-model breakdowns. + """ + with self._lock: + by_model = {} + total_cost = 0.0 + has_pricing = False + + for model, mu in self._by_model.items(): + pricing = get_pricing(model) + model_cost = None + if pricing: + has_pricing = True + model_cost = ( + mu.input_tokens * pricing.input_per_mtok + + mu.output_tokens * pricing.output_per_mtok + ) / 1_000_000 + total_cost += model_cost + + by_model[model] = { + "calls": mu.calls, + "input_tokens": mu.input_tokens, + "output_tokens": mu.output_tokens, + "cost": round(model_cost, 4) if model_cost is not None else None, + } + + total_in = sum(mu.input_tokens for mu in self._by_model.values()) + total_out = sum(mu.output_tokens for mu in self._by_model.values()) + + return { + "command": self._command, + "scenario": self._scenario, + "total_calls": sum(mu.calls for mu in self._by_model.values()), + "total_input_tokens": total_in, + "total_output_tokens": total_out, + "total_cost": round(total_cost, 4) if has_pricing else None, + "by_model": by_model, + "elapsed_seconds": round(time.time() - self._started_at, 1), + } + + def summary_line(self) -> str | None: + """One-line cost summary for CLI footer. + + Returns: + Formatted string like "$0.38 · openai/gpt-5 · 8 calls · 87k in / 12k out", + or None if no calls were recorded. + """ + with self._lock: + total_calls = sum(mu.calls for mu in self._by_model.values()) + if total_calls == 0: + return None + + total_in = sum(mu.input_tokens for mu in self._by_model.values()) + total_out = sum(mu.output_tokens for mu in self._by_model.values()) + models = list(self._by_model.keys()) + + cost = self.total_cost() + + parts = [] + + # Cost + if cost is not None: + parts.append(f"${cost:.2f}") + else: + parts.append("cost unknown") + + # Model(s) + if len(models) == 1: + parts.append(models[0]) + elif len(models) > 1: + parts.append(f"{len(models)} models") + + # Call count + parts.append(f"{total_calls} call{'s' if total_calls != 1 else ''}") + + # Token counts + parts.append(f"{_format_tokens(total_in)} in / {_format_tokens(total_out)} out") + + return " · ".join(parts) + + @property + def has_records(self) -> bool: + """Whether any calls have been recorded.""" + with self._lock: + return len(self._records) > 0 + + +def _format_tokens(n: int) -> str: + """Format token count for display (e.g., 87k, 1.5M).""" + if n >= 1_000_000: + return f"{n / 1_000_000:.1f}M" + elif n >= 1_000: + return f"{n / 1_000:.0f}k" + return str(n) diff --git a/extropy/core/llm.py b/extropy/core/llm.py index ecfa0a5..dfbc3db 100644 --- a/extropy/core/llm.py +++ b/extropy/core/llm.py @@ -1,20 +1,19 @@ """LLM clients for Extropy - Facade Layer. -This module provides a unified interface to LLM providers with two-zone routing: -- Pipeline (sync calls): simple_call, reasoning_call, agentic_research - → Uses the pipeline provider (configured for phases 1-2) -- Simulation (async calls): simple_call_async - → Uses the simulation provider (configured for phase 3) +This module provides a unified interface to LLM providers with two-tier routing: +- fast: simple_call → uses models.fast (cheap, fast tasks) +- strong: reasoning_call, agentic_research → uses models.strong (complex tasks) +- simulation: simple_call_async → uses simulation.strong/fast -Configure via `extropy config` CLI or programmatically via extropy.config.configure(). +Model strings use "provider/model" format. The provider is extracted to route +to the correct backend; the model name is passed through. -Each function supports retry with error feedback via the `previous_errors` parameter. -When validation fails, pass the error message back to let the LLM self-correct. +Configure via `extropy config` CLI or programmatically via extropy.config.configure(). """ -from .providers import get_pipeline_provider, get_simulation_provider +from .providers import get_provider from .providers.base import TokenUsage, ValidatorCallback, RetryCallback -from ..config import get_config +from ..config import get_config, parse_model_string __all__ = [ @@ -28,25 +27,14 @@ ] -def _get_pipeline_model_override(tier: str) -> str | None: - """Get pipeline model override from config if configured.""" +def _resolve_provider_and_model( + model_string: str, +) -> tuple: + """Resolve a "provider/model" string to (provider_instance, model_name).""" config = get_config() - pipeline = config.pipeline - if tier == "simple" and pipeline.model_simple: - return pipeline.model_simple - elif tier == "reasoning" and pipeline.model_reasoning: - return pipeline.model_reasoning - elif tier == "research" and pipeline.model_research: - return pipeline.model_research - return None - - -def _get_simulation_model_override() -> str | None: - """Get simulation model override from config if configured.""" - config = get_config() - if config.simulation.model: - return config.simulation.model - return None + provider_name, model_name = parse_model_string(model_string) + provider = get_provider(provider_name, config.providers) + return provider, model_name def simple_call( @@ -59,20 +47,21 @@ def simple_call( ) -> dict: """Simple LLM call with structured output, no reasoning, no web search. - Routed through the PIPELINE provider. + Uses the FAST tier (config.models.fast). Use for fast, cheap tasks: - Context sufficiency checks - Simple classification - Validation """ - provider = get_pipeline_provider() - effective_model = model or _get_pipeline_model_override("simple") + config = get_config() + model_string = model or config.resolve_pipeline_fast() + provider, model_name = _resolve_provider_and_model(model_string) return provider.simple_call( prompt=prompt, response_schema=response_schema, schema_name=schema_name, - model=effective_model, + model=model_name, log=log, max_tokens=max_tokens, ) @@ -87,18 +76,21 @@ async def simple_call_async( ) -> tuple[dict, TokenUsage]: """Async version of simple_call for concurrent API requests. - Routed through the SIMULATION provider. - Used for batch agent reasoning during simulation. + Model is passed explicitly from simulation caller (provider/model format). Returns (structured_data, token_usage) tuple. """ - provider = get_simulation_provider() - effective_model = model or _get_simulation_model_override() + if model: + provider, model_name = _resolve_provider_and_model(model) + else: + config = get_config() + model_string = config.resolve_sim_strong() + provider, model_name = _resolve_provider_and_model(model_string) return await provider.simple_call_async( prompt=prompt, response_schema=response_schema, schema_name=schema_name, - model=effective_model, + model=model_name, max_tokens=max_tokens, ) @@ -117,20 +109,21 @@ def reasoning_call( ) -> dict: """LLM call with reasoning and structured output, but NO web search. - Routed through the PIPELINE provider. + Uses the STRONG tier (config.models.strong). Use for tasks that require reasoning but not external data: - Attribute selection/categorization - Schema design - Logical analysis """ - provider = get_pipeline_provider() - effective_model = model or _get_pipeline_model_override("reasoning") + config = get_config() + model_string = model or config.resolve_pipeline_strong() + provider, model_name = _resolve_provider_and_model(model_string) return provider.reasoning_call( prompt=prompt, response_schema=response_schema, schema_name=schema_name, - model=effective_model, + model=model_name, reasoning_effort=reasoning_effort, log=log, previous_errors=previous_errors, @@ -154,21 +147,17 @@ def agentic_research( ) -> tuple[dict, list[str]]: """Perform agentic research with web search and structured output. - Routed through the PIPELINE provider. - - The model will: - 1. Decide what to search for - 2. Search the web (possibly multiple times) - 3. Reason about the results - 4. Return structured data matching the schema + Uses the STRONG tier (config.models.strong). + Web search is a provider capability, not a tier distinction. """ - provider = get_pipeline_provider() - effective_model = model or _get_pipeline_model_override("research") + config = get_config() + model_string = model or config.resolve_pipeline_strong() + provider, model_name = _resolve_provider_and_model(model_string) return provider.agentic_research( prompt=prompt, response_schema=response_schema, schema_name=schema_name, - model=effective_model, + model=model_name, reasoning_effort=reasoning_effort, log=log, previous_errors=previous_errors, diff --git a/extropy/core/models/scenario.py b/extropy/core/models/scenario.py index 0288667..ab592df 100644 --- a/extropy/core/models/scenario.py +++ b/extropy/core/models/scenario.py @@ -268,8 +268,11 @@ class ScenarioMeta(BaseModel): name: str = Field(description="Short identifier for the scenario") description: str = Field(description="Full scenario description") population_spec: str = Field(description="Path to population YAML") - agents_file: str = Field(description="Path to sampled agents JSON") - network_file: str = Field(description="Path to network JSON") + study_db: str = Field(description="Path to canonical study DB") + population_id: str = Field( + default="default", description="Population ID in study DB" + ) + network_id: str = Field(default="default", description="Network ID in study DB") created_at: datetime = Field(default_factory=datetime.now) @@ -305,7 +308,32 @@ def from_yaml(cls, path: Path | str) -> "ScenarioSpec": with open(path) as f: data = yaml.safe_load(f) - return cls.model_validate(data) + if not isinstance(data, dict): + raise ValueError("Scenario YAML must parse to an object") + + meta = data.get("meta", {}) + if isinstance(meta, dict) and ("agents_file" in meta or "network_file" in meta): + raise ValueError( + "Legacy scenario schema detected (meta.agents_file/meta.network_file). " + "Migrate with: extropy migrate scenario --input " + f"{path} --study-db study.db --population-id default --network-id default" + ) + + try: + return cls.model_validate(data) + except Exception as e: + if isinstance(meta, dict) and ( + "study_db" not in meta + or "population_id" not in meta + or "network_id" not in meta + ): + raise ValueError( + "Scenario metadata must include meta.study_db, meta.population_id, " + "and meta.network_id. If this is an older scenario, run: " + "extropy migrate scenario --input " + f"{path} --study-db study.db --population-id default --network-id default" + ) from e + raise def summary(self) -> str: """Get a text summary of the scenario spec.""" diff --git a/extropy/core/models/simulation.py b/extropy/core/models/simulation.py index 5532686..88d5617 100644 --- a/extropy/core/models/simulation.py +++ b/extropy/core/models/simulation.py @@ -338,17 +338,13 @@ class SimulationRunConfig(BaseModel): scenario_path: str = Field(description="Path to scenario YAML") output_dir: str = Field(description="Directory for results output") - model: str = Field( + strong: str = Field( default="", - description="LLM model for agent reasoning (empty = use config default)", + description="Strong model for Pass 1 role-play reasoning (provider/model format, empty = config default)", ) - pivotal_model: str = Field( + fast: str = Field( default="", - description="Model for pivotal reasoning (default: same as model)", - ) - routine_model: str = Field( - default="", - description="Cheap model for routine reasoning + classification (default: provider cheap tier)", + description="Fast model for Pass 2 classification (provider/model format, empty = config default)", ) reasoning_effort: str = Field(default="low", description="Reasoning effort level") multi_touch_threshold: int = Field( @@ -362,6 +358,19 @@ class SimulationRunConfig(BaseModel): default=50, description="Agents per reasoning chunk for checkpointing" ) + # Backward compat aliases + @property + def model(self) -> str: + return self.strong + + @property + def pivotal_model(self) -> str: + return self.strong + + @property + def routine_model(self) -> str: + return self.fast + # ============================================================================= # Timestep Summary diff --git a/extropy/core/pricing.py b/extropy/core/pricing.py index 616a25d..d2de21a 100644 --- a/extropy/core/pricing.py +++ b/extropy/core/pricing.py @@ -40,21 +40,49 @@ class ModelPricing: ), "claude-haiku-4.5": ModelPricing(input_per_mtok=0.80, output_per_mtok=4.00), "claude-haiku-4": ModelPricing(input_per_mtok=0.80, output_per_mtok=4.00), + # DeepSeek (direct API) + "deepseek-chat": ModelPricing(input_per_mtok=0.14, output_per_mtok=0.28), + "deepseek-reasoner": ModelPricing(input_per_mtok=0.55, output_per_mtok=2.19), } -# Provider default models (matches provider classes, no API key needed) +# Provider default models — 2-tier (fast/strong) PROVIDER_DEFAULTS: dict[str, dict[str, str]] = { "openai": { - "simple": "gpt-5-mini", - "reasoning": "gpt-5", + "fast": "gpt-5-mini", + "strong": "gpt-5", }, + "anthropic": { + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", + }, + "azure": { + "fast": "gpt-5-mini", + "strong": "gpt-5", + }, + "openrouter": { + "fast": "openai/gpt-5-mini", + "strong": "openai/gpt-5", + }, + "deepseek": { + "fast": "deepseek-chat", + "strong": "deepseek-reasoner", + }, + "together": { + "fast": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + "strong": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + }, + "groq": { + "fast": "llama-3.3-70b-versatile", + "strong": "llama-3.3-70b-versatile", + }, + # Legacy aliases "claude": { - "simple": "claude-haiku-4-5-20251001", - "reasoning": "claude-sonnet-4-5-20250929", + "fast": "claude-haiku-4-5-20251001", + "strong": "claude-sonnet-4-5-20250929", }, "azure_openai": { - "simple": "gpt-5-mini", - "reasoning": "gpt-5", + "fast": "gpt-5-mini", + "strong": "gpt-5", }, } @@ -64,15 +92,19 @@ def get_pricing(model: str) -> ModelPricing | None: return MODEL_PRICING.get(model) -def resolve_default_model(provider: str, tier: str = "reasoning") -> str: +def resolve_default_model(provider: str, tier: str = "strong") -> str: """Resolve default model name for a provider without instantiating it. Args: - provider: Provider name ('openai' or 'claude') - tier: 'simple' or 'reasoning' + provider: Provider name ('openai', 'anthropic', etc.) + tier: 'fast' or 'strong' (also accepts legacy 'simple'/'reasoning') Returns: Model name string """ + # Map legacy tier names + tier_map = {"simple": "fast", "reasoning": "strong"} + tier = tier_map.get(tier, tier) + defaults = PROVIDER_DEFAULTS.get(provider, PROVIDER_DEFAULTS["openai"]) - return defaults.get(tier, defaults["reasoning"]) + return defaults.get(tier, defaults["strong"]) diff --git a/extropy/core/providers/__init__.py b/extropy/core/providers/__init__.py index 195f949..1fb5c39 100644 --- a/extropy/core/providers/__init__.py +++ b/extropy/core/providers/__init__.py @@ -1,97 +1,237 @@ -"""LLM Provider factory. +"""LLM Provider registry and factory. -Provides two-zone provider routing: -- Pipeline provider: used for phases 1-2 (spec, extend, persona, scenario) -- Simulation provider: used for phase 3 (agent reasoning) +Provides: +- BUILTIN_PROVIDERS: Registry of known provider names → factory info +- get_provider(): Create a provider instance from a provider name +- get_pipeline_provider() / get_simulation_provider(): Zone-based provider access The simulation provider is cached so its async client can be reused across batch calls and closed cleanly before the event loop shuts down. """ -from .base import LLMProvider -from ...config import get_config, get_api_key, get_azure_config - +import os -# Cached simulation provider — reused across batch calls so the async -# client isn't re-created per request, and can be closed cleanly. -_simulation_provider: LLMProvider | None = None +from .base import LLMProvider +from ...config import ( + get_config, + get_api_key_for_provider, + parse_model_string, + CustomProviderConfig, +) + + +# ============================================================================= +# Provider Registry +# ============================================================================= + +# Each entry: (module, class_name, default_kwargs) +# Lazy-imported to avoid loading all SDKs at startup. +_BUILTIN_REGISTRY: dict[str, dict] = { + "openai": { + "module": ".openai", + "class": "OpenAIProvider", + }, + "anthropic": { + "module": ".anthropic", + "class": "AnthropicProvider", + }, + "openrouter": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "https://openrouter.ai/api/v1", + "supports_search": True, + "provider_label": "openrouter", + "default_fast": "openai/gpt-5-mini", + "default_strong": "openai/gpt-5", + }, + }, + "azure": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "", # resolved from env + "supports_search": False, + "provider_label": "azure", + "default_fast": "gpt-5-mini", + "default_strong": "gpt-5", + }, + }, + "deepseek": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "https://api.deepseek.com/v1", + "supports_search": False, + "provider_label": "deepseek", + "default_fast": "deepseek-chat", + "default_strong": "deepseek-reasoner", + }, + }, + "together": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "https://api.together.xyz/v1", + "supports_search": False, + "provider_label": "together", + "default_fast": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + "default_strong": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + }, + }, + "groq": { + "module": ".openai_compat", + "class": "OpenAICompatProvider", + "kwargs": { + "base_url": "https://api.groq.com/openai/v1", + "supports_search": False, + "provider_label": "groq", + "default_fast": "llama-3.3-70b-versatile", + "default_strong": "llama-3.3-70b-versatile", + }, + }, +} + + +def get_provider( + provider_name: str, + custom_providers: dict[str, CustomProviderConfig] | None = None, +) -> LLMProvider: + """Create a provider instance by name. + + Checks custom providers first, then built-in registry. + + Args: + provider_name: Provider name (e.g., "openai", "anthropic", "openrouter") + custom_providers: Optional custom provider configs from ExtropyConfig + + Returns: + LLMProvider instance + + Raises: + ValueError: If provider is unknown + """ + api_key = get_api_key_for_provider(provider_name, custom_providers) + # Check custom providers first + if custom_providers and provider_name in custom_providers: + from .openai_compat import OpenAICompatProvider -def _create_provider(provider_name: str) -> LLMProvider: - """Create a provider instance by name.""" - api_key = get_api_key(provider_name) - - if provider_name == "openai": - from .openai import OpenAIProvider + custom = custom_providers[provider_name] + return OpenAICompatProvider( + api_key=api_key, + base_url=custom.base_url, + supports_search=False, + provider_label=provider_name, + ) - return OpenAIProvider(api_key=api_key) - elif provider_name == "claude": - from .claude import ClaudeProvider + # Check built-in registry + if provider_name not in _BUILTIN_REGISTRY: + available = sorted( + set(list(_BUILTIN_REGISTRY.keys()) + list((custom_providers or {}).keys())) + ) + raise ValueError( + f"Unknown LLM provider: {provider_name!r}. " + f"Available: {', '.join(available)}" + ) - return ClaudeProvider(api_key=api_key) - elif provider_name == "azure_openai": - from .openai import OpenAIProvider + entry = _BUILTIN_REGISTRY[provider_name] - azure_cfg = get_azure_config(provider_name) - if not azure_cfg.get("azure_endpoint"): + # Special case: Azure needs endpoint from env + if provider_name == "azure": + endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "") + if not endpoint: raise ValueError( "AZURE_OPENAI_ENDPOINT not found. Set it as an environment variable.\n" " export AZURE_OPENAI_ENDPOINT=https://.cognitiveservices.azure.com/" ) - # Resolve api_format: config value > auto-default (chat_completions for Azure) + entry = dict(entry) + entry["kwargs"] = dict(entry.get("kwargs", {})) + entry["kwargs"]["base_url"] = endpoint + + # Lazy import + import importlib + + module = importlib.import_module(entry["module"], package=__package__) + cls = getattr(module, entry["class"]) + + kwargs = dict(entry.get("kwargs", {})) + kwargs["api_key"] = api_key + + return cls(**kwargs) + + +# ============================================================================= +# Zone-based provider access (backward compat) +# ============================================================================= + +# Cached providers — reused across calls for connection reuse +_cached_providers: dict[str, LLMProvider] = {} + + +def _get_or_create_provider(provider_name: str, cache_key: str = "") -> LLMProvider: + """Get or create a cached provider instance.""" + key = cache_key or provider_name + if key not in _cached_providers: config = get_config() - api_format = config.simulation.api_format or "chat_completions" - return OpenAIProvider( - api_key=api_key, - azure_endpoint=azure_cfg["azure_endpoint"], - api_version=azure_cfg.get("api_version", "2025-03-01-preview"), - azure_deployment=azure_cfg.get("azure_deployment", ""), - api_format=api_format, - ) - else: - raise ValueError( - f"Unknown LLM provider: {provider_name}. " - f"Valid options: 'openai', 'claude', 'azure_openai'" - ) + _cached_providers[key] = get_provider(provider_name, config.providers) + return _cached_providers[key] def get_pipeline_provider() -> LLMProvider: - """Get the provider for pipeline phases (spec, extend, persona, scenario).""" + """Get the provider for pipeline phases (spec, extend, persona, scenario). + + Uses the provider from models.fast (pipeline calls use both fast and strong, + but the provider is determined by the fast model string). + """ config = get_config() - return _create_provider(config.pipeline.provider) + provider, _ = parse_model_string(config.models.fast) + return _get_or_create_provider(provider, f"pipeline:{provider}") def get_simulation_provider() -> LLMProvider: """Get the cached provider for simulation phase (agent reasoning). - Caches the provider so the underlying async HTTP client is reused - across all calls in a batch, avoiding orphaned connections. + Uses the provider from the resolved simulation strong model. """ - global _simulation_provider config = get_config() - provider_name = config.simulation.provider - - if _simulation_provider is None: - _simulation_provider = _create_provider(provider_name) - - return _simulation_provider + strong_model = config.resolve_sim_strong() + provider, _ = parse_model_string(strong_model) + return _get_or_create_provider(provider, f"simulation:{provider}") async def close_simulation_provider() -> None: - """Close the cached simulation provider's async client. + """Close cached providers' async clients. Call this before the event loop shuts down to cleanly release HTTP connections and avoid 'Event loop is closed' errors. """ - global _simulation_provider - if _simulation_provider is not None: - await _simulation_provider.close_async() - _simulation_provider = None + for key, provider in list(_cached_providers.items()): + await provider.close_async() + _cached_providers.clear() + + +def reset_provider_cache() -> None: + """Reset the provider cache (for testing).""" + _cached_providers.clear() + + +# Legacy factory (kept for backward compat in tests) +def _create_provider(provider_name: str) -> LLMProvider: + """DEPRECATED: Use get_provider() instead.""" + # Map old names + name_map = {"claude": "anthropic", "azure_openai": "azure"} + canonical = name_map.get(provider_name, provider_name) + config = get_config() + return get_provider(canonical, config.providers) __all__ = [ "LLMProvider", + "get_provider", "get_pipeline_provider", "get_simulation_provider", "close_simulation_provider", + "reset_provider_cache", + "parse_model_string", ] diff --git a/extropy/core/providers/anthropic.py b/extropy/core/providers/anthropic.py new file mode 100644 index 0000000..d68ee01 --- /dev/null +++ b/extropy/core/providers/anthropic.py @@ -0,0 +1,388 @@ +"""Anthropic (Claude) LLM Provider implementation. + +Uses the tool use pattern for reliable structured output: +instead of asking Claude to output JSON in text, we define a tool +with the response schema. Claude "calls" the tool, returning structured +data guaranteed to match the schema. +""" + +import logging +import random +import time + +import anthropic + +from .base import LLMProvider, TokenUsage, ValidatorCallback, RetryCallback +from .logging import log_request_response, extract_error_summary + +_TRANSIENT_ANTHROPIC_ERRORS = ( + anthropic.APIConnectionError, + anthropic.InternalServerError, + anthropic.RateLimitError, +) +_MAX_API_RETRIES = 3 + + +logger = logging.getLogger(__name__) + + +def _clean_schema_for_tool(schema: dict) -> dict: + """Clean a JSON schema for use as a tool input_schema. + + Removes fields that aren't valid in tool input schemas + (like 'additionalProperties' in nested objects that Claude + doesn't support in tool definitions). + """ + cleaned = {} + for key, value in schema.items(): + if key == "additionalProperties": + continue + if isinstance(value, dict): + cleaned[key] = _clean_schema_for_tool(value) + elif isinstance(value, list): + cleaned[key] = [ + _clean_schema_for_tool(item) if isinstance(item, dict) else item + for item in value + ] + else: + cleaned[key] = value + return cleaned + + +def _make_structured_tool(schema_name: str, response_schema: dict) -> dict: + """Create a tool definition that forces structured output.""" + return { + "name": schema_name, + "description": ( + "Return your response as structured data. " + "You MUST call this tool with your complete response." + ), + "input_schema": _clean_schema_for_tool(response_schema), + } + + +def _extract_tool_input(response) -> dict | None: + """Extract tool_use input from a Claude response.""" + for block in response.content: + if block.type == "tool_use": + return block.input + return None + + +def _extract_usage(response) -> TokenUsage: + """Extract token usage from an Anthropic API response.""" + if not hasattr(response, "usage") or response.usage is None: + return TokenUsage() + return TokenUsage( + input_tokens=getattr(response.usage, "input_tokens", 0) or 0, + output_tokens=getattr(response.usage, "output_tokens", 0) or 0, + ) + + +class AnthropicProvider(LLMProvider): + """Anthropic (Claude) LLM provider. + + Uses the tool use pattern for structured output — Claude "calls" a tool + with the response data, guaranteeing valid JSON matching the schema. + """ + + provider_name = "anthropic" + + def __init__(self, api_key: str = "") -> None: + if not api_key: + raise ValueError( + "Anthropic API key not found. Set it via:\n" + " export ANTHROPIC_API_KEY=sk-ant-...\n" + "Get your key from: https://console.anthropic.com/settings/keys" + ) + super().__init__(api_key) + + def _with_retry(self, fn, max_retries: int = _MAX_API_RETRIES): + """Retry a sync API call on transient errors with exponential backoff.""" + for attempt in range(max_retries + 1): + try: + return fn() + except _TRANSIENT_ANTHROPIC_ERRORS as e: + if attempt == max_retries: + raise + wait = (2**attempt) + random.random() + logger.warning( + f"[Claude] Transient error (attempt {attempt + 1}/{max_retries + 1}): " + f"{type(e).__name__}: {e}. Retrying in {wait:.1f}s" + ) + time.sleep(wait) + + async def _with_retry_async(self, fn, max_retries: int = _MAX_API_RETRIES): + """Retry an async API call on transient errors with exponential backoff.""" + import asyncio + + for attempt in range(max_retries + 1): + try: + return await fn() + except _TRANSIENT_ANTHROPIC_ERRORS as e: + if attempt == max_retries: + raise + wait = (2**attempt) + random.random() + logger.warning( + f"[Claude] Transient error (attempt {attempt + 1}/{max_retries + 1}): " + f"{type(e).__name__}: {e}. Retrying in {wait:.1f}s" + ) + await asyncio.sleep(wait) + + @property + def default_fast_model(self) -> str: + return "claude-haiku-4-5-20251001" + + @property + def default_strong_model(self) -> str: + return "claude-sonnet-4-5-20250929" + + def _get_client(self) -> anthropic.Anthropic: + return anthropic.Anthropic(api_key=self._api_key) + + def _get_async_client(self) -> anthropic.AsyncAnthropic: + if self._cached_async_client is None: + self._cached_async_client = anthropic.AsyncAnthropic(api_key=self._api_key) + return self._cached_async_client + + def simple_call( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + log: bool = True, + max_tokens: int | None = None, + ) -> dict: + model = model or self.default_simple_model + client = self._get_client() + tool = _make_structured_tool(schema_name, response_schema) + + # Acquire rate limit capacity before making the call + self._acquire_rate_limit(prompt, model, max_output=max_tokens or 4096) + + logger.info( + f"[Claude] simple_call starting - model={model}, schema={schema_name}" + ) + + response = self._with_retry( + lambda: client.messages.create( + model=model, + max_tokens=max_tokens or 4096, + tools=[tool], + tool_choice={"type": "tool", "name": schema_name}, + messages=[{"role": "user", "content": prompt}], + ) + ) + + structured_data = _extract_tool_input(response) + + # Record token usage + usage = _extract_usage(response) + self._record_usage(model, usage, call_type="simple") + + if log: + log_request_response( + function_name="simple_call", + request={"model": model, "prompt_length": len(prompt)}, + response=response, + provider="claude", + ) + + return structured_data or {} + + async def simple_call_async( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + max_tokens: int | None = None, + ) -> tuple[dict, TokenUsage]: + model = model or self.default_simple_model + client = self._get_async_client() + tool = _make_structured_tool(schema_name, response_schema) + + response = await self._with_retry_async( + lambda: client.messages.create( + model=model, + max_tokens=max_tokens or 4096, + tools=[tool], + tool_choice={"type": "tool", "name": schema_name}, + messages=[{"role": "user", "content": prompt}], + ) + ) + + # Extract and record token usage + usage = _extract_usage(response) + self._record_usage(model, usage, call_type="async") + + return _extract_tool_input(response) or {}, usage + + def reasoning_call( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + reasoning_effort: str = "low", + log: bool = True, + previous_errors: str | None = None, + validator: ValidatorCallback | None = None, + max_retries: int = 2, + on_retry: RetryCallback | None = None, + ) -> dict: + """Claude reasoning call with tool-based structured output.""" + model = model or self.default_reasoning_model + client = self._get_client() + tool = _make_structured_tool(schema_name, response_schema) + + effective_prompt = prompt + if previous_errors: + effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" + + def _call(ep: str) -> dict: + # Acquire rate limit capacity before each API call + self._acquire_rate_limit(ep, model, max_output=16384) + + response = self._with_retry( + lambda: client.messages.create( + model=model, + max_tokens=16384, + tools=[tool], + tool_choice={"type": "tool", "name": schema_name}, + messages=[{"role": "user", "content": ep}], + ) + ) + structured_data = _extract_tool_input(response) + + # Record token usage + ru = _extract_usage(response) + self._record_usage(model, ru, call_type="reasoning") + + if log: + log_request_response( + function_name="reasoning_call", + request={"model": model, "prompt_length": len(ep)}, + response=response, + provider="claude", + ) + return structured_data or {} + + return self._retry_with_validation( + call_fn=_call, + prompt=prompt, + validator=validator, + max_retries=max_retries, + on_retry=on_retry, + extract_error_summary_fn=extract_error_summary, + initial_prompt=effective_prompt if previous_errors else None, + ) + + def agentic_research( + self, + prompt: str, + response_schema: dict, + schema_name: str = "research_data", + model: str | None = None, + reasoning_effort: str = "low", + log: bool = True, + previous_errors: str | None = None, + validator: ValidatorCallback | None = None, + max_retries: int = 2, + on_retry: RetryCallback | None = None, + ) -> tuple[dict, list[str]]: + """Claude agentic research with web search + tool-based structured output. + + Uses web_search tool for research and a structured output tool for the response. + Claude first searches, then calls the output tool with results. + """ + model = model or self.default_research_model + client = self._get_client() + output_tool = _make_structured_tool(schema_name, response_schema) + + effective_prompt = prompt + if previous_errors: + effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" + + all_sources: list[str] = [] + + def _call(ep: str) -> dict: + research_prompt = ( + f"{ep}\n\n" + f"After researching, call the '{schema_name}' tool with your structured findings." + ) + + # Acquire rate limit capacity before each API call + self._acquire_rate_limit(research_prompt, model, max_output=16384) + + logger.info(f"[Claude] agentic_research - model={model}") + + response = self._with_retry( + lambda: client.messages.create( + model=model, + max_tokens=16384, + tools=[ + { + "type": "web_search_20250305", + "name": "web_search", + "max_uses": 5, + }, + output_tool, + ], + messages=[{"role": "user", "content": research_prompt}], + ) + ) + + structured_data = None + sources: list[str] = [] + + for block in response.content: + if block.type == "web_search_tool_result": + if hasattr(block, "content") and block.content: + for res in block.content: + if hasattr(res, "url"): + sources.append(res.url) + + if block.type == "tool_use" and block.name == schema_name: + structured_data = block.input + + if block.type == "text": + if hasattr(block, "citations") and block.citations: + for citation in block.citations: + if hasattr(citation, "url"): + sources.append(citation.url) + + all_sources.extend(sources) + logger.info(f"[Claude] Web search completed, found {len(sources)} sources") + + # Record token usage + ru = _extract_usage(response) + self._record_usage(model, ru, call_type="agentic_research") + + if log: + log_request_response( + function_name="agentic_research", + request={"model": model, "prompt_length": len(research_prompt)}, + response=response, + provider="claude", + sources=list(set(sources)), + ) + + return structured_data or {} + + result = self._retry_with_validation( + call_fn=_call, + prompt=prompt, + validator=validator, + max_retries=max_retries, + on_retry=on_retry, + extract_error_summary_fn=extract_error_summary, + initial_prompt=effective_prompt if previous_errors else None, + ) + + return result, list(set(all_sources)) + + +# Backward compat alias +ClaudeProvider = AnthropicProvider diff --git a/extropy/core/providers/base.py b/extropy/core/providers/base.py index f33fbcf..d796ca7 100644 --- a/extropy/core/providers/base.py +++ b/extropy/core/providers/base.py @@ -1,15 +1,18 @@ """Abstract base class for LLM providers.""" +import logging from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Callable, TYPE_CHECKING +from pydantic import BaseModel + if TYPE_CHECKING: from ..rate_limiter import RateLimiter +logger = logging.getLogger(__name__) + -@dataclass -class TokenUsage: +class TokenUsage(BaseModel): """Token usage from a single LLM API call.""" input_tokens: int = 0 @@ -34,6 +37,8 @@ class LLMProvider(ABC): All providers must implement these methods with the same signatures to ensure drop-in compatibility. + Automatically records token usage into CostTracker after each call. + Args: api_key: API key or access token for the provider. """ @@ -89,6 +94,22 @@ def _acquire_rate_limit( estimated_output_tokens=max_output, ) + def _record_usage(self, model: str, usage: TokenUsage, call_type: str = "") -> None: + """Record token usage into the session CostTracker. + + Called after each API call. Safe to call even if no CostTracker + is active (e.g., in tests or library use without CLI). + """ + if usage.input_tokens == 0 and usage.output_tokens == 0: + return + try: + from ..cost.tracker import CostTracker + + CostTracker.get().record(model=model, usage=usage, call_type=call_type) + except Exception: + # Never let cost tracking break actual LLM calls + pass + async def close_async(self) -> None: """Close the cached async client to release connections cleanly. @@ -101,21 +122,28 @@ async def close_async(self) -> None: @property @abstractmethod - def default_simple_model(self) -> str: - """Default model for simple_call (fast, cheap).""" + def default_fast_model(self) -> str: + """Default model for fast/cheap calls (simple_call, Pass 2).""" ... @property @abstractmethod - def default_reasoning_model(self) -> str: - """Default model for reasoning_call (balanced).""" + def default_strong_model(self) -> str: + """Default model for strong/reasoning calls (reasoning_call, agentic_research, Pass 1).""" ... + # Backward-compat aliases (read-only) + @property + def default_simple_model(self) -> str: + return self.default_fast_model + + @property + def default_reasoning_model(self) -> str: + return self.default_strong_model + @property - @abstractmethod def default_research_model(self) -> str: - """Default model for agentic_research (with web search).""" - ... + return self.default_strong_model @abstractmethod def simple_call( diff --git a/extropy/core/providers/claude.py b/extropy/core/providers/claude.py index c06691a..15aa201 100644 --- a/extropy/core/providers/claude.py +++ b/extropy/core/providers/claude.py @@ -1,370 +1,8 @@ -"""Claude (Anthropic) LLM Provider implementation. +"""DEPRECATED: Use extropy.core.providers.anthropic instead. -Uses the tool use pattern for reliable structured output: -instead of asking Claude to output JSON in text, we define a tool -with the response schema. Claude "calls" the tool, returning structured -data guaranteed to match the schema. +This module re-exports AnthropicProvider as ClaudeProvider for backward compatibility. """ -import logging -import random -import time +from .anthropic import AnthropicProvider, ClaudeProvider # noqa: F401 -import anthropic - -from .base import LLMProvider, TokenUsage, ValidatorCallback, RetryCallback -from .logging import log_request_response, extract_error_summary - -_TRANSIENT_ANTHROPIC_ERRORS = ( - anthropic.APIConnectionError, - anthropic.InternalServerError, - anthropic.RateLimitError, -) -_MAX_API_RETRIES = 3 - - -logger = logging.getLogger(__name__) - - -def _clean_schema_for_tool(schema: dict) -> dict: - """Clean a JSON schema for use as a tool input_schema. - - Removes fields that aren't valid in tool input schemas - (like 'additionalProperties' in nested objects that Claude - doesn't support in tool definitions). - """ - cleaned = {} - for key, value in schema.items(): - if key == "additionalProperties": - continue - if isinstance(value, dict): - cleaned[key] = _clean_schema_for_tool(value) - elif isinstance(value, list): - cleaned[key] = [ - _clean_schema_for_tool(item) if isinstance(item, dict) else item - for item in value - ] - else: - cleaned[key] = value - return cleaned - - -def _make_structured_tool(schema_name: str, response_schema: dict) -> dict: - """Create a tool definition that forces structured output.""" - return { - "name": schema_name, - "description": ( - "Return your response as structured data. " - "You MUST call this tool with your complete response." - ), - "input_schema": _clean_schema_for_tool(response_schema), - } - - -def _extract_tool_input(response) -> dict | None: - """Extract tool_use input from a Claude response.""" - for block in response.content: - if block.type == "tool_use": - return block.input - return None - - -class ClaudeProvider(LLMProvider): - """Claude (Anthropic) LLM provider. - - Uses the tool use pattern for structured output — Claude "calls" a tool - with the response data, guaranteeing valid JSON matching the schema. - - """ - - provider_name = "anthropic" - - def __init__(self, api_key: str = "") -> None: - if not api_key: - raise ValueError( - "Anthropic API key not found. Set it via:\n" - " export ANTHROPIC_API_KEY=sk-ant-...\n" - "Get your key from: https://console.anthropic.com/settings/keys" - ) - super().__init__(api_key) - - def _with_retry(self, fn, max_retries: int = _MAX_API_RETRIES): - """Retry a sync API call on transient errors with exponential backoff.""" - for attempt in range(max_retries + 1): - try: - return fn() - except _TRANSIENT_ANTHROPIC_ERRORS as e: - if attempt == max_retries: - raise - wait = (2**attempt) + random.random() - logger.warning( - f"[Claude] Transient error (attempt {attempt + 1}/{max_retries + 1}): " - f"{type(e).__name__}: {e}. Retrying in {wait:.1f}s" - ) - time.sleep(wait) - - async def _with_retry_async(self, fn, max_retries: int = _MAX_API_RETRIES): - """Retry an async API call on transient errors with exponential backoff.""" - import asyncio - - for attempt in range(max_retries + 1): - try: - return await fn() - except _TRANSIENT_ANTHROPIC_ERRORS as e: - if attempt == max_retries: - raise - wait = (2**attempt) + random.random() - logger.warning( - f"[Claude] Transient error (attempt {attempt + 1}/{max_retries + 1}): " - f"{type(e).__name__}: {e}. Retrying in {wait:.1f}s" - ) - await asyncio.sleep(wait) - - @property - def default_simple_model(self) -> str: - return "claude-haiku-4-5-20251001" - - @property - def default_reasoning_model(self) -> str: - return "claude-sonnet-4-5-20250929" - - @property - def default_research_model(self) -> str: - return "claude-sonnet-4-5-20250929" - - def _get_client(self) -> anthropic.Anthropic: - return anthropic.Anthropic(api_key=self._api_key) - - def _get_async_client(self) -> anthropic.AsyncAnthropic: - if self._cached_async_client is None: - self._cached_async_client = anthropic.AsyncAnthropic(api_key=self._api_key) - return self._cached_async_client - - def simple_call( - self, - prompt: str, - response_schema: dict, - schema_name: str = "response", - model: str | None = None, - log: bool = True, - max_tokens: int | None = None, - ) -> dict: - model = model or self.default_simple_model - client = self._get_client() - tool = _make_structured_tool(schema_name, response_schema) - - # Acquire rate limit capacity before making the call - self._acquire_rate_limit(prompt, model, max_output=max_tokens or 4096) - - logger.info( - f"[Claude] simple_call starting - model={model}, schema={schema_name}" - ) - - response = self._with_retry( - lambda: client.messages.create( - model=model, - max_tokens=max_tokens or 4096, - tools=[tool], - tool_choice={"type": "tool", "name": schema_name}, - messages=[{"role": "user", "content": prompt}], - ) - ) - - structured_data = _extract_tool_input(response) - - if log: - log_request_response( - function_name="simple_call", - request={"model": model, "prompt_length": len(prompt)}, - response=response, - provider="claude", - ) - - return structured_data or {} - - async def simple_call_async( - self, - prompt: str, - response_schema: dict, - schema_name: str = "response", - model: str | None = None, - max_tokens: int | None = None, - ) -> tuple[dict, TokenUsage]: - model = model or self.default_simple_model - client = self._get_async_client() - tool = _make_structured_tool(schema_name, response_schema) - - response = await self._with_retry_async( - lambda: client.messages.create( - model=model, - max_tokens=max_tokens or 4096, - tools=[tool], - tool_choice={"type": "tool", "name": schema_name}, - messages=[{"role": "user", "content": prompt}], - ) - ) - - # Extract token usage - usage = TokenUsage() - if hasattr(response, "usage") and response.usage is not None: - usage = TokenUsage( - input_tokens=getattr(response.usage, "input_tokens", 0) or 0, - output_tokens=getattr(response.usage, "output_tokens", 0) or 0, - ) - - return _extract_tool_input(response) or {}, usage - - def reasoning_call( - self, - prompt: str, - response_schema: dict, - schema_name: str = "response", - model: str | None = None, - reasoning_effort: str = "low", - log: bool = True, - previous_errors: str | None = None, - validator: ValidatorCallback | None = None, - max_retries: int = 2, - on_retry: RetryCallback | None = None, - ) -> dict: - """Claude reasoning call with tool-based structured output.""" - model = model or self.default_reasoning_model - client = self._get_client() - tool = _make_structured_tool(schema_name, response_schema) - - effective_prompt = prompt - if previous_errors: - effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" - - def _call(ep: str) -> dict: - # Acquire rate limit capacity before each API call - self._acquire_rate_limit(ep, model, max_output=16384) - - response = self._with_retry( - lambda: client.messages.create( - model=model, - max_tokens=16384, - tools=[tool], - tool_choice={"type": "tool", "name": schema_name}, - messages=[{"role": "user", "content": ep}], - ) - ) - structured_data = _extract_tool_input(response) - if log: - log_request_response( - function_name="reasoning_call", - request={"model": model, "prompt_length": len(ep)}, - response=response, - provider="claude", - ) - return structured_data or {} - - return self._retry_with_validation( - call_fn=_call, - prompt=prompt, - validator=validator, - max_retries=max_retries, - on_retry=on_retry, - extract_error_summary_fn=extract_error_summary, - initial_prompt=effective_prompt if previous_errors else None, - ) - - def agentic_research( - self, - prompt: str, - response_schema: dict, - schema_name: str = "research_data", - model: str | None = None, - reasoning_effort: str = "low", - log: bool = True, - previous_errors: str | None = None, - validator: ValidatorCallback | None = None, - max_retries: int = 2, - on_retry: RetryCallback | None = None, - ) -> tuple[dict, list[str]]: - """Claude agentic research with web search + tool-based structured output. - - Uses web_search tool for research and a structured output tool for the response. - Claude first searches, then calls the output tool with results. - """ - model = model or self.default_research_model - client = self._get_client() - output_tool = _make_structured_tool(schema_name, response_schema) - - effective_prompt = prompt - if previous_errors: - effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" - - all_sources: list[str] = [] - - def _call(ep: str) -> dict: - research_prompt = ( - f"{ep}\n\n" - f"After researching, call the '{schema_name}' tool with your structured findings." - ) - - # Acquire rate limit capacity before each API call - self._acquire_rate_limit(research_prompt, model, max_output=16384) - - logger.info(f"[Claude] agentic_research - model={model}") - - response = self._with_retry( - lambda: client.messages.create( - model=model, - max_tokens=16384, - tools=[ - { - "type": "web_search_20250305", - "name": "web_search", - "max_uses": 5, - }, - output_tool, - ], - messages=[{"role": "user", "content": research_prompt}], - ) - ) - - structured_data = None - sources: list[str] = [] - - for block in response.content: - if block.type == "web_search_tool_result": - if hasattr(block, "content") and block.content: - for res in block.content: - if hasattr(res, "url"): - sources.append(res.url) - - if block.type == "tool_use" and block.name == schema_name: - structured_data = block.input - - if block.type == "text": - if hasattr(block, "citations") and block.citations: - for citation in block.citations: - if hasattr(citation, "url"): - sources.append(citation.url) - - all_sources.extend(sources) - logger.info(f"[Claude] Web search completed, found {len(sources)} sources") - - if log: - log_request_response( - function_name="agentic_research", - request={"model": model, "prompt_length": len(research_prompt)}, - response=response, - provider="claude", - sources=list(set(sources)), - ) - - return structured_data or {} - - result = self._retry_with_validation( - call_fn=_call, - prompt=prompt, - validator=validator, - max_retries=max_retries, - on_retry=on_retry, - extract_error_summary_fn=extract_error_summary, - initial_prompt=effective_prompt if previous_errors else None, - ) - - return result, list(set(all_sources)) +__all__ = ["ClaudeProvider", "AnthropicProvider"] diff --git a/extropy/core/providers/openai.py b/extropy/core/providers/openai.py index 871ad18..45d5599 100644 --- a/extropy/core/providers/openai.py +++ b/extropy/core/providers/openai.py @@ -110,6 +110,20 @@ def _extract_chat_completions_text(response) -> str | None: return content return None + def _extract_usage(self, response, use_chat: bool = False) -> TokenUsage: + """Extract token usage from an OpenAI API response.""" + if not hasattr(response, "usage") or response.usage is None: + return TokenUsage() + if use_chat: + return TokenUsage( + input_tokens=getattr(response.usage, "prompt_tokens", 0) or 0, + output_tokens=getattr(response.usage, "completion_tokens", 0) or 0, + ) + return TokenUsage( + input_tokens=getattr(response.usage, "input_tokens", 0) or 0, + output_tokens=getattr(response.usage, "output_tokens", 0) or 0, + ) + def _build_responses_params( self, model: str, @@ -193,15 +207,11 @@ async def _with_retry_async(self, fn, max_retries: int = _MAX_API_RETRIES): await asyncio.sleep(wait) @property - def default_simple_model(self) -> str: + def default_fast_model(self) -> str: return "gpt-5-mini" @property - def default_reasoning_model(self) -> str: - return "gpt-5" - - @property - def default_research_model(self) -> str: + def default_strong_model(self) -> str: return "gpt-5" def _get_client(self) -> OpenAI: @@ -278,6 +288,10 @@ def simple_call( raw_text = self._extract_output_text(response) structured_data = json.loads(raw_text) if raw_text else None + # Extract and record token usage + usage = self._extract_usage(response, use_chat=use_chat) + self._record_usage(model, usage, call_type="simple") + if log: log_request_response( function_name="simple_call", @@ -326,19 +340,9 @@ async def simple_call_async( raw_text = self._extract_output_text(response) structured_data = json.loads(raw_text) if raw_text else None - # Extract token usage - usage = TokenUsage() - if hasattr(response, "usage") and response.usage is not None: - if use_chat: - usage = TokenUsage( - input_tokens=getattr(response.usage, "prompt_tokens", 0) or 0, - output_tokens=getattr(response.usage, "completion_tokens", 0) or 0, - ) - else: - usage = TokenUsage( - input_tokens=getattr(response.usage, "input_tokens", 0) or 0, - output_tokens=getattr(response.usage, "output_tokens", 0) or 0, - ) + # Extract and record token usage + usage = self._extract_usage(response, use_chat=use_chat) + self._record_usage(model, usage, call_type="async") return structured_data or {}, usage @@ -384,6 +388,11 @@ def _call(ep: str) -> dict: ) raw_text = self._extract_output_text(response) structured_data = json.loads(raw_text) if raw_text else None + + # Record token usage + usage = self._extract_usage(response) + self._record_usage(model, usage, call_type="reasoning") + if log: log_request_response( function_name="reasoning_call", @@ -480,6 +489,10 @@ def _call(ep: str) -> dict: all_sources.extend(sources) + # Record token usage + usage = self._extract_usage(response) + self._record_usage(model, usage, call_type="agentic_research") + if log: log_request_response( function_name="agentic_research", diff --git a/extropy/core/providers/openai_compat.py b/extropy/core/providers/openai_compat.py new file mode 100644 index 0000000..4772987 --- /dev/null +++ b/extropy/core/providers/openai_compat.py @@ -0,0 +1,352 @@ +"""OpenAI-compatible LLM Provider for third-party endpoints. + +Supports any provider that implements the OpenAI Chat Completions API: +- OpenRouter, DeepSeek, Together, Groq, Azure OpenAI, etc. + +Uses `openai.OpenAI(base_url=...)` with Chat Completions API for all calls. +Supports `json_schema` response format for structured output. +For agentic_research, appends `:online` to model name if provider supports search, +and parses `url_citation` annotations for sources. +""" + +import json +import logging +import random +import time + +import openai +from openai import OpenAI, AsyncOpenAI + +from .base import LLMProvider, TokenUsage, ValidatorCallback, RetryCallback +from .logging import log_request_response, extract_error_summary + +_TRANSIENT_ERRORS = ( + openai.APIConnectionError, + openai.InternalServerError, + openai.RateLimitError, +) +_MAX_API_RETRIES = 3 + +logger = logging.getLogger(__name__) + + +class OpenAICompatProvider(LLMProvider): + """OpenAI-compatible provider for third-party endpoints. + + Uses the Chat Completions API with json_schema response format. + """ + + def __init__( + self, + api_key: str = "", + *, + base_url: str = "", + supports_search: bool = False, + provider_label: str = "openai_compat", + default_fast: str = "gpt-5-mini", + default_strong: str = "gpt-5", + ) -> None: + if not api_key: + raise ValueError( + f"API key not found for {provider_label}. " + f"Set it as an environment variable." + ) + super().__init__(api_key) + self._base_url = base_url + self._supports_search = supports_search + self.provider_name = provider_label + self._default_fast = default_fast + self._default_strong = default_strong + + @property + def default_fast_model(self) -> str: + return self._default_fast + + @property + def default_strong_model(self) -> str: + return self._default_strong + + def _get_client(self) -> OpenAI: + kwargs: dict = {"api_key": self._api_key} + if self._base_url: + kwargs["base_url"] = self._base_url + return OpenAI(**kwargs) + + def _get_async_client(self) -> AsyncOpenAI: + if self._cached_async_client is None: + kwargs: dict = {"api_key": self._api_key} + if self._base_url: + kwargs["base_url"] = self._base_url + self._cached_async_client = AsyncOpenAI(**kwargs) + return self._cached_async_client + + def _build_params( + self, + model: str, + prompt: str, + schema: dict, + schema_name: str, + max_tokens: int | None, + ) -> dict: + """Build Chat Completions API request parameters.""" + params: dict = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": schema_name, + "strict": True, + "schema": schema, + }, + }, + } + if max_tokens is not None: + params["max_tokens"] = max_tokens + return params + + @staticmethod + def _extract_text(response) -> str | None: + """Extract text from Chat Completions response.""" + if response.choices and len(response.choices) > 0: + content = response.choices[0].message.content + if content: + return content + return None + + @staticmethod + def _extract_sources(response) -> list[str]: + """Extract citation URLs from response annotations.""" + sources: list[str] = [] + if not response.choices: + return sources + message = response.choices[0].message + if hasattr(message, "annotations") and message.annotations: + for annotation in message.annotations: + if hasattr(annotation, "type") and annotation.type == "url_citation": + if hasattr(annotation, "url"): + sources.append(annotation.url) + return sources + + def _with_retry(self, fn, max_retries: int = _MAX_API_RETRIES): + """Retry on transient errors with exponential backoff.""" + for attempt in range(max_retries + 1): + try: + return fn() + except _TRANSIENT_ERRORS as e: + if attempt == max_retries: + raise + wait = (2**attempt) + random.random() + lbl = self.provider_name + att = f"{attempt + 1}/{max_retries + 1}" + logger.warning( + f"[{lbl}] Transient error ({att}): " + f"{type(e).__name__}: {e}. " + f"Retrying in {wait:.1f}s" + ) + time.sleep(wait) + + async def _with_retry_async(self, fn, max_retries: int = _MAX_API_RETRIES): + """Async retry on transient errors.""" + import asyncio + + for attempt in range(max_retries + 1): + try: + return await fn() + except _TRANSIENT_ERRORS as e: + if attempt == max_retries: + raise + wait = (2**attempt) + random.random() + lbl = self.provider_name + att = f"{attempt + 1}/{max_retries + 1}" + logger.warning( + f"[{lbl}] Transient error ({att}): " + f"{type(e).__name__}: {e}. " + f"Retrying in {wait:.1f}s" + ) + await asyncio.sleep(wait) + + def simple_call( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + log: bool = True, + max_tokens: int | None = None, + ) -> dict: + model = model or self.default_fast_model + client = self._get_client() + + self._acquire_rate_limit(prompt, model, max_output=max_tokens or 4096) + + params = self._build_params( + model, + prompt, + response_schema, + schema_name, + max_tokens, + ) + lbl = self.provider_name + logger.info(f"[{lbl}] simple_call model={model} schema={schema_name}") + + api_start = time.time() + response = self._with_retry(lambda: client.chat.completions.create(**params)) + api_elapsed = time.time() - api_start + logger.info(f"[{self.provider_name}] API response in {api_elapsed:.2f}s") + + raw_text = self._extract_text(response) + structured_data = json.loads(raw_text) if raw_text else None + + if log: + log_request_response( + function_name="simple_call", + request=params, + response=response, + provider=self.provider_name, + ) + + return structured_data or {} + + async def simple_call_async( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + max_tokens: int | None = None, + ) -> tuple[dict, TokenUsage]: + model = model or self.default_fast_model + client = self._get_async_client() + + params = self._build_params( + model, + prompt, + response_schema, + schema_name, + max_tokens, + ) + + response = await self._with_retry_async( + lambda: client.chat.completions.create(**params) + ) + + raw_text = self._extract_text(response) + structured_data = json.loads(raw_text) if raw_text else None + + usage = TokenUsage() + if hasattr(response, "usage") and response.usage is not None: + usage = TokenUsage( + input_tokens=getattr(response.usage, "prompt_tokens", 0) or 0, + output_tokens=getattr(response.usage, "completion_tokens", 0) or 0, + ) + + return structured_data or {}, usage + + def reasoning_call( + self, + prompt: str, + response_schema: dict, + schema_name: str = "response", + model: str | None = None, + reasoning_effort: str = "low", + log: bool = True, + previous_errors: str | None = None, + validator: ValidatorCallback | None = None, + max_retries: int = 2, + on_retry: RetryCallback | None = None, + ) -> dict: + model = model or self.default_strong_model + client = self._get_client() + + effective_prompt = prompt + if previous_errors: + effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" + + def _call(ep: str) -> dict: + self._acquire_rate_limit(ep, model, max_output=16384) + params = self._build_params(model, ep, response_schema, schema_name, None) + response = self._with_retry( + lambda: client.chat.completions.create(**params) + ) + raw_text = self._extract_text(response) + structured_data = json.loads(raw_text) if raw_text else None + if log: + log_request_response( + function_name="reasoning_call", + request=params, + response=response, + provider=self.provider_name, + ) + return structured_data or {} + + return self._retry_with_validation( + call_fn=_call, + prompt=prompt, + validator=validator, + max_retries=max_retries, + on_retry=on_retry, + extract_error_summary_fn=extract_error_summary, + initial_prompt=effective_prompt if previous_errors else None, + ) + + def agentic_research( + self, + prompt: str, + response_schema: dict, + schema_name: str = "research_data", + model: str | None = None, + reasoning_effort: str = "low", + log: bool = True, + previous_errors: str | None = None, + validator: ValidatorCallback | None = None, + max_retries: int = 2, + on_retry: RetryCallback | None = None, + ) -> tuple[dict, list[str]]: + model = model or self.default_strong_model + client = self._get_client() + + effective_prompt = prompt + if previous_errors: + effective_prompt = f"{previous_errors}\n\n---\n\n{prompt}" + + # For providers that support search, append :online suffix + search_model = f"{model}:online" if self._supports_search else model + + all_sources: list[str] = [] + + def _call(ep: str) -> dict: + self._acquire_rate_limit(ep, model, max_output=16384) + params = self._build_params( + search_model, ep, response_schema, schema_name, None + ) + response = self._with_retry( + lambda: client.chat.completions.create(**params) + ) + raw_text = self._extract_text(response) + structured_data = json.loads(raw_text) if raw_text else None + sources = self._extract_sources(response) + all_sources.extend(sources) + + if log: + log_request_response( + function_name="agentic_research", + request=params, + response=response, + provider=self.provider_name, + sources=list(set(sources)), + ) + + return structured_data or {} + + result = self._retry_with_validation( + call_fn=_call, + prompt=prompt, + validator=validator, + max_retries=max_retries, + on_retry=on_retry, + extract_error_summary_fn=extract_error_summary, + initial_prompt=effective_prompt if previous_errors else None, + ) + + return result, list(set(all_sources)) diff --git a/extropy/core/rate_limiter.py b/extropy/core/rate_limiter.py index 256cdb1..8ec6063 100644 --- a/extropy/core/rate_limiter.py +++ b/extropy/core/rate_limiter.py @@ -486,10 +486,11 @@ def stats(self) -> dict: class DualRateLimiter: - """Manages separate rate limiters for pivotal (Pass 1) and routine (Pass 2) models. + """Manages separate rate limiters for strong (Pass 1) and fast (Pass 2) models. - When pivotal and routine models are the same, uses a single shared limiter. + When strong and fast models are the same, uses a single shared limiter. When they differ, uses independent limiters since API limits are per-model. + Supports mixed providers (e.g., strong=anthropic, fast=openai). """ def __init__( @@ -499,51 +500,79 @@ def __init__( ): self.pivotal = pivotal self.routine = routine + # Aliases for new naming convention + self.strong = pivotal + self.fast = routine @classmethod def create( cls, - provider: str, + provider: str = "", pivotal_model: str = "", routine_model: str = "", tier: int | None = None, rpm_override: int | None = None, tpm_override: int | None = None, + *, + strong_model_string: str = "", + fast_model_string: str = "", ) -> "DualRateLimiter": """Create dual rate limiter for two-pass reasoning. - If both models are the same (or routine is empty), a single - shared limiter is used for both passes. + Accepts either: + - Legacy: provider + pivotal_model + routine_model (single provider) + - New: strong_model_string + fast_model_string (provider/model format, mixed providers) Args: - provider: Provider name - pivotal_model: Model for Pass 1 (role-play reasoning) - routine_model: Model for Pass 2 (classification) + provider: Provider name (legacy, used if model strings not provided) + pivotal_model: Model for Pass 1 (legacy) + routine_model: Model for Pass 2 (legacy) tier: Rate limit tier (1-4) - rpm_override: Override RPM (applies to pivotal limiter) - tpm_override: Override TPM (applies to pivotal limiter) + rpm_override: Override RPM + tpm_override: Override TPM + strong_model_string: "provider/model" for strong/pivotal (new) + fast_model_string: "provider/model" for fast/routine (new) Returns: DualRateLimiter instance """ + # Resolve strong limiter + if strong_model_string and "/" in strong_model_string: + from ..config import parse_model_string + + strong_provider, strong_model = parse_model_string(strong_model_string) + else: + strong_provider = provider + strong_model = pivotal_model + pivotal_limiter = RateLimiter.for_provider( - provider=provider, - model=pivotal_model, + provider=strong_provider, + model=strong_model, tier=tier, rpm_override=rpm_override, tpm_override=tpm_override, ) - # If routine model is the same as pivotal (or not specified), share the limiter - effective_routine = routine_model or pivotal_model - if effective_routine == pivotal_model or not effective_routine: + # Resolve fast limiter + if fast_model_string and "/" in fast_model_string: + from ..config import parse_model_string + + fast_provider, fast_model = parse_model_string(fast_model_string) + else: + fast_provider = provider + fast_model = routine_model + + # If same provider+model, share the limiter + effective_fast_model = fast_model or strong_model + if fast_provider == strong_provider and effective_fast_model == strong_model: + return cls(pivotal=pivotal_limiter, routine=pivotal_limiter) + + if not effective_fast_model and not fast_provider: return cls(pivotal=pivotal_limiter, routine=pivotal_limiter) - # Different models — create separate limiter for routine - # Overrides apply to both (on Azure, limits are per-resource not per-model) routine_limiter = RateLimiter.for_provider( - provider=provider, - model=effective_routine, + provider=fast_provider or strong_provider, + model=effective_fast_model, tier=tier, rpm_override=rpm_override, tpm_override=tpm_override, diff --git a/extropy/core/rate_limits.py b/extropy/core/rate_limits.py index d081e60..ec53025 100644 --- a/extropy/core/rate_limits.py +++ b/extropy/core/rate_limits.py @@ -85,11 +85,32 @@ }, } -# Map "claude" provider name to anthropic profiles +# Provider aliases — map alternate names to canonical profiles RATE_LIMIT_PROFILES["claude"] = RATE_LIMIT_PROFILES["anthropic"] - -# Azure OpenAI uses the same rate limit profiles as standard OpenAI RATE_LIMIT_PROFILES["azure_openai"] = RATE_LIMIT_PROFILES["openai"] +RATE_LIMIT_PROFILES["azure"] = RATE_LIMIT_PROFILES["openai"] + +# Third-party providers — conservative defaults +# These providers typically have per-key limits; adjust via rate_tier/rpm_override. +_THIRD_PARTY_DEFAULT = { + "default": { + 1: {"rpm": 60, "tpm": 100_000}, + 2: {"rpm": 200, "tpm": 500_000}, + 3: {"rpm": 500, "tpm": 1_000_000}, + 4: {"rpm": 1_000, "tpm": 2_000_000}, + }, +} +RATE_LIMIT_PROFILES["openrouter"] = _THIRD_PARTY_DEFAULT +RATE_LIMIT_PROFILES["deepseek"] = _THIRD_PARTY_DEFAULT +RATE_LIMIT_PROFILES["together"] = _THIRD_PARTY_DEFAULT +RATE_LIMIT_PROFILES["groq"] = { + "default": { + 1: {"rpm": 30, "tpm": 15_000}, + 2: {"rpm": 60, "tpm": 50_000}, + 3: {"rpm": 200, "tpm": 100_000}, + 4: {"rpm": 500, "tpm": 500_000}, + }, +} def get_limits( diff --git a/extropy/population/network/__init__.py b/extropy/population/network/__init__.py index 7772d0a..e6f03c6 100644 --- a/extropy/population/network/__init__.py +++ b/extropy/population/network/__init__.py @@ -7,7 +7,8 @@ Usage: from extropy.network import generate_network, NetworkConfig, NetworkResult - # Load agents from JSON + # Agents are typically loaded from study.db via CLI, then passed here. + # (load_agents_json is kept for explicit import/export workflows.) agents = load_agents_json("agents.json") # Generate network with default config (flat — no similarity structure) diff --git a/extropy/population/network/config.py b/extropy/population/network/config.py index f411c2d..08a4024 100644 --- a/extropy/population/network/config.py +++ b/extropy/population/network/config.py @@ -103,8 +103,18 @@ class NetworkConfig(BaseModel): Attributes: avg_degree: Target average degree (connections per agent) rewire_prob: Watts-Strogatz rewiring probability + similarity_store_threshold: Minimum similarity retained in sparse matrix similarity_threshold: Sigmoid threshold for edge probability similarity_steepness: Sigmoid steepness for edge probability + candidate_mode: Similarity candidate strategy. + - "exact": all-pairs (highest fidelity, slowest) + - "blocked": block-based candidate pruning (near-equivalent, much faster) + candidate_pool_multiplier: Candidate pool size per node as a multiple of avg_degree + min_candidate_pool: Lower bound for candidate pool size per node in blocked mode + blocking_attributes: Attributes used for blocking. Auto-selected if empty. + similarity_workers: Worker processes for similarity stage (1 = serial) + similarity_chunk_size: Row chunk size per worker task + checkpoint_every_rows: Save similarity checkpoint every N rows triadic_closure_prob: Probability of closing open triads (A-B, B-C -> A-C). Higher values create more realistic clustering. Default 0.4. target_clustering: Target clustering coefficient (0.3-0.5 is realistic). @@ -123,8 +133,16 @@ class NetworkConfig(BaseModel): avg_degree: float = 20.0 rewire_prob: float = 0.05 + similarity_store_threshold: float = 0.05 similarity_threshold: float = 0.3 similarity_steepness: float = 10.0 + candidate_mode: Literal["exact", "blocked"] = "exact" + candidate_pool_multiplier: float = 12.0 + min_candidate_pool: int = 80 + blocking_attributes: list[str] = Field(default_factory=list) + similarity_workers: int = 1 + similarity_chunk_size: int = 64 + checkpoint_every_rows: int = 250 triadic_closure_prob: float = 0.6 target_clustering: float = 0.35 target_modularity: float = 0.55 # Target modularity (0.4-0.7 range) diff --git a/extropy/population/network/generator.py b/extropy/population/network/generator.py index 09fa31c..f8ca719 100644 --- a/extropy/population/network/generator.py +++ b/extropy/population/network/generator.py @@ -5,12 +5,16 @@ import json import logging +import hashlib +import multiprocessing as mp import random +from concurrent.futures import ProcessPoolExecutor, as_completed from datetime import datetime from pathlib import Path from typing import Any from ...core.models import Edge, NetworkResult +from ...storage import open_study_db from ...utils.callbacks import NetworkProgressCallback from ...utils.eval_safe import ConditionError, eval_condition from .config import NetworkConfig, InfluenceFactorConfig @@ -21,6 +25,395 @@ logger = logging.getLogger(__name__) +_SIM_WORKER_AGENTS: list[dict[str, Any]] | None = None +_SIM_WORKER_ATTRIBUTE_WEIGHTS = None +_SIM_WORKER_ORDINAL_LEVELS: dict[str, dict[str, int]] | None = None +_SIM_WORKER_THRESHOLD: float = 0.05 +_SIM_WORKER_CANDIDATE_MAP: list[list[int]] | None = None + + +def _choose_blocking_attributes(config: NetworkConfig) -> list[str]: + """Choose blocking attributes for candidate pruning.""" + if config.blocking_attributes: + return list(config.blocking_attributes) + + weighted = sorted( + config.attribute_weights.items(), + key=lambda x: x[1].weight, + reverse=True, + ) + preferred = [ + attr for attr, cfg in weighted if cfg.match_type in {"exact", "within_n"} + ] + + if preferred: + return preferred[:3] + + return [attr for attr, _ in weighted[:2]] + + +def _build_blocked_candidate_map( + agents: list[dict[str, Any]], + config: NetworkConfig, + seed: int, +) -> tuple[list[list[int]] | None, list[str]]: + """Build per-agent candidate lists for blocked similarity mode.""" + attrs = _choose_blocking_attributes(config) + n = len(agents) + + if not attrs or n <= 1: + return None, attrs + + blocks: dict[str, dict[Any, list[int]]] = {attr: {} for attr in attrs} + + for idx, agent in enumerate(agents): + for attr in attrs: + val = agent.get(attr) + if val is None: + continue + blocks[attr].setdefault(val, []).append(idx) + + target_pool = max( + config.min_candidate_pool, + int(config.avg_degree * config.candidate_pool_multiplier), + ) + target_pool = max(1, min(n - 1, target_pool)) + + candidate_map: list[list[int]] = [[] for _ in range(n)] + + for i, agent in enumerate(agents): + scores: dict[int, int] = {} + + for attr in attrs: + val = agent.get(attr) + if val is None: + continue + for j in blocks[attr].get(val, []): + if j == i: + continue + scores[j] = scores.get(j, 0) + 1 + + ranked = sorted(scores.items(), key=lambda x: (-x[1], x[0])) + chosen = [j for j, _ in ranked[:target_pool]] + + if len(chosen) < target_pool: + rng = random.Random(seed + (i + 1) * 7919) + seen = set(chosen) + seen.add(i) + while len(chosen) < target_pool and len(seen) < n: + j = rng.randrange(n) + if j in seen: + continue + seen.add(j) + chosen.append(j) + + candidate_map[i] = sorted(chosen) + + return candidate_map, attrs + + +def _similarity_checkpoint_signature( + n: int, + seed: int, + config: NetworkConfig, + blocking_attrs: list[str], +) -> dict[str, Any]: + """Build a minimal signature to validate checkpoint compatibility.""" + return { + "n": n, + "seed": seed, + "candidate_mode": config.candidate_mode, + "threshold": config.similarity_store_threshold, + "candidate_pool_multiplier": config.candidate_pool_multiplier, + "min_candidate_pool": config.min_candidate_pool, + "blocking_attributes": blocking_attrs, + } + + +def _similarity_checkpoint_job_id(signature: dict[str, Any]) -> str: + raw = json.dumps(signature, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:24] + + +def _load_similarity_checkpoint( + checkpoint_db: Path, + expected_signature: dict[str, Any], +) -> tuple[dict[tuple[int, int], float], int, set[int]]: + """Load checkpoint and validate compatibility with current run settings.""" + job_id = _similarity_checkpoint_job_id(expected_signature) + with open_study_db(checkpoint_db) as db: + signature = db.get_network_similarity_job_signature(job_id) + if signature is None: + raise ValueError(f"Checkpoint not found in study DB: job_id={job_id}") + if signature != expected_signature: + raise ValueError( + "Checkpoint settings do not match current run. " + "Delete checkpoint or run with matching config." + ) + + done_chunks = db.list_completed_similarity_chunks(job_id) + done_starts = {start for start, _ in done_chunks} + similarities = db.load_similarity_pairs(job_id) + + contiguous_rows = 0 + for start, end in done_chunks: + if start != contiguous_rows: + break + contiguous_rows = end + + return similarities, max(0, contiguous_rows), done_starts + + +def _init_similarity_worker( + agents: list[dict[str, Any]], + attribute_weights, + ordinal_levels: dict[str, dict[str, int]] | None, + threshold: float, + candidate_map: list[list[int]] | None, +) -> None: + """Initialize process-local globals for similarity workers.""" + global _SIM_WORKER_AGENTS + global _SIM_WORKER_ATTRIBUTE_WEIGHTS + global _SIM_WORKER_ORDINAL_LEVELS + global _SIM_WORKER_THRESHOLD + global _SIM_WORKER_CANDIDATE_MAP + + _SIM_WORKER_AGENTS = agents + _SIM_WORKER_ATTRIBUTE_WEIGHTS = attribute_weights + _SIM_WORKER_ORDINAL_LEVELS = ordinal_levels + _SIM_WORKER_THRESHOLD = threshold + _SIM_WORKER_CANDIDATE_MAP = candidate_map + + +def _compute_similarity_chunk( + task: tuple[int, int], +) -> tuple[int, list[tuple[int, int, float]]]: + """Compute similarities for a chunk of row indices in a worker process.""" + start, end = task + if _SIM_WORKER_AGENTS is None: + raise RuntimeError("Similarity worker not initialized") + + n = len(_SIM_WORKER_AGENTS) + rows: list[tuple[int, int, float]] = [] + + for i in range(start, min(end, n)): + if _SIM_WORKER_CANDIDATE_MAP is None: + candidates = range(i + 1, n) + else: + candidates = _SIM_WORKER_CANDIDATE_MAP[i] + + for j in candidates: + if j <= i: + continue + sim = compute_similarity( + _SIM_WORKER_AGENTS[i], + _SIM_WORKER_AGENTS[j], + _SIM_WORKER_ATTRIBUTE_WEIGHTS, + _SIM_WORKER_ORDINAL_LEVELS, + ) + if sim >= _SIM_WORKER_THRESHOLD: + rows.append((i, j, sim)) + + return end, rows + + +def _compute_similarities_parallel( + agents: list[dict[str, Any]], + config: NetworkConfig, + candidate_map: list[list[int]] | None, + on_progress: NetworkProgressCallback | None = None, + checkpoint_path: Path | None = None, + checkpoint_signature: dict[str, Any] | None = None, + initial_similarities: dict[tuple[int, int], float] | None = None, + completed_rows: int = 0, + completed_chunk_starts: set[int] | None = None, + checkpoint_job_id: str | None = None, +) -> dict[tuple[int, int], float]: + """Compute sparse similarities with process parallelism.""" + n = len(agents) + similarities: dict[tuple[int, int], float] = dict(initial_similarities or {}) + + chunk_size = max(8, config.similarity_chunk_size) + tasks = [(i, min(i + chunk_size, n)) for i in range(0, n, chunk_size)] + task_ends = {start: end for start, end in tasks} + completed_starts: set[int] = set(completed_chunk_starts or set()) + for start, end in tasks: + if end <= completed_rows: + completed_starts.add(start) + pending_tasks = [(s, e) for s, e in tasks if s not in completed_starts] + workers = max(1, config.similarity_workers) + + completed_row_count = sum((e - s) for s, e in tasks if s in completed_starts) + if on_progress and completed_row_count > 0: + on_progress("Computing similarities", min(completed_row_count, n), n) + + try: + ctx = mp.get_context("spawn") + with ProcessPoolExecutor( + max_workers=workers, + mp_context=ctx, + initializer=_init_similarity_worker, + initargs=( + agents, + config.attribute_weights, + config.ordinal_levels, + config.similarity_store_threshold, + candidate_map, + ), + ) as ex: + futures = { + ex.submit(_compute_similarity_chunk, task): task + for task in pending_tasks + } + pending_results: dict[int, list[tuple[int, int, float]]] = {} + sorted_starts = [start for start, _ in tasks] + next_commit_idx = 0 + + for fut in as_completed(futures): + task_start, _task_end = futures[fut] + _row_end, local_rows = fut.result() + pending_results[task_start] = local_rows + + # Deterministic merge: commit completed chunks in chunk_start order. + while next_commit_idx < len(sorted_starts): + current_start = sorted_starts[next_commit_idx] + current_end = task_ends[current_start] + if current_start in completed_starts: + next_commit_idx += 1 + continue + if current_start not in pending_results: + break + + chunk_rows = pending_results.pop(current_start) + for i, j, sim in chunk_rows: + similarities[(i, j)] = sim + completed_starts.add(current_start) + completed_row_count += current_end - current_start + completed_rows = max(completed_rows, current_end) + + if ( + checkpoint_path is not None + and checkpoint_signature is not None + and checkpoint_job_id is not None + ): + with open_study_db(checkpoint_path) as db: + db.save_similarity_chunk_rows( + job_id=checkpoint_job_id, + chunk_start=current_start, + chunk_end=current_end, + rows=chunk_rows, + ) + + if on_progress: + on_progress( + "Computing similarities", min(completed_row_count, n), n + ) + next_commit_idx += 1 + + except Exception as e: + downgraded_config = config.model_copy( + update={ + "similarity_workers": 1, + "similarity_chunk_size": max(8, config.similarity_chunk_size // 2), + } + ) + logger.warning( + "Parallel similarity failed (%s). Falling back to serial mode " + "(chunk_size %d -> %d).", + e, + config.similarity_chunk_size, + downgraded_config.similarity_chunk_size, + ) + return _compute_similarities_serial( + agents=agents, + config=downgraded_config, + candidate_map=candidate_map, + on_progress=on_progress, + checkpoint_path=checkpoint_path, + initial_similarities=similarities, + start_row=completed_rows, + checkpoint_signature=checkpoint_signature, + completed_chunk_starts=completed_starts, + checkpoint_job_id=checkpoint_job_id, + ) + + return similarities + + +def _compute_similarities_serial( + agents: list[dict[str, Any]], + config: NetworkConfig, + candidate_map: list[list[int]] | None = None, + on_progress: NetworkProgressCallback | None = None, + checkpoint_path: Path | None = None, + initial_similarities: dict[tuple[int, int], float] | None = None, + start_row: int = 0, + checkpoint_signature: dict[str, Any] | None = None, + completed_chunk_starts: set[int] | None = None, + checkpoint_job_id: str | None = None, +) -> dict[tuple[int, int], float]: + """Compute sparse similarities serially, with optional checkpointing.""" + n = len(agents) + threshold = config.similarity_store_threshold + similarities = dict(initial_similarities or {}) + checkpoint_every = max(1, config.checkpoint_every_rows) + chunk_size = max(8, config.similarity_chunk_size) + tasks = [(i, min(i + chunk_size, n)) for i in range(0, n, chunk_size)] + completed_starts: set[int] = set(completed_chunk_starts or set()) + for start, end in tasks: + if end <= start_row: + completed_starts.add(start) + completed_row_count = sum((e - s) for s, e in tasks if s in completed_starts) + + for chunk_idx, (start, end) in enumerate(tasks): + if start in completed_starts: + continue + + local_rows: list[tuple[int, int, float]] = [] + for i in range(start, end): + if candidate_map is None: + candidates = range(i + 1, n) + else: + candidates = candidate_map[i] + + for j in candidates: + if j <= i: + continue + sim = compute_similarity( + agents[i], + agents[j], + config.attribute_weights, + config.ordinal_levels, + ) + if sim >= threshold: + similarities[(i, j)] = sim + local_rows.append((i, j, sim)) + + completed_starts.add(start) + completed_row_count += end - start + + if ( + checkpoint_path is not None + and checkpoint_signature is not None + and checkpoint_job_id is not None + ): + if ( + completed_row_count % checkpoint_every == 0 + or chunk_idx == len(tasks) - 1 + ): + with open_study_db(checkpoint_path) as db: + db.save_similarity_chunk_rows( + job_id=checkpoint_job_id, + chunk_start=start, + chunk_end=end, + rows=local_rows, + ) + + if on_progress: + on_progress("Computing similarities", min(completed_row_count, n), n) + + return similarities + def _eval_edge_condition( condition: str, @@ -429,6 +822,7 @@ def _triadic_closure( edge_set: set[tuple[str, str]], config: NetworkConfig, rng: random.Random, + similarities: dict[tuple[int, int], float] | None = None, communities: list[int] | None = None, target_clustering: float = 0.35, max_edge_increase: float = 1.5, @@ -480,9 +874,17 @@ def _triadic_closure( # Score triads by similarity and community membership triad_with_score = [] for a, c, b in open_triads: - sim = compute_similarity( - agents[a], agents[c], config.attribute_weights, config.ordinal_levels - ) + pair = (min(a, c), max(a, c)) + sim = similarities.get(pair) if similarities is not None else None + if sim is None: + sim = compute_similarity( + agents[a], + agents[c], + config.attribute_weights, + config.ordinal_levels, + ) + if similarities is not None: + similarities[pair] = sim same_community = ( communities is not None and communities[a] == communities[c] ) @@ -691,6 +1093,7 @@ def _generate_network_single_pass( edge_set, config, rng, + similarities=similarities, communities=communities, target_clustering=config.target_clustering, max_edge_increase=2.5, # Allow up to 2.5x edges for better clustering @@ -718,6 +1121,8 @@ def generate_network( agents: list[dict[str, Any]], config: NetworkConfig | None = None, on_progress: NetworkProgressCallback | None = None, + checkpoint_path: Path | str | None = None, + resume_from_checkpoint: bool = False, ) -> NetworkResult: """Generate a social network from sampled agents. @@ -746,27 +1151,108 @@ def generate_network( n = len(agents) agent_ids = [a.get("_id", f"agent_{i}") for i, a in enumerate(agents)] - - if on_progress: - on_progress("Computing similarities", 0, n) + checkpoint_file = Path(checkpoint_path) if checkpoint_path else None + if checkpoint_file is not None and checkpoint_file.suffix.lower() != ".db": + raise ValueError( + "Network checkpoints are DB-only now. Use --study-db (or --checkpoint )." + ) # Step 1: Compute degree factors degree_factors = [compute_degree_factor(a, config) for a in agents] - # Step 2: Compute similarity matrix (sparse) - similarities: dict[tuple[int, int], float] = {} - threshold = 0.05 + # Step 2: Build similarity candidates (exact/blocked) + candidate_map: list[list[int]] | None = None + blocking_attrs: list[str] = [] + candidate_mode = config.candidate_mode - for i in range(n): - for j in range(i + 1, n): - sim = compute_similarity( - agents[i], agents[j], config.attribute_weights, config.ordinal_levels + if config.candidate_mode == "blocked": + if on_progress: + on_progress("Preparing candidate blocks", 0, n) + candidate_map, blocking_attrs = _build_blocked_candidate_map( + agents, config, seed + ) + if on_progress: + on_progress("Preparing candidate blocks", n, n) + if candidate_map is None: + logger.warning( + "Blocked candidate mode could not be initialized. Falling back to exact mode." ) - if sim >= threshold: - similarities[(i, j)] = sim + candidate_mode = "exact" + + if on_progress: + on_progress("Computing similarities", 0, n) + + checkpoint_signature = _similarity_checkpoint_signature( + n=n, + seed=seed, + config=config, + blocking_attrs=blocking_attrs, + ) + checkpoint_job_id: str | None = None + if checkpoint_file is not None: + checkpoint_job_id = _similarity_checkpoint_job_id(checkpoint_signature) + if not resume_from_checkpoint: + with open_study_db(checkpoint_file) as db: + db.init_network_similarity_job( + network_run_id=f"checkpoint:{checkpoint_job_id}", + signature=checkpoint_signature, + job_id=checkpoint_job_id, + ) + db.mark_similarity_job_running(checkpoint_job_id) + + similarities: dict[tuple[int, int], float] + start_row = 0 + completed_chunk_starts: set[int] = set() - if on_progress and i % 50 == 0: - on_progress("Computing similarities", i, n) + if resume_from_checkpoint and checkpoint_file is None: + raise ValueError("--resume-checkpoint requires a checkpoint DB path") + + if resume_from_checkpoint: + if checkpoint_file is None or not checkpoint_file.exists(): + raise ValueError(f"Checkpoint not found: {checkpoint_file}") + similarities, start_row, completed_chunk_starts = _load_similarity_checkpoint( + checkpoint_file, checkpoint_signature + ) + if checkpoint_job_id and checkpoint_file is not None: + with open_study_db(checkpoint_file) as db: + db.mark_similarity_job_running(checkpoint_job_id) + if on_progress: + on_progress("Computing similarities", min(start_row, n), n) + else: + similarities = {} + + use_parallel_similarity = config.similarity_workers > 1 + + if use_parallel_similarity: + similarities = _compute_similarities_parallel( + agents=agents, + config=config, + candidate_map=candidate_map if candidate_mode == "blocked" else None, + on_progress=on_progress, + checkpoint_path=checkpoint_file, + checkpoint_signature=checkpoint_signature, + initial_similarities=similarities, + completed_rows=start_row, + completed_chunk_starts=completed_chunk_starts, + checkpoint_job_id=checkpoint_job_id, + ) + else: + similarities = _compute_similarities_serial( + agents=agents, + config=config, + candidate_map=candidate_map if candidate_mode == "blocked" else None, + on_progress=on_progress, + checkpoint_path=checkpoint_file, + initial_similarities=similarities, + start_row=start_row, + checkpoint_signature=checkpoint_signature, + completed_chunk_starts=completed_chunk_starts, + checkpoint_job_id=checkpoint_job_id, + ) + + if checkpoint_job_id and checkpoint_file is not None: + with open_study_db(checkpoint_file) as db: + db.mark_similarity_job_complete(checkpoint_job_id) if on_progress: on_progress("Computing similarities", n, n) @@ -935,6 +1421,10 @@ def generate_network( "rewired_count": rewired_count, "algorithm": "adaptive_calibration", "seed": seed, + "candidate_mode": candidate_mode, + "similarity_pairs": len(similarities), + "blocking_attributes": blocking_attrs if candidate_mode == "blocked" else [], + "resumed_from_checkpoint": resume_from_checkpoint, "config": { "avg_degree_target": config.avg_degree, "rewire_prob": config.rewire_prob, @@ -951,6 +1441,8 @@ def generate_network_with_metrics( agents: list[dict[str, Any]], config: NetworkConfig | None = None, on_progress: NetworkProgressCallback | None = None, + checkpoint_path: Path | str | None = None, + resume_from_checkpoint: bool = False, ) -> NetworkResult: """Generate network and compute all metrics. @@ -960,7 +1452,13 @@ def generate_network_with_metrics( """ from .metrics import compute_network_metrics, compute_node_metrics - result = generate_network(agents, config, on_progress) + result = generate_network( + agents, + config, + on_progress, + checkpoint_path=checkpoint_path, + resume_from_checkpoint=resume_from_checkpoint, + ) agent_ids = [a.get("_id", f"agent_{i}") for i, a in enumerate(agents)] diff --git a/extropy/scenario/__init__.py b/extropy/scenario/__init__.py index 6dd793d..31c8893 100644 --- a/extropy/scenario/__init__.py +++ b/extropy/scenario/__init__.py @@ -16,8 +16,9 @@ >>> spec, result = create_scenario( ... "Netflix announces $3 price increase", ... "population.yaml", - ... "agents.json", - ... "network.json", + ... study_db_path="study.db", + ... population_id="default", + ... network_id="default", ... "scenario.yaml" ... ) >>> result.valid diff --git a/extropy/scenario/compiler.py b/extropy/scenario/compiler.py index 7be6831..bb108f2 100644 --- a/extropy/scenario/compiler.py +++ b/extropy/scenario/compiler.py @@ -8,7 +8,6 @@ 5. Assemble and validate ScenarioSpec """ -import json import re from datetime import datetime from pathlib import Path @@ -26,7 +25,8 @@ from .interaction import determine_interaction_model from .outcomes import define_outcomes from ..utils.callbacks import StepProgressCallback -from .validator import validate_scenario, get_agent_count +from .validator import validate_scenario +from ..storage import open_study_db def _generate_scenario_name(description: str) -> str: @@ -57,43 +57,38 @@ def _determine_simulation_config(population_size: int) -> SimulationConfig: ) -def _load_network_summary(network_path: Path) -> dict | None: - """Load network summary for exposure generation.""" - if not network_path.exists(): - return None +def _load_network_summary(network_data: dict[str, object]) -> dict[str, object]: + """Build network summary for exposure generation from network payload.""" + edge_types = set() + node_count = 0 - try: - with open(network_path) as f: - network = json.load(f) + meta = network_data.get("meta") + if isinstance(meta, dict): + raw_count = meta.get("node_count") + if isinstance(raw_count, int): + node_count = raw_count - # Extract summary information - edge_types = set() - node_count = 0 + edges = network_data.get("edges") + if isinstance(edges, list): + for edge in edges: + if not isinstance(edge, dict): + continue + edge_type = edge.get("edge_type") or edge.get("type") + if isinstance(edge_type, str): + edge_types.add(edge_type) - if "meta" in network: - node_count = network["meta"].get("node_count", 0) - - if "edges" in network: - for edge in network["edges"]: - # Check both 'edge_type' and 'type' fields (different network formats) - if "edge_type" in edge: - edge_types.add(edge["edge_type"]) - elif "type" in edge: - edge_types.add(edge["type"]) - - return { - "node_count": node_count, - "edge_types": list(edge_types), - } - except (json.JSONDecodeError, KeyError, TypeError): - return None + return { + "node_count": node_count, + "edge_types": list(edge_types), + } def create_scenario( description: str, population_spec_path: str | Path, - agents_path: str | Path, - network_path: str | Path, + study_db_path: str | Path, + population_id: str = "default", + network_id: str = "default", output_path: str | Path | None = None, on_progress: StepProgressCallback | None = None, ) -> tuple[ScenarioSpec, ValidationResult]: @@ -114,8 +109,9 @@ def create_scenario( Args: description: Natural language scenario description population_spec_path: Path to population YAML file - agents_path: Path to agents JSON file - network_path: Path to network JSON file + study_db_path: Path to canonical study DB + population_id: Population ID in study DB + network_id: Network ID in study DB output_path: Optional path to save scenario YAML on_progress: Optional callback(step, status) for progress updates @@ -130,16 +126,16 @@ def create_scenario( >>> spec, result = create_scenario( ... "Netflix announces $3 price increase", ... "population.yaml", - ... "agents.json", - ... "network.json", + ... "study.db", + ... "default", + ... "default", ... "scenario.yaml" ... ) >>> result.valid True """ population_spec_path = Path(population_spec_path) - agents_path = Path(agents_path) - network_path = Path(network_path) + study_db_path = Path(study_db_path) def progress(step: str, status: str): if on_progress: @@ -157,7 +153,13 @@ def progress(step: str, status: str): population_spec = PopulationSpec.from_yaml(population_spec_path) # Load network summary for exposure generation - network_summary = _load_network_summary(network_path) + with open_study_db(study_db_path) as db: + network = db.get_network(network_id) + if not network.get("edges"): + raise FileNotFoundError( + f"Network '{network_id}' not found in study DB: {study_db_path}" + ) + network_summary = _load_network_summary(network) # ========================================================================= # Step 1: Parse scenario description @@ -220,8 +222,9 @@ def progress(step: str, status: str): name=scenario_name, description=description, population_spec=str(population_spec_path), - agents_file=str(agents_path), - network_file=str(network_path), + study_db=str(study_db_path), + population_id=population_id, + network_id=network_id, created_at=datetime.now(), ) @@ -240,18 +243,9 @@ def progress(step: str, status: str): # Validate # ========================================================================= - # Note: We validate agent count consistency, which requires loading the file. - # We use get_agent_count() to do this safely/robustly. - agent_count = get_agent_count(agents_path) - - # Load network for validation (needed for edge type reference validation) - network = None - if network_path.exists(): - try: - with open(network_path) as f: - network = json.load(f) - except (json.JSONDecodeError, OSError): - pass + with open_study_db(study_db_path) as db: + agent_count = db.get_agent_count(population_id) + network = db.get_network(network_id) validation_result = validate_scenario(spec, population_spec, agent_count, network) @@ -268,8 +262,9 @@ def progress(step: str, status: str): def compile_scenario_from_files( description: str, population_spec_path: str | Path, - agents_path: str | Path, - network_path: str | Path, + study_db_path: str | Path, + population_id: str = "default", + network_id: str = "default", ) -> ScenarioSpec: """ Convenience function to create a scenario spec. @@ -279,8 +274,9 @@ def compile_scenario_from_files( Args: description: Natural language scenario description population_spec_path: Path to population YAML file - agents_path: Path to agents JSON file - network_path: Path to network JSON file + study_db_path: Path to canonical study DB + population_id: Population ID in study DB + network_id: Network ID in study DB Returns: ScenarioSpec @@ -292,8 +288,9 @@ def compile_scenario_from_files( spec, result = create_scenario( description, population_spec_path, - agents_path, - network_path, + study_db_path, + population_id, + network_id, ) if not result.valid: diff --git a/extropy/scenario/validator.py b/extropy/scenario/validator.py index 13ab389..cdc8ac4 100644 --- a/extropy/scenario/validator.py +++ b/extropy/scenario/validator.py @@ -20,6 +20,7 @@ extract_names_from_expression, validate_expression_syntax, ) +from ..storage import open_study_db # Helper functions to create ValidationIssue with appropriate severity @@ -411,12 +412,10 @@ def validate_scenario( base_file = Path(spec_file) population_path = resolve_relative_to(spec.meta.population_spec, base_file) - agents_path = resolve_relative_to(spec.meta.agents_file, base_file) - network_path = resolve_relative_to(spec.meta.network_file, base_file) + study_db_path = resolve_relative_to(spec.meta.study_db, base_file) else: population_path = Path(spec.meta.population_spec) - agents_path = Path(spec.meta.agents_file) - network_path = Path(spec.meta.network_file) + study_db_path = Path(spec.meta.study_db) if not population_path.exists(): errors.append( @@ -428,22 +427,12 @@ def validate_scenario( ) ) - if not agents_path.exists(): + if not study_db_path.exists(): errors.append( ValidationError( category="file_reference", - location="meta.agents_file", - message=f"Agents file not found: {spec.meta.agents_file}", - suggestion="Check the file path", - ) - ) - - if not network_path.exists(): - errors.append( - ValidationError( - category="file_reference", - location="meta.network_file", - message=f"Network file not found: {spec.meta.network_file}", + location="meta.study_db", + message=f"Study DB not found: {spec.meta.study_db}", suggestion="Check the file path", ) ) @@ -462,6 +451,38 @@ def validate_scenario( ) ) + # Validate IDs inside study DB when available. + if study_db_path.exists(): + try: + with open_study_db(study_db_path) as db: + if db.get_agent_count(spec.meta.population_id) == 0: + errors.append( + ValidationError( + category="file_reference", + location="meta.population_id", + message=f"Population ID not found in study DB: {spec.meta.population_id}", + suggestion="Run `extropy sample ... --study-db ... --population-id ...` first", + ) + ) + if db.get_network_edge_count(spec.meta.network_id) == 0: + errors.append( + ValidationError( + category="file_reference", + location="meta.network_id", + message=f"Network ID not found in study DB: {spec.meta.network_id}", + suggestion="Run `extropy network ... --study-db ... --network-id ...` first", + ) + ) + except Exception: + errors.append( + ValidationError( + category="file_reference", + location="meta.study_db", + message=f"Failed to read study DB: {spec.meta.study_db}", + suggestion="Check that the file is a valid SQLite study DB", + ) + ) + return ValidationResult(issues=[*errors, *warnings]) @@ -541,15 +562,12 @@ def load_and_validate_scenario( except Exception: pass # Will be caught as validation error - agents_path = resolve_relative_to(spec.meta.agents_file, scenario_path) - if agents_path.exists(): - agent_count = get_agent_count(agents_path) - - network_path = resolve_relative_to(spec.meta.network_file, scenario_path) - if network_path.exists(): + study_db_path = resolve_relative_to(spec.meta.study_db, scenario_path) + if study_db_path.exists(): try: - with open(network_path) as f: - network = json.load(f) + with open_study_db(study_db_path) as db: + agent_count = db.get_agent_count(spec.meta.population_id) + network = db.get_network(spec.meta.network_id) except Exception: pass diff --git a/extropy/simulation/__init__.py b/extropy/simulation/__init__.py index 1f22a33..1d3ecb8 100644 --- a/extropy/simulation/__init__.py +++ b/extropy/simulation/__init__.py @@ -25,11 +25,8 @@ Output: Results directory containing: - - simulation.db: SQLite database with all state - - timeline.jsonl: Streaming event log - - agent_states.json: Final state per agent + - study.db: Canonical SQLite database with simulation state/checkpoints - by_timestep.json: Metrics over time - - outcome_distributions.json: Final outcome distributions - meta.json: Run configuration """ diff --git a/extropy/simulation/engine.py b/extropy/simulation/engine.py index 29f4053..56ae3b4 100644 --- a/extropy/simulation/engine.py +++ b/extropy/simulation/engine.py @@ -14,8 +14,12 @@ import json import logging +import queue import random +import sqlite3 +import threading import time +import uuid from datetime import datetime from pathlib import Path from typing import Any @@ -36,16 +40,16 @@ float_to_conviction, ) from ..core.rate_limiter import DualRateLimiter -from ..population.network import load_agents_json from ..population.persona import PersonaConfig +from ..storage import open_study_db from .progress import SimulationProgress from .state import StateManager from .persona import generate_persona from .reasoning import batch_reason_agents, create_reasoning_context from .propagation import apply_seed_exposures, propagate_through_network from .stopping import evaluate_stopping_conditions -from .timeline import TimelineManager from ..utils.callbacks import TimestepProgressCallback +from ..utils.resource_governor import ResourceGovernor from .aggregation import ( compute_timestep_summary, compute_final_aggregates, @@ -65,12 +69,29 @@ _PRIVATE_FLIP_CONVICTION = CONVICTION_MAP[ConvictionLevel.FIRM] +class _StateTimelineAdapter: + """Timeline adapter that persists events into StateManager timeline table.""" + + def __init__(self, state_manager: StateManager): + self.state_manager = state_manager + + def log_event(self, event: SimulationEvent) -> None: + self.state_manager.log_event(event) + + def flush(self) -> None: + return + + def close(self) -> None: + return + + class SimulationSummary: """Summary of a completed simulation run.""" def __init__( self, scenario_name: str, + run_id: str | None, population_size: int, total_timesteps: int, stopped_reason: str | None, @@ -83,6 +104,7 @@ def __init__( completed_at: datetime, ): self.scenario_name = scenario_name + self.run_id = run_id self.population_size = population_size self.total_timesteps = total_timesteps self.stopped_reason = stopped_reason @@ -98,6 +120,7 @@ def to_dict(self) -> dict[str, Any]: """Convert to dictionary.""" return { "scenario_name": self.scenario_name, + "run_id": self.run_id, "population_size": self.population_size, "total_timesteps": self.total_timesteps, "stopped_reason": self.stopped_reason, @@ -128,6 +151,13 @@ def __init__( persona_config: PersonaConfig | None = None, rate_limiter: DualRateLimiter | None = None, chunk_size: int = 50, + state_db_path: Path | str | None = None, + run_id: str | None = None, + checkpoint_every_chunks: int = 1, + retention_lite: bool = False, + writer_queue_size: int = 256, + db_write_batch_size: int = 100, + resource_governor: ResourceGovernor | None = None, ): """Initialize simulation engine. @@ -149,6 +179,19 @@ def __init__( self.persona_config = persona_config self.rate_limiter = rate_limiter self.chunk_size = chunk_size + self.run_id = run_id or f"run_{uuid.uuid4().hex[:12]}" + self.checkpoint_every_chunks = max(1, checkpoint_every_chunks) + self.retention_lite = retention_lite + self.writer_queue_size = max(1, writer_queue_size) + self.db_write_batch_size = max(1, db_write_batch_size) + self.resource_governor = resource_governor + self.reasoning_max_concurrency = 50 + if self.resource_governor is not None: + self.reasoning_max_concurrency = self.resource_governor.recommend_workers( + requested_workers=50, + memory_per_worker_gb=0.2, + ) + self._last_guardrail_timestep = -1 # Build agent map for quick lookup self.agent_map = {a.get("_id", str(i)): a for i, a in enumerate(agents)} @@ -174,13 +217,18 @@ def __init__( self.output_dir.mkdir(parents=True, exist_ok=True) # Initialize state manager + state_db_file = ( + Path(state_db_path) if state_db_path else self.output_dir / "study.db" + ) self.state_manager = StateManager( - self.output_dir / "simulation.db", + state_db_file, agents, + run_id=self.run_id, ) + self.study_db = open_study_db(state_db_file) # Initialize timeline manager - self.timeline = TimelineManager(self.output_dir / "timeline.jsonl") + self.timeline = _StateTimelineAdapter(self.state_manager) # Pre-generate personas for all agents # Extract decision-relevant attributes from outcome config (trait salience) @@ -258,6 +306,52 @@ def set_progress_state(self, progress: SimulationProgress) -> None: """ self._progress = progress + def _apply_runtime_guardrails(self, timestep: int) -> None: + """Downshift runtime knobs when process memory nears configured budget.""" + if ( + self.resource_governor is None + or self.resource_governor.resource_mode != "auto" + ): + return + + ratio = self.resource_governor.memory_pressure_ratio() + if ratio < 0.85: + return + + factor = 0.5 if ratio >= 0.98 else 0.75 + old_concurrency = self.reasoning_max_concurrency + old_batch = self.db_write_batch_size + old_queue = self.writer_queue_size + + self.reasoning_max_concurrency = self.resource_governor.downshift_int( + self.reasoning_max_concurrency, factor=factor, minimum=1 + ) + self.db_write_batch_size = self.resource_governor.downshift_int( + self.db_write_batch_size, factor=factor, minimum=1 + ) + self.writer_queue_size = self.resource_governor.downshift_int( + self.writer_queue_size, factor=factor, minimum=4 + ) + + changed = ( + old_concurrency != self.reasoning_max_concurrency + or old_batch != self.db_write_batch_size + or old_queue != self.writer_queue_size + ) + if changed and timestep != self._last_guardrail_timestep: + self._last_guardrail_timestep = timestep + logger.warning( + "[RESOURCE] Memory pressure %.2fx budget; " + "reasoning_concurrency %d->%d, writer_batch %d->%d, writer_queue %d->%d", + ratio, + old_concurrency, + self.reasoning_max_concurrency, + old_batch, + self.db_write_batch_size, + old_queue, + self.writer_queue_size, + ) + def _report_progress(self, timestep: int, status: str) -> None: """Report progress to callback.""" if self._on_progress: @@ -320,7 +414,7 @@ def run(self) -> SimulationSummary: """Execute the full simulation. Supports automatic resume: if the output directory contains a - simulation.db with partial progress, the engine picks up where + study.db with partial progress, the engine picks up where it left off. Returns: @@ -379,6 +473,7 @@ def run(self) -> SimulationSummary: self._export_results() finally: self.state_manager.close() + self.study_db.close() return summary @@ -491,6 +586,7 @@ def _reason_agents(self, timestep: int) -> tuple[int, int, int]: Returns: Tuple of (agents_reasoned, state_changes, shares_occurred). """ + self._apply_runtime_guardrails(timestep) agents_to_reason = self.state_manager.get_agents_to_reason( timestep, self.config.multi_touch_threshold, @@ -547,12 +643,96 @@ def _on_agent_done(agent_id: str, result: Any) -> None: context = self._build_reasoning_context(agent_id, old_state) contexts.append(context) - # Split into chunks - total_reasoned = 0 - total_changes = 0 - total_shares = 0 + completed_chunks = self.study_db.get_completed_simulation_chunks( + self.run_id, timestep + ) + totals = {"reasoned": 0, "changes": 0, "shares": 0} + work_queue: queue.Queue[tuple[int, list[tuple[str, Any]], bool] | object] = ( + queue.Queue(maxsize=self.writer_queue_size) + ) + sentinel = object() + writer_error: list[Exception] = [] + + def _writer_loop() -> None: + chunks_since_checkpoint = 0 + pending_chunks: list[tuple[int, list[tuple[str, Any]], bool]] = [] + + def _flush_pending() -> None: + nonlocal chunks_since_checkpoint + if not pending_chunks: + return + + with self.state_manager.transaction(): + for chunk_index, chunk_results, _is_last_chunk in pending_chunks: + reasoned, changes, shares = self._process_reasoning_chunk( + timestep, chunk_results, old_states + ) + totals["reasoned"] += reasoned + totals["changes"] += changes + totals["shares"] += shares + + for chunk_index, _chunk_results, is_last_chunk in pending_chunks: + self.study_db.save_simulation_checkpoint( + run_id=self.run_id, + timestep=timestep, + chunk_index=chunk_index, + status="done", + ) + chunks_since_checkpoint += 1 + if ( + chunks_since_checkpoint >= self.checkpoint_every_chunks + or is_last_chunk + ): + self.study_db.set_run_metadata( + self.run_id, + "last_checkpoint", + f"{timestep}:{chunk_index}", + ) + chunks_since_checkpoint = 0 + + pending_chunks.clear() + + try: + while True: + item = work_queue.get() + try: + if item is sentinel: + _flush_pending() + break + + chunk_index, chunk_results, is_last_chunk = item + if chunk_index in completed_chunks: + continue + pending_chunks.append( + (chunk_index, chunk_results, is_last_chunk) + ) + if ( + len(pending_chunks) >= self.db_write_batch_size + or is_last_chunk + ): + _flush_pending() + finally: + work_queue.task_done() + except Exception as e: # pragma: no cover - surfaced to caller + writer_error.append(e) + + writer_thread = threading.Thread( + target=_writer_loop, + name=f"sim-writer-{self.run_id}-{timestep}", + daemon=True, + ) + writer_thread.start() for chunk_start in range(0, len(contexts), self.chunk_size): + if writer_error: + break + self._apply_runtime_guardrails(timestep) + chunk_index = chunk_start // self.chunk_size + if chunk_index in completed_chunks: + logger.info( + f"[TIMESTEP {timestep}] Skipping completed chunk {chunk_index}" + ) + continue chunk_contexts = contexts[chunk_start : chunk_start + self.chunk_size] reasoning_start = time.time() @@ -560,13 +740,13 @@ def _on_agent_done(agent_id: str, result: Any) -> None: chunk_contexts, self.scenario, self.config, + max_concurrency=self.reasoning_max_concurrency, rate_limiter=self.rate_limiter, on_agent_done=_on_agent_done, ) reasoning_elapsed = time.time() - reasoning_start self.total_reasoning_calls += len(chunk_results) - # Accumulate token usage self.pivotal_input_tokens += chunk_usage.pivotal_input_tokens self.pivotal_output_tokens += chunk_usage.pivotal_output_tokens self.routine_input_tokens += chunk_usage.routine_input_tokens @@ -578,18 +758,27 @@ def _on_agent_done(agent_id: str, result: Any) -> None: if chunk_results else f"[TIMESTEP {timestep}] Chunk empty" ) - - # Process and commit this chunk - with self.state_manager.transaction(): - reasoned, changes, shares = self._process_reasoning_chunk( - timestep, chunk_results, old_states - ) - - total_reasoned += reasoned - total_changes += changes - total_shares += shares - - return total_reasoned, total_changes, total_shares + is_last_chunk = chunk_start + self.chunk_size >= len(contexts) + work_queue.put((chunk_index, chunk_results, is_last_chunk)) + + work_queue.put(sentinel) + while work_queue.unfinished_tasks > 0: + if writer_error: + while True: + try: + work_queue.get_nowait() + work_queue.task_done() + except queue.Empty: + break + break + time.sleep(0.01) + + work_queue.join() + writer_thread.join(timeout=1) + if writer_error: + raise writer_error[0] + + return totals["reasoned"], totals["changes"], totals["shares"] def _process_reasoning_chunk( self, @@ -778,7 +967,7 @@ def _process_reasoning_chunk( private_outcomes=private_outcomes, committed=is_committed, outcomes=private_outcomes, - raw_reasoning=response.reasoning, + raw_reasoning=None if self.retention_lite else response.reasoning, updated_at=timestep, ) @@ -813,6 +1002,9 @@ def _process_reasoning_chunk( "public_conviction": new_state.public_conviction, "private_conviction": new_state.private_conviction, "will_share": new_state.will_share, + "raw_reasoning": None + if self.retention_lite + else response.reasoning, }, ) ) @@ -1054,6 +1246,7 @@ def _finalize( return SimulationSummary( scenario_name=self.scenario.meta.name, + run_id=self.run_id, population_size=len(self.agents), total_timesteps=final_timestep + 1, stopped_reason=stopped_reason, @@ -1062,7 +1255,7 @@ def _finalize( final_exposure_rate=self.state_manager.get_exposure_rate(), outcome_distributions=outcome_dists, runtime_seconds=runtime, - model_used=self.config.model, + model_used=self.config.strong, completed_at=datetime.now(), ) @@ -1072,7 +1265,7 @@ def _compute_cost(self) -> dict[str, Any]: Returns: Cost dictionary with token counts and estimated USD. """ - from ..core.pricing import get_pricing, resolve_default_model + from ..core.pricing import get_pricing from ..config import get_config cost: dict[str, Any] = { @@ -1087,19 +1280,14 @@ def _compute_cost(self) -> dict[str, Any]: # Resolve effective model names for pricing config = get_config() - provider = config.simulation.provider - pivotal_model = ( - self.config.pivotal_model - or self.config.model - or config.simulation.pivotal_model - or config.simulation.model - or resolve_default_model(provider, "reasoning") - ) - routine_model = ( - self.config.routine_model - or config.simulation.routine_model - or resolve_default_model(provider, "simple") - ) + from ..config import parse_model_string + + strong_model_str = self.config.strong or config.resolve_sim_strong() + fast_model_str = self.config.fast or config.resolve_sim_fast() + + # Strip provider prefix for pricing lookup (pricing is keyed by bare model name) + _, pivotal_model = parse_model_string(strong_model_str) + _, routine_model = parse_model_string(fast_model_str) cost["pivotal_model"] = pivotal_model cost["routine_model"] = routine_model @@ -1141,7 +1329,7 @@ def _compute_cost(self) -> dict[str, Any]: return cost def _export_results(self) -> None: - """Export all results to output directory.""" + """Export compact default artifacts to output directory.""" # Export summary summaries = self.state_manager.get_timestep_summaries() timeline_agg = compute_timeline_aggregates(summaries) @@ -1149,48 +1337,13 @@ def _export_results(self) -> None: with open(self.output_dir / "by_timestep.json", "w") as f: json.dump(timeline_agg, f, indent=2) - # Export final agent states - final_states = self.state_manager.export_final_states() - - # Merge with agent attributes - agent_results = [] - for state in final_states: - agent_id = state["agent_id"] - agent = self.agent_map.get(agent_id, {}) - - agent_results.append( - { - "agent_id": agent_id, - "attributes": { - k: v for k, v in agent.items() if not k.startswith("_") - }, - "final_state": state, - "reasoning_count": ( - 1 if state["last_reasoning_timestep"] >= 0 else 0 - ), - } - ) - - with open(self.output_dir / "agent_states.json", "w") as f: - json.dump(agent_results, f, indent=2) - - # Export outcome distributions - outcome_dists = compute_outcome_distributions( - self.state_manager, - self.scenario.outcomes.suggested_outcomes, - ) - - with open(self.output_dir / "outcome_distributions.json", "w") as f: - json.dump(outcome_dists, f, indent=2) - # Export meta information meta = { "scenario_name": self.scenario.meta.name, "scenario_path": self.config.scenario_path, "population_size": len(self.agents), - "model": self.config.model, - "pivotal_model": self.config.pivotal_model, - "routine_model": self.config.routine_model, + "strong_model": self.config.strong, + "fast_model": self.config.fast, "seed": self.seed, "multi_touch_threshold": self.config.multi_touch_threshold, "completed_at": datetime.now().isoformat(), @@ -1209,9 +1362,9 @@ def _export_results(self) -> None: def run_simulation( scenario_path: str | Path, output_dir: str | Path, - model: str = "", - pivotal_model: str = "", - routine_model: str = "", + study_db_path: str | Path | None = None, + strong: str = "", + fast: str = "", multi_touch_threshold: int = 3, random_seed: int | None = None, on_progress: TimestepProgressCallback | None = None, @@ -1221,6 +1374,13 @@ def run_simulation( tpm_override: int | None = None, chunk_size: int = 50, progress: SimulationProgress | None = None, + run_id: str | None = None, + resume: bool = False, + checkpoint_every_chunks: int = 1, + retention_lite: bool = False, + writer_queue_size: int = 256, + db_write_batch_size: int = 100, + resource_governor: ResourceGovernor | None = None, ) -> SimulationSummary: """Run a simulation from a scenario file. @@ -1229,9 +1389,8 @@ def run_simulation( Args: scenario_path: Path to scenario YAML file output_dir: Directory for results output - model: LLM model for agent reasoning - pivotal_model: Model for pivotal reasoning (default: same as model) - routine_model: Cheap model for routine + classification + strong: Strong model for Pass 1 reasoning (provider/model format) + fast: Fast model for Pass 2 classification (provider/model format) multi_touch_threshold: Re-reason after N new exposures random_seed: Random seed for reproducibility on_progress: Progress callback(timestep, max, status) @@ -1241,12 +1400,48 @@ def run_simulation( tpm_override: Override TPM limit chunk_size: Agents per reasoning chunk for checkpointing progress: Optional SimulationProgress for live display tracking + run_id: Optional run identifier for resume and bookkeeping + resume: Resume a prior run from DB checkpoints + checkpoint_every_chunks: Mark simulation checkpoint every N chunks + retention_lite: Reduce payload volume by dropping full raw reasoning text + writer_queue_size: Maximum buffered chunks waiting for DB writer + db_write_batch_size: Number of chunks applied per DB writer transaction + resource_governor: Optional governor for runtime downshift guardrails Returns: SimulationSummary with results """ scenario_path = Path(scenario_path) output_dir = Path(output_dir) + if resume and not run_id: + raise ValueError("--resume requires --run-id") + + def _reset_runtime_tables(path: Path, run_key: str) -> None: + conn = sqlite3.connect(str(path)) + try: + cur = conn.cursor() + statements = [ + "DELETE FROM agent_states WHERE run_id = ?", + "DELETE FROM exposures WHERE run_id = ?", + "DELETE FROM memory_traces WHERE run_id = ?", + "DELETE FROM timeline WHERE run_id = ?", + "DELETE FROM timestep_summaries WHERE run_id = ?", + "DELETE FROM shared_to WHERE run_id = ?", + "DELETE FROM simulation_metadata WHERE run_id = ?", + ] + for sql in statements: + try: + cur.execute(sql, (run_key,)) + except sqlite3.OperationalError: + # Legacy table shape fallback. + table = sql.split()[2] + cur.execute(f"DELETE FROM {table}") + conn.commit() + except sqlite3.OperationalError: + # First run on this DB may not have simulation tables yet. + pass + finally: + conn.close() # Load scenario scenario = ScenarioSpec.from_yaml(scenario_path) @@ -1257,18 +1452,63 @@ def run_simulation( pop_path = scenario_path.parent / pop_path population_spec = PopulationSpec.from_yaml(pop_path) - # Load agents - agents_path = Path(scenario.meta.agents_file) - if not agents_path.is_absolute(): - agents_path = scenario_path.parent / agents_path - agents = load_agents_json(agents_path) + # Resolve canonical study DB + if study_db_path is None: + if not getattr(scenario.meta, "study_db", None): + raise ValueError( + "Legacy scenario format detected. Rebuild scenario with --study-db." + ) + study_db_resolved = Path(scenario.meta.study_db) + if not study_db_resolved.is_absolute(): + study_db_resolved = scenario_path.parent / study_db_resolved + else: + study_db_resolved = Path(study_db_path) + + if not study_db_resolved.exists(): + raise FileNotFoundError(f"Study DB not found: {study_db_resolved}") + + resolved_run_id = ( + run_id + or f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" + ) + + with open_study_db(study_db_resolved) as db: + agents = db.get_agents(scenario.meta.population_id) + if not agents: + raise ValueError( + f"No agents for population_id '{scenario.meta.population_id}' in {study_db_resolved}" + ) + network = db.get_network(scenario.meta.network_id) + if not network.get("edges"): + raise ValueError( + f"No network edges for network_id '{scenario.meta.network_id}' in {study_db_resolved}" + ) + db.create_simulation_run( + run_id=resolved_run_id, + scenario_name=scenario.meta.name, + population_id=scenario.meta.population_id, + network_id=scenario.meta.network_id, + config={ + "scenario_path": str(scenario_path), + "output_dir": str(output_dir), + "strong": strong, + "fast": fast, + "multi_touch_threshold": multi_touch_threshold, + "chunk_size": chunk_size, + "checkpoint_every_chunks": checkpoint_every_chunks, + "retention_lite": retention_lite, + "writer_queue_size": writer_queue_size, + "db_write_batch_size": db_write_batch_size, + "resume": resume, + }, + seed=random_seed, + status="running", + ) + db.set_run_metadata(resolved_run_id, "output_dir", str(output_dir)) + db.set_run_metadata(resolved_run_id, "study_db", str(study_db_resolved)) - # Load network - network_path = Path(scenario.meta.network_file) - if not network_path.is_absolute(): - network_path = scenario_path.parent / network_path - with open(network_path) as f: - network = json.load(f) + if not resume: + _reset_runtime_tables(study_db_resolved, resolved_run_id) # Load persona config if provided persona_config = None @@ -1288,26 +1528,22 @@ def run_simulation( config = SimulationRunConfig( scenario_path=str(scenario_path), output_dir=str(output_dir), - model=model, - pivotal_model=pivotal_model, - routine_model=routine_model, + strong=strong, + fast=fast, multi_touch_threshold=multi_touch_threshold, random_seed=random_seed, ) - # Create dual rate limiter (separate limiters for pivotal and routine models) + # Resolve effective model strings for rate limiting from ..config import get_config entropy_config = get_config() - provider = entropy_config.simulation.provider - effective_model = model or entropy_config.simulation.model or "" - effective_pivotal = pivotal_model or effective_model - effective_routine = routine_model or entropy_config.simulation.routine_model or "" + effective_strong = strong or entropy_config.resolve_sim_strong() + effective_fast = fast or entropy_config.resolve_sim_fast() rate_limiter = DualRateLimiter.create( - provider=provider, - pivotal_model=effective_pivotal, - routine_model=effective_routine, + strong_model_string=effective_strong, + fast_model_string=effective_fast, tier=rate_tier, rpm_override=rpm_override, tpm_override=tpm_override, @@ -1323,6 +1559,13 @@ def run_simulation( persona_config=persona_config, rate_limiter=rate_limiter, chunk_size=chunk_size, + state_db_path=study_db_resolved, + run_id=resolved_run_id, + checkpoint_every_chunks=checkpoint_every_chunks, + retention_lite=retention_lite, + writer_queue_size=writer_queue_size, + db_write_batch_size=db_write_batch_size, + resource_governor=resource_governor, ) if on_progress: @@ -1331,4 +1574,23 @@ def run_simulation( if progress: engine.set_progress_state(progress) - return engine.run() + try: + summary = engine.run() + except Exception as e: + with open_study_db(study_db_resolved) as db: + db.update_simulation_run( + run_id=resolved_run_id, + status="failed", + stopped_reason=str(e), + ) + raise + + final_status = "stopped" if summary.stopped_reason else "completed" + with open_study_db(study_db_resolved) as db: + db.update_simulation_run( + run_id=resolved_run_id, + status=final_status, + stopped_reason=summary.stopped_reason, + ) + + return summary diff --git a/extropy/simulation/estimator.py b/extropy/simulation/estimator.py index 376245a..5a4a103 100644 --- a/extropy/simulation/estimator.py +++ b/extropy/simulation/estimator.py @@ -9,7 +9,7 @@ from typing import Any from ..core.models import ScenarioSpec, PopulationSpec -from ..core.pricing import ModelPricing, get_pricing, resolve_default_model +from ..core.pricing import ModelPricing, get_pricing from ..utils.eval_safe import eval_condition, ConditionError @@ -138,9 +138,8 @@ def estimate_simulation_cost( population_spec: PopulationSpec, agents: list[dict[str, Any]], network: dict[str, Any], - provider: str = "openai", - pivotal_model: str = "", - routine_model: str = "", + strong_model: str = "", + fast_model: str = "", multi_touch_threshold: int = 3, ) -> CostEstimate: """Estimate the cost of running a simulation. @@ -153,9 +152,8 @@ def estimate_simulation_cost( population_spec: Population specification agents: List of agent dictionaries network: Network data dict - provider: LLM provider name - pivotal_model: Model for Pass 1 (empty = provider default) - routine_model: Model for Pass 2 (empty = provider cheap tier) + strong_model: Model for Pass 1 (provider/model format, empty = config default) + fast_model: Model for Pass 2 (provider/model format, empty = config default) multi_touch_threshold: Re-reasoning threshold Returns: @@ -167,9 +165,14 @@ def estimate_simulation_cost( share_prob = scenario.spread.share_probability will_share_rate = 0.55 # accounts for conviction-gated sharing - # Resolve models - eff_pivotal = pivotal_model or resolve_default_model(provider, "reasoning") - eff_routine = routine_model or resolve_default_model(provider, "simple") + # Resolve models — strip provider prefix for pricing lookup + from ..config import get_config, parse_model_string + + config = get_config() + eff_strong_str = strong_model or config.resolve_sim_strong() + eff_fast_str = fast_model or config.resolve_sim_fast() + _, eff_pivotal = parse_model_string(eff_strong_str) + _, eff_routine = parse_model_string(eff_fast_str) # Pre-compute seed exposure schedule: timestep -> expected new seed exposures seed_schedule: dict[int, float] = {} diff --git a/extropy/simulation/reasoning.py b/extropy/simulation/reasoning.py index bf3ad52..ef9f824 100644 --- a/extropy/simulation/reasoning.py +++ b/extropy/simulation/reasoning.py @@ -455,8 +455,8 @@ async def _reason_agent_two_pass_async( position_outcome = _get_primary_position_outcome(scenario) # Determine models - main_model = config.model or None # None = provider default - classify_model = config.routine_model or None # None = provider default (cheap) + main_model = config.strong or None # None = provider default + classify_model = config.fast or None # None = provider default (cheap) # === Pass 1: Role-play === pass1_usage = TokenUsage() @@ -687,7 +687,7 @@ def reason_agent( if pass2_schema: pass2_prompt = build_pass2_prompt(reasoning, scenario) - classify_model = config.routine_model or None + classify_model = config.fast or None for attempt in range(config.max_retries): try: @@ -776,69 +776,90 @@ def batch_reason_agents( logger.info(f"[BATCH] Starting two-pass async reasoning for {total} agents") async def run_all(): - # Always use a semaphore to cap concurrent tasks. - # When rate limiter is available, size it from max_safe_concurrent. if rate_limiter: - concurrency = rate_limiter.max_safe_concurrent - # Stagger interval: spread launches across the RPM window - # e.g. 500 RPM → 8.3 req/s → 120ms between launches + target_concurrency = min( + max(1, rate_limiter.max_safe_concurrent), + max(1, max_concurrency), + ) stagger_interval = 60.0 / rate_limiter.pivotal.rpm logger.info( - f"[BATCH] Concurrency cap: {concurrency}, " + f"[BATCH] Concurrency cap: {target_concurrency}, " f"stagger: {stagger_interval * 1000:.0f}ms between launches" ) else: - concurrency = max_concurrency + target_concurrency = max(1, max_concurrency) stagger_interval = 0.0 - semaphore = asyncio.Semaphore(concurrency) completed = [0] + adaptive_concurrency = target_concurrency async def reason_with_pacing( + idx: int, ctx: ReasoningContext, - ) -> tuple[str, ReasoningResponse | None]: - async with semaphore: - start = time.time() - result = await _reason_agent_two_pass_async( - ctx, scenario, config, rate_limiter - ) - elapsed = time.time() - start - completed[0] += 1 - - if result: - logger.info( - f"[BATCH] {completed[0]}/{total}: {ctx.agent_id} done in {elapsed:.2f}s " - f"(position={result.position}, sentiment={result.sentiment}, " - f"conviction={float_to_conviction(result.conviction)})" - ) - else: - logger.warning( - f"[BATCH] {completed[0]}/{total}: {ctx.agent_id} FAILED" - ) - - if on_agent_done: - on_agent_done(ctx.agent_id, result) - - return (ctx.agent_id, result) + ) -> tuple[int, str, ReasoningResponse | None, float]: + start = time.time() + result = await _reason_agent_two_pass_async( + ctx, scenario, config, rate_limiter + ) + elapsed = time.time() - start + completed[0] += 1 - # Stagger task launches to avoid burst of requests hitting API at once. - # Each task is created with a small delay so they don't all enter - # the semaphore simultaneously. - tasks = [] - for i, ctx in enumerate(contexts): - tasks.append(asyncio.create_task(reason_with_pacing(ctx))) - if stagger_interval > 0 and i < concurrency - 1: - # Only stagger the first batch — after that the semaphore - # naturally gates as tasks complete and new ones enter - await asyncio.sleep(stagger_interval) + if result: + logger.info( + f"[BATCH] {completed[0]}/{total}: {ctx.agent_id} done in {elapsed:.2f}s " + f"(position={result.position}, sentiment={result.sentiment}, " + f"conviction={float_to_conviction(result.conviction)})" + ) + else: + logger.warning(f"[BATCH] {completed[0]}/{total}: {ctx.agent_id} FAILED") + + if on_agent_done: + on_agent_done(ctx.agent_id, result) + + return (idx, ctx.agent_id, result, elapsed) + + results: list[tuple[str, ReasoningResponse | None] | None] = [None] * total + next_idx = 0 + while next_idx < total: + batch_end = min(total, next_idx + adaptive_concurrency) + batch_contexts = contexts[next_idx:batch_end] + tasks = [] + for local_offset, ctx in enumerate(batch_contexts): + idx = next_idx + local_offset + tasks.append(asyncio.create_task(reason_with_pacing(idx, ctx))) + if stagger_interval > 0 and local_offset < len(batch_contexts) - 1: + await asyncio.sleep(stagger_interval) + + batch_results = await asyncio.gather(*tasks) + latencies: list[float] = [] + failures = 0 + for idx, agent_id, result, elapsed in batch_results: + results[idx] = (agent_id, result) + latencies.append(elapsed) + if result is None: + failures += 1 + + # Adaptive concurrency control: + # - high error rate or high latency => downshift + # - clean/fast batches => cautiously upshift + avg_latency = sum(latencies) / len(latencies) if latencies else 0.0 + fail_rate = failures / len(batch_results) if batch_results else 0.0 + if fail_rate >= 0.2 or avg_latency >= 20.0: + adaptive_concurrency = max(1, int(adaptive_concurrency * 0.7)) + elif fail_rate == 0 and avg_latency <= 8.0: + adaptive_concurrency = min(target_concurrency, adaptive_concurrency + 1) - results = await asyncio.gather(*tasks) + logger.info( + f"[BATCH] Adaptive concurrency={adaptive_concurrency} " + f"(avg_latency={avg_latency:.2f}s, fail_rate={fail_rate:.0%})" + ) + next_idx = batch_end # Close the async HTTP client before the event loop shuts down. # Without this, orphaned httpx connections produce "Event loop is # closed" errors during garbage collection. await close_simulation_provider() - return results + return [r for r in results if r is not None] batch_start = time.time() results = asyncio.run(run_all()) diff --git a/extropy/simulation/state.py b/extropy/simulation/state.py index 94a4350..22cbd4b 100644 --- a/extropy/simulation/state.py +++ b/extropy/simulation/state.py @@ -27,16 +27,23 @@ class StateManager: for frequently accessed data. """ - def __init__(self, db_path: Path | str, agents: list[dict[str, Any]] | None = None): + def __init__( + self, + db_path: Path | str, + agents: list[dict[str, Any]] | None = None, + run_id: str = "default", + ): """Initialize state manager with database path. Args: db_path: Path to SQLite database file agents: Optional list of agents to initialize + run_id: Simulation run scope key """ self.db_path = Path(db_path) + self.run_id = run_id self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path)) + self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) self.conn.row_factory = sqlite3.Row self.conn.execute("PRAGMA foreign_keys = ON") @@ -54,7 +61,8 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS agent_states ( - agent_id TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + agent_id TEXT NOT NULL, aware INTEGER DEFAULT 0, exposure_count INTEGER DEFAULT 0, last_reasoning_timestep INTEGER DEFAULT -1, @@ -73,7 +81,8 @@ def _create_schema(self) -> None: private_conviction REAL, private_outcomes_json TEXT, raw_reasoning TEXT, - updated_at INTEGER DEFAULT 0 + updated_at INTEGER DEFAULT 0, + PRIMARY KEY (run_id, agent_id) ) """ ) @@ -82,6 +91,7 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS exposures ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, agent_id TEXT, timestep INTEGER, @@ -89,7 +99,7 @@ def _create_schema(self) -> None: source_agent_id TEXT, content TEXT, credibility REAL, - FOREIGN KEY (agent_id) REFERENCES agent_states(agent_id) + FOREIGN KEY (run_id, agent_id) REFERENCES agent_states(run_id, agent_id) ) """ ) @@ -98,13 +108,14 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS memory_traces ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, agent_id TEXT, timestep INTEGER, sentiment REAL, conviction REAL, summary TEXT, - FOREIGN KEY (agent_id) REFERENCES agent_states(agent_id) + FOREIGN KEY (run_id, agent_id) REFERENCES agent_states(run_id, agent_id) ) """ ) @@ -113,6 +124,7 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS timeline ( + run_id TEXT NOT NULL, id INTEGER PRIMARY KEY AUTOINCREMENT, timestep INTEGER, event_type TEXT, @@ -127,7 +139,8 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS timestep_summaries ( - timestep INTEGER PRIMARY KEY, + run_id TEXT NOT NULL, + timestep INTEGER NOT NULL, new_exposures INTEGER, agents_reasoned INTEGER, shares_occurred INTEGER, @@ -136,7 +149,8 @@ def _create_schema(self) -> None: position_distribution_json TEXT, average_sentiment REAL, average_conviction REAL, - sentiment_variance REAL + sentiment_variance REAL, + PRIMARY KEY (run_id, timestep) ) """ ) @@ -145,37 +159,37 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_exposures_agent - ON exposures(agent_id) + ON exposures(run_id, agent_id) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_exposures_timestep - ON exposures(timestep) + ON exposures(run_id, timestep) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_timeline_timestep - ON timeline(timestep) + ON timeline(run_id, timestep) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_agent_states_aware - ON agent_states(aware) + ON agent_states(run_id, aware) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_agent_states_will_share - ON agent_states(will_share) + ON agent_states(run_id, will_share) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_memory_traces_agent - ON memory_traces(agent_id) + ON memory_traces(run_id, agent_id) """ ) @@ -183,18 +197,19 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS shared_to ( + run_id TEXT NOT NULL, source_agent_id TEXT, target_agent_id TEXT, timestep INTEGER, position TEXT, - PRIMARY KEY (source_agent_id, target_agent_id) + PRIMARY KEY (run_id, source_agent_id, target_agent_id) ) """ ) cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_shared_to_source - ON shared_to(source_agent_id) + ON shared_to(run_id, source_agent_id) """ ) @@ -202,8 +217,11 @@ def _create_schema(self) -> None: cursor.execute( """ CREATE TABLE IF NOT EXISTS simulation_metadata ( - key TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + key TEXT NOT NULL, value TEXT + , + PRIMARY KEY (run_id, key) ) """ ) @@ -215,6 +233,13 @@ def _upgrade_schema(self) -> None: cursor = self.conn.cursor() migrations = [ + ("agent_states", "run_id", "TEXT DEFAULT 'default'"), + ("exposures", "run_id", "TEXT DEFAULT 'default'"), + ("memory_traces", "run_id", "TEXT DEFAULT 'default'"), + ("timeline", "run_id", "TEXT DEFAULT 'default'"), + ("timestep_summaries", "run_id", "TEXT DEFAULT 'default'"), + ("shared_to", "run_id", "TEXT DEFAULT 'default'"), + ("simulation_metadata", "run_id", "TEXT DEFAULT 'default'"), ("agent_states", "conviction", "REAL"), ("agent_states", "public_statement", "TEXT"), ("timestep_summaries", "average_conviction", "REAL"), @@ -237,6 +262,28 @@ def _upgrade_schema(self) -> None: # Column already exists pass + cursor.execute( + "UPDATE agent_states SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + cursor.execute( + "UPDATE exposures SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + cursor.execute( + "UPDATE memory_traces SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + cursor.execute( + "UPDATE timeline SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + cursor.execute( + "UPDATE timestep_summaries SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + cursor.execute( + "UPDATE shared_to SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + cursor.execute( + "UPDATE simulation_metadata SET run_id = COALESCE(run_id, 'default') WHERE run_id IS NULL" + ) + self.conn.commit() @contextmanager @@ -264,10 +311,10 @@ def initialize_agents(self, agents: list[dict[str, Any]]) -> None: agent_id = agent.get("_id", str(agent.get("id", ""))) cursor.execute( """ - INSERT OR IGNORE INTO agent_states (agent_id) - VALUES (?) + INSERT OR IGNORE INTO agent_states (run_id, agent_id) + VALUES (?, ?) """, - (agent_id,), + (self.run_id, agent_id), ) self.conn.commit() @@ -286,9 +333,9 @@ def get_agent_state(self, agent_id: str) -> AgentState: # Get main state cursor.execute( """ - SELECT * FROM agent_states WHERE agent_id = ? + SELECT * FROM agent_states WHERE run_id = ? AND agent_id = ? """, - (agent_id,), + (self.run_id, agent_id), ) row = cursor.fetchone() @@ -298,9 +345,11 @@ def get_agent_state(self, agent_id: str) -> AgentState: # Get exposure history cursor.execute( """ - SELECT * FROM exposures WHERE agent_id = ? ORDER BY timestep + SELECT * FROM exposures + WHERE run_id = ? AND agent_id = ? + ORDER BY timestep """, - (agent_id,), + (self.run_id, agent_id), ) exposure_rows = cursor.fetchall() @@ -386,35 +435,44 @@ def get_agent_state(self, agent_id: str) -> AgentState: def get_unaware_agents(self) -> list[str]: """Get IDs of agents who haven't been exposed yet.""" cursor = self.conn.cursor() - cursor.execute("SELECT agent_id FROM agent_states WHERE aware = 0") + cursor.execute( + "SELECT agent_id FROM agent_states WHERE run_id = ? AND aware = 0", + (self.run_id,), + ) return [row["agent_id"] for row in cursor.fetchall()] def get_aware_agents(self) -> list[str]: """Get IDs of agents who are aware of the event.""" cursor = self.conn.cursor() - cursor.execute("SELECT agent_id FROM agent_states WHERE aware = 1") + cursor.execute( + "SELECT agent_id FROM agent_states WHERE run_id = ? AND aware = 1", + (self.run_id,), + ) return [row["agent_id"] for row in cursor.fetchall()] def get_sharers(self) -> list[str]: """Get IDs of agents who will share.""" cursor = self.conn.cursor() cursor.execute( - "SELECT agent_id FROM agent_states WHERE aware = 1 AND will_share = 1" + "SELECT agent_id FROM agent_states WHERE run_id = ? AND aware = 1 AND will_share = 1", + (self.run_id,), ) return [row["agent_id"] for row in cursor.fetchall()] def get_all_agent_ids(self) -> list[str]: """Get all agent IDs in the database.""" cursor = self.conn.cursor() - cursor.execute("SELECT agent_id FROM agent_states") + cursor.execute( + "SELECT agent_id FROM agent_states WHERE run_id = ?", (self.run_id,) + ) return [row["agent_id"] for row in cursor.fetchall()] def get_network_hop_depth(self, agent_id: str) -> int | None: """Get the minimum network hop depth from a seed exposure for an agent.""" cursor = self.conn.cursor() cursor.execute( - "SELECT network_hop_depth FROM agent_states WHERE agent_id = ?", - (agent_id,), + "SELECT network_hop_depth FROM agent_states WHERE run_id = ? AND agent_id = ?", + (self.run_id, agent_id), ) row = cursor.fetchone() if not row: @@ -443,8 +501,9 @@ def get_agents_to_reason(self, timestep: int, threshold: int) -> list[str]: cursor.execute( """ SELECT agent_id FROM agent_states - WHERE aware = 1 AND last_reasoning_timestep < 0 - """ + WHERE run_id = ? AND aware = 1 AND last_reasoning_timestep < 0 + """, + (self.run_id,), ) never_reasoned = [row["agent_id"] for row in cursor.fetchall()] @@ -456,14 +515,17 @@ def get_agents_to_reason(self, timestep: int, threshold: int) -> list[str]: COUNT(DISTINCT e.source_agent_id) as unique_sources FROM agent_states s JOIN exposures e - ON e.agent_id = s.agent_id + ON e.run_id = s.run_id + AND e.agent_id = s.agent_id AND e.timestep > s.last_reasoning_timestep AND e.source_agent_id IS NOT NULL - WHERE s.aware = 1 + WHERE s.run_id = ? + AND s.aware = 1 AND s.last_reasoning_timestep >= 0 AND s.committed = 0 GROUP BY s.agent_id - """ + """, + (self.run_id,), ) multi_touch = [] @@ -490,10 +552,10 @@ def record_share( cursor.execute( """ INSERT OR REPLACE INTO shared_to - (source_agent_id, target_agent_id, timestep, position) - VALUES (?, ?, ?, ?) + (run_id, source_agent_id, target_agent_id, timestep, position) + VALUES (?, ?, ?, ?, ?) """, - (source_id, target_id, timestep, position), + (self.run_id, source_id, target_id, timestep, position), ) def get_unshared_neighbors( @@ -518,10 +580,11 @@ def get_unshared_neighbors( f""" SELECT target_agent_id, position FROM shared_to - WHERE source_agent_id = ? + WHERE run_id = ? + AND source_agent_id = ? AND target_agent_id IN ({placeholders}) """, - [source_id] + neighbor_ids, + [self.run_id, source_id] + neighbor_ids, ) already_shared = { @@ -547,8 +610,8 @@ def save_metadata(self, key: str, value: str) -> None: """ cursor = self.conn.cursor() cursor.execute( - "INSERT OR REPLACE INTO simulation_metadata (key, value) VALUES (?, ?)", - (key, value), + "INSERT OR REPLACE INTO simulation_metadata (run_id, key, value) VALUES (?, ?, ?)", + (self.run_id, key, value), ) self.conn.commit() @@ -562,7 +625,10 @@ def get_metadata(self, key: str) -> str | None: Value string or None if not found """ cursor = self.conn.cursor() - cursor.execute("SELECT value FROM simulation_metadata WHERE key = ?", (key,)) + cursor.execute( + "SELECT value FROM simulation_metadata WHERE run_id = ? AND key = ?", + (self.run_id, key), + ) row = cursor.fetchone() return row["value"] if row else None @@ -573,7 +639,10 @@ def delete_metadata(self, key: str) -> None: key: Metadata key to delete """ cursor = self.conn.cursor() - cursor.execute("DELETE FROM simulation_metadata WHERE key = ?", (key,)) + cursor.execute( + "DELETE FROM simulation_metadata WHERE run_id = ? AND key = ?", + (self.run_id, key), + ) self.conn.commit() def get_last_completed_timestep(self) -> int: @@ -583,7 +652,10 @@ def get_last_completed_timestep(self) -> int: Max timestep from timestep_summaries, or -1 if none exist. """ cursor = self.conn.cursor() - cursor.execute("SELECT MAX(timestep) as max_ts FROM timestep_summaries") + cursor.execute( + "SELECT MAX(timestep) as max_ts FROM timestep_summaries WHERE run_id = ?", + (self.run_id,), + ) row = cursor.fetchone() if row and row["max_ts"] is not None: return row["max_ts"] @@ -625,8 +697,8 @@ def get_agents_already_reasoned_this_timestep(self, timestep: int) -> set[str]: """ cursor = self.conn.cursor() cursor.execute( - "SELECT agent_id FROM agent_states WHERE last_reasoning_timestep = ?", - (timestep,), + "SELECT agent_id FROM agent_states WHERE run_id = ? AND last_reasoning_timestep = ?", + (self.run_id, timestep), ) return {row["agent_id"] for row in cursor.fetchall()} @@ -642,10 +714,19 @@ def record_exposure(self, agent_id: str, exposure: ExposureRecord) -> None: # Insert exposure record cursor.execute( """ - INSERT INTO exposures (agent_id, timestep, channel, source_agent_id, content, credibility) - VALUES (?, ?, ?, ?, ?, ?) + INSERT INTO exposures ( + run_id, + agent_id, + timestep, + channel, + source_agent_id, + content, + credibility + ) + VALUES (?, ?, ?, ?, ?, ?, ?) """, ( + self.run_id, agent_id, exposure.timestep, exposure.channel, @@ -675,13 +756,15 @@ def record_exposure(self, agent_id: str, exposure: ExposureRecord) -> None: ELSE MIN(network_hop_depth, ?) END, updated_at = ? - WHERE agent_id = ? + WHERE run_id = ? + AND agent_id = ? """, ( new_hop_depth, new_hop_depth, new_hop_depth, exposure.timestep, + self.run_id, agent_id, ), ) @@ -724,7 +807,8 @@ def apply_conviction_decay( ELSE will_share END, updated_at = ? - WHERE aware = 1 + WHERE run_id = ? + AND aware = 1 AND conviction IS NOT NULL AND conviction > ? AND last_reasoning_timestep < ? @@ -738,6 +822,7 @@ def apply_conviction_decay( decay_multiplier, sharing_threshold, timestep, + self.run_id, sharing_threshold, timestep, ), @@ -783,7 +868,8 @@ def update_agent_state( raw_reasoning = ?, last_reasoning_timestep = ?, updated_at = ? - WHERE agent_id = ? + WHERE run_id = ? + AND agent_id = ? """, ( state.position, @@ -808,6 +894,7 @@ def update_agent_state( state.raw_reasoning, timestep, timestep, + self.run_id, agent_id, ), ) @@ -851,7 +938,8 @@ def batch_update_states( raw_reasoning = ?, last_reasoning_timestep = ?, updated_at = ? - WHERE agent_id = ? + WHERE run_id = ? + AND agent_id = ? """, ( state.position, @@ -876,6 +964,7 @@ def batch_update_states( state.raw_reasoning, timestep, timestep, + self.run_id, agent_id, ), ) @@ -895,10 +984,11 @@ def save_memory_entry(self, agent_id: str, entry: MemoryEntry) -> None: # Insert new entry cursor.execute( """ - INSERT INTO memory_traces (agent_id, timestep, sentiment, conviction, summary) - VALUES (?, ?, ?, ?, ?) + INSERT INTO memory_traces (run_id, agent_id, timestep, sentiment, conviction, summary) + VALUES (?, ?, ?, ?, ?, ?) """, ( + self.run_id, agent_id, entry.timestep, entry.sentiment, @@ -913,12 +1003,12 @@ def save_memory_entry(self, agent_id: str, entry: MemoryEntry) -> None: DELETE FROM memory_traces WHERE id NOT IN ( SELECT id FROM memory_traces - WHERE agent_id = ? + WHERE run_id = ? AND agent_id = ? ORDER BY timestep DESC LIMIT 3 - ) AND agent_id = ? + ) AND run_id = ? AND agent_id = ? """, - (agent_id, agent_id), + (self.run_id, agent_id, self.run_id, agent_id), ) def get_memory_traces(self, agent_id: str) -> list[MemoryEntry]: @@ -934,10 +1024,10 @@ def get_memory_traces(self, agent_id: str) -> list[MemoryEntry]: cursor.execute( """ SELECT * FROM memory_traces - WHERE agent_id = ? + WHERE run_id = ? AND agent_id = ? ORDER BY timestep ASC """, - (agent_id,), + (self.run_id, agent_id), ) return [ @@ -960,10 +1050,18 @@ def log_event(self, event: SimulationEvent) -> None: cursor.execute( """ - INSERT INTO timeline (timestep, event_type, agent_id, details_json, wall_timestamp) - VALUES (?, ?, ?, ?, ?) + INSERT INTO timeline ( + run_id, + timestep, + event_type, + agent_id, + details_json, + wall_timestamp + ) + VALUES (?, ?, ?, ?, ?, ?) """, ( + self.run_id, event.timestep, event.event_type.value, event.agent_id, @@ -976,13 +1074,19 @@ def get_exposure_rate(self) -> float: """Get fraction of population that is aware.""" cursor = self.conn.cursor() - cursor.execute("SELECT COUNT(*) as total FROM agent_states") + cursor.execute( + "SELECT COUNT(*) as total FROM agent_states WHERE run_id = ?", + (self.run_id,), + ) total = cursor.fetchone()["total"] if total == 0: return 0.0 - cursor.execute("SELECT COUNT(*) as aware FROM agent_states WHERE aware = 1") + cursor.execute( + "SELECT COUNT(*) as aware FROM agent_states WHERE run_id = ? AND aware = 1", + (self.run_id,), + ) aware = cursor.fetchone()["aware"] return aware / total @@ -995,9 +1099,11 @@ def get_position_distribution(self) -> dict[str, int]: """ SELECT COALESCE(private_position, position) as position, COUNT(*) as cnt FROM agent_states - WHERE COALESCE(private_position, position) IS NOT NULL + WHERE run_id = ? + AND COALESCE(private_position, position) IS NOT NULL GROUP BY COALESCE(private_position, position) - """ + """, + (self.run_id,), ) return {row["position"]: row["cnt"] for row in cursor.fetchall()} @@ -1010,8 +1116,10 @@ def get_average_sentiment(self) -> float | None: """ SELECT AVG(COALESCE(private_sentiment, sentiment)) as avg_sentiment FROM agent_states - WHERE COALESCE(private_sentiment, sentiment) IS NOT NULL - """ + WHERE run_id = ? + AND COALESCE(private_sentiment, sentiment) IS NOT NULL + """, + (self.run_id,), ) row = cursor.fetchone() @@ -1025,8 +1133,10 @@ def get_average_conviction(self) -> float | None: """ SELECT AVG(COALESCE(private_conviction, conviction)) as avg_conviction FROM agent_states - WHERE COALESCE(private_conviction, conviction) IS NOT NULL - """ + WHERE run_id = ? + AND COALESCE(private_conviction, conviction) IS NOT NULL + """, + (self.run_id,), ) row = cursor.fetchone() @@ -1040,8 +1150,10 @@ def get_sentiment_variance(self) -> float | None: """ SELECT AVG(COALESCE(private_sentiment, sentiment)) as mean_s, COUNT(*) as cnt FROM agent_states - WHERE COALESCE(private_sentiment, sentiment) IS NOT NULL - """ + WHERE run_id = ? + AND COALESCE(private_sentiment, sentiment) IS NOT NULL + """, + (self.run_id,), ) row = cursor.fetchone() @@ -1056,9 +1168,10 @@ def get_sentiment_variance(self) -> float | None: * (COALESCE(private_sentiment, sentiment) - ?) ) as variance FROM agent_states - WHERE COALESCE(private_sentiment, sentiment) IS NOT NULL + WHERE run_id = ? + AND COALESCE(private_sentiment, sentiment) IS NOT NULL """, - (mean, mean), + (mean, mean, self.run_id), ) var_row = cursor.fetchone() return var_row["variance"] if var_row else None @@ -1074,12 +1187,13 @@ def save_timestep_summary(self, summary: TimestepSummary) -> None: cursor.execute( """ INSERT OR REPLACE INTO timestep_summaries - (timestep, new_exposures, agents_reasoned, shares_occurred, + (run_id, timestep, new_exposures, agents_reasoned, shares_occurred, state_changes, exposure_rate, position_distribution_json, average_sentiment, average_conviction, sentiment_variance) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( + self.run_id, summary.timestep, summary.new_exposures, summary.agents_reasoned, @@ -1099,8 +1213,11 @@ def get_timestep_summaries(self) -> list[TimestepSummary]: cursor.execute( """ - SELECT * FROM timestep_summaries ORDER BY timestep - """ + SELECT * FROM timestep_summaries + WHERE run_id = ? + ORDER BY timestep + """, + (self.run_id,), ) summaries = [] @@ -1137,7 +1254,7 @@ def export_final_states(self) -> list[dict[str, Any]]: """ cursor = self.conn.cursor() - cursor.execute("SELECT * FROM agent_states") + cursor.execute("SELECT * FROM agent_states WHERE run_id = ?", (self.run_id,)) agent_rows = cursor.fetchall() states = [] @@ -1145,8 +1262,10 @@ def export_final_states(self) -> list[dict[str, Any]]: """ SELECT agent_id, COUNT(*) as cnt FROM exposures + WHERE run_id = ? GROUP BY agent_id - """ + """, + (self.run_id,), ) exposure_counts = {row["agent_id"]: row["cnt"] for row in cursor.fetchall()} @@ -1227,7 +1346,10 @@ def export_timeline(self) -> list[dict[str, Any]]: """ cursor = self.conn.cursor() - cursor.execute("SELECT * FROM timeline ORDER BY timestep, id") + cursor.execute( + "SELECT * FROM timeline WHERE run_id = ? ORDER BY timestep, id", + (self.run_id,), + ) events = [] for row in cursor.fetchall(): @@ -1253,7 +1375,10 @@ def export_timeline(self) -> list[dict[str, Any]]: def get_population_count(self) -> int: """Get total number of agents.""" cursor = self.conn.cursor() - cursor.execute("SELECT COUNT(*) as cnt FROM agent_states") + cursor.execute( + "SELECT COUNT(*) as cnt FROM agent_states WHERE run_id = ?", + (self.run_id,), + ) return cursor.fetchone()["cnt"] def close(self) -> None: diff --git a/extropy/storage/__init__.py b/extropy/storage/__init__.py new file mode 100644 index 0000000..a1b35ab --- /dev/null +++ b/extropy/storage/__init__.py @@ -0,0 +1,18 @@ +"""Storage layer for canonical study database.""" + +from .study_db import StudyDB, open_study_db +from .schemas import ( + AgentDBRecord, + NetworkEdgeDBRecord, + ChatMessagePayload, + ReadOnlySQLRequest, +) + +__all__ = [ + "StudyDB", + "open_study_db", + "AgentDBRecord", + "NetworkEdgeDBRecord", + "ChatMessagePayload", + "ReadOnlySQLRequest", +] diff --git a/extropy/storage/schemas.py b/extropy/storage/schemas.py new file mode 100644 index 0000000..198e122 --- /dev/null +++ b/extropy/storage/schemas.py @@ -0,0 +1,46 @@ +"""Pydantic schemas for canonical study DB payloads.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, Field, ConfigDict + + +class AgentDBRecord(BaseModel): + """Validated representation of an agent row in the study DB.""" + + population_id: str + agent_id: str + attrs_json: dict[str, Any] + sample_run_id: str + + +class NetworkEdgeDBRecord(BaseModel): + """Validated representation of a network edge row in the study DB.""" + + network_id: str + source_id: str + target_id: str + weight: float + edge_type: str + influence_st: float | None = None + influence_ts: float | None = None + + +class ChatMessagePayload(BaseModel): + """Validated chat message payload persisted in chat_messages.""" + + role: str = Field(min_length=1) + content: str = Field(min_length=1) + citations: dict[str, Any] = Field(default_factory=dict) + token_usage: dict[str, Any] = Field(default_factory=dict) + + +class ReadOnlySQLRequest(BaseModel): + """Read-only SQL request contract for query CLI.""" + + model_config = ConfigDict(str_strip_whitespace=True) + + sql: str = Field(min_length=1) + limit: int = Field(default=1000, ge=1) diff --git a/extropy/storage/study_db.py b/extropy/storage/study_db.py new file mode 100644 index 0000000..627d288 --- /dev/null +++ b/extropy/storage/study_db.py @@ -0,0 +1,885 @@ +"""Canonical study database storage for Extropy. + +This module provides the schema and helper operations for ``study.db``. +""" + +from __future__ import annotations + +import json +import sqlite3 +import uuid +from datetime import datetime +from pathlib import Path +from typing import Any + +from .schemas import AgentDBRecord, NetworkEdgeDBRecord, ChatMessagePayload + + +def _now_iso() -> str: + return datetime.now().isoformat() + + +def _dumps(data: Any) -> str: + return json.dumps(data, default=str) + + +class StudyDB: + """SQLite-backed canonical study store.""" + + def __init__(self, path: Path | str): + self.path = Path(path) + self.path.parent.mkdir(parents=True, exist_ok=True) + self.conn = sqlite3.connect(str(self.path), check_same_thread=False) + self.conn.row_factory = sqlite3.Row + self._set_pragmas() + self.init_schema() + + def _set_pragmas(self) -> None: + cursor = self.conn.cursor() + cursor.execute("PRAGMA foreign_keys = ON") + cursor.execute("PRAGMA journal_mode = WAL") + cursor.execute("PRAGMA synchronous = NORMAL") + cursor.execute("PRAGMA temp_store = MEMORY") + self.conn.commit() + + def init_schema(self) -> None: + """Create canonical schema and indexes.""" + cursor = self.conn.cursor() + cursor.executescript( + """ + CREATE TABLE IF NOT EXISTS study_meta ( + key TEXT PRIMARY KEY, + value TEXT + ); + + CREATE TABLE IF NOT EXISTS population_specs ( + population_id TEXT PRIMARY KEY, + spec_yaml TEXT NOT NULL, + source_path TEXT, + created_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS sample_runs ( + sample_run_id TEXT PRIMARY KEY, + population_id TEXT NOT NULL, + seed INTEGER, + count INTEGER, + created_at TEXT NOT NULL, + meta_json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS agents ( + population_id TEXT NOT NULL, + agent_id TEXT NOT NULL, + attrs_json TEXT NOT NULL, + sample_run_id TEXT NOT NULL, + PRIMARY KEY (population_id, agent_id) + ); + + CREATE TABLE IF NOT EXISTS network_runs ( + network_run_id TEXT PRIMARY KEY, + population_id TEXT NOT NULL, + network_id TEXT NOT NULL, + config_json TEXT NOT NULL, + seed INTEGER, + candidate_mode TEXT NOT NULL, + status TEXT NOT NULL, + created_at TEXT NOT NULL, + completed_at TEXT, + meta_json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS network_edges ( + network_id TEXT NOT NULL, + source_id TEXT NOT NULL, + target_id TEXT NOT NULL, + weight REAL NOT NULL, + edge_type TEXT NOT NULL, + influence_st REAL, + influence_ts REAL, + PRIMARY KEY (network_id, source_id, target_id) + ); + + CREATE TABLE IF NOT EXISTS network_metrics ( + network_id TEXT PRIMARY KEY, + metrics_json TEXT NOT NULL, + computed_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS network_similarity_jobs ( + job_id TEXT PRIMARY KEY, + network_run_id TEXT NOT NULL, + signature_json TEXT NOT NULL, + status TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS network_similarity_chunks ( + job_id TEXT NOT NULL, + chunk_start INTEGER NOT NULL, + chunk_end INTEGER NOT NULL, + status TEXT NOT NULL, + pair_count INTEGER NOT NULL DEFAULT 0, + updated_at TEXT NOT NULL, + PRIMARY KEY (job_id, chunk_start) + ); + + CREATE TABLE IF NOT EXISTS network_similarity_pairs ( + job_id TEXT NOT NULL, + i INTEGER NOT NULL, + j INTEGER NOT NULL, + sim REAL NOT NULL, + PRIMARY KEY (job_id, i, j) + ) WITHOUT ROWID; + + CREATE TABLE IF NOT EXISTS simulation_runs ( + run_id TEXT PRIMARY KEY, + scenario_name TEXT, + population_id TEXT NOT NULL, + network_id TEXT NOT NULL, + config_json TEXT NOT NULL, + seed INTEGER, + status TEXT NOT NULL, + started_at TEXT NOT NULL, + completed_at TEXT, + stopped_reason TEXT + ); + + CREATE TABLE IF NOT EXISTS agent_states ( + run_id TEXT NOT NULL, + agent_id TEXT NOT NULL, + aware INTEGER DEFAULT 0, + exposure_count INTEGER DEFAULT 0, + last_reasoning_timestep INTEGER DEFAULT -1, + position TEXT, + sentiment REAL, + conviction REAL, + public_statement TEXT, + action_intent TEXT, + will_share INTEGER DEFAULT 0, + outcomes_json TEXT, + public_position TEXT, + public_sentiment REAL, + public_conviction REAL, + private_position TEXT, + private_sentiment REAL, + private_conviction REAL, + private_outcomes_json TEXT, + raw_reasoning TEXT, + committed INTEGER DEFAULT 0, + network_hop_depth INTEGER, + updated_at INTEGER DEFAULT 0, + PRIMARY KEY (run_id, agent_id) + ); + + CREATE TABLE IF NOT EXISTS exposures ( + run_id TEXT NOT NULL, + id INTEGER PRIMARY KEY AUTOINCREMENT, + agent_id TEXT, + timestep INTEGER, + channel TEXT, + source_agent_id TEXT, + content TEXT, + credibility REAL, + FOREIGN KEY (run_id, agent_id) REFERENCES agent_states(run_id, agent_id) + ); + + CREATE TABLE IF NOT EXISTS memory_traces ( + run_id TEXT NOT NULL, + id INTEGER PRIMARY KEY AUTOINCREMENT, + agent_id TEXT, + timestep INTEGER, + sentiment REAL, + conviction REAL, + summary TEXT, + FOREIGN KEY (run_id, agent_id) REFERENCES agent_states(run_id, agent_id) + ); + + CREATE TABLE IF NOT EXISTS timeline ( + run_id TEXT NOT NULL, + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestep INTEGER, + event_type TEXT, + agent_id TEXT, + details_json TEXT, + wall_timestamp TEXT + ); + + CREATE TABLE IF NOT EXISTS timestep_summaries ( + run_id TEXT NOT NULL, + timestep INTEGER NOT NULL, + new_exposures INTEGER, + agents_reasoned INTEGER, + shares_occurred INTEGER, + state_changes INTEGER, + exposure_rate REAL, + position_distribution_json TEXT, + average_sentiment REAL, + average_conviction REAL, + sentiment_variance REAL, + PRIMARY KEY (run_id, timestep) + ); + + CREATE TABLE IF NOT EXISTS shared_to ( + run_id TEXT NOT NULL, + source_agent_id TEXT, + target_agent_id TEXT, + timestep INTEGER, + position TEXT, + PRIMARY KEY (run_id, source_agent_id, target_agent_id) + ); + + CREATE TABLE IF NOT EXISTS simulation_metadata ( + run_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT + , + PRIMARY KEY (run_id, key) + ); + + CREATE TABLE IF NOT EXISTS run_metadata ( + run_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT, + PRIMARY KEY (run_id, key) + ); + + CREATE TABLE IF NOT EXISTS simulation_checkpoints ( + run_id TEXT NOT NULL, + timestep INTEGER NOT NULL, + chunk_index INTEGER NOT NULL, + status TEXT NOT NULL, + updated_at TEXT NOT NULL, + PRIMARY KEY (run_id, timestep, chunk_index) + ); + + CREATE TABLE IF NOT EXISTS chat_sessions ( + session_id TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + agent_id TEXT NOT NULL, + mode TEXT NOT NULL, + created_at TEXT NOT NULL, + closed_at TEXT, + meta_json TEXT + ); + + CREATE TABLE IF NOT EXISTS chat_messages ( + session_id TEXT NOT NULL, + turn_index INTEGER NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + citations_json TEXT, + token_usage_json TEXT, + created_at TEXT NOT NULL, + PRIMARY KEY (session_id, turn_index) + ); + + CREATE TABLE IF NOT EXISTS chat_artifacts ( + session_id TEXT NOT NULL, + key TEXT NOT NULL, + value_json TEXT NOT NULL, + PRIMARY KEY (session_id, key) + ); + + CREATE INDEX IF NOT EXISTS idx_agents_population ON agents(population_id); + CREATE INDEX IF NOT EXISTS idx_network_edges_src ON network_edges(network_id, source_id); + CREATE INDEX IF NOT EXISTS idx_network_edges_tgt ON network_edges(network_id, target_id); + CREATE INDEX IF NOT EXISTS idx_net_sim_chunks_status ON network_similarity_chunks(job_id, status); + CREATE INDEX IF NOT EXISTS idx_sim_ckpt ON simulation_checkpoints(run_id, timestep, chunk_index); + CREATE INDEX IF NOT EXISTS idx_chat_session_agent ON chat_sessions(run_id, agent_id); + CREATE INDEX IF NOT EXISTS idx_agent_states_aware ON agent_states(run_id, aware); + CREATE INDEX IF NOT EXISTS idx_agent_states_will_share ON agent_states(run_id, will_share); + CREATE INDEX IF NOT EXISTS idx_agent_states_last_reasoning ON agent_states(run_id, last_reasoning_timestep); + CREATE INDEX IF NOT EXISTS idx_agent_states_run_awws + ON agent_states(run_id, aware, will_share, last_reasoning_timestep); + CREATE INDEX IF NOT EXISTS idx_exposures_agent_timestep ON exposures(run_id, agent_id, timestep); + CREATE INDEX IF NOT EXISTS idx_timeline_timestep ON timeline(run_id, timestep); + CREATE INDEX IF NOT EXISTS idx_shared_to_source ON shared_to(run_id, source_agent_id); + """ + ) + self.conn.commit() + + def close(self) -> None: + self.conn.close() + + def __enter__(self) -> "StudyDB": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + self.close() + return False + + def save_population_spec( + self, + population_id: str, + spec_yaml: str, + source_path: str | None, + ) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO population_specs (population_id, spec_yaml, source_path, created_at) + VALUES (?, ?, ?, ?) + """, + (population_id, spec_yaml, source_path, _now_iso()), + ) + self.conn.commit() + + def get_population_spec_yaml(self, population_id: str) -> str | None: + cursor = self.conn.cursor() + cursor.execute( + "SELECT spec_yaml FROM population_specs WHERE population_id = ?", + (population_id,), + ) + row = cursor.fetchone() + return str(row["spec_yaml"]) if row else None + + def save_sample_result( + self, + population_id: str, + agents: list[dict[str, Any]], + meta: dict[str, Any], + seed: int | None = None, + sample_run_id: str | None = None, + ) -> str: + run_id = sample_run_id or str(uuid.uuid4()) + cursor = self.conn.cursor() + + cursor.execute( + """ + INSERT OR REPLACE INTO sample_runs + (sample_run_id, population_id, seed, count, created_at, meta_json) + VALUES (?, ?, ?, ?, ?, ?) + """, + (run_id, population_id, seed, len(agents), _now_iso(), _dumps(meta)), + ) + + cursor.execute("DELETE FROM agents WHERE population_id = ?", (population_id,)) + + rows = [] + for i, agent in enumerate(agents): + agent_id = str(agent.get("_id", f"agent_{i}")) + row_agent = dict(agent) + row_agent["_id"] = agent_id + rec = AgentDBRecord( + population_id=population_id, + agent_id=agent_id, + attrs_json=row_agent, + sample_run_id=run_id, + ) + rows.append( + ( + rec.population_id, + rec.agent_id, + _dumps(rec.attrs_json), + rec.sample_run_id, + ) + ) + + cursor.executemany( + """ + INSERT INTO agents (population_id, agent_id, attrs_json, sample_run_id) + VALUES (?, ?, ?, ?) + """, + rows, + ) + self.conn.commit() + return run_id + + def get_agents(self, population_id: str) -> list[dict[str, Any]]: + cursor = self.conn.cursor() + cursor.execute( + """ + SELECT attrs_json + FROM agents + WHERE population_id = ? + ORDER BY agent_id + """, + (population_id,), + ) + agents = [] + for row in cursor.fetchall(): + try: + agents.append(json.loads(row["attrs_json"])) + except json.JSONDecodeError: + continue + return agents + + def get_agent_count(self, population_id: str) -> int: + cursor = self.conn.cursor() + cursor.execute( + "SELECT COUNT(*) AS cnt FROM agents WHERE population_id = ?", + (population_id,), + ) + row = cursor.fetchone() + return int(row["cnt"]) if row else 0 + + def save_network_result( + self, + population_id: str, + network_id: str, + config: dict[str, Any], + result_meta: dict[str, Any], + edges: list[dict[str, Any]], + seed: int | None, + candidate_mode: str, + network_metrics: dict[str, Any] | None = None, + network_run_id: str | None = None, + ) -> str: + run_id = network_run_id or str(uuid.uuid4()) + cursor = self.conn.cursor() + now = _now_iso() + + cursor.execute( + """ + INSERT OR REPLACE INTO network_runs + (network_run_id, population_id, network_id, config_json, seed, candidate_mode, + status, created_at, completed_at, meta_json) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + run_id, + population_id, + network_id, + _dumps(config), + seed, + candidate_mode, + "completed", + now, + now, + _dumps(result_meta), + ), + ) + + cursor.execute("DELETE FROM network_edges WHERE network_id = ?", (network_id,)) + + rows = [] + for edge in edges: + infl = edge.get("influence_weight") or {} + rec = NetworkEdgeDBRecord( + network_id=network_id, + source_id=str(edge.get("source", "")), + target_id=str(edge.get("target", "")), + weight=float(edge.get("weight", 0.0)), + edge_type=str(edge.get("type", edge.get("edge_type", "unknown"))), + influence_st=float(infl.get("source_to_target", edge.get("weight", 0.0))), + influence_ts=float(infl.get("target_to_source", edge.get("weight", 0.0))), + ) + rows.append( + ( + rec.network_id, + rec.source_id, + rec.target_id, + rec.weight, + rec.edge_type, + rec.influence_st, + rec.influence_ts, + ) + ) + + cursor.executemany( + """ + INSERT INTO network_edges + (network_id, source_id, target_id, weight, edge_type, influence_st, influence_ts) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + rows, + ) + + if network_metrics is not None: + cursor.execute( + """ + INSERT OR REPLACE INTO network_metrics (network_id, metrics_json, computed_at) + VALUES (?, ?, ?) + """, + (network_id, _dumps(network_metrics), now), + ) + + self.conn.commit() + return run_id + + def get_network(self, network_id: str) -> dict[str, Any]: + cursor = self.conn.cursor() + + cursor.execute( + "SELECT meta_json FROM network_runs WHERE network_id = ? ORDER BY completed_at DESC LIMIT 1", + (network_id,), + ) + run_row = cursor.fetchone() + meta = {} + if run_row: + try: + meta = json.loads(run_row["meta_json"]) + except json.JSONDecodeError: + meta = {} + + cursor.execute( + """ + SELECT source_id, target_id, weight, edge_type, influence_st, influence_ts + FROM network_edges + WHERE network_id = ? + ORDER BY source_id, target_id + """, + (network_id,), + ) + edges = [] + for row in cursor.fetchall(): + edges.append( + { + "source": row["source_id"], + "target": row["target_id"], + "weight": row["weight"], + "type": row["edge_type"], + "bidirectional": True, + "influence_weight": { + "source_to_target": row["influence_st"], + "target_to_source": row["influence_ts"], + }, + } + ) + + return {"meta": meta, "edges": edges} + + def get_network_edge_count(self, network_id: str) -> int: + cursor = self.conn.cursor() + cursor.execute( + "SELECT COUNT(*) AS cnt FROM network_edges WHERE network_id = ?", + (network_id,), + ) + row = cursor.fetchone() + return int(row["cnt"]) if row else 0 + + def init_network_similarity_job( + self, + network_run_id: str, + signature: dict[str, Any], + job_id: str | None = None, + ) -> str: + job = job_id or str(uuid.uuid4()) + cursor = self.conn.cursor() + now = _now_iso() + cursor.execute( + """ + INSERT OR REPLACE INTO network_similarity_jobs + (job_id, network_run_id, signature_json, status, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + """, + (job, network_run_id, _dumps(signature), "running", now, now), + ) + self.conn.commit() + return job + + def get_network_similarity_job_signature(self, job_id: str) -> dict[str, Any] | None: + cursor = self.conn.cursor() + cursor.execute( + "SELECT signature_json FROM network_similarity_jobs WHERE job_id = ?", + (job_id,), + ) + row = cursor.fetchone() + if not row: + return None + try: + return json.loads(row["signature_json"]) + except json.JSONDecodeError: + return None + + def get_completed_similarity_chunks(self, job_id: str) -> set[int]: + cursor = self.conn.cursor() + cursor.execute( + """ + SELECT chunk_start + FROM network_similarity_chunks + WHERE job_id = ? AND status = 'done' + """, + (job_id,), + ) + return {int(row["chunk_start"]) for row in cursor.fetchall()} + + def list_completed_similarity_chunks(self, job_id: str) -> list[tuple[int, int]]: + cursor = self.conn.cursor() + cursor.execute( + """ + SELECT chunk_start, chunk_end + FROM network_similarity_chunks + WHERE job_id = ? AND status = 'done' + ORDER BY chunk_start + """, + (job_id,), + ) + return [ + (int(row["chunk_start"]), int(row["chunk_end"])) for row in cursor.fetchall() + ] + + def save_similarity_chunk_rows( + self, + job_id: str, + chunk_start: int, + chunk_end: int, + rows: list[tuple[int, int, float]], + ) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO network_similarity_chunks + (job_id, chunk_start, chunk_end, status, pair_count, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + """, + (job_id, chunk_start, chunk_end, "running", len(rows), _now_iso()), + ) + if rows: + cursor.executemany( + """ + INSERT OR REPLACE INTO network_similarity_pairs (job_id, i, j, sim) + VALUES (?, ?, ?, ?) + """, + [(job_id, i, j, sim) for i, j, sim in rows], + ) + cursor.execute( + """ + UPDATE network_similarity_chunks + SET status = 'done', updated_at = ? + WHERE job_id = ? AND chunk_start = ? + """, + (_now_iso(), job_id, chunk_start), + ) + cursor.execute( + """ + UPDATE network_similarity_jobs + SET updated_at = ? + WHERE job_id = ? + """, + (_now_iso(), job_id), + ) + self.conn.commit() + + def mark_similarity_job_running(self, job_id: str) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + UPDATE network_similarity_jobs + SET status = 'running', updated_at = ? + WHERE job_id = ? + """, + (_now_iso(), job_id), + ) + self.conn.commit() + + def load_similarity_pairs(self, job_id: str) -> dict[tuple[int, int], float]: + cursor = self.conn.cursor() + cursor.execute( + "SELECT i, j, sim FROM network_similarity_pairs WHERE job_id = ?", + (job_id,), + ) + return {(int(row["i"]), int(row["j"])): float(row["sim"]) for row in cursor.fetchall()} + + def mark_similarity_job_complete(self, job_id: str, drop_pairs: bool = False) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + UPDATE network_similarity_jobs + SET status = 'completed', updated_at = ? + WHERE job_id = ? + """, + (_now_iso(), job_id), + ) + if drop_pairs: + cursor.execute("DELETE FROM network_similarity_pairs WHERE job_id = ?", (job_id,)) + self.conn.commit() + + def create_simulation_run( + self, + run_id: str, + scenario_name: str, + population_id: str, + network_id: str, + config: dict[str, Any], + seed: int | None, + status: str = "running", + ) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO simulation_runs + (run_id, scenario_name, population_id, network_id, config_json, seed, status, started_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (run_id, scenario_name, population_id, network_id, _dumps(config), seed, status, _now_iso()), + ) + self.conn.commit() + + def update_simulation_run( + self, + run_id: str, + status: str, + stopped_reason: str | None = None, + ) -> None: + cursor = self.conn.cursor() + completed_at = _now_iso() if status in {"completed", "failed", "stopped"} else None + cursor.execute( + """ + UPDATE simulation_runs + SET status = ?, stopped_reason = ?, completed_at = COALESCE(?, completed_at) + WHERE run_id = ? + """, + (status, stopped_reason, completed_at, run_id), + ) + self.conn.commit() + + def set_run_metadata(self, run_id: str, key: str, value: str) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO run_metadata (run_id, key, value) + VALUES (?, ?, ?) + """, + (run_id, key, value), + ) + self.conn.commit() + + def get_run_metadata(self, run_id: str, key: str) -> str | None: + cursor = self.conn.cursor() + cursor.execute( + "SELECT value FROM run_metadata WHERE run_id = ? AND key = ?", + (run_id, key), + ) + row = cursor.fetchone() + return str(row["value"]) if row else None + + def save_simulation_checkpoint( + self, + run_id: str, + timestep: int, + chunk_index: int, + status: str, + ) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO simulation_checkpoints + (run_id, timestep, chunk_index, status, updated_at) + VALUES (?, ?, ?, ?, ?) + """, + (run_id, timestep, chunk_index, status, _now_iso()), + ) + self.conn.commit() + + def get_completed_simulation_chunks(self, run_id: str, timestep: int) -> set[int]: + cursor = self.conn.cursor() + cursor.execute( + """ + SELECT chunk_index + FROM simulation_checkpoints + WHERE run_id = ? AND timestep = ? AND status = 'done' + """, + (run_id, timestep), + ) + return {int(row["chunk_index"]) for row in cursor.fetchall()} + + def create_chat_session( + self, + run_id: str, + agent_id: str, + mode: str, + meta: dict[str, Any] | None = None, + session_id: str | None = None, + ) -> str: + sid = session_id or str(uuid.uuid4()) + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO chat_sessions + (session_id, run_id, agent_id, mode, created_at, meta_json) + VALUES (?, ?, ?, ?, ?, ?) + """, + (sid, run_id, agent_id, mode, _now_iso(), _dumps(meta or {})), + ) + self.conn.commit() + return sid + + def append_chat_message( + self, + session_id: str, + role: str, + content: str, + citations: dict[str, Any] | None = None, + token_usage: dict[str, Any] | None = None, + ) -> int: + payload = ChatMessagePayload( + role=role, + content=content, + citations=citations or {}, + token_usage=token_usage or {}, + ) + + cursor = self.conn.cursor() + cursor.execute( + "SELECT COALESCE(MAX(turn_index), -1) AS max_turn FROM chat_messages WHERE session_id = ?", + (session_id,), + ) + turn = int(cursor.fetchone()["max_turn"]) + 1 + cursor.execute( + """ + INSERT INTO chat_messages + (session_id, turn_index, role, content, citations_json, token_usage_json, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + session_id, + turn, + payload.role, + payload.content, + _dumps(payload.citations), + _dumps(payload.token_usage), + _now_iso(), + ), + ) + self.conn.commit() + return turn + + def get_chat_messages(self, session_id: str) -> list[dict[str, Any]]: + cursor = self.conn.cursor() + cursor.execute( + """ + SELECT turn_index, role, content, citations_json, token_usage_json, created_at + FROM chat_messages + WHERE session_id = ? + ORDER BY turn_index + """, + (session_id,), + ) + out: list[dict[str, Any]] = [] + for row in cursor.fetchall(): + out.append( + { + "turn_index": int(row["turn_index"]), + "role": row["role"], + "content": row["content"], + "citations": json.loads(row["citations_json"] or "{}"), + "token_usage": json.loads(row["token_usage_json"] or "{}"), + "created_at": row["created_at"], + } + ) + return out + + def run_select( + self, + query: str, + params: tuple[Any, ...] = (), + limit: int | None = None, + ) -> list[dict[str, Any]]: + cursor = self.conn.cursor() + sql = query.strip().rstrip(";") + if limit is not None and " limit " not in sql.lower(): + sql = f"{sql} LIMIT {int(limit)}" + cursor.execute(sql, params) + cols = [d[0] for d in cursor.description or []] + rows = [] + for row in cursor.fetchall(): + rows.append({k: row[idx] for idx, k in enumerate(cols)}) + return rows + + +def open_study_db(path: Path | str) -> StudyDB: + """Open ``study.db`` and ensure schema exists.""" + return StudyDB(path) diff --git a/extropy/utils/__init__.py b/extropy/utils/__init__.py index 347c833..6ddb4a0 100644 --- a/extropy/utils/__init__.py +++ b/extropy/utils/__init__.py @@ -40,6 +40,10 @@ resolve_relative_to, make_relative_to, ) +from .resource_governor import ( + ResourceGovernor, + ResourceSnapshot, +) __all__ = [ # Graphs @@ -69,4 +73,7 @@ # Paths "resolve_relative_to", "make_relative_to", + # Resource governor + "ResourceGovernor", + "ResourceSnapshot", ] diff --git a/extropy/utils/resource_governor.py b/extropy/utils/resource_governor.py new file mode 100644 index 0000000..f485143 --- /dev/null +++ b/extropy/utils/resource_governor.py @@ -0,0 +1,124 @@ +"""Resource auto-tuning helpers for CPU/memory constrained environments.""" + +from __future__ import annotations + +import os +import platform +import resource +import subprocess +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ResourceSnapshot: + cpu_count: int + total_memory_gb: float + memory_budget_gb: float + + +class ResourceGovernor: + """Computes safe worker/chunk recommendations from local machine resources.""" + + def __init__( + self, + resource_mode: str = "auto", + safe_auto_workers: bool = True, + max_memory_gb: float | None = None, + ): + self.resource_mode = resource_mode + self.safe_auto_workers = safe_auto_workers + self.max_memory_gb = max_memory_gb + + @staticmethod + def _detect_total_memory_gb() -> float: + # Linux and many Unix systems + try: + page_size = os.sysconf("SC_PAGE_SIZE") + phys_pages = os.sysconf("SC_PHYS_PAGES") + if page_size > 0 and phys_pages > 0: + return (page_size * phys_pages) / (1024**3) + except (ValueError, OSError, AttributeError): + pass + + # macOS fallback + if platform.system().lower() == "darwin": + try: + out = subprocess.check_output(["sysctl", "-n", "hw.memsize"], text=True) + return int(out.strip()) / (1024**3) + except Exception: + pass + + # Conservative fallback + return 8.0 + + def snapshot(self) -> ResourceSnapshot: + cpu_count = max(1, os.cpu_count() or 1) + total_mem = self._detect_total_memory_gb() + capped = min(total_mem, self.max_memory_gb) if self.max_memory_gb else total_mem + budget = max(1.0, capped * 0.80) + return ResourceSnapshot( + cpu_count=cpu_count, + total_memory_gb=round(total_mem, 2), + memory_budget_gb=round(budget, 2), + ) + + @staticmethod + def _current_process_memory_gb() -> float: + usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + system = platform.system().lower() + # Linux reports KB, macOS reports bytes. + if system == "darwin": + return float(usage) / (1024**3) + return float(usage) / (1024**2) + + def memory_pressure_ratio(self) -> float: + snap = self.snapshot() + current = self._current_process_memory_gb() + budget = max(0.1, snap.memory_budget_gb) + return current / budget + + @staticmethod + def downshift_int(current: int, factor: float, minimum: int = 1) -> int: + return max(minimum, int(max(1, current) * factor)) + + def recommend_workers( + self, + requested_workers: int, + memory_per_worker_gb: float, + ) -> int: + requested_workers = max(1, int(requested_workers)) + if self.resource_mode != "auto": + return requested_workers + + snap = self.snapshot() + cpu_cap = ( + max(1, snap.cpu_count - 1) if self.safe_auto_workers else snap.cpu_count + ) + mem_cap = max(1, int(snap.memory_budget_gb / max(0.1, memory_per_worker_gb))) + + if self.safe_auto_workers: + cpu_cap = min(cpu_cap, 8) + + return max(1, min(requested_workers, cpu_cap, mem_cap)) + + def recommend_chunk_size( + self, + requested_chunk_size: int, + min_chunk_size: int = 8, + max_chunk_size: int = 4096, + ) -> int: + requested_chunk_size = max(min_chunk_size, int(requested_chunk_size)) + if self.resource_mode != "auto": + return min(max_chunk_size, requested_chunk_size) + + snap = self.snapshot() + if snap.memory_budget_gb <= 4: + tuned = min(requested_chunk_size, 32) + elif snap.memory_budget_gb <= 8: + tuned = min(requested_chunk_size, 64) + elif snap.memory_budget_gb <= 16: + tuned = min(requested_chunk_size, 128) + else: + tuned = requested_chunk_size + + return max(min_chunk_size, min(max_chunk_size, tuned)) diff --git a/tests/test_cli.py b/tests/test_cli.py index d38079a..2b81c4e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,11 +1,16 @@ """CLI smoke tests using typer's CliRunner.""" +import json +import sqlite3 +from types import SimpleNamespace from pathlib import Path from typer.testing import CliRunner from extropy.cli.app import app from extropy.cli.commands.validate import _is_scenario_file +from extropy.population.network.config import NetworkConfig +from extropy.storage import open_study_db runner = CliRunner() @@ -16,7 +21,7 @@ class TestConfigCommand: def test_config_show(self): result = runner.invoke(app, ["config", "show"]) assert result.exit_code == 0 - assert "Pipeline" in result.output + assert "Models" in result.output assert "Simulation" in result.output def test_config_set_invalid_key(self): @@ -59,3 +64,301 @@ def test_version_output(self): result = runner.invoke(app, ["--version"]) assert result.exit_code == 0 assert "extropy" in result.output + + +class TestNetworkCommand: + """Smoke tests for the network command options.""" + + def test_network_command_supports_fast_mode_and_checkpoint(self, tmp_path): + study_db = tmp_path / "study.db" + config_path = tmp_path / "network-config.yaml" + output_path = tmp_path / "network.json" + + agents = [ + {"_id": "a0", "role": "x", "team": "alpha"}, + {"_id": "a1", "role": "x", "team": "alpha"}, + {"_id": "a2", "role": "y", "team": "beta"}, + {"_id": "a3", "role": "y", "team": "beta"}, + ] + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", agents=agents, meta={"source": "test"} + ) + + NetworkConfig(seed=42, avg_degree=2.0).to_yaml(config_path) + + result = runner.invoke( + app, + [ + "network", + "--study-db", + str(study_db), + "-o", + str(output_path), + "-c", + str(config_path), + "--no-metrics", + "--candidate-mode", + "blocked", + "--candidate-pool-multiplier", + "4.0", + "--block-attr", + "role", + "--similarity-workers", + "1", + "--similarity-chunk-size", + "8", + "--checkpoint", + str(study_db), + "--checkpoint-every", + "1", + ], + ) + + assert result.exit_code == 0 + assert output_path.exists() + with open_study_db(study_db) as db: + rows = db.run_select("SELECT COUNT(*) AS cnt FROM network_similarity_jobs") + assert rows and int(rows[0]["cnt"]) >= 1 + + def test_network_resume_requires_checkpoint(self): + result = runner.invoke( + app, + [ + "network", + "--study-db", + "study.db", + "-o", + "network.json", + "--resume-checkpoint", + ], + ) + assert result.exit_code == 1 + assert "Study DB not found" in result.output + + def test_network_checkpoint_must_match_study_db(self, tmp_path): + study_db = tmp_path / "study.db" + other_db = tmp_path / "other.db" + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", agents=[{"_id": "a0"}], meta={} + ) + + result = runner.invoke( + app, + [ + "network", + "--study-db", + str(study_db), + "--checkpoint", + str(other_db), + ], + ) + assert result.exit_code == 1 + assert ( + "--checkpoint must point to the same canonical file as --study-db" + in result.output + ) + + +def _seed_run_scoped_state(study_db: Path) -> None: + agents = [ + {"_id": "a0", "team": "alpha"}, + {"_id": "a1", "team": "beta"}, + ] + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", agents=agents, meta={"source": "test"} + ) + db.create_simulation_run( + run_id="run_old", + scenario_name="s", + population_id="default", + network_id="default", + config={}, + seed=1, + status="completed", + ) + db.create_simulation_run( + run_id="run_new", + scenario_name="s", + population_id="default", + network_id="default", + config={}, + seed=2, + status="running", + ) + + conn = sqlite3.connect(str(study_db)) + cur = conn.cursor() + cur.execute( + """ + INSERT INTO agent_states (run_id, agent_id, aware, position, private_position, updated_at) + VALUES ('run_old', 'a0', 1, 'old_pos', 'old_pos', 0) + """ + ) + cur.execute( + """ + INSERT INTO agent_states (run_id, agent_id, aware, position, private_position, updated_at) + VALUES ('run_new', 'a0', 1, 'new_pos', 'new_pos', 0) + """ + ) + cur.execute( + """ + INSERT INTO timestep_summaries ( + run_id, timestep, new_exposures, agents_reasoned, shares_occurred, + state_changes, exposure_rate, position_distribution_json + ) + VALUES ('run_new', 0, 1, 1, 0, 1, 0.5, '{}') + """ + ) + conn.commit() + conn.close() + + +class TestRunScopedCliReads: + def test_results_defaults_to_latest_run(self, tmp_path): + study_db = tmp_path / "study.db" + _seed_run_scoped_state(study_db) + + result = runner.invoke(app, ["results", "--study-db", str(study_db)]) + assert result.exit_code == 0 + assert "run_id=run_new" in result.output + assert "new_pos" in result.output + assert "old_pos" not in result.output + + def test_export_states_defaults_to_latest_run(self, tmp_path): + study_db = tmp_path / "study.db" + out = tmp_path / "states.jsonl" + _seed_run_scoped_state(study_db) + + result = runner.invoke( + app, + ["export", "states", "--study-db", str(study_db), "--to", str(out)], + ) + assert result.exit_code == 0 + rows = [ + json.loads(line) for line in out.read_text(encoding="utf-8").splitlines() + ] + assert len(rows) == 1 + assert rows[0]["run_id"] == "run_new" + assert rows[0]["private_position"] == "new_pos" + + def test_chat_ask_reads_state_for_requested_run(self, tmp_path): + study_db = tmp_path / "study.db" + _seed_run_scoped_state(study_db) + + result = runner.invoke( + app, + [ + "chat", + "ask", + "--study-db", + str(study_db), + "--run-id", + "run_old", + "--agent-id", + "a0", + "--prompt", + "what is my stance", + "--json", + ], + ) + assert result.exit_code == 0 + payload = json.loads(result.stdout.strip()) + assert payload["session_id"] + assert "old_pos" in payload["assistant_text"] + assert "new_pos" not in payload["assistant_text"] + + +class TestPersonaCommand: + def test_persona_show_loads_agents_from_study_db(self, tmp_path, monkeypatch): + import extropy.cli.commands.persona as persona_cmd + import extropy.population.persona as persona_pkg + + class DummyPopulationSpec: + @classmethod + def from_yaml(cls, _path): + return SimpleNamespace( + meta=SimpleNamespace(description="test population"), + attributes=[{"name": "age"}], + ) + + class DummyPersonaConfig: + @classmethod + def from_file(cls, _path): + return object() + + monkeypatch.setattr(persona_cmd, "PopulationSpec", DummyPopulationSpec) + monkeypatch.setattr(persona_pkg, "PersonaConfig", DummyPersonaConfig) + monkeypatch.setattr( + persona_pkg, + "preview_persona", + lambda _agent, _config, max_width=80: "I am a test persona.", + ) + + spec_file = tmp_path / "population.yaml" + spec_file.write_text("meta: {}\n", encoding="utf-8") + persona_file = spec_file.with_suffix(".persona.yaml") + persona_file.write_text("dummy: true\n", encoding="utf-8") + + study_db = tmp_path / "study.db" + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", + agents=[{"_id": "a0", "age": 30}, {"_id": "a1", "age": 41}], + meta={"source": "test"}, + ) + + result = runner.invoke( + app, + [ + "persona", + str(spec_file), + "--study-db", + str(study_db), + "--population-id", + "default", + "--show", + ], + ) + assert result.exit_code == 0 + assert "Loaded 2 agents from study DB population_id=default" in result.output + assert "Persona for Agent a0" in result.output + + def test_persona_rejects_agents_and_study_db_together(self, tmp_path, monkeypatch): + import extropy.cli.commands.persona as persona_cmd + + monkeypatch.setattr( + persona_cmd.PopulationSpec, + "from_yaml", + classmethod( + lambda cls, _path: SimpleNamespace( + meta=SimpleNamespace(description="test population"), + attributes=[{"name": "age"}], + ) + ), + ) + spec_file = tmp_path / "population.yaml" + spec_file.write_text("meta: {}\n", encoding="utf-8") + agents_file = tmp_path / "agents.json" + agents_file.write_text("[]\n", encoding="utf-8") + study_db = tmp_path / "study.db" + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", agents=[{"_id": "a0"}], meta={} + ) + + result = runner.invoke( + app, + [ + "persona", + str(spec_file), + "--agents", + str(agents_file), + "--study-db", + str(study_db), + ], + ) + assert result.exit_code == 1 + assert "Use either --agents or --study-db, not both" in result.output diff --git a/tests/test_compiler.py b/tests/test_compiler.py index c6413ea..eec5597 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -3,7 +3,6 @@ Tests the 5-step compilation pipeline and auto-configuration logic. """ -import json from unittest.mock import patch import pytest @@ -17,6 +16,7 @@ _determine_simulation_config, create_scenario, ) +from extropy.storage import open_study_db class TestGenerateScenarioName: @@ -85,31 +85,39 @@ def mock_files(self, minimal_population_spec, tmp_path): pop_path = tmp_path / "population.yaml" minimal_population_spec.to_yaml(pop_path) - # Create agents JSON agents = [ {"_id": f"agent_{i:03d}", "age": 30 + i, "gender": "male"} for i in range(10) ] - agents_path = tmp_path / "agents.json" - agents_path.write_text(json.dumps(agents)) - - # Create network JSON - network = { - "meta": {"node_count": 10}, - "nodes": [{"id": f"agent_{i:03d}"} for i in range(10)], - "edges": [ - { - "source": f"agent_{i:03d}", - "target": f"agent_{(i + 1) % 10:03d}", - "type": "colleague", - } - for i in range(10) - ], - } - network_path = tmp_path / "network.json" - network_path.write_text(json.dumps(network)) + edges = [ + { + "source": f"agent_{i:03d}", + "target": f"agent_{(i + 1) % 10:03d}", + "weight": 1.0, + "type": "colleague", + "influence_weight": {"source_to_target": 1.0, "target_to_source": 1.0}, + } + for i in range(10) + ] - return pop_path, agents_path, network_path + study_db = tmp_path / "study.db" + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", + agents=agents, + meta={"source": "test_fixture"}, + ) + db.save_network_result( + population_id="default", + network_id="default", + config={}, + result_meta={"node_count": 10}, + edges=edges, + seed=42, + candidate_mode="test", + ) + + return pop_path, study_db @patch("extropy.scenario.compiler.parse_scenario") @patch("extropy.scenario.compiler.generate_seed_exposure") @@ -137,7 +145,7 @@ def test_creates_valid_scenario( OutcomeType, ) - pop_path, agents_path, network_path = mock_files + pop_path, study_db = mock_files # Configure mocks mock_parse.return_value = Event( @@ -188,8 +196,9 @@ def test_creates_valid_scenario( spec, validation_result = create_scenario( description="Test product launch scenario", population_spec_path=pop_path, - agents_path=agents_path, - network_path=network_path, + study_db_path=study_db, + population_id="default", + network_id="default", ) assert spec.meta.name is not None @@ -223,7 +232,7 @@ def test_progress_callback_called( OutcomeType, ) - pop_path, agents_path, network_path = mock_files + pop_path, study_db = mock_files mock_parse.return_value = Event( type=EventType.PRODUCT_LAUNCH, @@ -270,8 +279,9 @@ def on_progress(step, status): create_scenario( description="Test", population_spec_path=pop_path, - agents_path=agents_path, - network_path=network_path, + study_db_path=study_db, + population_id="default", + network_id="default", on_progress=on_progress, ) diff --git a/tests/test_engine.py b/tests/test_engine.py index 2375e42..690c65c 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -20,6 +20,7 @@ SimulationRunConfig, ) from extropy.simulation.progress import SimulationProgress +from extropy.utils.resource_governor import ResourceGovernor from extropy.core.models.scenario import ( Event, EventType, @@ -49,8 +50,9 @@ def minimal_scenario(): name="test_scenario", description="Test scenario", population_spec="test.yaml", - agents_file="test.json", - network_file="test_network.json", + study_db="study.db", + population_id="default", + network_id="default", created_at=datetime(2024, 1, 1), ), event=Event( @@ -854,6 +856,66 @@ def test_agents_already_reasoned( already_4 = engine.state_manager.get_agents_already_reasoned_this_timestep(4) assert "a0" not in already_4 + def test_chunk_checkpoints_written_with_writer_pipeline( + self, + minimal_scenario, + simple_agents, + simple_network, + minimal_pop_spec, + tmp_path, + ): + """Writer pipeline should persist per-chunk checkpoints and last checkpoint marker.""" + config = SimulationRunConfig( + scenario_path="test.yaml", + output_dir=str(tmp_path / "output"), + chunk_size=1, + ) + engine = SimulationEngine( + scenario=minimal_scenario, + population_spec=minimal_pop_spec, + agents=simple_agents, + network=simple_network, + config=config, + chunk_size=1, + checkpoint_every_chunks=2, + writer_queue_size=2, + db_write_batch_size=2, + ) + + for aid in ["a0", "a1", "a2"]: + exposure = ExposureRecord( + timestep=0, channel="broadcast", content="Test", credibility=0.9 + ) + engine.state_manager.record_exposure(aid, exposure) + + def fake_batch( + contexts, + scenario, + cfg, + max_concurrency=50, + rate_limiter=None, + on_agent_done=None, + ): + response = _make_reasoning_response() + results = [] + for ctx in contexts: + if on_agent_done: + on_agent_done(ctx.agent_id, response) + results.append((ctx.agent_id, response)) + return results, BatchTokenUsage() + + with patch( + "extropy.simulation.engine.batch_reason_agents", side_effect=fake_batch + ): + reasoned, _, _ = engine._reason_agents(0) + + assert reasoned == 3 + completed = engine.study_db.get_completed_simulation_chunks(engine.run_id, 0) + assert completed == {0, 1, 2} + assert ( + engine.study_db.get_run_metadata(engine.run_id, "last_checkpoint") == "0:2" + ) + class TestResumeLogic: """Test engine resume/checkpoint logic.""" @@ -1138,7 +1200,14 @@ def test_progress_state_updated( response_a0 = _make_reasoning_response(position="adopt", conviction=0.5) response_a1 = _make_reasoning_response(position="reject", conviction=0.7) - def fake_batch(contexts, scenario, cfg, rate_limiter=None, on_agent_done=None): + def fake_batch( + contexts, + scenario, + cfg, + max_concurrency=50, + rate_limiter=None, + on_agent_done=None, + ): results = [] for ctx in contexts: if ctx.agent_id == "a0": @@ -1201,7 +1270,14 @@ def test_on_agent_done_callback_passed( received_kwargs = {} - def fake_batch(contexts, scenario, cfg, rate_limiter=None, on_agent_done=None): + def fake_batch( + contexts, + scenario, + cfg, + max_concurrency=50, + rate_limiter=None, + on_agent_done=None, + ): received_kwargs["on_agent_done"] = on_agent_done resp = _make_reasoning_response() return [(ctx.agent_id, resp) for ctx in contexts], BatchTokenUsage() @@ -1396,7 +1472,14 @@ def test_tokens_accumulate_across_chunks( call_count = [0] - def fake_batch(contexts, scenario, cfg, rate_limiter=None, on_agent_done=None): + def fake_batch( + contexts, + scenario, + cfg, + max_concurrency=50, + rate_limiter=None, + on_agent_done=None, + ): call_count[0] += 1 resp = _make_reasoning_response() results = [(ctx.agent_id, resp) for ctx in contexts] @@ -1483,8 +1566,8 @@ def test_cost_unknown_model_returns_null_usd( config = SimulationRunConfig( scenario_path="test.yaml", output_dir=str(tmp_path / "output"), - model="unknown-model-xyz", - routine_model="unknown-model-abc", + strong="unknown-provider/unknown-model-xyz", + fast="unknown-provider/unknown-model-abc", ) engine = SimulationEngine( scenario=minimal_scenario, @@ -1505,3 +1588,75 @@ def test_cost_unknown_model_returns_null_usd( meta = json.load(f) assert meta["cost"]["estimated_usd"] is None + + def test_export_results_keeps_compact_default_artifacts( + self, + minimal_scenario, + simple_agents, + simple_network, + minimal_pop_spec, + tmp_path, + ): + """Default export should keep compact summaries and skip large JSON dumps.""" + config = SimulationRunConfig( + scenario_path="test.yaml", + output_dir=str(tmp_path / "output"), + ) + engine = SimulationEngine( + scenario=minimal_scenario, + population_spec=minimal_pop_spec, + agents=simple_agents, + network=simple_network, + config=config, + ) + + engine._export_results() + + assert (tmp_path / "output" / "meta.json").exists() + assert (tmp_path / "output" / "by_timestep.json").exists() + assert not (tmp_path / "output" / "agent_states.json").exists() + assert not (tmp_path / "output" / "outcome_distributions.json").exists() + + def test_runtime_guardrails_downshift_under_pressure( + self, + minimal_scenario, + simple_agents, + simple_network, + minimal_pop_spec, + tmp_path, + ): + """Runtime memory pressure should downshift concurrency/write knobs.""" + + class HighPressureGovernor(ResourceGovernor): + def memory_pressure_ratio(self) -> float: + return 1.1 + + config = SimulationRunConfig( + scenario_path="test.yaml", + output_dir=str(tmp_path / "output"), + ) + governor = HighPressureGovernor(resource_mode="auto") + engine = SimulationEngine( + scenario=minimal_scenario, + population_spec=minimal_pop_spec, + agents=simple_agents, + network=simple_network, + config=config, + writer_queue_size=64, + db_write_batch_size=16, + resource_governor=governor, + ) + before = ( + engine.reasoning_max_concurrency, + engine.db_write_batch_size, + engine.writer_queue_size, + ) + engine._apply_runtime_guardrails(timestep=0) + after = ( + engine.reasoning_max_concurrency, + engine.db_write_batch_size, + engine.writer_queue_size, + ) + assert after[0] < before[0] + assert after[1] < before[1] + assert after[2] < before[2] diff --git a/tests/test_estimator.py b/tests/test_estimator.py index e4c114e..a40a053 100644 --- a/tests/test_estimator.py +++ b/tests/test_estimator.py @@ -140,8 +140,9 @@ def small_scenario() -> ScenarioSpec: name="test_scenario", description="Test scenario for estimation", population_spec="pop.yaml", - agents_file="agents.json", - network_file="network.json", + study_db="study.db", + population_id="default", + network_id="default", ), event=Event( type=EventType.ANNOUNCEMENT, @@ -322,7 +323,8 @@ def test_basic_estimate( population_spec=small_pop_spec, agents=small_agents, network=small_network, - provider="openai", + strong_model="openai/gpt-5", + fast_model="openai/gpt-5-mini", ) assert est.population_size == 10 @@ -369,7 +371,8 @@ def test_model_resolution_openai( population_spec=small_pop_spec, agents=small_agents, network=small_network, - provider="openai", + strong_model="openai/gpt-5", + fast_model="openai/gpt-5-mini", ) assert est.pivotal_model == "gpt-5" assert est.routine_model == "gpt-5-mini" @@ -382,7 +385,8 @@ def test_model_resolution_claude( population_spec=small_pop_spec, agents=small_agents, network=small_network, - provider="claude", + strong_model="anthropic/claude-sonnet-4-5-20250929", + fast_model="anthropic/claude-haiku-4-5-20251001", ) assert est.pivotal_model == "claude-sonnet-4-5-20250929" assert est.routine_model == "claude-haiku-4-5-20251001" @@ -395,9 +399,8 @@ def test_explicit_model_override( population_spec=small_pop_spec, agents=small_agents, network=small_network, - provider="openai", - pivotal_model="gpt-5-mini", - routine_model="gpt-5-mini", + strong_model="openai/gpt-5-mini", + fast_model="openai/gpt-5-mini", ) assert est.pivotal_model == "gpt-5-mini" assert est.routine_model == "gpt-5-mini" @@ -410,8 +413,8 @@ def test_unknown_model_pricing_none( population_spec=small_pop_spec, agents=small_agents, network=small_network, - pivotal_model="unknown-model-x", - routine_model="unknown-model-y", + strong_model="openai/unknown-model-x", + fast_model="openai/unknown-model-y", ) assert est.pivotal_pricing is None assert est.routine_pricing is None diff --git a/tests/test_integration_timestep.py b/tests/test_integration_timestep.py index c88874f..fe21005 100644 --- a/tests/test_integration_timestep.py +++ b/tests/test_integration_timestep.py @@ -77,8 +77,9 @@ def _make_scenario( name="test_scenario", description="Test scenario", population_spec="test.yaml", - agents_file="test.json", - network_file="test_network.json", + study_db="study.db", + population_id="default", + network_id="default", created_at=datetime(2024, 1, 1), ), event=Event( diff --git a/tests/test_network.py b/tests/test_network.py index 7564cff..749f023 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -158,8 +158,14 @@ def test_default_config(self): assert config.avg_degree == 20.0 assert config.rewire_prob == 0.05 + assert config.similarity_store_threshold == 0.05 assert config.similarity_threshold == 0.3 assert config.similarity_steepness == 10.0 + assert config.candidate_mode == "exact" + assert config.candidate_pool_multiplier == 12.0 + assert config.min_candidate_pool == 80 + assert config.similarity_workers == 1 + assert config.checkpoint_every_rows == 250 assert config.seed is None def test_custom_config(self): @@ -704,6 +710,111 @@ def on_progress(stage, current, total): stages = set(call[0] for call in progress_calls) assert "Computing similarities" in stages + def test_generate_network_blocked_mode_reproducibility(self, sample_agents): + """Blocked candidate mode should remain deterministic with fixed seed.""" + config = REFERENCE_NETWORK_CONFIG.model_copy( + update={ + "seed": 42, + "candidate_mode": "blocked", + "candidate_pool_multiplier": 8.0, + "blocking_attributes": ["employer_type", "federal_state"], + } + ) + + result1 = generate_network(sample_agents, config) + result2 = generate_network(sample_agents, config) + + edges1 = {(e.source, e.target) for e in result1.edges} + edges2 = {(e.source, e.target) for e in result2.edges} + + assert result1.meta["candidate_mode"] == "blocked" + assert result2.meta["candidate_mode"] == "blocked" + assert edges1 == edges2 + + def test_generate_network_resume_from_checkpoint_matches_fresh(self, sample_agents): + """Resuming from a saved similarity checkpoint should match a fresh run.""" + import sqlite3 + + config = REFERENCE_NETWORK_CONFIG.model_copy( + update={ + "seed": 42, + "similarity_chunk_size": 8, + "checkpoint_every_rows": 1, + } + ) + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "study.db" + + # Build and persist checkpoint from a full run. + result_checkpointed = generate_network( + sample_agents, + config, + checkpoint_path=checkpoint_path, + ) + assert checkpoint_path.exists() + + # Simulate interruption by dropping the latter half of completed chunks. + conn = sqlite3.connect(str(checkpoint_path)) + cur = conn.cursor() + cur.execute( + "SELECT job_id FROM network_similarity_jobs ORDER BY created_at DESC LIMIT 1" + ) + job_id = cur.fetchone()[0] + cutoff = max(8, (len(sample_agents) // 2)) + cur.execute( + """ + SELECT MIN(chunk_start) + FROM network_similarity_chunks + WHERE job_id = ? AND chunk_start >= ? + """, + (job_id, cutoff), + ) + drop_start = cur.fetchone()[0] + if drop_start is None: + drop_start = 0 + cur.execute( + "DELETE FROM network_similarity_chunks WHERE job_id = ? AND chunk_start >= ?", + (job_id, drop_start), + ) + cur.execute( + "DELETE FROM network_similarity_pairs WHERE job_id = ? AND i >= ?", + (job_id, drop_start), + ) + cur.execute( + "UPDATE network_similarity_jobs SET status = 'running' WHERE job_id = ?", + (job_id,), + ) + conn.commit() + conn.close() + + resumed = generate_network( + sample_agents, + config, + checkpoint_path=checkpoint_path, + resume_from_checkpoint=True, + ) + fresh = generate_network(sample_agents, config) + + resumed_edges = {(e.source, e.target) for e in resumed.edges} + fresh_edges = {(e.source, e.target) for e in fresh.edges} + + assert resumed.meta["resumed_from_checkpoint"] is True + assert resumed_edges == fresh_edges + assert len(resumed.edges) == len(result_checkpointed.edges) + + def test_generate_network_checkpoint_requires_db_path(self, sample_agents): + """Checkpoint path must be a SQLite DB path.""" + config = REFERENCE_NETWORK_CONFIG.model_copy(update={"seed": 42}) + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "network-similarity.pkl" + with pytest.raises(ValueError, match="DB-only"): + generate_network( + sample_agents, + config, + checkpoint_path=checkpoint_path, + ) + class TestGenerateNetworkWithMetrics: """Tests for network generation with metrics.""" diff --git a/tests/test_propagation.py b/tests/test_propagation.py index a52075b..b0d9051 100644 --- a/tests/test_propagation.py +++ b/tests/test_propagation.py @@ -72,8 +72,9 @@ def _make_scenario( name="test", description="Test scenario", population_spec="test.yaml", - agents_file="test.json", - network_file="test_network.json", + study_db="study.db", + population_id="default", + network_id="default", created_at=datetime(2024, 1, 1), ), event=Event( diff --git a/tests/test_providers.py b/tests/test_providers.py index 291d846..752eb77 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -380,9 +380,8 @@ def test_no_validator_returns_immediately(self): """With no validator, first result is returned.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -421,9 +420,8 @@ def test_initial_prompt_used_on_first_call(self): """When initial_prompt is provided, it should be used for the first call.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -463,9 +461,8 @@ def test_validation_retries_use_base_prompt_not_initial(self): """Validation retries should use prompt, not initial_prompt.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -528,9 +525,8 @@ def test_validator_succeeds_on_first_attempt_with_initial_prompt(self): """When validator passes on first try with initial_prompt, no retries occur.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -574,9 +570,8 @@ def test_on_retry_callback_invoked_correctly(self): """Test that on_retry callback is invoked with correct parameters.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -640,9 +635,8 @@ def test_no_initial_prompt_defaults_to_prompt(self): """When initial_prompt is None, prompt is used for first call.""" class ConcreteProvider(LLMProvider): - default_simple_model = "test" - default_reasoning_model = "test" - default_research_model = "test" + default_fast_model = "test" + default_strong_model = "test" def simple_call(self, *a, **kw): return {} @@ -813,12 +807,10 @@ class TestProviderFactoryAzure: ) def test_create_azure_openai_provider(self): from extropy.core.providers import _create_provider + from extropy.core.providers.openai_compat import OpenAICompatProvider provider = _create_provider("azure_openai") - assert isinstance(provider, OpenAIProvider) - assert provider._is_azure is True - assert provider._azure_endpoint == "https://my-resource.openai.azure.com" - assert provider._azure_deployment == "my-deployment" + assert isinstance(provider, OpenAICompatProvider) @patch.dict( "os.environ", diff --git a/tests/test_reasoning_prompts.py b/tests/test_reasoning_prompts.py index f043dba..d8feb18 100644 --- a/tests/test_reasoning_prompts.py +++ b/tests/test_reasoning_prompts.py @@ -48,8 +48,9 @@ def _make_scenario(**overrides): name="test", description="Test", population_spec="test.yaml", - agents_file="test.json", - network_file="test_network.json", + study_db="study.db", + population_id="default", + network_id="default", created_at=datetime(2024, 1, 1), ), event=Event( diff --git a/tests/test_resource_governor.py b/tests/test_resource_governor.py new file mode 100644 index 0000000..8a1eaa5 --- /dev/null +++ b/tests/test_resource_governor.py @@ -0,0 +1,25 @@ +"""Tests for resource auto-tuning and runtime guardrails.""" + +from extropy.utils.resource_governor import ResourceGovernor + + +def test_downshift_int_respects_minimum(): + assert ResourceGovernor.downshift_int(100, factor=0.5, minimum=1) == 50 + assert ResourceGovernor.downshift_int(2, factor=0.1, minimum=4) == 4 + + +def test_memory_pressure_ratio_uses_budget(monkeypatch): + governor = ResourceGovernor(resource_mode="auto", max_memory_gb=8.0) + monkeypatch.setattr( + governor, + "_detect_total_memory_gb", + lambda: 8.0, + ) + monkeypatch.setattr( + governor, + "_current_process_memory_gb", + lambda: 3.2, + ) + + # Budget is 80% of capped memory => 6.4 GB, so ratio should be 0.5. + assert governor.memory_pressure_ratio() == 0.5 diff --git a/tests/test_scenario.py b/tests/test_scenario.py index 15111fb..274b4af 100644 --- a/tests/test_scenario.py +++ b/tests/test_scenario.py @@ -342,8 +342,9 @@ def test_scenario_meta_creation(self): name="ai_tool_announcement", description="Hospital announces new AI diagnostic tool", population_spec="surgeons.yaml", - agents_file="agents.json", - network_file="network.json", + study_db="study.db", + population_id="default", + network_id="default", ) assert meta.name == "ai_tool_announcement" assert meta.population_spec == "surgeons.yaml" @@ -361,8 +362,9 @@ def sample_scenario_spec(self): name="test_scenario", description="Test scenario", population_spec="pop.yaml", - agents_file="agents.json", - network_file="network.json", + study_db="study.db", + population_id="default", + network_id="default", ), event=Event( type=EventType.ANNOUNCEMENT, @@ -596,8 +598,9 @@ def test_full_scenario_with_all_features(self): name="ai_tool_full_scenario", description="Hospital announces mandatory AI diagnostic tool", population_spec="german_surgeons.yaml", - agents_file="agents_500.json", - network_file="network_500.json", + study_db="study.db", + population_id="default", + network_id="default", ), event=Event( type=EventType.ANNOUNCEMENT, diff --git a/tests/test_scenario_validator.py b/tests/test_scenario_validator.py index 3aa6d57..7f23cde 100644 --- a/tests/test_scenario_validator.py +++ b/tests/test_scenario_validator.py @@ -1,6 +1,5 @@ """Tests for scenario validation behavior.""" -import json from pathlib import Path from extropy.core.models.scenario import ( @@ -21,20 +20,21 @@ SpreadConfig, ) from extropy.scenario.validator import load_and_validate_scenario, validate_scenario +from extropy.storage import open_study_db def _make_scenario_spec( population_path: str, - agents_path: str, - network_path: str, + study_db_path: str, ) -> ScenarioSpec: return ScenarioSpec( meta=ScenarioMeta( name="test_scenario", description="Validation test scenario", population_spec=population_path, - agents_file=agents_path, - network_file=network_path, + study_db=study_db_path, + population_id="default", + network_id="default", ), event=Event( type=EventType.ANNOUNCEMENT, @@ -83,17 +83,19 @@ def _make_scenario_spec( def test_validate_scenario_surfaces_errors(tmp_path: Path): """Validation should preserve and return discovered errors.""" population_path = tmp_path / "population.yaml" - agents_path = tmp_path / "agents.json" - network_path = tmp_path / "network.json" + study_db = tmp_path / "study.db" population_path.write_text("placeholder: true\n") - agents_path.write_text("[]\n") - network_path.write_text('{"meta": {}, "edges": []}\n') + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", + agents=[], + meta={"source": "test"}, + ) spec = _make_scenario_spec( str(population_path), - str(agents_path), - str(network_path), + str(study_db), ) spec.seed_exposure.rules[0].channel = "missing_channel" @@ -110,15 +112,38 @@ def test_load_and_validate_scenario_resolves_relative_paths( ): """Relative file references should resolve against scenario file location.""" population_path = tmp_path / "population.yaml" - agents_path = tmp_path / "agents.json" - network_path = tmp_path / "network.json" + study_db = tmp_path / "study.db" scenario_path = tmp_path / "scenario.yaml" minimal_population_spec.to_yaml(population_path) - agents_path.write_text('[{"_id": "agent_0", "age": 35, "gender": "male"}]\n') - network_path.write_text(json.dumps({"meta": {"node_count": 1}, "edges": []})) + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", + agents=[{"_id": "agent_0", "age": 35, "gender": "male"}], + meta={"source": "test"}, + ) + db.save_network_result( + population_id="default", + network_id="default", + config={}, + result_meta={"node_count": 1}, + edges=[ + { + "source": "agent_0", + "target": "agent_0", + "weight": 1.0, + "type": "self", + "influence_weight": { + "source_to_target": 1.0, + "target_to_source": 1.0, + }, + } + ], + seed=None, + candidate_mode="test", + ) - spec = _make_scenario_spec("population.yaml", "agents.json", "network.json") + spec = _make_scenario_spec("population.yaml", "study.db") spec.to_yaml(scenario_path) _, result = load_and_validate_scenario(scenario_path) @@ -133,17 +158,17 @@ def test_load_and_validate_scenario_resolves_relative_paths( def test_validate_scenario_allows_edge_weight_in_spread_modifier(tmp_path: Path): """edge_weight should be treated as a valid spread modifier reference.""" population_path = tmp_path / "population.yaml" - agents_path = tmp_path / "agents.json" - network_path = tmp_path / "network.json" + study_db = tmp_path / "study.db" population_path.write_text("placeholder: true\n") - agents_path.write_text("[]\n") - network_path.write_text('{"meta": {}, "edges": []}\n') + with open_study_db(study_db) as db: + db.save_sample_result( + population_id="default", agents=[], meta={"source": "test"} + ) spec = _make_scenario_spec( str(population_path), - str(agents_path), - str(network_path), + str(study_db), ) spec.spread.share_modifiers = [ SpreadModifier(when="edge_weight > 0.7", multiply=1.1, add=0.0) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index a5866d3..aecec95 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -402,6 +402,34 @@ def test_get_population_count(self, temp_db, agents): count = manager.get_population_count() assert count == 3 + def test_run_scope_isolation(self, temp_db, agents): + """Different run_id views should not leak state into each other.""" + exposure = ExposureRecord( + timestep=0, + channel="email", + content="Scoped exposure", + credibility=0.9, + ) + + with StateManager(temp_db, agents=agents, run_id="run_a") as run_a: + run_a.record_exposure("agent_000", exposure) + assert run_a.get_exposure_rate() == pytest.approx(1 / 3, abs=0.01) + assert run_a.get_checkpoint_timestep() is None + run_a.mark_timestep_started(2) + assert run_a.get_checkpoint_timestep() == 2 + + with StateManager(temp_db, agents=agents, run_id="run_b") as run_b: + assert run_b.get_exposure_rate() == 0.0 + assert run_b.get_agent_state("agent_000").aware is False + assert run_b.get_checkpoint_timestep() is None + run_b.record_exposure("agent_001", exposure) + assert run_b.get_exposure_rate() == pytest.approx(1 / 3, abs=0.01) + + with StateManager(temp_db, agents=agents, run_id="run_a") as run_a_again: + assert run_a_again.get_agent_state("agent_000").aware is True + assert run_a_again.get_agent_state("agent_001").aware is False + assert run_a_again.get_checkpoint_timestep() == 2 + class TestPersonaGeneration: """Tests for persona generation."""