diff --git a/.gitignore b/.gitignore
index df1a13b..f10dba1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,16 @@
-/logs
\ No newline at end of file
+/logs
+*.pyc
+__pycache__/
+*.pyo
+*.egg-info
+/build
+/datasets
+/results
+/BEHAVIOR-1K
+.worktrees/
+/docs/plans
+yolo11n.pt
+.tmp*
+/scripts
+.sisyphus/
+.claude/
diff --git a/AGENT.md b/AGENT.md
new file mode 100644
index 0000000..5f568e2
--- /dev/null
+++ b/AGENT.md
@@ -0,0 +1,629 @@
+# AGENT.md — AI-Assisted Development Guide
+
+This document helps AI coding agents (and human developers) understand the EASI library and contribute new benchmarks effectively.
+
+## How the Library Works
+
+EASI evaluates embodied AI agents in interactive simulators. The core loop:
+
+```
+for each episode in dataset:
+    sim.reset(task.format_reset_config(episode))
+    while not done and step < max_steps:
+        action = agent.act(observation, instruction)
+        observation = sim.step(action)
+    result = task.evaluate_episode(episode, trajectory)
+summary = task.aggregate_results(all_records)
+```
+
+Four components plug together:
+
+| Component | Role | Lives in |
+|---|---|---|
+| **Task** | Defines episodes, action space, success metrics | `easi/tasks/<task_name>/` |
+| **Simulator** | Runs the 3D environment in a subprocess | `easi/simulators/<sim>/` |
+| **Bridge** | Wraps the simulator's Python API for IPC | `easi/tasks/<task_name>/bridge.py` or `easi/simulators/<sim>/<version>/bridge.py` |
+| **Agent** | Decides actions (DummyAgent or ReActAgent+LLM) | `easi/agents/` |
+
+---
+
+## Adding a New Benchmark (Step-by-Step)
+
+This is the most common contribution. Follow the existing task structure exactly.
+
+### Prerequisites
+
+- The simulator is already integrated (check `easi env list`). If not, see "Adding a New Simulator" below.
+- The benchmark's dataset is on HuggingFace (or you have local episodes).
+- You have the benchmark's source code to reference for environment setup and evaluation logic.
+
+### Step 1: Create the Task Folder
+
+```
+easi/tasks/<task_name>/
+├── __init__.py             # Empty
+├── task.py                 # Task class (required)
+├── bridge.py               # Bridge script (if task needs custom env wrapping)
+├── actions.py              # Action space definitions (if static)
+├── prompts.py              # PromptBuilder for LLM interaction (optional)
+├── <task_name>_base.yaml   # Base config (or split configs)
+├── _base.yaml              # Shared config for multi-split tasks
+├── config/                 # Few-shot examples, etc. (optional)
+└── vendor/                 # Vendored benchmark env code (optional)
+    └── __init__.py
+```
+
+Use `easi task scaffold <name>` to generate boilerplate, then customize.
+
+### Step 2: Define the Task YAML
+
+Every task needs at least one `.yaml` config file. The registry auto-discovers all `easi/tasks/*/*.yaml` files (excluding files without a `name` key).
+ +**Minimal config (single-split task):** + +```yaml +name: my_benchmark +display_name: "My Benchmark" +description: "Description of the benchmark" +simulator: "ai2thor:v5_0_0" +task_class: "easi.tasks.my_benchmark.task.MyBenchmarkTask" +max_steps: 50 +dataset: + source: huggingface + repo_id: "username/my-benchmark-dataset" + subset: null + split: "test" +``` + +**Multi-split task (recommended for benchmarks with difficulty splits):** + +```yaml +# _base.yaml — shared config, NOT registered as a task (no `name` key) +display_name: "My Benchmark" +simulator: "ai2thor:v5_0_0" +task_class: "easi.tasks.my_benchmark.task.MyBenchmarkTask" +max_steps: 50 +dataset: + source: huggingface + repo_id: "username/my-benchmark-dataset" +simulator_configs: + screen_height: 500 + screen_width: 500 + additional_deps: + - "gym" +agent: + prompt_builder: "easi.tasks.my_benchmark.prompts.MyPromptBuilder" + prompt_builder_kwargs: + n_shot: 3 + generation_kwargs: + temperature: 0 + max_tokens: 2048 +``` + +```yaml +# my_benchmark_base.yaml — registered as task "my_benchmark_base" +extends: _base.yaml +name: my_benchmark_base +display_name: "My Benchmark Base Split" +dataset: + split: "base" +``` + +```yaml +# my_benchmark_hard.yaml — registered as task "my_benchmark_hard" +extends: _base.yaml +name: my_benchmark_hard +display_name: "My Benchmark Hard Split" +dataset: + split: "hard" +``` + +**YAML fields reference:** + +| Field | Required | Description | +|---|---|---| +| `name` | Yes | Task name used in CLI (`easi start `) | +| `display_name` | No | Human-readable name | +| `description` | No | Task description | +| `simulator` | Yes | Simulator key, e.g. `"ai2thor:v5_0_0"` or `"dummy:v1"` | +| `task_class` | Yes | Dotted import path to your Task class | +| `max_steps` | No | Max steps per episode (default: 500) | +| `dataset.source` | Yes | `"huggingface"` or `"local"` | +| `dataset.repo_id` | HF only | HuggingFace repo ID | +| `dataset.split` | HF only | Dataset split name | +| `dataset.subset` | No | Dataset subset (auto-detected if single) | +| `dataset.zip_files` | No | List of zip files to extract after download | +| `simulator_configs` | No | Dict passed to bridge as `simulator_kwargs` | +| `simulator_configs.additional_deps` | No | Extra pip packages for the simulator env | +| `simulator_configs.env_vars` | No | Environment variables for bridge subprocess | +| `agent.prompt_builder` | No | Dotted path to PromptBuilder class | +| `agent.prompt_builder_kwargs` | No | Kwargs passed to PromptBuilder constructor | +| `agent.generation_kwargs` | No | LLM generation defaults (temperature, max_tokens, etc.) | +| `extends` | No | Relative path to base YAML for template inheritance | + +### Step 3: Implement the Task Class + +Subclass `BaseTask` and implement 3 abstract methods: + +```python +"""My benchmark task for EASI.""" +from __future__ import annotations + +from pathlib import Path + +from easi.core.base_task import BaseTask +from easi.core.episode import StepResult + +class MyBenchmarkTask(BaseTask): + + def get_task_yaml_path(self) -> Path: + """Return path to the default YAML config.""" + return Path(__file__).parent / "my_benchmark_base.yaml" + + def format_reset_config(self, episode: dict) -> dict: + """Map a dataset row to simulator reset kwargs. + + The returned dict is passed to bridge.reset(reset_config). + Include everything the bridge needs to initialize the episode. 
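+        All values must be JSON-serializable: the dict crosses the
+        filesystem IPC boundary to the bridge subprocess as a JSON file.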
+ """ + return { + "episode_id": episode.get("id", "unknown"), + "scene": episode["scene"], + "instruction": episode["instruction"], + # Add all fields your bridge needs + } + + def evaluate_episode( + self, episode: dict, trajectory: list[StepResult] + ) -> dict[str, float]: + """Score a completed episode. + + Args: + episode: The raw dataset row dict. + trajectory: List of StepResult from the agent-simulator loop. + Each StepResult has: observation, reward, done, info. + + Returns: + Dict of metric_name -> float. These are saved to result.json + and passed to aggregate_results(). + """ + if not trajectory: + return {"task_success": 0.0, "num_steps": 0.0} + + last_step = trajectory[-1] + return { + "task_success": last_step.info.get("task_success", 0.0), + "num_steps": float(len(trajectory)), + } +``` + +**Optional overrides:** + +```python + # Static action space (if not dynamic per-episode) + def _build_action_space(self) -> list[str]: + return ["MoveForward", "TurnLeft", "TurnRight", "Stop"] + + # Custom bridge script (if task needs special env wrapping) + def get_bridge_script_path(self) -> Path: + return Path(__file__).parent / "bridge.py" + + # Extract instruction from episode (if field name differs) + def get_instruction(self, episode: dict) -> str: + return episode.get("task_description", self.name) + + # Dynamic action space per episode (e.g., EB-Alfred) + def on_episode_reset(self, observation, agent) -> None: + new_actions = observation.metadata.get("action_space", "").split(",") + if new_actions and hasattr(agent, "update_action_space"): + agent.update_action_space(new_actions) + + # Custom cross-episode aggregation + def aggregate_results(self, records): + """Custom aggregation with access to trajectories and episode data. + + Args: + records: list[EpisodeRecord], each with: + - record.episode: raw dataset row dict + - record.trajectory: list[StepResult] + - record.episode_results: dict from evaluate_episode() + """ + n = len(records) + successes = sum(r.episode_results.get("task_success", 0) for r in records) + return { + "success_rate": round(successes / n, 4) if n else 0.0, + "avg_steps": round( + sum(r.episode_results.get("num_steps", 0) for r in records) / n, 2 + ) if n else 0.0, + } + + # Built-in episodes for testing without dataset download + def _get_builtin_episodes(self) -> list[dict]: + return [{"id": 0, "scene": "TestScene", "instruction": "test"}] +``` + +### Step 4: Implement the Bridge (if needed) + +If your benchmark uses a vendored environment that differs from the simulator's default bridge, create a task-specific bridge. The bridge runs as a **subprocess** in the simulator's conda env. + +```python +"""My benchmark bridge — wraps vendored env via BaseBridge. + +This script runs inside the simulator's conda env (e.g., Python 3.10). +""" +from __future__ import annotations + +import sys +from pathlib import Path + +# Ensure repo root is importable +_repo_root = Path(__file__).resolve().parents[3] +if str(_repo_root) not in sys.path: + sys.path.insert(0, str(_repo_root)) + +from easi.simulators.base_bridge import BaseBridge + + +class MyBenchmarkBridge(BaseBridge): + """Wraps vendored MyEnv via BaseBridge.""" + + def _create_env(self, reset_config, simulator_kwargs): + """Create the environment. Called once on first reset.""" + from easi.tasks.my_benchmark.vendor.my_env import MyEnv + resolution = simulator_kwargs.get("screen_height", 500) + return MyEnv(resolution=resolution) + + def _on_reset(self, env, reset_config): + """Reset with episode data. 
Return observation.""" + return env.reset(scene=reset_config["scene"]) + + def _on_step(self, env, action_text): + """Execute action. Return (obs, reward, done, info) tuple.""" + return env.step(action_text) + + def _extract_image(self, obs): + """Extract RGB numpy array (H, W, 3) from observation.""" + return obs["rgb"] # np.ndarray + + def _extract_info(self, info): + """Filter info dict to JSON-serializable values.""" + return { + "task_success": float(info.get("success", 0.0)), + "feedback": str(info.get("feedback", "")), + } + + +if __name__ == "__main__": + MyBenchmarkBridge.main() +``` + +**BaseBridge hooks:** + +| Method | Default | Override when | +|---|---|---| +| `_create_env(reset_config, simulator_kwargs)` | `raise NotImplementedError` | Always (required) | +| `_extract_image(obs)` | `raise NotImplementedError` | Always (required) | +| `_on_reset(env, reset_config)` | `env.reset()` | Env needs episode data passed to reset | +| `_on_step(env, action_text)` | `env.step(action_text)` | Action needs translation (text → int, etc.) | +| `_extract_info(info)` | Filters to scalar values | You want specific keys in result.json | + +### Step 5: Implement the PromptBuilder (optional) + +For LLM-powered evaluation, create a task-specific PromptBuilder: + +```python +"""Prompt builder for My Benchmark.""" +from __future__ import annotations + +import json + +from easi.agents.prompt_builder import validate_action_name, _encode_image_base64 +from easi.core.episode import Action +from easi.core.memory import AgentMemory + + +class MyPromptBuilder: + """Builds prompts for My Benchmark's ReAct agent.""" + + def __init__(self, n_shot=3, use_feedback=True): + self.n_shot = n_shot + self.use_feedback = use_feedback + + def build_messages(self, memory: AgentMemory) -> list[dict]: + """Build LLM messages from agent memory. + + Args: + memory: AgentMemory with task_description, action_space, + current_observation, steps (history), action_history. + + Returns: + List of message dicts: [{"role": "system", "content": [...]}, ...] + """ + messages = [] + + # System message with instructions + system_text = f"You are an agent. Task: {memory.task_description}\n" + system_text += f"Actions: {', '.join(memory.action_space)}" + messages.append({"role": "system", "content": [{"type": "text", "text": system_text}]}) + + # Current observation (with image) + user_content = [] + if memory.current_observation and memory.current_observation.rgb_path: + img_b64 = _encode_image_base64(memory.current_observation.rgb_path) + user_content.append({ + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_b64}"}, + }) + user_content.append({"type": "text", "text": "What action should you take?"}) + messages.append({"role": "user", "content": user_content}) + + return messages + + def parse_response(self, llm_response: str, memory: AgentMemory) -> list[Action]: + """Parse LLM text response into validated Action objects. + + Returns: + List of Action objects. Empty list = parsing failed. + """ + try: + data = json.loads(llm_response) + except json.JSONDecodeError: + return [] + + plan = data.get("executable_plan", []) + actions = [] + for entry in plan: + name = entry.get("action_name", "") + validated = validate_action_name(name, memory.action_space) + if validated: + actions.append(Action(action_name=validated)) + else: + break # Stop at first invalid action + return actions +``` + +### Step 6: Write Tests + +Follow the pattern in existing test files. All tests run offline (no simulator, no LLM). 
+ +```python +"""Tests for My Benchmark task (offline, no simulator needed).""" +import pytest +from pathlib import Path + +from easi.core.episode import Observation, StepResult + + +class TestMyBenchmarkTask: + @pytest.fixture + def task(self): + from easi.tasks.my_benchmark.task import MyBenchmarkTask + return MyBenchmarkTask() + + def test_name(self, task): + assert task.name == "my_benchmark_base" + + def test_simulator_key(self, task): + assert task.simulator_key == "ai2thor:v5_0_0" + + def test_action_space(self, task): + assert len(task.action_space) > 0 + + def test_max_steps(self, task): + assert task.max_steps == 50 + + def test_format_reset_config(self, task): + episode = {"id": 0, "scene": "TestScene", "instruction": "test"} + config = task.format_reset_config(episode) + assert "scene" in config + + def test_evaluate_episode(self, task): + obs = Observation(rgb_path="/tmp/fake.png") + trajectory = [ + StepResult(observation=obs, reward=0.0, done=True, + info={"task_success": 1.0}), + ] + result = task.evaluate_episode({"id": 0}, trajectory) + assert "task_success" in result + + def test_evaluate_empty_trajectory(self, task): + result = task.evaluate_episode({}, []) + assert result["task_success"] == 0.0 + + def test_bridge_script_path(self, task): + path = task.get_bridge_script_path() + if path is not None: + assert path.exists() + + def test_registry_discovers_task(self): + from easi.tasks.registry import list_tasks + tasks = list_tasks() + assert "my_benchmark_base" in tasks +``` + +Run tests: `pytest tests/test_my_benchmark.py -v` + +### Step 7: Verify + +```bash +# All existing tests still pass +pytest tests/ -v --timeout=60 + +# Registry discovers your task +easi task list | grep my_benchmark + +# Task info looks correct +easi task info my_benchmark_base + +# Dummy agent smoke test (no LLM needed) +easi start my_benchmark_base --agent dummy --max-episodes 1 +``` + +--- + +## Adding a New Simulator + +Less common. Only needed when a benchmark uses a simulator not yet in EASI. 
+ +### Structure + +``` +easi/simulators// +├── __init__.py +├── manifest.yaml # Declares name, versions, classes +└── / + ├── __init__.py + ├── simulator.py # Subclass of BaseSimulator + ├── env_manager.py # Subclass of BaseEnvironmentManager + ├── bridge.py # Default bridge script + ├── conda_env.yaml # Conda environment spec + └── requirements.txt # Pip dependencies +``` + +### manifest.yaml + +```yaml +name: my_sim +display_name: "My Simulator" +default_version: "v1_0_0" +versions: + v1_0_0: + description: "My Simulator 1.0.0" + simulator_class: "easi.simulators.my_sim.v1_0_0.simulator.MySimSimulator" + env_manager_class: "easi.simulators.my_sim.v1_0_0.env_manager.MySimEnvManager" + python_version: "3.10" +``` + +### simulator.py + +```python +from pathlib import Path +from easi.core.base_simulator import BaseSimulator + +class MySimSimulator(BaseSimulator): + @property + def name(self) -> str: + return "my_sim" + + @property + def version(self) -> str: + return "v1_0_0" + + def _get_bridge_script_path(self) -> Path: + return Path(__file__).parent / "bridge.py" +``` + +### env_manager.py + +```python +from pathlib import Path +from easi.core.base_env_manager import BaseEnvironmentManager + +class MySimEnvManager(BaseEnvironmentManager): + @property + def simulator_name(self) -> str: + return "my_sim" + + @property + def version(self) -> str: + return "v1_0_0" + + @property + def needs_display(self) -> bool: + return True # Set True if simulator needs X11/Xvfb + + def get_conda_env_yaml_path(self) -> Path: + return Path(__file__).parent / "conda_env.yaml" + + def get_requirements_txt_path(self) -> Path: + return Path(__file__).parent / "requirements.txt" + + def get_system_deps(self) -> list[str]: + return ["conda"] # Add "xvfb" if needs_display is True + + def get_validation_import(self) -> str: + return "from my_sim import Controller; print('ok')" +``` + +### Verify + +```bash +easi env list # Should show my_sim +easi env install my_sim # Install conda env +easi env check my_sim # Validate +easi sim test my_sim # Smoke test bridge +``` + +--- + +## Vendoring Benchmark Code + +When integrating an external benchmark (e.g., from EmbodiedBench), vendor only the environment code you need: + +1. **Create `vendor/` directory** in your task folder +2. **Copy only the env class** (not the full benchmark runner/evaluator) +3. **Remove external dependencies** the benchmark used for logging, dataset loading, gym registration — EASI handles all of these +4. 
**Adapt the interface**: + - `reset(episode)` accepts an episode dict (from EASI's dataset) + - `step(action)` returns `(obs, reward, done, info)` tuple + - Remove internal image saving (bridge handles this) + - Remove internal logging (EASI's logger handles this) + +--- + +## Key Conventions + +### Logging + +```python +from easi.utils.logging import get_logger +logger = get_logger(__name__) +# Use logger.info(), logger.warning(), logger.error() +# Use logger.trace() for verbose debug output +# NEVER use print() +``` + +### Imports + +```python +from __future__ import annotations # Always first import +``` + +### Testing + +- All tests run offline (mock simulators, no LLM calls) +- Use `Observation(rgb_path="/tmp/fake.png")` for test observations +- Use `StepResult(observation=obs, done=True, info={...})` for test trajectories +- Test file naming: `tests/test_.py` + +### summary.json Structure + +```json +{ + "num_episodes": 100, + "model": "gpt-4o", + "backend": "openai", + "llm_usage": {"total_calls": 500, "total_tokens": 150000}, + "metrics": { + "success_rate": 0.73, + "avg_num_steps": 24.3, + "avg_task_success": 0.73 + } +} +``` + +Metrics (from `task.aggregate_results()`) are nested under `"metrics"`. Run metadata stays at the top level. + +--- + +## Existing Benchmarks Reference + +| Task | Simulator | Splits | Action Type | Max Steps | +|---|---|---|---|---| +| `dummy_task` | `dummy:v1` | 1 | 4 discrete text | 100 | +| `ebalfred_*` | `ai2thor:v2_1_0` | 6 | ~133 skill text | 50 | +| `ebnavigation_*` | `ai2thor:v5_0_0` | 5 | 8 discrete int | 20 | +| `ebhabitat_*` | `habitat_sim:v0_3_0` | 4 | 4 discrete text | varies | +| `ebmanipulation_*` | `coppeliasim:v4_1_0` | 4 | continuous params | varies | + +Use the closest existing task as a template when adding a new one. The `dummy_task` is the simplest reference; `ebalfred` is the most complete. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..ba9bbf2 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,196 @@ +# CLAUDE.md + +This file provides guidance to Claude Code when working with this repository. + +## Project Overview + +EASI is a unified evaluation framework for embodied AI agents. It has two layers: + +1. **Static benchmarks** — VLMEvalKit and lmms-eval submodules for VLM evaluation (image Q&A, spatial reasoning). These are mature and rarely modified. +2. **Embodied agent evaluation** (`easi/` library) — The active development focus. Subprocess-isolated simulators, multi-split tasks, and LLM-powered agents for interactive benchmarks (EB-Alfred, EB-Navigation, EB-Habitat, EB-Manipulation, HAZARD). + +Most development work happens in the `easi/` library. 
+ +## Quick Reference + +```bash +# Setup +pip install -e ".[dev]" + +# Run tests (540 tests, ~4-5min) +pytest tests/ -v --timeout=60 + +# CLI +easi task list # List all tasks +easi env list # List all simulators +easi env install ai2thor:v2_1_0 # Install simulator env +easi sim test dummy # Smoke test simulator +easi start dummy_task --agent dummy # Run evaluation (no LLM) + +# Real evaluation +easi start ebalfred_base --agent react --backend openai --model gpt-4o +easi start ebalfred_base --agent react --backend openai --model gpt-4o --num-parallel 4 +easi start --resume ./logs/ebalfred_base/ +``` + +## Architecture + +``` +easi/ +├── core/ # Abstract base classes + dataclasses +│ ├── base_task.py # BaseTask — task interface +│ ├── base_simulator.py # BaseSimulator — simulator interface +│ ├── base_agent.py # BaseAgent — agent interface +│ ├── base_env_manager.py # BaseEnvironmentManager — conda env setup +│ ├── episode.py # Observation, Action, StepResult, EpisodeRecord +│ ├── memory.py # AgentMemory — shared agent/prompt state +│ ├── protocols.py # Runtime-checkable Protocol interfaces +│ └── exceptions.py # EASIError hierarchy +│ +├── agents/ # Agent implementations +│ ├── dummy_agent.py # Random action picker (testing) +│ ├── react_agent.py # ReAct agent with multi-action buffering +│ └── prompt_builder.py # PromptBuilder protocol + DefaultPromptBuilder +│ +├── simulators/ # Simulator implementations (subprocess-isolated) +│ ├── base_bridge.py # BaseBridge — Gym-like env wrapper for IPC +│ ├── subprocess_runner.py # SubprocessRunner — process lifecycle +│ ├── registry.py # Auto-discovery via manifest.yaml +│ ├── dummy/v1/ # In-memory testing simulator +│ ├── ai2thor/v2_1_0/ # AI2-THOR 2.1.0 (EB-Alfred, Python 3.8) +│ ├── ai2thor/v5_0_0/ # AI2-THOR 5.0.0 (EB-Navigation, Python 3.10) +│ ├── habitat_sim/v0_3_0/ # Habitat-Sim 0.3.0 (EB-Habitat, Python 3.9) +│ ├── coppeliasim/v4_1_0/ # CoppeliaSim 4.1.0 (EB-Manipulation, Python 3.10) +│ └── tdw/v1_11_23/ # ThreeDWorld 1.11.23 (HAZARD, Python 3.10) +│ +├── tasks/ # Benchmark task definitions +│ ├── registry.py # Auto-discovery via *.yaml glob +│ ├── yaml_utils.py # Template inheritance (extends) +│ ├── dataset.py # HuggingFace + local dataset loading +│ ├── scaffold.py # Task boilerplate generator +│ ├── dummy_task/ # 3-episode testing task +│ ├── ebalfred/ # EB-Alfred (6 splits) +│ ├── ebnavigation/ # EB-Navigation (5 splits) +│ ├── ebhabitat/ # EB-Habitat (4 splits) +│ └── ebmanipulation/ # EB-Manipulation (4 splits) +│ +├── evaluation/ # Evaluation orchestration +│ ├── runner.py # EvaluationRunner (sequential) +│ ├── parallel_runner.py # ParallelRunner (thread-pool, any backend) +│ └── metrics.py # default_aggregate + legacy aggregate_metrics +│ +├── llm/ # LLM client infrastructure +│ ├── client.py # LLMClient (LiteLLM wrapper, any backend) +│ ├── api_client.py # LLMApiClient (legacy OpenAI-only) +│ ├── server_manager.py # vLLM server lifecycle +│ ├── dummy_server.py # Dummy LLM server for testing +│ └── utils.py # Backend config (parse, validate, split kwargs) +│ +├── communication/ # Filesystem IPC (parent <-> bridge subprocess) +│ ├── filesystem.py # Atomic JSON read/write, command/response +│ └── schemas.py # Command/response schemas +│ +├── utils/ # Shared utilities +│ ├── logging.py # Centralized logging (TRACE/DEBUG/INFO/WARNING/ERROR) +│ ├── import_utils.py # Dynamic class importing +│ ├── json_repair.py # LLM response JSON repair +│ └── ... 
# paths, locking, system_deps, spinner
+│
+└── cli.py                  # CLI entry point (easi command)
+```
+
+## Key Patterns
+
+### Subprocess Isolation
+Each simulator runs in its own conda environment (potentially different Python version). The bridge script communicates with the parent process via filesystem IPC (atomic JSON files in a temp directory). This enables Python 3.8 for AI2-THOR v2.1 while the host runs Python 3.10+.
+
+### Multi-Split Tasks
+Each task folder can have multiple YAML configs. The task registry discovers all `*.yaml` files and registers each as a separate task (e.g., `ebalfred_base`, `ebalfred_spatial`). Split YAMLs use template inheritance via `extends: _base.yaml`.
+
+### Pluggable Metrics
+Two-phase metric system:
+- **Per-episode**: `task.evaluate_episode(episode, trajectory) -> dict` (always user-defined)
+- **Cross-episode**: `task.aggregate_results(records: list[EpisodeRecord]) -> dict` (optional override, default averages all numeric keys)
+
+Metrics are nested under `summary["metrics"]` in summary.json, separated from run metadata.
+
+### ReAct Agent + PromptBuilder
+The agent uses a PromptBuilder protocol for task-specific prompts. The builder constructs messages from AgentMemory and parses LLM responses into validated Actions. Multi-action buffering: LLM returns a plan, agent executes one action per step, clears buffer on failure.
+
+### Auto-Discovery
+- **Simulators**: Discovered via `easi/simulators/*/manifest.yaml`
+- **Tasks**: Discovered via `easi/tasks/*/*.yaml`
+- Both use dotted import paths to load classes dynamically
+
+## CLI Commands
+
+| Command | Description |
+|---|---|
+| `easi env list` | List available simulators |
+| `easi env install <simulator>` | Install simulator conda env |
+| `easi env check <simulator>` | Verify environment is ready |
+| `easi task list` | List available tasks |
+| `easi task info <task>` | Show task details |
+| `easi task download <task>` | Download task dataset |
+| `easi task scaffold <name>` | Generate new task boilerplate |
+| `easi sim test <simulator>` | Smoke test a simulator bridge |
+| `easi start <task>` | Run evaluation |
+| `easi llm-server` | Start dummy LLM server |
+
+### Key `easi start` Options
+
+```bash
+easi start <task> \
+  --agent {dummy|react} \
+  --backend {vllm|openai|anthropic|gemini} \
+  --model <model> \
+  --num-parallel <N> \     # Thread-pool parallelism
+  --max-episodes <N> \
+  --resume <run_dir> \
+  --output-dir ./logs \
+  --llm-kwargs '{"temperature": 0.7}' \
+  --llm-instances <N> \    # Number of LLM server instances (default: 1)
+  --llm-gpus 0,1,2,3 \     # GPUs for vLLM (split across instances)
+  --sim-gpus 4,5           # GPUs for simulator rendering
+```
+
+**Parallel vLLM example** (2 instances with TP=2, 8 workers, simulators on separate GPUs):
+```bash
+easi start ebalfred_base \
+  --agent react --backend vllm \
+  --model Qwen/Qwen2.5-VL-72B-Instruct \
+  --num-parallel 8 --llm-instances 2 \
+  --llm-gpus 0,1,2,3 --sim-gpus 4,5 \
+  --llm-kwargs '{"tensor_parallel_size": 2}'
+```
+
+## Output Structure
+
+```
+logs/<task>/<timestamp>_<model>/
+  config.json          # CLI options + resolved config
+  summary.json         # {"num_episodes": N, "metrics": {...}, "model": "...", ...}
+  episodes/
+    000_<episode_id>/
+      result.json        # Per-episode metrics
+      trajectory.jsonl   # Action log (one JSON line per step)
+      step_0000.png      # Observation images
+```
+
+## Testing
+
+```bash
+pytest tests/ -v --timeout=60     # Full suite (540 tests)
+pytest tests/test_metrics.py -v   # Specific file
+```
+
+All tests run offline without simulators or LLMs. Tests mock subprocess bridges and use DummyTask + DummyAgent.
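+
+A minimal offline test in this style; the `DummyTask` module path and the
+metric keys are assumptions based on the task-authoring guide, so adjust
+them to the real layout:
+
+```python
+from easi.core.episode import Observation, StepResult
+
+
+def test_dummy_task_scores_offline():
+    from easi.tasks.dummy_task.task import DummyTask  # assumed module path
+
+    task = DummyTask()
+    obs = Observation(rgb_path="/tmp/fake.png")  # offline observation stub
+    trajectory = [
+        StepResult(observation=obs, reward=0.0, done=True,
+                   info={"task_success": 1.0}),
+    ]
+    result = task.evaluate_episode({"id": 0}, trajectory)
+    assert "task_success" in result  # per-episode metrics dict
+```
+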
+ +## Logging Convention + +```python +from easi.utils.logging import get_logger +logger = get_logger(__name__) +``` + +Use `logger.info()` for user-facing messages, `logger.trace()` for detailed debug output. Never use `print()`. diff --git a/debug_mirror_one_episode.sh b/debug_mirror_one_episode.sh new file mode 100644 index 0000000..b30d72e --- /dev/null +++ b/debug_mirror_one_episode.sh @@ -0,0 +1,99 @@ +#!/usr/bin/env bash +# Debug helper: run one LHPR-VLN mirror-prompt episode end-to-end and +# capture the exact flipped PNGs the builder ships to the LLM. The +# capture happens inside LHPRVLNMirrorSFTPromptBuilder (guarded by the +# MIRROR_DEBUG_DIR env var), so what lands on disk *is* what the model +# saw — no post-processing. +# +# Usage: +# bash debug_mirror_one_episode.sh [split] +# +# split = unseen_val_filtered_sft_mirror (default, smallest split) +# = unseen_test_filtered_sft_mirror +# +# Env overrides: +# SIM_GPUS (default: 0) — habitat_sim render GPU(s) +# LLM_GPUS (default: 1,2) — vLLM GPU(s) +# TP (default: 2) — tensor_parallel_size +# +# After the run: +# debug_logs//_/episodes/000_ep_0/step_*.png +# original sim-rendered front/left/right frames. +# debug_logs//_/episodes/000_ep_0/mirror/step_*.png +# flipped + slot-swapped frames the mirror builder served to the LLM. +# debug_logs//_/episodes/000_ep_0/mirror/step_NNNN_history_*.png +# historical front-views the builder sampled (also flipped). + +set -euo pipefail + +MODEL="${1:-}" +SPLIT="${2:-unseen_val_filtered_sft_mirror}" +TASK="lhpr_vln_${SPLIT}" +OUTPUT_DIR="./debug_logs" +SIM_GPUS="${SIM_GPUS:-0}" +LLM_GPUS="${LLM_GPUS:-1,2}" +TP="${TP:-2}" + +if [ -z "$MODEL" ]; then + echo "Error: model path required." + echo "Usage: $0 [split]" + exit 1 +fi + +REPO_ROOT="$(cd "$(dirname "$0")" && pwd)" +cd "$REPO_ROOT" + +# shellcheck disable=SC1091 +source .venv/bin/activate + +# Staging dir for the builder's per-step dumps. Uses a timestamp so repeat +# runs don't collide. Moved into the final episode dir after easi finishes. +TS="$(date +%Y%m%d_%H%M%S)" +MIRROR_STAGE="$REPO_ROOT/$OUTPUT_DIR/_mirror_stage_${TS}" +mkdir -p "$MIRROR_STAGE" +export MIRROR_DEBUG_DIR="$MIRROR_STAGE" + +echo "=== Mirror debug run ===" +echo " model: $MODEL" +echo " task: $TASK" +echo " staging: $MIRROR_STAGE" +echo " sim_gpus: $SIM_GPUS llm_gpus: $LLM_GPUS tp: $TP" +echo "" + +easi start "$TASK" \ + --agent react --backend vllm \ + --model "$MODEL" \ + --episodes :1 \ + --num-parallel 1 \ + --sim-gpus "$SIM_GPUS" \ + --llm-gpus "$LLM_GPUS" \ + --llm-kwargs "{\"tensor_parallel_size\": $TP, \"trust_remote_code\": true, \"startup_timeout\": 900, \"skip_special_tokens\": false}" \ + --output-dir "$OUTPUT_DIR" \ + --verbosity TRACE + +RUN_DIR="$(ls -td "$OUTPUT_DIR/$TASK"/*/ | head -1)" +EP_DIR="$(ls -d "$RUN_DIR"episodes/*/ | head -1)" + +# Move the staged mirror PNGs into the episode dir so they sit next to the +# originals for easy A/B. +MIRROR_DIR="${EP_DIR}mirror" +mkdir -p "$MIRROR_DIR" +if compgen -G "$MIRROR_STAGE/*.png" > /dev/null; then + mv "$MIRROR_STAGE"/*.png "$MIRROR_DIR/" +fi +rmdir "$MIRROR_STAGE" 2>/dev/null || true + +FRAME_COUNT="$(find "$MIRROR_DIR" -maxdepth 1 -name '*.png' | wc -l)" + +echo "" +echo "=== Done ===" +echo " run dir: $RUN_DIR" +echo " originals: ${EP_DIR}step_*.png" +echo " mirror: ${MIRROR_DIR}/ ($FRAME_COUNT frames)" +echo "" +echo "Sanity checks:" +echo " jq . '${EP_DIR}result.json'" +echo " head '${EP_DIR}trajectory.jsonl' | jq ." 
+echo "" +echo "Look for a step where llm_response contains <|left|> and action is" +echo "turn_right (or vice versa). That confirms the remap fired." diff --git a/docs/cli-reference.md b/docs/cli-reference.md new file mode 100644 index 0000000..ed30fb2 --- /dev/null +++ b/docs/cli-reference.md @@ -0,0 +1,720 @@ +# EASI CLI Reference + +Complete reference for the `easi` command-line interface. + +## Global Options + +All commands support: + +``` +--verbosity {TRACE,DEBUG,INFO,WARNING,ERROR} + Set logging verbosity (default: INFO) +``` + +--- + +## `easi start` — Run Evaluation + +Execute evaluation on one or more tasks with an agent and LLM backend. + +``` +easi start [TASK ...] [options] +``` + +### Task Selection + +| Argument | Description | +|---|---| +| `TASK` | Task name(s) as positional arguments (e.g., `ebalfred_base`) | +| `--tasks TASKS` | Comma-separated task names (overrides positional args) | + +### Agent + +| Option | Description | +|---|---| +| `--agent {dummy,react}` | **Required.** Agent type to use | + +- `dummy` — Random action picker (no LLM needed) +- `react` — ReAct agent with multi-action buffering (requires LLM backend) + +### LLM Backend + +| Option | Description | +|---|---| +| `--backend {vllm,custom,openai,anthropic,gemini,dummy}` | LLM backend (required for `react` agent) | +| `--model MODEL` | Model identifier (HuggingFace ID for `vllm`, registry name for `custom`) | +| `--llm-url URL` | LLM server base URL (for external servers) | +| `--port PORT` | Port for local LLM server (default: 8080) | +| `--llm-kwargs JSON` | Extra LLM/server kwargs as JSON string | +| `--max-retries N` | Max retry attempts on transient LLM errors (default: 3) | + +**Model identifiers by backend:** + +| Backend | Example `--model` values | Description | +|---|---|---| +| `vllm` | `Qwen/Qwen2.5-VL-72B-Instruct` | vLLM-supported HuggingFace models | +| `custom` | `qwen3_vl`, `echo` | Custom model registry name (see `easi model list`) | +| `openai` | `gpt-4o`, `gpt-5.2-2025-12-11` | OpenAI API models | +| `anthropic` | `claude-sonnet-4-20250514` | Anthropic API models | +| `gemini` | `gemini-2.0-flash` | Google Gemini API models | + +### Execution Control + +| Option | Description | +|---|---| +| `--num-parallel N` | Parallel simulator instances (default: 1). Works with any backend. | +| `--episodes FILTER` | Episode filter: IDs (`2,5,7`), ranges (`10:20`), or `:N` for first N (default: all) | +| `--seed SEED` | Random seed for agent reproducibility | +| `--render-platform PLATFORM` | Rendering platform override (default: simulator's preference). See [Render Platforms](#render-platforms). | + +### GPU Allocation (Local Backends) + +These options apply to local LLM backends (`vllm` and `custom`). + +| Option | Description | +|---|---| +| `--llm-instances N` | Number of LLM server instances to start (default: 1). Each runs on a subset of `--llm-gpus`. | +| `--llm-gpus IDS` | Comma-separated GPU IDs for LLM inference (e.g., `0,1,2,3`). GPUs are split evenly across instances. | +| `--sim-gpus IDS` | Comma-separated GPU IDs for simulator rendering (e.g., `4,5`). Sets `CUDA_VISIBLE_DEVICES` for simulator subprocesses. | + +**Notes:** +- `--llm-gpus` is required when `--llm-instances > 1`. +- `--llm-gpus` and `--sim-gpus` must not overlap. +- GPU IDs are validated against hardware at startup (via `nvidia-smi`). +- All LLM instances start in parallel (processes spawned first, then health-checked concurrently). 
+- Workers are assigned to LLM instances via round-robin (e.g., 8 workers across 2 instances → 4 workers per instance). +- These options are ignored with a warning if `--backend` is not `vllm` or `custom`. +- Local backends use a 600s default timeout (vs 120s for API backends) to handle request queueing when workers outnumber server instances. + +### Data & Output + +| Option | Description | +|---|---| +| `--output-dir PATH` | Base output directory (default: `./logs`) | +| `--data-dir PATH` | Dataset cache directory (default: `./datasets`) | +| `--refresh-data` | Delete cached dataset and re-download | + +### Resume + +| Option | Description | +|---|---| +| `--resume DIR` | Resume from a previous run directory (contains `config.json`) | + +When resuming, completed episodes are skipped and evaluation continues from the next episode. The task name is loaded from the saved config. New CLI arguments override saved values. + +### Examples + +```bash +# Quick test with dummy agent (no LLM) +easi start dummy_task --agent dummy + +# OpenAI API +easi start ebalfred_base --agent react --backend openai --model gpt-4o + +# Anthropic API +easi start ebalfred_base --agent react --backend anthropic --model claude-sonnet-4-20250514 + +# vLLM (auto-starts server) +easi start ebalfred_base --agent react --backend vllm \ + --model Qwen/Qwen2.5-VL-72B-Instruct --port 8080 + +# vLLM (external server) +easi start ebalfred_base --agent react --backend vllm \ + --model Qwen/Qwen2.5-VL-72B-Instruct --llm-url http://localhost:8000 + +# Custom generation kwargs +easi start ebalfred_base --agent react --backend openai --model gpt-4o \ + --llm-kwargs '{"temperature": 0.7, "max_tokens": 500}' + +# Limit episodes +easi start ebalfred_base --agent dummy --episodes :5 --seed 42 + +# Parallel evaluation (API backend) +easi start ebalfred_base --agent react --backend openai --model gpt-4o \ + --num-parallel 4 + +# Parallel evaluation with local vLLM (1 instance, all GPUs) +easi start ebalfred_base --agent react --backend vllm \ + --model Qwen/Qwen2.5-VL-7B-Instruct --num-parallel 8 + +# Parallel vLLM with 2 instances (TP=2 each) + separate sim GPUs +easi start ebalfred_base --agent react --backend vllm \ + --model Qwen/Qwen2.5-VL-72B-Instruct \ + --num-parallel 8 --llm-instances 2 \ + --llm-gpus 0,1,2,3 --sim-gpus 4,5 \ + --llm-kwargs '{"tensor_parallel_size": 2}' + +# External multi-URL vLLM (pre-started servers, no auto-management) +easi start ebalfred_base --agent react --backend vllm \ + --model Qwen/Qwen2.5-VL-72B-Instruct --num-parallel 8 \ + --llm-url http://localhost:8000/v1,http://localhost:8001/v1 + +# Custom model server (auto-starts, single instance) +easi start ebalfred_base --agent react --backend custom \ + --model qwen3_vl \ + --llm-kwargs '{"model_path": "Qwen/Qwen3-VL-8B-Instruct"}' + +# Custom model with parallel workers and 2 server instances +easi start ebalfred_base --agent react --backend custom \ + --model qwen3_vl --num-parallel 8 \ + --llm-instances 2 --llm-gpus 0,1,2,3 --sim-gpus 4,5 \ + --llm-kwargs '{"model_path": "Qwen/Qwen3-VL-8B-Instruct"}' + +# Custom model with generation kwargs +easi start ebalfred_base --agent react --backend custom \ + --model qwen3_vl \ + --llm-kwargs '{"model_path": "Qwen/Qwen3-VL-8B-Instruct", "temperature": 0.7, "max_tokens": 2048}' + +# Multiple tasks +easi start ebalfred_base ebnavigation_base --agent react \ + --backend openai --model gpt-4o + +# Multiple tasks (CSV form) +easi start --tasks ebalfred_base,ebnavigation_base --agent react \ + --backend openai 
--model gpt-4o
+
+# Resume a previous run
+easi start --resume ./logs/ebalfred_base/20260215_093045_gpt-4o
+
+# Force dataset re-download
+easi start ebalfred_base --agent dummy --refresh-data
+
+# Override render platform (e.g., force native display)
+easi start ebmanipulation_base --agent react --backend openai --model gpt-4o \
+  --render-platform native
+
+# Verbose logging
+easi start ebalfred_base --agent dummy --verbosity TRACE
+```
+
+### Output Structure
+
+```
+<output-dir>/<task>/<timestamp>_<model>[_<suffix>]/
+├── config.json            # CLI options + resolved configuration
+├── summary.json           # Aggregated metrics
+└── episodes/
+    ├── 000_<episode_id>/
+    │   ├── result.json        # Per-episode metrics
+    │   ├── trajectory.jsonl   # Action log (one JSON line per step)
+    │   ├── step_0000.png      # Observation images
+    │   └── ...
+    └── 001_<episode_id>/
+        └── ...
+```
+
+**`summary.json` format:**
+```json
+{
+  "num_episodes": 10,
+  "model": "gpt-4o",
+  "agent": "react",
+  "metrics": {
+    "success_rate": 0.7,
+    "avg_steps": 12.3
+  }
+}
+```
+
+### Notes
+
+- `--num-parallel > 1` works with any backend. It uses a thread pool with one simulator per thread.
+- When using `--backend vllm` or `--backend custom` without `--llm-url`, local server(s) are auto-started and stopped after evaluation.
+- `--resume` cannot be combined with multiple tasks.
+- `--llm-kwargs` is split into server kwargs (e.g., `tensor_parallel_size`, `dtype`, `model_path`) and generation kwargs (e.g., `temperature`, `max_tokens`). Server kwargs are passed to the server process; generation kwargs are sent per-request.
+- For `--backend custom`, `model_path` in `--llm-kwargs` specifies the HuggingFace model ID or local path to weights. The `--model` flag selects which custom model class to use from the registry.
+
+---
+
+## `easi env` — Manage Simulator Environments
+
+### `easi env list`
+
+List all available simulators and their versions.
+
+```bash
+easi env list
+```
+
+Output shows each simulator as `name:version`, with the default version marked.
+
+---
+
+### `easi env install <simulator>`
+
+Install a simulator environment (creates a conda env with required dependencies).
+
+```
+easi env install <simulator> [--reinstall] [--with-task-deps TASK]
+```
+
+| Argument | Description |
+|---|---|
+| `simulator` | Simulator key (e.g., `ai2thor:v2_1_0`, `tdw:v1_11_23`) |
+| `--reinstall` | Remove existing environment and install from scratch |
+| `--with-task-deps TASK` | Also install additional dependencies from a specific task |
+
+**Examples:**
+
+```bash
+# Install AI2-THOR v2.1.0
+easi env install ai2thor:v2_1_0
+
+# Reinstall from scratch
+easi env install ai2thor:v2_1_0 --reinstall
+
+# Install with task-specific dependencies
+easi env install ai2thor:v2_1_0 --with-task-deps ebalfred_base
+```
+
+The created conda environment is named `easi_<name>_<version>` (e.g., `easi_ai2thor_v2_1_0`).
+
+---
+
+### `easi env check <simulator>`
+
+Check if a simulator environment is ready for use.
+
+```bash
+easi env check ai2thor:v2_1_0
+```
+
+Reports missing system dependencies, the Python executable path, and whether the environment is ready.
+
+---
+
+## `easi task` — Manage Tasks
+
+### `easi task list`
+
+List all available tasks discovered in the registry.
+
+```bash
+easi task list
+```
+
+Output format: `task_name -- display_name (simulator: simulator_key)`
+
+---
+
+### `easi task info <task>`
+
+Display detailed information about a specific task.
+
+```bash
+easi task info ebalfred_base
+```
+
+Shows task name, description, simulator key, and max steps.
+
+---
+
+### `easi task download <task>`
+
+Download and cache the task dataset locally.
+```
+easi task download <task> [--refresh-data]
+```
+
+| Argument | Description |
+|---|---|
+| `task` | Task name (e.g., `ebalfred_base`) |
+| `--refresh-data` | Delete cached dataset and re-download from source |
+
+**Examples:**
+
+```bash
+easi task download ebalfred_base
+easi task download ebalfred_base --refresh-data
+```
+
+---
+
+### `easi task scaffold <name>`
+
+Generate boilerplate code for a new benchmark task.
+
+```
+easi task scaffold <name> [--simulator SIM] [--max-steps N]
+```
+
+| Argument | Description |
+|---|---|
+| `name` | Task name in snake_case (e.g., `my_benchmark`) |
+| `--simulator SIM` | Simulator key to use (default: `dummy:v1`) |
+| `--max-steps N` | Maximum steps per episode (default: 50) |
+
+**Example:**
+
+```bash
+easi task scaffold my_benchmark --simulator ai2thor:v2_1_0 --max-steps 100
+```
+
+Creates:
+- `easi/tasks/my_benchmark/bridge.py`
+- `easi/tasks/my_benchmark/task.py`
+- `easi/tasks/my_benchmark/my_benchmark.yaml`
+- `tests/test_my_benchmark.py`
+
+---
+
+## `easi sim` — Control Simulators
+
+### `easi sim test <simulator>`
+
+Run a smoke test on a simulator (reset + N steps).
+
+```
+easi sim test <simulator> [--steps N] [--timeout SECONDS] [--render-platform PLATFORM]
+```
+
+| Argument | Description |
+|---|---|
+| `simulator` | Simulator key (e.g., `dummy`, `ai2thor:v5_0_0`) |
+| `--steps N` | Number of steps to execute (default: 5) |
+| `--timeout SECONDS` | Bridge startup timeout (default: 200.0) |
+| `--render-platform PLATFORM` | Rendering platform override (default: simulator's preference). See [Render Platforms](#render-platforms). |
+
+**Examples:**
+
+```bash
+easi sim test dummy
+easi sim test ai2thor:v5_0_0 --steps 10
+easi sim test ai2thor:v2_1_0 --steps 3 --timeout 300
+easi sim test coppeliasim:v4_1_0 --render-platform native
+```
+
+Executes `MoveAhead` for each step and reports observations and rewards.
+
+---
+
+## `easi model` — Manage Custom Models
+
+### `easi model list`
+
+List all custom models discovered in the registry.
+
+```bash
+easi model list
+```
+
+Output shows each model name and its display name.
+
+---
+
+### `easi model info <model>`
+
+Display detailed information about a custom model.
+
+```bash
+easi model info qwen3_vl
+```
+
+Shows model name, display name, description, model class, and default kwargs.
+
+---
+
+### Custom Model Overview
+
+Custom models allow running model architectures not supported by vLLM. Each model is defined by:
+
+1. **A Python class** extending `BaseModelServer` with `load()`, `generate()`, and `unload()` methods
+2. **A `manifest.yaml`** file for auto-discovery by the registry
+
+Models live in `easi/llm/models/<name>/` and are auto-discovered at startup.
+
+**Built-in custom models:**
+
+| Name | Description |
+|---|---|
+| `echo` | Echoes input back (testing) |
+| `qwen3_vl` | Qwen3-VL vision-language model (8B, 72B, etc.) |
+
+**Installation:**
+
+Custom models require additional dependencies not included in the base install:
+
+```bash
+pip install -e ".[custom-models]"
+```
+
+This installs `torch`, `transformers`, `accelerate`, `fastapi`, `uvicorn`, and `Pillow`.
+
+**How it works:**
+
+When you run `--backend custom --model <name>`:
+1. The registry looks up the model class from `easi/llm/models/<name>/manifest.yaml`
+2. A FastAPI HTTP server is started as a subprocess, loading the model
+3. The server exposes an OpenAI-compatible `/v1/chat/completions` endpoint
+4. LiteLLM connects to it transparently via the `openai/` prefix
+5.
Manifest `default_kwargs` (e.g., `dtype`, `attn_implementation`) are merged with CLI `--llm-kwargs` + +**Adding a new custom model:** + +Create a directory under `easi/llm/models/` with: + +``` +easi/llm/models/my_model/ +├── __init__.py +├── manifest.yaml +└── model.py +``` + +`manifest.yaml`: +```yaml +name: my_model +display_name: "My Custom Model" +description: "Description of the model" +model_class: "easi.llm.models.my_model.model.MyModel" +default_kwargs: + dtype: "bfloat16" +``` + +`model.py`: +```python +from easi.llm.models.base_model_server import BaseModelServer + +class MyModel(BaseModelServer): + def load(self, model_path: str, device: str, **kwargs) -> None: + # Load model weights + ... + + def generate(self, messages: list[dict], **kwargs) -> str: + # messages are in OpenAI format (with image_url for vision) + # Return generated text + ... + + def unload(self) -> None: + # Release GPU memory + ... +``` + +Helper utilities are available in `easi.llm.models.helpers`: +- `extract_images(messages)` — Extract PIL Images from base64 image_url entries +- `extract_text_only(messages)` — Concatenate all text content +- `extract_by_role(messages)` — Group text by role + +--- + +## `easi ps` — Show EASI Processes + +List all running EASI-related processes (LLM servers, simulator bridges) and optionally kill them. + +``` +easi ps [--kill] +``` + +| Option | Description | +|---|---| +| `--kill` | Send SIGTERM (then SIGKILL) to all found EASI processes | + +**Detected process types:** + +| Type | Description | +|---|---| +| `http_server` | Custom model server (`easi.llm.models.http_server`) | +| `api_server` | vLLM server (`vllm.entrypoints.openai.api_server`) | +| `dummy_server` | Dummy LLM server (`easi.llm.dummy_server`) | +| `bridge` | Simulator bridge subprocess | + +**Output includes:** +- PID, status, CPU%, MEM%, process type, and command +- `[ZOMBIE]` tag for zombie processes +- GPU memory held by EASI processes (via `nvidia-smi`) + +**Examples:** + +```bash +# List all EASI processes +easi ps + +# Kill all orphaned EASI processes (e.g., after Ctrl+C) +easi ps --kill +``` + +--- + +## `easi llm-server` — Dummy LLM Server + +Start a minimal OpenAI-compatible dummy LLM server for testing. + +``` +easi llm-server [--host HOST] [--port PORT] [--mode MODE] [--action-space ACTION ...] +``` + +| Option | Description | +|---|---| +| `--host HOST` | Server host (default: `127.0.0.1`) | +| `--port PORT` | Server port (default: `8000`) | +| `--mode {fixed,random}` | Response mode: `fixed` returns first action, `random` returns random action | +| `--action-space ACTION ...` | Space-separated action names (default: `MoveAhead TurnLeft TurnRight Stop`) | + +**Examples:** + +```bash +# Default dummy server +easi llm-server + +# Custom port and fixed mode +easi llm-server --port 8080 --mode fixed + +# Custom action space +easi llm-server --mode random --action-space Forward Backward TurnLeft TurnRight +``` + +**Endpoints:** +- `POST /v1/chat/completions` — OpenAI-compatible chat completion +- `GET /health` — Health check + +Use with `easi start`: +```bash +# Terminal 1: start dummy server +easi llm-server --port 8000 + +# Terminal 2: run evaluation against it +easi start ebalfred_base --agent react --backend openai \ + --model dummy --llm-url http://localhost:8000 +``` + +--- + +## Render Platforms + +Render platforms control how a simulator gets a display for rendering. Each simulator declares a default platform and a set of supported platforms in its manifest. 
Use `--render-platform` to override. + +### Built-in Platforms + +| Platform | Description | +|---|---| +| `auto` | Use native display if `DISPLAY` is set, fall back to xvfb | +| `native` | Require an existing `DISPLAY` (fails if none) | +| `xvfb` | Wrap with `xvfb-run` (virtual X11 framebuffer) | +| `egl` | GPU-accelerated headless rendering via EGL (no X11) | +| `headless` | No display at all (simulator has native headless support) | +| `xorg` | Auto-managed Xorg server per GPU (GPU-accelerated X11, defaults to GPU 0, use `--sim-gpus` to specify) | + +### Custom Platforms + +Some simulators register custom render platform classes in their manifest that extend the built-in platforms with simulator-specific environment variables. For example, CoppeliaSim defines custom `auto`, `native`, and `xvfb` platforms that set `QT_QPA_PLATFORM_PLUGIN_PATH` and control the `COPPELIASIM_HEADLESS` flag. + +Custom platforms are resolved automatically — when you pass `--render-platform xvfb` for a CoppeliaSim task, the CoppeliaSim-specific xvfb platform is used instead of the generic one. + +### Platform Defaults by Simulator + +| Simulator | Default | Supported | +|---|---|---| +| `dummy:v1` | `headless` | `headless` | +| `ai2thor:v2_1_0` | `auto` | `auto`, `native`, `xvfb`, `xorg` | +| `ai2thor:v5_0_0` | `auto` | `auto`, `native`, `xvfb`, `xorg` | +| `habitat_sim:v0_3_0` | `auto` | `auto`, `native`, `xvfb`, `egl`, `xorg` | +| `coppeliasim:v4_1_0` | `auto` | `auto`, `native`, `xvfb`, `xorg` | +| `tdw:v1_11_23` | `auto` | `auto`, `native`, `xvfb`, `xorg` | + +--- + +## Environment Variables + +The CLI itself does not use environment variables, but the LLM backends require API keys: + +| Variable | Backend | +|---|---| +| `OPENAI_API_KEY` | `openai` | +| `ANTHROPIC_API_KEY` | `anthropic` | +| `GOOGLE_API_KEY` | `gemini` | + +These are handled by the underlying LiteLLM client. The `vllm` and `custom` backends do not require API keys (a dummy key is used automatically for the local OpenAI-compatible server). 
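+
+For example, to run an API-backed evaluation:
+
+```bash
+export OPENAI_API_KEY="sk-..."    # or ANTHROPIC_API_KEY / GOOGLE_API_KEY
+easi start ebalfred_base --agent react --backend openai --model gpt-4o
+```
+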
+ +--- + +## Available Simulators + +| Key | Description | +|---|---| +| `dummy:v1` | In-memory testing simulator (no external deps) | +| `ai2thor:v2_1_0` | AI2-THOR 2.1.0 (EB-Alfred, Python 3.8) | +| `ai2thor:v5_0_0` | AI2-THOR 5.0.0 (EB-Navigation, Python 3.10) | +| `habitat_sim:v0_3_0` | Habitat-Sim 0.3.0 (EB-Habitat, Python 3.9) | +| `coppeliasim:v4_1_0` | CoppeliaSim 4.1.0 (EB-Manipulation, Python 3.10) | +| `tdw:v1_11_23` | ThreeDWorld 1.11.23 (HAZARD, Python 3.10) | + +--- + +## Available Tasks + +| Task | Simulator | Description | +|---|---|---| +| `dummy_task` | `dummy:v1` | 3-episode testing task | +| `ebalfred_base` | `ai2thor:v2_1_0` | EB-Alfred base split | +| `ebalfred_spatial` | `ai2thor:v2_1_0` | EB-Alfred spatial reasoning | +| `ebalfred_commonsense` | `ai2thor:v2_1_0` | EB-Alfred commonsense reasoning | +| `ebalfred_complex` | `ai2thor:v2_1_0` | EB-Alfred complex tasks | +| `ebalfred_long_horizon` | `ai2thor:v2_1_0` | EB-Alfred long-horizon tasks | +| `ebalfred_image` | `ai2thor:v2_1_0` | EB-Alfred image understanding | +| `ebnavigation_base` | `ai2thor:v5_0_0` | EB-Navigation base split | +| `ebnavigation_spatial` | `ai2thor:v5_0_0` | EB-Navigation spatial | +| `ebnavigation_commonsense` | `ai2thor:v5_0_0` | EB-Navigation commonsense | +| `ebnavigation_complex` | `ai2thor:v5_0_0` | EB-Navigation complex | +| `ebnavigation_image` | `ai2thor:v5_0_0` | EB-Navigation image | +| `ebhabitat_base` | `habitat_sim:v0_3_0` | EB-Habitat base split | +| `ebhabitat_spatial` | `habitat_sim:v0_3_0` | EB-Habitat spatial | +| `ebhabitat_commonsense` | `habitat_sim:v0_3_0` | EB-Habitat commonsense | +| `ebhabitat_complex` | `habitat_sim:v0_3_0` | EB-Habitat complex | +| `ebmanipulation_base` | `coppeliasim:v4_1_0` | EB-Manipulation base split | +| `ebmanipulation_spatial` | `coppeliasim:v4_1_0` | EB-Manipulation spatial | +| `ebmanipulation_commonsense` | `coppeliasim:v4_1_0` | EB-Manipulation commonsense | +| `ebmanipulation_complex` | `coppeliasim:v4_1_0` | EB-Manipulation complex | +| `hazard_fire` | `tdw:v1_11_23` | HAZARD fire scenario | +| `hazard_flood` | `tdw:v1_11_23` | HAZARD flood scenario | +| `hazard_wind` | `tdw:v1_11_23` | HAZARD wind scenario | + +--- + +## Workflow Examples + +### First-Time Setup and Evaluation + +```bash +# 1. Install simulator +easi env install ai2thor:v2_1_0 --with-task-deps ebalfred_base + +# 2. Verify environment +easi env check ai2thor:v2_1_0 + +# 3. Smoke test the simulator +easi sim test ai2thor:v2_1_0 + +# 4. Download dataset +easi task download ebalfred_base + +# 5. Run evaluation +easi start ebalfred_base --agent react --backend openai --model gpt-4o +``` + +### Creating a New Benchmark + +```bash +# 1. Scaffold the task +easi task scaffold my_benchmark --simulator ai2thor:v2_1_0 --max-steps 100 + +# 2. Edit the generated files: +# easi/tasks/my_benchmark/bridge.py — implement _create_env(), _extract_image() +# easi/tasks/my_benchmark/task.py — implement format_reset_config() +# easi/tasks/my_benchmark/my_benchmark.yaml — configure dataset source + +# 3. Run tests +pytest tests/test_my_benchmark.py -v + +# 4. 
Test with dummy agent +easi start my_benchmark --agent dummy +``` + +### Batch Evaluation Across Tasks + +```bash +# Run all EB-Alfred splits +easi start --tasks ebalfred_base,ebalfred_spatial,ebalfred_commonsense \ + --agent react --backend openai --model gpt-4o --num-parallel 4 + +# Results saved to ./logs/// for each task +``` diff --git a/docs/easi-prompt-format-reference.md b/docs/easi-prompt-format-reference.md new file mode 100644 index 0000000..2af819b --- /dev/null +++ b/docs/easi-prompt-format-reference.md @@ -0,0 +1,528 @@ +# EASI Standard Prompt Format Reference + +This document defines the standard prompt format for EASI benchmarks that do not provide their own prompt format. EmbodiedBench benchmarks (EB-Alfred, EB-Navigation, EB-Habitat, EB-Manipulation) retain their original published formats for reproducibility. + +## Scope + +**Applies to:** New benchmarks and benchmarks without a published prompt format (VLN-CE R2R, VLN-CE RxR, ManipulaTHOR, AI2-THOR Rearrangement, HAZARD text-plan variants, all future tasks). + +**Does not apply to:** EmbodiedBench benchmarks (retain original format), HAZARD multiple-choice format (fundamentally different paradigm). + +--- + +## System Prompt Structure + +The system prompt uses markdown sections in this fixed order. Required sections must always be present. Optional sections are included only when the benchmark needs them. + +``` +## Role and Environment [REQUIRED] +## Observation Description [OPTIONAL] +## Available Actions [REQUIRED] +## Strategy [OPTIONAL] +## Guidelines [REQUIRED] +## Response Format [REQUIRED] +``` + +### Role and Environment (Required) + +One paragraph establishing who the agent is and what environment it operates in. Keep it concise — 2-3 sentences. + +``` +You are a robot navigating in a 3D indoor environment. You observe the +environment through a front-facing camera and must follow natural language +instructions to navigate to a goal location. +``` + +### Observation Description (Optional) + +Describes what each piece of environment feedback means. Include this section only when the benchmark provides dynamic feedback (geodesic distances, object states, GPS coordinates, etc.) that the LLM needs context to interpret. + +``` +## Observation Description +- **Distance to goal**: Geodesic (shortest walkable path) distance in meters + to the goal location. Decreases as you get closer. +- **Held object**: Name of the object currently being held, or "none". +``` + +Do NOT describe the image observation here — the LLM can see the image directly. Only describe non-visual feedback that appears as text. + +### Available Actions (Required) + +List all actions the agent can take, with a brief description of what each does. Include any validity constraints. + +``` +## Available Actions +- move_forward: Move forward by 0.25 meters +- turn_left: Turn left by 30 degrees +- turn_right: Turn right by 30 degrees +- look_up: Tilt camera up by 30 degrees +- look_down: Tilt camera down by 30 degrees +- stop: Stop and end navigation (use ONLY when you believe you have reached + the destination) +``` + +For benchmarks with parameterized actions, include the parameter format: + +``` +- find : Navigate to the named receptacle +- pick_up : Pick up the named object (must be nearby, hands empty) +``` + +### Strategy (Optional) + +Benchmark-specific tactical advice. Include this when the task has non-obvious strategies that improve performance. Keep it actionable. + +``` +## Strategy +1. 
Follow the instruction step by step, matching landmarks mentioned +2. Use move_forward to advance and turn_left/turn_right to change direction +3. Use stop ONLY when confident you have reached the described destination +``` + +### Guidelines (Required) + +Universal rules that apply regardless of benchmark. Always include these core guidelines, adding benchmark-specific ones as needed: + +``` +## Guidelines +1. Always output at least one action in executable_plan. +2. Only use actions from the Available Actions list. +3. If previous actions failed, reason about why and try a different approach. +4. Do not repeatedly execute the same action sequence. +5. Keep your plan efficient and concise. +``` + +### Response Format (Required) + +Always use the standard 4-field JSON format: + +``` +## Response Format +Output a JSON object with exactly these 4 fields: +{ + "visual_state_description": "Describe what you see in the current image", + "reasoning_and_reflection": "Reason about your situation, reflect on + history and feedback", + "language_plan": "Describe your next plan in natural language", + "executable_plan": [{"action": ""}] +} + +You may include multiple actions in executable_plan. Actions execute +sequentially. +``` + +--- + +## Response Format Specification + +### JSON Schema + +All EASI prompt builders (within scope) must use this response format: + +```json +{ + "visual_state_description": "string", + "reasoning_and_reflection": "string", + "language_plan": "string", + "executable_plan": [ + {"action": "action_name"}, + {"action": "action_name"} + ] +} +``` + +### Field Definitions + +| Field | Type | Purpose | +|-------|------|---------| +| `visual_state_description` | string | Describe what the agent observes in the current image | +| `reasoning_and_reflection` | string | Reason about current state, reflect on history and feedback, explain why previous actions may have failed | +| `language_plan` | string | Natural language description of the planned actions | +| `executable_plan` | array | Ordered list of actions to execute | + +### Action Entry Format + +Each action in `executable_plan` is an object with an `action` field: + +```json +{"action": "move_forward"} +``` + +Do NOT use `action_id` — numeric IDs are an internal concept. The LLM should always reference actions by name. + +For parameterized actions, include a `params` field: + +```json +{"action": "find", "params": {"target": "Cabinet_2"}} +``` + +### Parsing Rules + +1. Apply `fix_json()` before parsing (handles common LLM JSON errors) +2. Accept both `{"action": "name"}` and `{"action_name": "name"}` +3. Validate each action name against `memory.action_space` +4. On first invalid action, stop parsing (don't skip — the plan is ordered) +5. On complete parse failure, return empty action list (agent will re-prompt) + +--- + +## Action History + +Action history provides the LLM with context about what happened in previous steps. It is a text section embedded in the user message. 
+ +### Format + +``` +## Action History (last N steps) +Step 0: move_forward -> Distance to goal: 8.2m +Step 1: turn_left -> Distance to goal: 8.2m +Step 2: move_forward -> Distance to goal: 7.9m +``` + +Each entry: `Step {i}: {action_name} -> {feedback}` + +If feedback is disabled (`use_feedback: false`), omit the feedback portion: + +``` +Step 0: move_forward +Step 1: turn_left +Step 2: move_forward +``` + +### Configuration + +| YAML Parameter | Type | Default | Description | +|----------------|------|---------|-------------| +| `action_history_len` | int | 20 | Maximum entries to include. 0 = disabled. | +| `use_feedback` | bool | true | Include environment feedback in each entry. | + +### Data Source + +Action history comes from `memory.action_history`, which returns `list[tuple[str, str]]` — pairs of `(action_name, feedback_string)`. + +### Truncation + +When history exceeds `action_history_len`, keep only the most recent entries: + +```python +history = memory.action_history[-self.action_history_len:] +``` + +--- + +## Chat History + +Chat history provides the LLM with its own previous responses, enabling it to maintain reasoning continuity across steps. It is a text section embedded in the user message, parallel to action history. + +### Format + +``` +## Chat History (last N responses) +[Step 0 Response] +{"visual_state_description": "I see a hallway...", "reasoning_and_reflection": "I need to...", "language_plan": "Move forward...", "executable_plan": [{"action": "move_forward"}]} + +[Step 1 Response] +{"visual_state_description": "I see a door...", "reasoning_and_reflection": "The door matches...", "language_plan": "Turn right...", "executable_plan": [{"action": "turn_right"}]} +``` + +Each entry is the full LLM JSON response from that step, preceded by a `[Step N Response]` header. + +### Configuration + +| YAML Parameter | Type | Default | Description | +|----------------|------|---------|-------------| +| `chat_history` | bool | false | Enable chat history section. | +| `message_window_len` | int | 5 | Maximum responses to include. | + +### Data Source + +Chat history comes from `memory.steps`, which contains `StepRecord` objects with `llm_response` fields: + +```python +if self.chat_history: + responses = [ + s.llm_response for s in memory.steps + if s.llm_response is not None + ][-self.message_window_len:] +``` + +### Interaction with Action History + +When both are enabled, action history and chat history appear as separate sections in the user message. Action history provides a compact summary; chat history provides the full reasoning. They are complementary, not redundant. + +Recommended default: `action_history_len: 20, chat_history: false`. Enable chat history only for benchmarks where maintaining reasoning continuity significantly improves performance. 
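+
+As a concrete illustration, here is a minimal sketch of how a prompt builder
+might assemble both history sections. The helper name `render_history` and its
+keyword defaults are hypothetical; the data sources (`memory.action_history`,
+`memory.steps`) and the entry formats are the ones defined above.
+
+```python
+def render_history(memory, action_history_len=20, use_feedback=True,
+                   chat_history=False, message_window_len=5) -> str:
+    """Render the action-history and chat-history prompt sections."""
+    sections = []
+    if action_history_len > 0 and memory.action_history:
+        # Truncation rule: keep only the most recent entries.
+        recent = memory.action_history[-action_history_len:]
+        offset = len(memory.action_history) - len(recent)
+        lines = [f"## Action History (last {len(recent)} steps)"]
+        for i, (action_name, feedback) in enumerate(recent):
+            entry = f"Step {offset + i}: {action_name}"
+            if use_feedback:  # omit the feedback portion when disabled
+                entry += f" -> {feedback}"
+            lines.append(entry)
+        sections.append("\n".join(lines))
+    if chat_history:
+        responses = [s.llm_response for s in memory.steps
+                     if s.llm_response is not None][-message_window_len:]
+        if responses:
+            lines = [f"## Chat History (last {len(responses)} responses)"]
+            # Indices here are window-relative; a real builder may carry
+            # absolute step numbers instead.
+            for i, resp in enumerate(responses):
+                lines.append(f"[Step {i} Response]\n{resp}")
+            sections.append("\n".join(lines))
+    return "\n\n".join(sections)
+```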
+ +--- + +## Image Handling + +### Encoding + +All images are encoded as base64 data URLs: + +```python +data:image/png;base64,{base64_encoded_data} +``` + +### Position in Message + +Images appear BEFORE text in the user message content array: + +```python +content = [] +# Images first +content.append({"type": "image_url", "image_url": {"url": img_url}}) +# Text after +content.append({"type": "text", "text": prompt_text}) +``` + +### Multiple Images + +When a benchmark provides multiple images (e.g., RGB + depth, current + goal), label them: + +```python +# In the text portion: +"(Image 1: Current view, Image 2: Goal state)" +``` + +### Image Sources + +| Source | Field | When to Use | +|--------|-------|-------------| +| Primary RGB | `observation.rgb_path` | Always (every benchmark has this) | +| Depth | `observation.metadata["depth_path"]` | When depth sensing is relevant | +| Goal/reference | `observation.metadata["goal_rgb_path"]` | For rearrangement/comparison tasks | +| Multi-view | `observation.metadata["{view}_rgb_path"]` | For panoramic or multi-camera setups | + +--- + +## Environment Feedback + +Environment feedback is benchmark-specific dynamic information from the simulator, delivered via `observation.metadata`. The standard defines where feedback appears and how to toggle it, not what it contains. + +### Where Feedback Appears + +1. **In action history entries**: `"Step N: action -> {feedback_text}"` +2. **As a dedicated section** (optional): For rich contextual feedback that applies to the current state, not just the last action. + +``` +## Environment Feedback +Distance to goal: 5.3m +``` + +### Toggle + +Controlled by `use_feedback: true/false` in YAML config. When false, omit feedback from action history entries and omit the Environment Feedback section. + +### Common Feedback Patterns + +| Pattern | Example | Used By | +|---------|---------|---------| +| Distance to goal | `"Distance to goal: 5.3m"` | VLN-CE R2R/RxR | +| Action success/failure | `"success"` / `"fail: object not reachable"` | EB-Alfred, HAZARD | +| Object states | `"Holding: Apple_1"` | Rearrangement, ManipulaTHOR | +| Spatial info | `"Position: (3.2, 0.1, -1.5), Rotation: 90°"` | Rearrangement, ManipulaTHOR | + +--- + +## YAML Configuration Standard + +Every prompt builder should accept these common kwargs. Benchmark-specific kwargs can be added below them. + +```yaml +agent: + prompt_builder: "easi.tasks..prompts." 
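+  # A concrete value is the dotted import path to your builder class,
+  # e.g. (hypothetical task): easi.tasks.my_benchmark.prompts.MyBenchmarkPromptBuilder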
+ prompt_builder_kwargs: + # Standard kwargs (all prompt builders should support these) + use_feedback: true # Include environment feedback + action_history_len: 20 # Max action history entries (0 = disabled) + chat_history: false # Include previous model responses + message_window_len: 5 # Max chat history entries (when chat_history: true) + + # Benchmark-specific kwargs (examples) + # use_geo_distance: true # VLN-CE: show geodesic distance + # n_shot: 3 # Few-shot examples count + # use_depth: false # Depth image toggle + generation_kwargs: + temperature: 0 + max_tokens: 4096 + top_p: 0.95 +``` + +--- + +## User Message Assembly Order + +The user message is assembled in this fixed order: + +``` +[Image(s)] <- base64 encoded, before text +[Text content, assembled as:] + ## Task <- instruction / task description + ## Observation Description <- only if defined (from system prompt context) + ## Environment Feedback <- current-step feedback (if use_feedback) + ## Action History (last N steps) <- if action_history_len > 0 and has history + ## Chat History (last N responses) <- if chat_history and has history + [Response format reminder] <- brief reminder of JSON format +``` + +On the first turn, action history and chat history are empty and omitted. + +--- + +## Message Structure + +Always exactly 2 messages: + +```python +messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": [image_parts..., text_part]}, +] +``` + +The system message contains the static prompt (role, actions, strategy, guidelines, response format). The user message contains the dynamic per-step content (images, task instruction, feedback, history). + +This is simpler than multi-turn conversation and works consistently across all LLM backends. + +--- + +## Complete Example: Navigation Benchmark + +### System Prompt + +``` +## Role and Environment +You are a robot navigating in a 3D indoor environment. You observe the +environment through a front-facing camera and must follow natural language +instructions to navigate to a goal location. + +## Observation Description +- **Distance to goal**: Geodesic distance in meters to the goal. Decreases + as you approach the destination. + +## Available Actions +- move_forward: Move forward by 0.25 meters +- turn_left: Turn left by 15 degrees +- turn_right: Turn right by 15 degrees +- stop: Stop and end navigation (use ONLY when you believe you have reached + the destination described in the instruction) + +## Strategy +1. Carefully read the navigation instruction +2. Observe your surroundings in the image +3. Follow the instruction step by step, matching landmarks and directions +4. Use stop ONLY when confident you have reached the described destination + +## Guidelines +1. Always output at least one action in executable_plan. +2. Only use actions from the Available Actions list. +3. If previous actions failed, reason about why and try a different approach. +4. Do not repeatedly execute the same action sequence. +5. Keep your plan efficient and concise. + +## Response Format +Output a JSON object with exactly these 4 fields: +{ + "visual_state_description": "Describe what you see in the current image", + "reasoning_and_reflection": "Reason about your situation and history", + "language_plan": "Describe your next plan in natural language", + "executable_plan": [{"action": ""}] +} + +You may include multiple actions in executable_plan. Actions execute +sequentially. 
+``` + +### User Message (Step 5) + +``` +[Image: base64 encoded current view] + +## Task +Walk down the hallway and turn right into the bedroom. + +## Environment Feedback +Distance to goal: 5.3m + +## Action History (last 5 steps) +Step 0: move_forward -> Distance to goal: 8.2m +Step 1: move_forward -> Distance to goal: 7.9m +Step 2: move_forward -> Distance to goal: 7.6m +Step 3: turn_right -> Distance to goal: 7.6m +Step 4: move_forward -> Distance to goal: 7.3m + +Respond with the JSON format specified above. +``` + +--- + +## Complete Example: Object Manipulation Benchmark + +### System Prompt + +``` +## Role and Environment +You are a robotic arm in an indoor environment. You can pick up, place, and +manipulate objects on a table using discrete actions. + +## Observation Description +- **Held object**: The object currently in the gripper, or "none". +- **Nearby objects**: Objects within interaction range and their positions. + +## Available Actions +- move_to : Move the arm to the named object +- pick_up : Grasp the named object (must be nearby, gripper empty) +- place_on : Place held object on the named receptacle +- open : Open a closed receptacle +- close : Close an open receptacle +- done: Signal task completion + +## Strategy +1. Locate the target object before attempting to pick it up +2. Ensure your gripper is empty before picking up a new object +3. Navigate to the destination before placing an object + +## Guidelines +1. Always output at least one action in executable_plan. +2. Only use actions from the Available Actions list. +3. If previous actions failed, reason about why and try a different approach. +4. Do not repeatedly execute the same action sequence. +5. Keep your plan efficient and concise. + +## Response Format +Output a JSON object with exactly these 4 fields: +{ + "visual_state_description": "Describe what you see in the current image", + "reasoning_and_reflection": "Reason about your situation and history", + "language_plan": "Describe your next plan in natural language", + "executable_plan": [{"action": ""}] +} +``` + +### User Message (Step 3, with chat history enabled) + +``` +[Image: base64 encoded current view] + +## Task +Pick up the apple and place it in the bowl. + +## Environment Feedback +Held object: none +Nearby objects: Apple_1 (0.3m), Bowl_2 (1.2m) + +## Action History (last 3 steps) +Step 0: move_to Apple_1 -> success +Step 1: pick_up Apple_1 -> fail: object not reachable +Step 2: move_to Apple_1 -> success + +## Chat History (last 2 responses) +[Step 1 Response] +{"visual_state_description": "I see a red apple on the counter...", "reasoning_and_reflection": "I moved to the apple successfully...", "language_plan": "Pick up the apple", "executable_plan": [{"action": "pick_up Apple_1"}]} + +[Step 2 Response] +{"visual_state_description": "The apple is still on the counter...", "reasoning_and_reflection": "Pickup failed, I need to get closer...", "language_plan": "Move closer then pick up", "executable_plan": [{"action": "move_to Apple_1"}]} + +Respond with the JSON format specified above. +``` diff --git a/docs/superpowers/plans/2026-03-12-reverie-ce.md b/docs/superpowers/plans/2026-03-12-reverie-ce.md new file mode 100644 index 0000000..30608f1 --- /dev/null +++ b/docs/superpowers/plans/2026-03-12-reverie-ce.md @@ -0,0 +1,665 @@ +# REVERIE-CE Integration Implementation Plan + +> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. 
Steps use checkbox (`- [ ]`) syntax for tracking.
+
+**Goal:** Add REVERIE-CE as a navigation-only task in EASI, reusing the VLN-CE R2R infrastructure.
+
+**Architecture:** REVERIE-CE is a thin task layer inheriting from VLN-CE R2R. Same simulator (`habitat_sim:v0_1_7`), same action space (4 discrete), same metrics (SR/SPL/NE/NDTW/SDTW). Only the prompt builder and dataset are new.
+
+**Tech Stack:** Python 3.10+ (host), Habitat-Sim 0.1.7 (bridge subprocess, Python 3.8), HuggingFace datasets, Matterport3D scenes.
+
+**Spec:** `docs/superpowers/specs/2026-03-12-reverie-ce-design.md`
+
+---
+
+## File Structure
+
+| Action | File | Responsibility |
+|--------|------|---------------|
+| Create | `easi/tasks/reverie_ce/__init__.py` | Package marker |
+| Create | `easi/tasks/reverie_ce/task.py` | `ReverieCETask` — inherits `VLNCETask`, overrides paths |
+| Create | `easi/tasks/reverie_ce/bridge.py` | `ReverieCEBridge` — inherits `VLNCEBridge`, standalone entry point |
+| Create | `easi/tasks/reverie_ce/prompts.py` | `ReverieCEPromptBuilder` — high-level instruction prompt |
+| Create | `easi/tasks/reverie_ce/actions.py` | Re-exports `get_action_space` from `vlnce_r2r.actions` |
+| Create | `easi/tasks/reverie_ce/_base.yaml` | Task config pointing to REVERIE-CE HuggingFace repo |
+| Create | `easi/tasks/reverie_ce/reverie_ce_val_unseen.yaml` | Val unseen split |
+| Create | `easi/tasks/reverie_ce/reverie_ce_test.yaml` | Test split |
+| Create | `tests/test_reverie_ce_task.py` | Unit tests for task, prompt builder, action space |
+| None | `easi/tasks/vlnce_r2r/` | Imported, not modified |
+| None | `easi/simulators/habitat_sim/` | Unchanged |
+
+---
+
+## Chunk 1: Data Preparation
+
+### Task 1: Reformat Dynam3D Data into EASI HuggingFace Repo
+
+This is a one-time manual/scripted step done outside the EASI codebase. It prepares the dataset that the task will consume.
+
+**Files:**
+- External: HuggingFace repo `oscarqjh/REVERIE-CE_easi`
+
+- [ ] **Step 1: Download Dynam3D REVERIE-CE data**
+
+Download from HuggingFace `MrZihanWang/Dynam3D` (a dataset repo, so `--repo-type dataset` is required):
+```bash
+# Download the REVERIE-CE specific files
+huggingface-cli download MrZihanWang/Dynam3D \
+    data/datasets/reverie_training_data \
+    data/datasets/reverie_val_unseen_data.json \
+    data/datasets/reverie_test_data.json \
+    data/datasets/reverie_val_unseen_gt.json \
+    data/datasets/reverie_test_gt.json \
+    --repo-type dataset \
+    --local-dir ./dynam3d_download
+```
+
+- [ ] **Step 2: Write a conversion script to reshape into JSONL**
+
+Create a temporary script (not committed to EASI) that:
+1. Reads the per-scene JSON training files from `reverie_training_data/`
+2. Reads the val/test single JSON files
+3. Normalises each episode into the VLN-CE R2R JSONL format:
+
+```python
+# Expected output format per line:
+{
+    "episode_id": str(item["episode_id"]),
+    "scene_id": item["scene_id"].replace("mp3d/", "").replace(".glb", "").split("/")[-1],
+    "instruction": item["instruction"]["instruction_text"],
+    "start_position": item["start_position"],
+    "start_rotation": item["start_rotation"],
+    "goal_position": item["goals"][0]["position"],
+    "geodesic_distance": item["info"]["geodesic_distance"],
+    "gt_locations": item["reference_path"]
+}
+```
+
+Note: `scene_id` in VLN-CE R2R format is just the scan name (e.g. `cV4RVeZvu5T`), not the full path. The bridge constructs the full path from `data_dir + mp3d/ + scene_id + scene_id.glb`.
+
+4. Writes `train.jsonl`, `val_unseen.jsonl`, `test.jsonl`
+5. 
Copies ground truth files as-is + +- [ ] **Step 3: Verify episode counts match Dynam3D source** + +```bash +wc -l data/train.jsonl data/val_unseen.jsonl data/test.jsonl +``` + +- [ ] **Step 4: Upload to HuggingFace** + +```bash +huggingface-cli upload oscarqjh/REVERIE-CE_easi ./reverie_ce_easi/ \ + --repo-type dataset +``` + +Include `mp3d_scenes.zip` (same Matterport3D scenes as R2R — can be copied from the R2R dataset repo). + +--- + +## Chunk 2: Task Implementation + +### Task 2: Create Action Space Module + +**Files:** +- Create: `easi/tasks/reverie_ce/actions.py` +- Test: `tests/test_reverie_ce_task.py` + +- [ ] **Step 1: Write failing test** + +```python +# tests/test_reverie_ce_task.py +class TestActionSpace: + def test_has_four_actions(self): + from easi.tasks.reverie_ce.actions import get_action_space + actions = get_action_space() + assert len(actions) == 4 + assert "move_forward" in actions + assert "turn_left" in actions + assert "turn_right" in actions + assert "stop" in actions +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `.venv/bin/pytest tests/test_reverie_ce_task.py::TestActionSpace -v` +Expected: FAIL (module not found) + +- [ ] **Step 3: Create the module** + +```python +# easi/tasks/reverie_ce/__init__.py +# (empty) + +# easi/tasks/reverie_ce/actions.py +"""REVERIE-CE action space — same as VLN-CE R2R.""" +from easi.tasks.vlnce_r2r.actions import get_action_space # noqa: F401 +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `.venv/bin/pytest tests/test_reverie_ce_task.py::TestActionSpace -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add easi/tasks/reverie_ce/__init__.py easi/tasks/reverie_ce/actions.py tests/test_reverie_ce_task.py +git commit -m "feat(reverie-ce): add action space module" +``` + +--- + +### Task 3: Create Task Class + +**Files:** +- Create: `easi/tasks/reverie_ce/task.py` +- Test: `tests/test_reverie_ce_task.py` + +- [ ] **Step 1: Write failing tests** + +```python +# Add to tests/test_reverie_ce_task.py +import json +import pytest +from unittest.mock import MagicMock +from easi.core.episode import EpisodeRecord, Observation, StepResult + + +class TestReverieCETask: + @pytest.fixture + def task(self): + from easi.tasks.reverie_ce.task import ReverieCETask + mock_config = { + "name": "reverie_ce_val_unseen", + "display_name": "REVERIE-CE Val Unseen", + "simulator": "habitat_sim:v0_1_7", + "task_class": "easi.tasks.reverie_ce.task.ReverieCETask", + "max_steps": 500, + "dataset": {"source": "huggingface", "repo_id": "oscarqjh/REVERIE-CE_easi", "split": "val_unseen"}, + "simulator_configs": {}, + "agent": {"prompt_builder": "easi.tasks.reverie_ce.prompts.ReverieCEPromptBuilder"}, + } + task = ReverieCETask.__new__(ReverieCETask) + task._config = mock_config + task._yaml_path = None + task._action_space = None + return task + + def test_action_space(self, task): + actions = task._build_action_space() + assert actions == ["move_forward", "turn_left", "turn_right", "stop"] + + def test_format_reset_config(self, task): + episode = { + "episode_id": "50001", + "scene_id": "cV4RVeZvu5T", + "instruction": "Go to the laundry room and get the cushion", + "start_position": [1.0, 0.5, -2.0], + "start_rotation": [0, 0.707, 0, 0.707], + "goal_position": [4.5, 0.5, 1.2], + "geodesic_distance": 10.5, + "gt_locations": [[1.0, 0.5, -2.0], [2.0, 0.5, -1.0]], + "_data_dir": "/data/reverie_ce", + } + config = task.format_reset_config(episode) + assert config["scene_id"] == "cV4RVeZvu5T" + assert config["data_dir"] == 
"/data/reverie_ce" + assert json.loads(config["start_position"]) == [1.0, 0.5, -2.0] + + def test_evaluate_episode_success(self, task): + info = { + "success": 1.0, "oracle_success": 1.0, "spl": 0.8, + "navigation_error": 1.5, "ndtw": 0.9, "sdtw": 0.85, + "path_length": 8.0, + } + obs = Observation(rgb_path="/tmp/step.png") + step = StepResult(observation=obs, done=True, info=info) + result = task.evaluate_episode({}, [step]) + assert result["success"] == 1.0 + assert result["spl"] == 0.8 + + def test_evaluate_episode_empty(self, task): + result = task.evaluate_episode({}, []) + assert result["success"] is None + assert result["path_length"] == 0.0 + + def test_aggregate_results(self, task): + records = [ + EpisodeRecord(episode={}, trajectory=[], episode_results={ + "success": 1.0, "oracle_success": 1.0, "spl": 0.8, + "navigation_error": 1.5, "ndtw": 0.9, "sdtw": 0.85, + "path_length": 8.0, "steps_taken": 30.0, + }), + EpisodeRecord(episode={}, trajectory=[], episode_results={ + "success": 0.0, "oracle_success": 0.0, "spl": 0.0, + "navigation_error": 6.0, "ndtw": 0.3, "sdtw": 0.0, + "path_length": 12.0, "steps_taken": 50.0, + }), + ] + summary = task.aggregate_results(records) + assert summary["num_episodes"] == 2 + assert summary["SR"] == 0.5 + assert summary["SPL"] == 0.4 + + def test_bridge_script_path(self, task): + path = task.get_bridge_script_path() + assert path.name == "bridge.py" + assert "reverie_ce" in str(path) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `.venv/bin/pytest tests/test_reverie_ce_task.py::TestReverieCETask -v` +Expected: FAIL (ReverieCETask not found) + +- [ ] **Step 3: Implement the task class** + +```python +# easi/tasks/reverie_ce/task.py +"""REVERIE-CE task for EASI. + +Navigation-only evaluation of REVERIE in continuous environments. +Inherits from VLNCETask — same metrics, same bridge protocol. +""" +from __future__ import annotations + +from pathlib import Path + +from easi.tasks.vlnce_r2r.task import VLNCETask +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + + +class ReverieCETask(VLNCETask): + + def get_task_yaml_path(self) -> Path: + return Path(__file__).parent / "_base.yaml" + + def get_bridge_script_path(self) -> Path: + return Path(__file__).parent / "bridge.py" +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `.venv/bin/pytest tests/test_reverie_ce_task.py::TestReverieCETask -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add easi/tasks/reverie_ce/task.py tests/test_reverie_ce_task.py +git commit -m "feat(reverie-ce): add ReverieCETask inheriting VLNCETask" +``` + +--- + +### Task 4: Create Bridge + +**Files:** +- Create: `easi/tasks/reverie_ce/bridge.py` + +The bridge must be a standalone script (runs in Python 3.8 subprocess). It inherits from `VLNCEBridge` but provides its own `__main__` entry point. + +- [ ] **Step 1: Create the bridge module** + +```python +# easi/tasks/reverie_ce/bridge.py +"""REVERIE-CE bridge — inherits VLN-CE R2R bridge. + +Runs inside the easi_habitat_sim_v0_1_7 conda env (Python 3.8). +Inherits all reset/step/extract logic from VLNCEBridge. 
+ +Usage: + python bridge.py --workspace /tmp/easi_xxx [--simulator-kwargs '{}'] +""" +from __future__ import annotations + +import sys +from pathlib import Path + +_repo_root = Path(__file__).resolve().parents[3] +if str(_repo_root) not in sys.path: + sys.path.insert(0, str(_repo_root)) + +from easi.tasks.vlnce_r2r.bridge import VLNCEBridge # noqa: E402 + + +class ReverieCEBridge(VLNCEBridge): + """Bridge for REVERIE-CE. Identical to VLN-CE R2R for now.""" + pass + + +if __name__ == "__main__": + ReverieCEBridge.main() +``` + +- [ ] **Step 2: Verify bridge script path resolves correctly** + +The test from Task 3 (`test_bridge_script_path`) already validates this. + +Run: `.venv/bin/pytest tests/test_reverie_ce_task.py::TestReverieCETask::test_bridge_script_path -v` +Expected: PASS + +- [ ] **Step 3: Commit** + +```bash +git add easi/tasks/reverie_ce/bridge.py +git commit -m "feat(reverie-ce): add bridge inheriting VLNCEBridge" +``` + +--- + +### Task 5: Create Prompt Builder + +**Files:** +- Create: `easi/tasks/reverie_ce/prompts.py` +- Test: `tests/test_reverie_ce_task.py` + +- [ ] **Step 1: Write failing tests** + +```python +# Add to tests/test_reverie_ce_task.py +class TestReverieCEPromptBuilder: + @pytest.fixture + def mock_encode(self): + # Must patch in vlnce_r2r.prompts where the function is actually + # called (super().build_messages() resolves it there, not in + # reverie_ce.prompts). + import easi.tasks.vlnce_r2r.prompts as prompts_mod + original = prompts_mod._encode_image_base64 + prompts_mod._encode_image_base64 = lambda x: "data:image/png;base64,AAAA" + yield + prompts_mod._encode_image_base64 = original + + def _make_memory(self, action_history=None): + memory = MagicMock() + memory.task_description = "Go to the laundry room and bring me the blue cushion" + memory.action_space = ["move_forward", "turn_left", "turn_right", "stop"] + memory.current_observation = Observation( + rgb_path="/tmp/test.png", + metadata={"geo_distance": "5.3"}, + ) + memory.action_history = action_history or [] + memory.steps = [] + return memory + + def test_system_prompt_mentions_high_level(self, mock_encode): + from easi.tasks.reverie_ce.prompts import ReverieCEPromptBuilder + builder = ReverieCEPromptBuilder() + memory = self._make_memory() + messages = builder.build_messages(memory) + system_msg = messages[0]["content"] + assert "high-level" in system_msg.lower() or "described location" in system_msg.lower() + + def test_build_messages_has_image(self, mock_encode): + from easi.tasks.reverie_ce.prompts import ReverieCEPromptBuilder + builder = ReverieCEPromptBuilder() + memory = self._make_memory() + messages = builder.build_messages(memory) + user_content = messages[1]["content"] + image_blocks = [b for b in user_content if b.get("type") == "image_url"] + assert len(image_blocks) == 1 + + def test_build_messages_has_instruction(self, mock_encode): + from easi.tasks.reverie_ce.prompts import ReverieCEPromptBuilder + builder = ReverieCEPromptBuilder() + memory = self._make_memory() + messages = builder.build_messages(memory) + text_blocks = [b["text"] for b in messages[1]["content"] if b.get("type") == "text"] + full_text = "\n".join(text_blocks) + assert "laundry room" in full_text + + def test_build_messages_has_distance(self, mock_encode): + from easi.tasks.reverie_ce.prompts import ReverieCEPromptBuilder + builder = ReverieCEPromptBuilder() + memory = self._make_memory() + messages = builder.build_messages(memory) + text_blocks = [b["text"] for b in messages[1]["content"] if b.get("type") == 
"text"] + full_text = "\n".join(text_blocks) + assert "5.3" in full_text + + def test_parse_response_valid(self): + from easi.tasks.reverie_ce.prompts import ReverieCEPromptBuilder + builder = ReverieCEPromptBuilder() + memory = MagicMock() + memory.action_space = ["move_forward", "turn_left", "turn_right", "stop"] + response = json.dumps({ + "visual_state_description": "I see a hallway", + "reasoning_and_reflection": "Need to find the laundry room", + "language_plan": "Move forward", + "executable_plan": [{"action": "move_forward"}], + }) + actions = builder.parse_response(response, memory) + assert len(actions) == 1 + assert actions[0].action_name == "move_forward" + + def test_parse_response_invalid_json(self): + from easi.tasks.reverie_ce.prompts import ReverieCEPromptBuilder + builder = ReverieCEPromptBuilder() + memory = MagicMock() + memory.action_space = ["move_forward", "turn_left", "turn_right", "stop"] + actions = builder.parse_response("not json", memory) + assert actions == [] +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `.venv/bin/pytest tests/test_reverie_ce_task.py::TestReverieCEPromptBuilder -v` +Expected: FAIL (module not found) + +- [ ] **Step 3: Implement the prompt builder** + +```python +# easi/tasks/reverie_ce/prompts.py +"""REVERIE-CE prompt builder. + +Adapted from VLN-CE R2R for REVERIE's high-level instruction style. +REVERIE instructions describe a target location/object rather than +step-by-step route directions. +""" +from __future__ import annotations + +from easi.tasks.vlnce_r2r.prompts import VLNCEPromptBuilder +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + +SYSTEM_PROMPT = """\ +## Role and Environment +You are a robot navigating in a 3D indoor environment. You observe the \ +environment through a front-facing camera and must navigate to the location \ +described in a high-level natural language instruction. + +## Observation Description +- **Distance to goal**: Geodesic (shortest walkable path) distance in meters \ +to the described location. Decreases as you get closer. + +## Available Actions +- move_forward: Move forward by 0.25 meters +- turn_left: Turn left by 15 degrees +- turn_right: Turn right by 15 degrees +- stop: Stop and end navigation (use ONLY when you believe you have reached \ +the described location) + +## Strategy +1. Read the instruction carefully — it describes a target location or object \ +in the environment, not a step-by-step route +2. Observe your surroundings in the image +3. Reason about which direction the described location is likely in +4. Navigate room by room, using landmarks and room types to orient yourself +5. Use stop ONLY when you are confident you have reached the described location + +## Guidelines +1. Always output at least one action in executable_plan. +2. Only use actions from the Available Actions list. +3. If previous actions failed, reason about why and try a different approach. +4. Do not repeatedly execute the same action sequence. +5. Keep your plan efficient and concise. + +## Response Format +Output a JSON object with exactly these 4 fields: +{ + "visual_state_description": "Describe what you see in the current image", + "reasoning_and_reflection": "Reason about your situation, reflect on \ +history and feedback", + "language_plan": "Describe your next navigation plan in natural language", + "executable_plan": [{"action": ""}] +} + +You may include multiple actions in executable_plan. 
Actions execute \ +sequentially.""" + + +class ReverieCEPromptBuilder(VLNCEPromptBuilder): + """Prompt builder for REVERIE-CE benchmark. + + Inherits message construction and response parsing from VLNCEPromptBuilder. + Overrides only the system prompt to frame the task around high-level + instructions rather than step-by-step route following. + """ + + def build_messages(self, memory): + # Use parent's build_messages but swap the system prompt + messages = super().build_messages(memory) + messages[0]["content"] = SYSTEM_PROMPT + return messages +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `.venv/bin/pytest tests/test_reverie_ce_task.py::TestReverieCEPromptBuilder -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add easi/tasks/reverie_ce/prompts.py tests/test_reverie_ce_task.py +git commit -m "feat(reverie-ce): add prompt builder for high-level instructions" +``` + +--- + +## Chunk 3: YAML Configs and Integration + +### Task 6: Create YAML Configs + +**Files:** +- Create: `easi/tasks/reverie_ce/_base.yaml` +- Create: `easi/tasks/reverie_ce/reverie_ce_val_unseen.yaml` +- Create: `easi/tasks/reverie_ce/reverie_ce_test.yaml` + +- [ ] **Step 1: Create _base.yaml** + +```yaml +# easi/tasks/reverie_ce/_base.yaml +display_name: "REVERIE-CE" +description: "REVERIE in Continuous Environments (navigation-only)" +simulator: "habitat_sim:v0_1_7" +task_class: "easi.tasks.reverie_ce.task.ReverieCETask" +max_steps: 500 + +dataset: + source: huggingface + repo_id: "oscarqjh/REVERIE-CE_easi" + subset: null + hf_data_dir: "data" + zip_files: + - "mp3d_scenes.zip" + +simulator_configs: + render_platform: auto + screen_height: 480 + screen_width: 480 + hfov: 90 + sensor_height: 1.25 + gpu_device_id: 0 + success_distance: 3.0 + forward_step_size: 0.25 + turn_angle: 15 + allow_sliding: true + additional_deps: + - "fastdtw>=0.3.4" + +agent: + prompt_builder: "easi.tasks.reverie_ce.prompts.ReverieCEPromptBuilder" + prompt_builder_kwargs: + use_feedback: true + use_geo_distance: true + action_history_len: 20 + chat_history: false + message_window_len: 5 + generation_kwargs: + temperature: 0 + max_tokens: 4096 + top_p: 0.95 +``` + +- [ ] **Step 2: Create split configs** + +```yaml +# easi/tasks/reverie_ce/reverie_ce_val_unseen.yaml +extends: _base.yaml +name: reverie_ce_val_unseen +display_name: "REVERIE-CE Val Unseen" +description: "REVERIE-CE validation split (unseen environments)" +dataset: + split: "val_unseen" +``` + +```yaml +# easi/tasks/reverie_ce/reverie_ce_test.yaml +extends: _base.yaml +name: reverie_ce_test +display_name: "REVERIE-CE Test" +description: "REVERIE-CE test split" +dataset: + split: "test" +``` + +- [ ] **Step 3: Verify task discovery** + +Run: `.venv/bin/easi task list 2>&1 | grep -i reverie` +Expected: Should show `reverie_ce_val_unseen` and `reverie_ce_test` + +- [ ] **Step 4: Commit** + +```bash +git add easi/tasks/reverie_ce/_base.yaml easi/tasks/reverie_ce/reverie_ce_val_unseen.yaml easi/tasks/reverie_ce/reverie_ce_test.yaml +git commit -m "feat(reverie-ce): add YAML task configs for val_unseen and test splits" +``` + +--- + +### Task 7: Run Full Test Suite + +- [ ] **Step 1: Run all existing tests to verify no regressions** + +Run: `.venv/bin/pytest tests/ -v --timeout=60` +Expected: All tests pass (946+) + +- [ ] **Step 2: Run REVERIE-CE specific tests** + +Run: `.venv/bin/pytest tests/test_reverie_ce_task.py -v` +Expected: All REVERIE-CE tests pass + +- [ ] **Step 3: Final commit if any test fixes needed** + +```bash +git add -A +git commit -m 
"fix(reverie-ce): address test failures" +``` + +--- + +## Post-Implementation + +After all tasks are complete: + +1. **Data preparation** (Task 1) must be done separately — download Dynam3D data, reformat to JSONL, upload to `oscarqjh/REVERIE-CE_easi` +2. **End-to-end smoke test** once the HuggingFace repo is ready: + ```bash + easi task download reverie_ce_val_unseen + easi sim test habitat_sim:v0_1_7 + easi start reverie_ce_val_unseen --agent dummy --max-episodes 1 + ``` diff --git a/docs/superpowers/specs/2026-03-12-reverie-ce-design.md b/docs/superpowers/specs/2026-03-12-reverie-ce-design.md new file mode 100644 index 0000000..1245c65 --- /dev/null +++ b/docs/superpowers/specs/2026-03-12-reverie-ce-design.md @@ -0,0 +1,231 @@ +# REVERIE-CE Integration Design + +**Date:** 2026-03-12 +**Status:** Approved + +## Overview + +Integrate REVERIE-CE (navigation-only) as a new task in EASI, reusing the existing `habitat_sim:v0_1_7` simulator infrastructure and VLN-CE R2R vendor code. REVERIE-CE uses the pre-converted Dynam3D dataset (HuggingFace) with Matterport3D scenes. + +## Scope + +- **Navigation-only**: No object grounding. The agent navigates to the described area and calls stop. +- **Reuse existing simulator**: Habitat-Sim 0.1.7 (Python 3.8), same as VLN-CE R2R. +- **Pre-converted data**: Use Dynam3D's pre-converted REVERIE-CE episodes from HuggingFace, repackaged into an EASI-compatible repo. + +## Architecture + +``` +easi/tasks/reverie_ce/ +├── task.py # ReverieCETask (inherits from VLNCETask) +├── _base.yaml # Config pointing to REVERIE-CE dataset repo +├── bridge.py # ReverieCEBridge (inherits from VLNCEBridge) +├── prompts.py # ReverieCEPromptBuilder (high-level instruction style) +├── actions.py # Same 4 discrete actions as VLN-CE R2R +├── vendor/ # Reuse vlnce_r2r vendor code (import, not symlink) +├── reverie_ce_val_unseen.yaml +└── reverie_ce_test.yaml +``` + +### Key decisions + +- **Bridge**: Inherit from `VLNCEBridge` (not symlink) to allow future customization. +- **Task**: Inherit from `VLNCETask` to reuse metric extraction and aggregation. +- **Vendor**: Import `SceneSimulator` and `scene_config` from `vlnce_r2r/vendor/` — no duplication. +- **Simulator**: No new simulator code. Uses `habitat_sim:v0_1_7` as-is. + +## Data Pipeline + +### Source + +Dynam3D pre-converted REVERIE-CE data from HuggingFace (`MrZihanWang/Dynam3D`): +- `reverie_training_data/` — per-scene JSON files (~60 files) +- `reverie_val_unseen_data.json` — single file +- `reverie_test_data.json` — single file +- `reverie_val_unseen_gt.json`, `reverie_test_gt.json` — ground truth + +### EASI HuggingFace Repo + +Reformat (not re-convert) the Dynam3D output into EASI's per-split JSONL convention and upload to `oscarqjh/REVERIE-CE_easi`. This is a one-time reshaping step — no discrete-to-CE conversion is involved. + +``` +oscarqjh/REVERIE-CE_easi/ +├── data/ +│ ├── train.jsonl +│ ├── val_unseen.jsonl +│ ├── test.jsonl +│ ├── val_unseen_gt.json +│ └── test_gt.json +├── mp3d_scenes.zip # Matterport3D .glb files (shared with R2R) +``` + +### Episode Format (JSONL) + +Each line matches the VLN-CE R2R format used by `vlnce_r2r`: + +```json +{ + "episode_id": "50001", + "scene_id": "mp3d/cV4RVeZvu5T/cV4RVeZvu5T.glb", + "instruction": "Go to the laundry room and bring me the blue cushion", + "start_position": [x, y, z], + "start_rotation": [qx, qy, qz, qw], + "goal_position": [x, y, z], + "geodesic_distance": 10.5, + "gt_locations": [[x1, y1, z1], [x2, y2, z2], ...] 
+} +``` + +The `mp3d_scenes.zip` contains the same Matterport3D scenes as R2R. Can be shared or symlinked to avoid duplication on disk. + +## Task Configuration + +### _base.yaml + +```yaml +name: reverie_ce +simulator: "habitat_sim:v0_1_7" +task_class: "easi.tasks.reverie_ce.task.ReverieCETask" + +dataset: + source: huggingface + repo_id: "oscarqjh/REVERIE-CE_easi" + hf_data_dir: "data" + zip_files: ["mp3d_scenes.zip"] + +simulator_configs: + render_platform: auto + screen_height: 480 + screen_width: 480 + hfov: 90 + sensor_height: 1.25 + forward_step_size: 0.25 + turn_angle: 15 + allow_sliding: true + gpu_device_id: 0 + success_distance: 3.0 + additional_deps: ["fastdtw>=0.3.4"] + +agent: + prompt_builder: "easi.tasks.reverie_ce.prompts.ReverieCEPromptBuilder" + prompt_builder_kwargs: + use_feedback: true + use_geo_distance: true + action_history_len: 20 + chat_history: false + message_window_len: 5 + generation_kwargs: + temperature: 0 + max_tokens: 4096 + top_p: 0.95 +``` + +### Split configs + +Each split YAML extends `_base.yaml`: + +```yaml +# reverie_ce_val_unseen.yaml +extends: _base.yaml +name: reverie_ce_val_unseen +dataset: + split: "val_unseen" +``` + +```yaml +# reverie_ce_test.yaml +extends: _base.yaml +name: reverie_ce_test +dataset: + split: "test" +``` + +## Action Space + +Same 4 discrete actions as VLN-CE R2R: + +| Action | Effect | +|---|---| +| `move_forward` | Move 0.25m forward | +| `turn_left` | Turn 15 degrees left | +| `turn_right` | Turn 15 degrees right | +| `stop` | End navigation, evaluate success | + +## Metrics + +Same as VLN-CE R2R (navigation-only): + +| Metric | Description | +|---|---| +| SR (Success Rate) | 1.0 if agent stops within 3m of goal | +| SPL | Success weighted by path efficiency | +| NE (Navigation Error) | Geodesic distance to goal at stop | +| Oracle SR | Best geodesic distance achieved during episode | +| NDTW | Normalized Dynamic Time Warping | +| SDTW | Success-weighted DTW | +| path_length | Total distance traveled | +| steps_taken | Number of actions executed | + +Implemented by inheriting `VLNCETask.evaluate_episode()` and `aggregate_results()`. + +## Prompt Builder + +### Differences from VLN-CE R2R + +REVERIE uses high-level instructions ("Go to the laundry room and bring me the blue cushion") vs R2R's turn-by-turn route descriptions ("Exit the bedroom and turn left, walk past the kitchen..."). 
+ +The system prompt should: +- Frame the task as "navigate to the described area" rather than "follow route instructions" +- Emphasize spatial reasoning from the high-level description +- Otherwise keep the same structure: image, instruction, geodesic feedback, action history + +### Response format + +Same JSON format as VLN-CE R2R: + +```json +{ + "visual_state_description": "...", + "reasoning_and_reflection": "...", + "language_plan": "...", + "executable_plan": [{"action": "move_forward"}] +} +``` + +## Component Reuse Summary + +| Component | Source | Method | +|---|---|---| +| Simulator | `habitat_sim:v0_1_7` | As-is, no changes | +| Bridge | `vlnce_r2r.bridge.VLNCEBridge` | Inherit | +| SceneSimulator | `vlnce_r2r.vendor.scene_simulator` | Import | +| scene_config | `vlnce_r2r.vendor.scene_config` | Import | +| DTW metrics | `vlnce_r2r.vendor.dtw` | Import | +| Task class | `vlnce_r2r.task.VLNCETask` | Inherit | +| Actions | `vlnce_r2r.actions` | Import or duplicate (trivial) | +| Prompt builder | New | Adapted for high-level instructions | +| Dataset | New HuggingFace repo | Repackaged from Dynam3D | + +## CLI Usage + +```bash +# List available tasks +easi task list # Should show reverie_ce_val_unseen, reverie_ce_test + +# Download dataset +easi task download reverie_ce_val_unseen + +# Run evaluation +easi start reverie_ce_val_unseen --agent react --backend openai --model gpt-4o + +# Parallel evaluation +easi start reverie_ce_val_unseen --agent react --backend openai --model gpt-4o --num-parallel 4 +``` + +## References + +- [Dynam3D (GitHub)](https://github.com/MrZihan/Dynam3D) — conversion scripts and pre-converted data +- [Dynam3D (HuggingFace)](https://huggingface.co/datasets/MrZihanWang/Dynam3D) — pre-converted REVERIE-CE episodes +- [VLN-CE (GitHub)](https://github.com/jacobkrantz/VLN-CE) — original R2R-CE implementation +- [REVERIE (GitHub)](https://github.com/YuankaiQi/REVERIE) — original discrete REVERIE benchmark +- [REVE-CE (IEEE Xplore)](https://ieeexplore.ieee.org/document/9674225) — prior work porting REVERIE to CE (no public code) diff --git a/easi/__init__.py b/easi/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/easi/agents/__init__.py b/easi/agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/easi/agents/dummy_agent.py b/easi/agents/dummy_agent.py new file mode 100644 index 0000000..fcdb2f4 --- /dev/null +++ b/easi/agents/dummy_agent.py @@ -0,0 +1,24 @@ +"""Dummy agent for testing — returns random actions without calling an LLM.""" +from __future__ import annotations + +import random + +from easi.core.base_agent import BaseAgent +from easi.core.episode import Action, Observation + + +class DummyAgent(BaseAgent): + """Agent that picks random actions from the action space. + + Does not call the LLM client. Useful for testing the full pipeline + without needing a running LLM server. + """ + + def __init__(self, action_space: list[str], seed: int | None = None): + super().__init__(llm_client=None, action_space=action_space) + self._rng = random.Random(seed) + + def act(self, observation: Observation, task_description: str) -> Action: + """Pick a random action from the action space.""" + self._step_count += 1 + return Action(action_name=self._rng.choice(self.action_space)) diff --git a/easi/agents/prompt_builder.py b/easi/agents/prompt_builder.py new file mode 100644 index 0000000..f59349c --- /dev/null +++ b/easi/agents/prompt_builder.py @@ -0,0 +1,206 @@ +"""PromptBuilder protocol and default implementation. 
+ +Both protocol methods receive AgentMemory as their state source. +Contributors adding a new task only need to implement build_messages +and parse_response. +""" +from __future__ import annotations + +import base64 +import io +import json +import time +from pathlib import Path +from typing import Protocol, runtime_checkable + +from easi.core.episode import Action, Observation +from easi.core.memory import AgentMemory +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + +_IMAGE_READ_RETRIES = 3 +_IMAGE_READ_BASE_DELAY = 0.1 # seconds, doubles each retry + + +def _encode_image_base64(image_path: str) -> str | None: + """Read an image file and return base64-encoded data URL. + + Validates the image is a complete PNG/JPEG before encoding. + Retries with exponential backoff if the file appears truncated. + If still truncated after retries, returns the raw data anyway — + the caller (vLLM) will raise an error and the episode will restart. + + Returns None only if the file doesn't exist. + """ + from PIL import Image + + p = Path(image_path) + if not p.exists(): + logger.warning("Image file not found: %s", image_path) + return None + suffix = p.suffix.lower().lstrip(".") + mime = {"png": "image/png", "jpg": "image/jpeg", "jpeg": "image/jpeg"}.get(suffix, "image/png") + + for attempt in range(_IMAGE_READ_RETRIES): + data = p.read_bytes() + try: + Image.open(io.BytesIO(data)).verify() + return f"data:{mime};base64,{base64.b64encode(data).decode('utf-8')}" + except Exception: + if attempt < _IMAGE_READ_RETRIES - 1: + delay = _IMAGE_READ_BASE_DELAY * (2 ** attempt) + logger.debug("Image truncated (attempt %d), retrying in %.1fs: %s", attempt + 1, delay, image_path) + time.sleep(delay) + + # Return the data as-is — let vLLM raise the error and trigger episode restart + logger.warning("Image still truncated after %d retries, sending anyway: %s", _IMAGE_READ_RETRIES, image_path) + return f"data:{mime};base64,{base64.b64encode(data).decode('utf-8')}" + + +def validate_action_name(action_name: str, action_space: list[str]) -> str | None: + """Validate action name against action_space. Returns canonical name or None.""" + if action_name in action_space: + return action_name + # Case-insensitive fallback + for valid in action_space: + if valid.lower() == action_name.lower(): + return valid + return None + + +@runtime_checkable +class PromptBuilderProtocol(Protocol): + """Interface for task-specific prompt construction. + + Implementations are referenced in task.yaml via: + agent: + prompt_builder: "easi.tasks.my_task.prompts.MyPromptBuilder" + + Required methods: + build_messages(memory) -> list[dict] + parse_response(llm_response, memory) -> list[Action] + + Optional methods: + get_response_format(memory) -> dict | None + Return a response_format dict for API-level JSON enforcement. + E.g. {"type": "json_schema", "json_schema": {"name": "...", "schema": {...}}} + When provided, the agent passes it to LLMClient.generate(). + Builders that don't implement this get no schema enforcement. + """ + + def build_messages(self, memory: AgentMemory) -> list[dict]: + """Build COMPLETE message list to send to LLM.""" + ... + + def parse_response(self, llm_response: str, memory: AgentMemory) -> list[Action]: + """Parse LLM response into validated actions.""" + ... + + +class DefaultPromptBuilder: + """Generic prompt builder that works with any task. + + Produces OpenAI-format messages with interleaved text+image. 
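+
+    Response parsing validates each action in executable_plan against
+    memory.action_space (with a case-insensitive fallback) and stops at
+    the first invalid entry, since plans are ordered.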
+ """ + + SYSTEM_TEMPLATE = """You are an embodied agent operating in a simulated environment. Given a task, you must accomplish it by choosing actions from the available action space. + +## Task +{task_description} + +## Available Actions +{action_list} + +## Output Format +You MUST respond with valid JSON in this exact format: +{{ + "observation": "Describe what you see in the current image", + "reasoning": "Explain your step-by-step reasoning", + "plan": "Your high-level plan", + "executable_plan": [ + {{"action": ""}}, + {{"action": ""}} + ] +}} + +## Guidelines +1. Always output at least one action in executable_plan. +2. Only use actions from the available action list. +3. If previous actions failed, reason about why and try a different approach. +4. Output at most 10 actions per plan. +""" + + STEP_TEMPLATE = """Task: {task_description} + +{history_section} + +Based on the current observation image, decide your next action(s). Respond with valid JSON.""" + + def build_messages(self, memory: AgentMemory) -> list[dict]: + """Build complete message list from memory state.""" + messages: list[dict] = [] + + # System message + action_list = "\n".join( + f" {i}. {name}" for i, name in enumerate(memory.action_space) + ) + system_text = self.SYSTEM_TEMPLATE.format( + action_list=action_list, + task_description=memory.task_description, + ) + messages.append({"role": "system", "content": system_text}) + + # User message with observation + action_history = memory.action_history + if action_history: + history_lines = [] + for i, (action_name, feedback) in enumerate(action_history): + history_lines.append(f" Step {i+1}: {action_name} -> {feedback}") + history_section = "## Action History\n" + "\n".join(history_lines) + else: + history_section = "This is the first step." + + text = self.STEP_TEMPLATE.format( + task_description=memory.task_description, + history_section=history_section, + ) + + content_parts: list[dict] = [{"type": "text", "text": text}] + if memory.current_observation and memory.current_observation.rgb_path: + image_url = _encode_image_base64(memory.current_observation.rgb_path) + if image_url: + content_parts.append({ + "type": "image_url", + "image_url": {"url": image_url}, + }) + + messages.append({"role": "user", "content": content_parts}) + return messages + + def parse_response(self, llm_response: str, memory: AgentMemory) -> list[Action]: + """Parse JSON response into validated actions.""" + try: + data = json.loads(llm_response) + except json.JSONDecodeError as e: + logger.warning("Failed to parse LLM response as JSON: %s", e) + return [] + + plan = data.get("executable_plan", []) + if not isinstance(plan, list) or not plan: + logger.warning("No executable_plan in LLM response") + return [] + + actions = [] + for entry in plan: + if not isinstance(entry, dict): + continue + action_name = entry.get("action", "") + validated = validate_action_name(action_name, memory.action_space) + if validated: + actions.append(Action(action_name=validated)) + else: + logger.warning("Skipping invalid action: '%s'", action_name) + break + + return actions diff --git a/easi/agents/react_agent.py b/easi/agents/react_agent.py new file mode 100644 index 0000000..f6f3835 --- /dev/null +++ b/easi/agents/react_agent.py @@ -0,0 +1,323 @@ +"""ReAct agent with multi-action buffering and PromptBuilder delegation. + +The agent is a thin orchestrator: it populates AgentMemory, delegates +prompt construction and response parsing to the PromptBuilder, and +manages action buffering. 
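+
+Parse failures are handled by a configurable fallback strategy
+("default_action" or "reprompt"); if that also fails, a last-resort
+action is emitted, and too many consecutive fallbacks force an early stop.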
+""" +from __future__ import annotations + +from easi.agents.prompt_builder import DefaultPromptBuilder, PromptBuilderProtocol +from easi.core.base_agent import BaseAgent +from easi.core.episode import Action, Observation +from easi.core.memory import AgentMemory +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + +# Exception types that indicate response_format is unsupported by the backend. +# Lazy-resolved on first use to avoid importing litellm at module level. +_FORMAT_UNSUPPORTED_ERRORS: tuple[type[Exception], ...] | None = None + + +def _get_format_unsupported_errors() -> tuple[type[Exception], ...]: + """Return exception types for unsupported response_format.""" + global _FORMAT_UNSUPPORTED_ERRORS + if _FORMAT_UNSUPPORTED_ERRORS is None: + try: + from litellm.exceptions import BadRequestError + _FORMAT_UNSUPPORTED_ERRORS = (BadRequestError,) + except ImportError: + _FORMAT_UNSUPPORTED_ERRORS = () + return _FORMAT_UNSUPPORTED_ERRORS + + +def _format_messages_for_log(messages: list[dict]) -> str: + """Extract readable text from OpenAI-format messages for logging. + + Shows image positions inline as [img_N] markers so interleaved + placement is visible in the log output. + """ + parts = [] + for msg in messages: + role = msg.get("role", "?") + content = msg.get("content", "") + if isinstance(content, str): + text = content + elif isinstance(content, list): + img_idx = 0 + text_parts = [] + for block in content: + if block.get("type") == "text": + text_parts.append(block.get("text", "")) + elif block.get("type") == "image_url": + img_idx += 1 + text_parts.append(f"[img_{img_idx}]") + text = "".join(text_parts) + else: + text = str(content) + parts.append(f"--- {role} ---\n{text}") + return "\n".join(parts) + + +class ReActAgent(BaseAgent): + """ReAct agent with action buffering and pluggable prompt building. + + Flow per LLM call: + 1. PromptBuilder constructs messages from AgentMemory + 2. LLM returns response text + 3. PromptBuilder parses response into validated Actions + 4. Agent buffers actions, returns first + 5. Subsequent act() calls pop from buffer without LLM call + 6. On failure feedback -> clear buffer -> next act() re-queries LLM + """ + + # Registry of fallback strategies. To add a new strategy: + # 1. Add a _fallback_ method that takes (messages, response_format, failed_response) + # and returns list[Action] (empty = give up, use default action) + # 2. 
Register it here + _FALLBACK_STRATEGIES = {"default_action", "reprompt"} + + def __init__( + self, + llm_client, + action_space: list[str] | None = None, + prompt_builder: PromptBuilderProtocol | None = None, + fallback_action: str | None = None, + fallback_strategy: str = "default_action", + max_fallback_retries: int = 1, + max_consecutive_fallbacks: int = 0, + ): + super().__init__(llm_client=llm_client, action_space=action_space or []) + self.prompt_builder: PromptBuilderProtocol = prompt_builder or DefaultPromptBuilder() + self.memory = AgentMemory(action_space=self.action_space) + self._action_buffer: list[Action] = [] + self._supports_response_format: bool | None = None # None = unknown + self._fallback_action_name = fallback_action + self._fallback_strategy = fallback_strategy + self._max_fallback_retries = max_fallback_retries + self._max_consecutive_fallbacks = max_consecutive_fallbacks # 0 = disabled + self._consecutive_fallbacks: int = 0 + self.triggered_fallback: bool = False + self.forced_early_stop: bool = False + if fallback_strategy not in self._FALLBACK_STRATEGIES: + raise ValueError( + f"Unknown fallback_strategy '{fallback_strategy}'. " + f"Available: {sorted(self._FALLBACK_STRATEGIES)}" + ) + + def reset(self) -> None: + super().reset() + self.memory.clear() + self._action_buffer.clear() + self._consecutive_fallbacks = 0 + self.forced_early_stop = False + + def update_action_space(self, action_space: list[str]) -> None: + """Update the action space (e.g., after dynamic expansion per episode).""" + self.action_space = action_space + self.memory.action_space = action_space + if hasattr(self.prompt_builder, 'set_action_space'): + self.prompt_builder.set_action_space(action_space) + + def act(self, observation: Observation, task_description: str) -> Action: + """Return the next action. + + If buffer has pending actions, pop and return (no LLM call). + Otherwise, call LLM, parse response via builder, buffer actions. 
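+
+        On parse failure, the configured fallback strategy runs first;
+        if it also fails, _default_fallback_action() supplies a
+        last-resort action, and repeated consecutive fallbacks can
+        force an early "stop".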
+ """ + # Buffered action path + if self._action_buffer: + action = self._action_buffer.pop(0) + self.memory.record_step(observation, action, llm_response=None) + self.triggered_fallback = False + return action + + # LLM call path + self.memory.current_observation = observation + self.memory.task_description = task_description + + messages = self.prompt_builder.build_messages(self.memory) + prompt_text = _format_messages_for_log(messages) + + logger.trace("Step %d prompt (%d messages):\n%s", + self._step_count + 1, len(messages), prompt_text) + + # Query builder for response_format (optional method) + get_rf = getattr(self.prompt_builder, 'get_response_format', None) + response_format = get_rf(self.memory) if get_rf else None + + response = self._generate_with_fallback(messages, response_format) + + logger.trace("Step %d LLM response:\n%s", + self._step_count + 1, response) + + actions = self.prompt_builder.parse_response(response, self.memory) + + # If parsing failed, run the configured fallback strategy + if not actions: + actions = self._run_fallback(messages, response_format, response) + + # If still no actions after fallback, use the default action + if not actions: + action = self._default_fallback_action() + self.memory.record_step( + observation, action, llm_response=response, prompt_text=prompt_text, + ) + self._step_count += 1 + self.triggered_fallback = True + self._consecutive_fallbacks += 1 + + # Force stop if too many consecutive fallbacks + if (self._max_consecutive_fallbacks > 0 + and self._consecutive_fallbacks >= self._max_consecutive_fallbacks): + logger.warning( + "Forcing stop: %d consecutive fallbacks reached limit (%d)", + self._consecutive_fallbacks, self._max_consecutive_fallbacks, + ) + self.forced_early_stop = True + return Action(action_name="stop") + + return action + + self.triggered_fallback = False + self._consecutive_fallbacks = 0 + self.memory.record_step( + observation, actions[0], llm_response=response, prompt_text=prompt_text, + ) + self._step_count += 1 + + if len(actions) > 1: + self._action_buffer = actions[1:] + + return actions[0] + + def add_feedback(self, action_name: str, feedback: str) -> None: + """Record action feedback. Clear buffer on failure.""" + self.memory.record_feedback(feedback) + if any(kw in feedback.lower() for kw in ("fail", "error", "invalid")): + if self._action_buffer: + logger.info( + "Action '%s' failed, clearing %d buffered actions", + action_name, len(self._action_buffer), + ) + self._action_buffer.clear() + + # ---- Fallback system ---- + + def _run_fallback( + self, + messages: list[dict], + response_format: dict | None, + failed_response: str, + ) -> list[Action]: + """Dispatch to the configured fallback strategy. + + Returns a list of actions if the strategy recovered, or [] to + fall through to _default_fallback_action(). + """ + handler = getattr(self, f"_fallback_{self._fallback_strategy}", None) + if handler is None: + return [] + return handler(messages, response_format, failed_response) + + def _default_fallback_action(self) -> Action: + """Last-resort action when all fallback strategies fail. + + Priority: + 1. Configured fallback_action (from YAML agent config) + 2. "stop"/"Stop" if in action space (case-insensitive) + 3. 
"<>" sentinel to end the episode + """ + if self._fallback_action_name: + logger.warning("Fallback: using configured action '%s'", self._fallback_action_name) + return Action(action_name=self._fallback_action_name) + stop_names = {a for a in self.action_space if a.lower() == "stop"} + if stop_names: + name = next(iter(stop_names)) + logger.warning("Fallback: using '%s' from action space", name) + return Action(action_name=name) + logger.warning("Fallback: no suitable action, signalling <>") + return Action(action_name="<>") + + def _fallback_default_action( + self, messages, response_format, failed_response, + ) -> list[Action]: + """Strategy 'default_action': skip reprompt, go straight to default.""" + return [] + + def _fallback_reprompt( + self, messages, response_format, failed_response, + ) -> list[Action]: + """Strategy 'reprompt': re-query the LLM with a warning about the failure. + + Appends the failed response + a correction prompt, then retries. + Falls through to default action after max_fallback_retries attempts. + """ + retry_messages = list(messages) + + # Get correction prompt from builder, or use default + get_correction = getattr(self.prompt_builder, 'get_reprompt_message', None) + if get_correction: + correction_text = get_correction() + else: + correction_text = ( + "Your previous response could not be executed. " + "Make sure you reply in proper JSON format and " + "do NOT leave the executable_plan field as an empty list. " + "You MUST include at least one valid action." + ) + + for attempt in range(1, self._max_fallback_retries + 1): + # Append the failed response as assistant + correction as user + retry_messages.append({ + "role": "assistant", + "content": [{"type": "text", "text": failed_response}], + }) + retry_messages.append({ + "role": "user", + "content": [{"type": "text", "text": correction_text}], + }) + + logger.info( + "Fallback reprompt attempt %d/%d", + attempt, self._max_fallback_retries, + ) + + response = self._generate_with_fallback(retry_messages, response_format) + logger.trace("Reprompt attempt %d response:\n%s", attempt, response) + + actions = self.prompt_builder.parse_response(response, self.memory) + if actions: + logger.info("Reprompt attempt %d succeeded with %d actions", attempt, len(actions)) + return actions + + # Update failed_response for next iteration + failed_response = response + + logger.warning( + "Reprompt failed after %d attempts, falling back to default action", + self._max_fallback_retries, + ) + return [] + + def _generate_with_fallback( + self, messages: list[dict], response_format: dict | None, + ) -> str: + """Call LLM with optional response_format, falling back on failure. + + If response_format is provided and the backend doesn't support it, + the failure is caught, cached, and retried without response_format. + The prompt template is already in messages, so fallback always works. 
+ """ + if response_format is None or self._supports_response_format is False: + return self.llm_client.generate(messages) + + try: + return self.llm_client.generate(messages, response_format=response_format) + except _get_format_unsupported_errors() as e: + logger.warning( + "response_format not supported by backend, " + "falling back to prompt-only: %s", e, + ) + self._supports_response_format = False + return self.llm_client.generate(messages) diff --git a/easi/analysis/__init__.py b/easi/analysis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/easi/analysis/trajectory_video.py b/easi/analysis/trajectory_video.py new file mode 100644 index 0000000..9ea1b32 --- /dev/null +++ b/easi/analysis/trajectory_video.py @@ -0,0 +1,405 @@ +"""Trajectory video generator for post-evaluation analysis. + +Generates per-episode videos showing the robot's path on a top-down map +alongside the agent's camera view. No simulator dependencies — pure +post-processing from episode output directories. + +Requires: opencv-python-headless (optional dependency) +""" +from __future__ import annotations + +import json +import random +from pathlib import Path + +import numpy as np + +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + + +def _require_cv2(): + """Import cv2 with helpful error if missing.""" + try: + import cv2 + return cv2 + except ImportError: + raise ImportError( + "opencv-python-headless is required for trajectory video generation.\n" + "Install it with: pip install opencv-python-headless" + ) + + +def discover_episodes( + run_dir: Path | str, + filter_by: str | None = None, + sample_n: int | None = None, + seed: int = 42, +) -> list[Path]: + """Discover and filter episode directories in a run. + + Args: + run_dir: Path to evaluation run directory. + filter_by: "success" or "failed" to filter by outcome. + sample_n: Randomly sample N episodes after filtering. + seed: Random seed for sampling. + + Returns: + Sorted list of episode directory paths. + """ + run_dir = Path(run_dir) + episodes_dir = run_dir / "episodes" + if not episodes_dir.is_dir(): + logger.warning("No episodes/ directory found in %s", run_dir) + return [] + + episode_dirs = sorted( + d for d in episodes_dir.iterdir() + if d.is_dir() and (d / "trajectory.jsonl").exists() + ) + + if filter_by: + filtered = [] + for ep_dir in episode_dirs: + result_path = ep_dir / "result.json" + if not result_path.exists(): + continue + try: + result = json.loads(result_path.read_text()) + except (json.JSONDecodeError, OSError): + continue + success = result.get("success") + if filter_by == "success" and success == 1.0: + filtered.append(ep_dir) + elif filter_by == "failed" and success != 1.0: + filtered.append(ep_dir) + episode_dirs = filtered + + if sample_n is not None and sample_n < len(episode_dirs): + episode_dirs = random.Random(seed).sample(episode_dirs, sample_n) + episode_dirs.sort() + + return episode_dirs + + +def world_to_pixel( + world_x: float, world_z: float, map_meta: dict +) -> tuple[int, int]: + """Project world [x, z] to pixel coords using map metadata. + + Habitat-Sim uses Y-up coordinates. The floor plane is [x, z]. 
+ """ + bounds_lower = map_meta["bounds_lower"] + mpp = map_meta["meters_per_pixel"] + px = int((world_x - bounds_lower[0]) / mpp) + pz = int((world_z - bounds_lower[2]) / mpp) + return px, pz + + +def world_to_pixel_fallback( + world_x: float, + world_z: float, + all_positions: list[list[float]], + canvas_size: tuple[int, int], + padding: int = 20, +) -> tuple[int, int]: + """Project world coords to pixel coords on a blank canvas. + + Computes bounding box from all positions and maps linearly. + """ + xs = [p[0] for p in all_positions] + zs = [p[1] for p in all_positions] + x_min, x_max = min(xs), max(xs) + z_min, z_max = min(zs), max(zs) + + # Avoid division by zero for single-point paths + x_range = max(x_max - x_min, 0.01) + z_range = max(z_max - z_min, 0.01) + + draw_w = canvas_size[0] - 2 * padding + draw_h = canvas_size[1] - 2 * padding + + px = int(padding + (world_x - x_min) / x_range * draw_w) + pz = int(padding + (world_z - z_min) / z_range * draw_h) + return px, pz + + +def _load_trajectory(ep_dir: Path) -> list[dict]: + """Load trajectory.jsonl entries.""" + path = ep_dir / "trajectory.jsonl" + entries = [] + with open(path) as f: + for line in f: + line = line.strip() + if line: + entries.append(json.loads(line)) + return entries + + +def _parse_positions(entries: list[dict]) -> list[list[float] | None]: + """Extract [x, z] positions from trajectory entries. + + Returns a list parallel to entries. Reset entry (step 0) has None + since its info is empty — start position comes from episode_meta.json. + """ + positions = [] + for entry in entries: + info = entry.get("info", {}) + raw = info.get("agent_position") + if raw is not None: + pos_3d = json.loads(raw) if isinstance(raw, str) else raw + positions.append([pos_3d[0], pos_3d[2]]) # [x, z] floor plane + else: + positions.append(None) + return positions + + +def render_episode_video( + ep_dir: Path, + output_path: Path, + fps: int = 4, +) -> None: + """Render a trajectory video for one episode. + + Args: + ep_dir: Path to episode directory containing trajectory.jsonl, step_*.png, etc. + output_path: Where to write the MP4 video. + fps: Frames per second. + """ + cv2 = _require_cv2() + from PIL import Image + + traj_path = ep_dir / "trajectory.jsonl" + if not traj_path.exists(): + logger.warning("No trajectory.jsonl in %s, skipping", ep_dir) + return + + entries = _load_trajectory(ep_dir) + if len(entries) < 2: + logger.warning("Trajectory too short in %s, skipping", ep_dir) + return + + positions = _parse_positions(entries) + + # Load episode metadata (start position, goal, gt_locations) + meta_path = ep_dir / "episode_meta.json" + ep_meta = json.loads(meta_path.read_text()) if meta_path.exists() else {} + start_pos = ep_meta.get("start_position") + goal_pos = ep_meta.get("goal_position") + gt_locations = ep_meta.get("gt_locations") + + # Start position as [x, z] + start_xz = [start_pos[0], start_pos[2]] if start_pos else None + + # Goal position as [x, z] + goal_xz = [goal_pos[0], goal_pos[2]] if goal_pos else None + + # GT path as [[x, z], ...] 
+ gt_xz = [[p[0], p[2]] for p in gt_locations] if gt_locations else None + + # Load topdown map or create blank canvas + map_path = ep_dir / "topdown_map.png" + map_meta_path = ep_dir / "topdown_map_meta.json" + has_map = map_path.exists() and map_meta_path.exists() + + if has_map: + map_img = np.array(Image.open(map_path).convert("RGB")) + map_meta = json.loads(map_meta_path.read_text()) + else: + map_img = None + map_meta = None + + # Collect all valid [x, z] positions for bounding box fallback + all_xz = [p for p in positions if p is not None] + if start_xz: + all_xz.insert(0, start_xz) + if goal_xz: + all_xz.append(goal_xz) + + if not all_xz: + logger.warning("No positions found in %s, skipping", ep_dir) + return + + # Determine panel height from first step image + first_img_path = ep_dir / entries[0].get("rgb_path", "step_0000.png") + if first_img_path.exists(): + cam_h = np.array(Image.open(first_img_path)).shape[0] + else: + cam_h = 480 + panel_h = cam_h + + # Blank canvas fallback + if map_img is None: + map_img = np.full((panel_h, panel_h, 3), 40, dtype=np.uint8) + + # Resize map to match panel height + scale = panel_h / map_img.shape[0] + map_w = int(map_img.shape[1] * scale) + map_base = cv2.resize(map_img, (map_w, panel_h)) + + # Load result for final frame overlay + result_path = ep_dir / "result.json" + result = json.loads(result_path.read_text()) if result_path.exists() else {} + + # Helper: project world coord to map pixel + def to_pixel(x, z): + if map_meta: + px, pz = world_to_pixel(x, z, map_meta) + return int(px * scale), int(pz * scale) + else: + return world_to_pixel_fallback(x, z, all_xz, (map_w, panel_h)) + + # Set up video writer + output_path.parent.mkdir(parents=True, exist_ok=True) + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + frame_w = map_w + cam_h # map_panel + camera_panel (camera is square) + writer = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_w, panel_h)) + + if not writer.isOpened(): + logger.error("Failed to open video writer for %s", output_path) + return + + try: + path_so_far = [] + if start_xz: + path_so_far.append(start_xz) + + for i, entry in enumerate(entries): + # Update path + if positions[i] is not None: + path_so_far.append(positions[i]) + + # Draw map panel + map_frame = map_base.copy() + + # Draw GT path (dashed, faint) + if gt_xz and len(gt_xz) >= 2: + for j in range(len(gt_xz) - 1): + p1 = to_pixel(*gt_xz[j]) + p2 = to_pixel(*gt_xz[j + 1]) + # Dashed line: draw every other segment + if j % 2 == 0: + cv2.line(map_frame, p1, p2, (180, 180, 180), 1) + + # Draw agent path (solid, growing) + if len(path_so_far) >= 2: + for j in range(len(path_so_far) - 1): + p1 = to_pixel(*path_so_far[j]) + p2 = to_pixel(*path_so_far[j + 1]) + cv2.line(map_frame, p1, p2, (0, 200, 0), 2) + + # Draw start (blue circle) + if start_xz: + sp = to_pixel(*start_xz) + cv2.circle(map_frame, sp, 6, (255, 100, 100), -1) + + # Draw goal (red circle) + if goal_xz: + gp = to_pixel(*goal_xz) + cv2.circle(map_frame, gp, 6, (100, 100, 255), -1) + + # Draw current position (green arrowhead) + if path_so_far: + cp = to_pixel(*path_so_far[-1]) + cv2.circle(map_frame, cp, 5, (0, 255, 0), -1) + + # Step/distance text on map panel + step_text = f"Step: {entry.get('step', i)}" + cv2.putText(map_frame, step_text, (10, panel_h - 30), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) + geo_dist = entry.get("info", {}).get("geo_distance") + if geo_dist: + dist_text = f"Dist: {geo_dist}m" + cv2.putText(map_frame, dist_text, (10, panel_h - 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, 
(255, 255, 255), 1)
+
+            # Load camera panel
+            rgb_name = entry.get("rgb_path", f"step_{i:04d}.png")
+            cam_path = ep_dir / rgb_name
+            if cam_path.exists():
+                cam_img = np.array(Image.open(cam_path).convert("RGB"))
+                cam_img = cv2.resize(cam_img, (cam_h, panel_h))
+            else:
+                logger.warning("Missing image %s, using placeholder", cam_path.name)
+                cam_img = np.full((panel_h, cam_h, 3), 30, dtype=np.uint8)
+
+            # Action overlay on camera panel
+            action = entry.get("action", "")
+            if action:
+                cv2.putText(cam_img, action, (10, panel_h - 10),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+
+            # Final frame: overlay outcome
+            if i == len(entries) - 1 and result:
+                success = result.get("success")
+                if success == 1.0:
+                    label = "SUCCESS"
+                    color = (0, 255, 0)
+                elif success is not None:
+                    label = "FAILURE"
+                    color = (0, 0, 255)
+                else:
+                    label = "NO GOAL"
+                    color = (200, 200, 200)
+                cv2.putText(cam_img, label, (10, 30),
+                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
+
+            # Concatenate panels
+            frame = np.concatenate([map_frame, cam_img], axis=1)
+
+            # OpenCV uses BGR
+            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+    finally:
+        writer.release()
+
+    logger.trace("Wrote %s (%d frames)", output_path, len(entries))
+
+
+def generate_trajectory_videos(
+    run_dir: str,
+    filter_by: str | None = None,
+    sample_n: int | None = None,
+    fps: int = 4,
+    seed: int = 42,
+) -> None:
+    """Generate trajectory videos for all matching episodes in a run.
+
+    Args:
+        run_dir: Path to evaluation run directory.
+        filter_by: "success" or "failed".
+        sample_n: Randomly sample N episodes.
+        fps: Video frame rate.
+        seed: Random seed for sampling.
+    """
+    _require_cv2()
+
+    run_path = Path(run_dir)
+    if not run_path.is_dir():
+        logger.error("Run directory not found: %s", run_dir)
+        return
+
+    episodes = discover_episodes(run_path, filter_by=filter_by, sample_n=sample_n, seed=seed)
+    if not episodes:
+        logger.info("No episodes found matching criteria in %s", run_dir)
+        return
+
+    output_dir = run_path / "analysis" / "videos"
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    logger.info("Generating %d trajectory videos in %s", len(episodes), output_dir)
+
+    from easi.utils.progress import ProgressBar
+
+    failed = 0
+    with ProgressBar(total=len(episodes)) as bar:
+        for i, ep_dir in enumerate(episodes):
+            output_path = output_dir / f"{ep_dir.name}.mp4"
+            try:
+                render_episode_video(ep_dir, output_path, fps=fps)
+            except Exception:
+                logger.exception("Failed to render %s", ep_dir.name)
+                failed += 1
+            bar.update(completed=i + 1, failed=failed)
+
+    logger.info("Done. %d videos saved to %s", len(episodes) - failed, output_dir)
diff --git a/easi/cli.py b/easi/cli.py
new file mode 100644
index 0000000..0f3c945
--- /dev/null
+++ b/easi/cli.py
@@ -0,0 +1,1075 @@
+"""EASI CLI entry point.
+
+Usage:
+    easi env list|install|check
+    easi task list|info|download|scaffold
+    easi sim test <simulator>
+    easi start <task> [<task> ...] 
[--tasks t1,t2] # Run evaluation + easi llm-server [--port PORT] [--mode MODE] +""" + +import argparse +import sys + +from easi.utils.logging import get_logger, setup_logging + +logger = get_logger(__name__) + + +def build_parser() -> argparse.ArgumentParser: + # Shared parent so --verbosity works at any position in the command + common = argparse.ArgumentParser(add_help=False) + common.add_argument( + "--verbosity", + type=str, + default="INFO", + choices=["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"], + help="Set logging verbosity (default: INFO)", + ) + + parser = argparse.ArgumentParser( + prog="easi", + description="EASI - Embodied Reasoning Evaluation for Spatial Intelligence", + parents=[common], + ) + + subparsers = parser.add_subparsers(dest="command") + + # --- env command group --- + env_parser = subparsers.add_parser( + "env", help="Manage simulator environments", parents=[common] + ) + env_sub = env_parser.add_subparsers(dest="env_action") + + env_sub.add_parser( + "list", help="List available simulators and versions", parents=[common] + ) + + env_install = env_sub.add_parser( + "install", help="Install a simulator environment", parents=[common] + ) + env_install.add_argument( + "simulator", type=str, help="e.g., 'dummy' or 'ai2thor:v2_1_0'" + ) + env_install.add_argument( + "--reinstall", + action="store_true", + help="Remove existing env and install from scratch", + ) + env_install.add_argument( + "--with-task-deps", + type=str, + default=None, + metavar="TASK", + help="Also install additional_deps from a task (e.g., 'ebalfred_base')", + ) + + env_check = env_sub.add_parser( + "check", help="Check if environment is ready", parents=[common] + ) + env_check.add_argument("simulator", type=str) + + # --- task command group --- + task_parser = subparsers.add_parser( + "task", help="Manage tasks (benchmarks)", parents=[common] + ) + task_sub = task_parser.add_subparsers(dest="task_action") + + task_sub.add_parser("list", help="List available tasks", parents=[common]) + + task_info = task_sub.add_parser("info", help="Show task details", parents=[common]) + task_info.add_argument("task", type=str, help="e.g., 'dummy_task'") + + task_download = task_sub.add_parser( + "download", help="Download task dataset", parents=[common] + ) + task_download.add_argument("task", type=str) + task_download.add_argument( + "--refresh-data", + action="store_true", + dest="refresh_data", + help="Delete cached dataset and re-download from source", + ) + + task_scaffold = task_sub.add_parser( + "scaffold", help="Generate boilerplate for a new benchmark", parents=[common] + ) + task_scaffold.add_argument( + "name", type=str, help="Task name in snake_case (e.g., 'my_benchmark')" + ) + task_scaffold.add_argument( + "--simulator", + type=str, + default="dummy:v1", + help="Simulator key (e.g., 'ai2thor:v2_1_0')", + ) + task_scaffold.add_argument("--max-steps", type=int, default=50) + + # --- sim command group --- + sim_parser = subparsers.add_parser( + "sim", help="Control simulators", parents=[common] + ) + sim_sub = sim_parser.add_subparsers(dest="sim_action") + + sim_test = sim_sub.add_parser( + "test", help="Run a smoke test (reset + N steps)", parents=[common] + ) + sim_test.add_argument( + "simulator", type=str, help="e.g., 'dummy' or 'ai2thor:v5_0_0'" + ) + sim_test.add_argument("--steps", type=int, default=5, help="Number of steps") + sim_test.add_argument( + "--timeout", + type=float, + default=200.0, + help="Bridge startup timeout in seconds (default: 200)", + ) + sim_test.add_argument( + 
"--render-platform", + type=str, + default=None, + dest="render_platform", + help="Rendering platform override (auto, native, xvfb, egl, headless, xorg)", + ) + sim_test.add_argument( + "--sim-gpus", + type=str, + default=None, + dest="sim_gpus", + help="Comma-separated GPU IDs for xorg render platform (e.g., '0' or '1,2'). Defaults to GPU 0.", + ) + + # --- start command --- + start_parser = subparsers.add_parser( + "start", help="Run a full evaluation", parents=[common] + ) + # All defaults are None so resume logic can distinguish "user provided" from "default". + # Real defaults live in EvaluationRunner.__init__. + start_parser.add_argument( + "task_names_positional", + type=str, + nargs="*", + default=None, + metavar="task", + help="Task name(s) (e.g., 'dummy_task', 'ebalfred_base'). " + "Optional when --resume is provided.", + ) + start_parser.add_argument( + "--tasks", + type=str, + default=None, + dest="tasks_csv", + help="Comma-separated task names (e.g., 'ebalfred_base,ebnavigation_base')", + ) + start_parser.add_argument( + "--agent", type=str, default=None, choices=["dummy", "react"], dest="agent_type" + ) + start_parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Base output directory (creates ///)", + ) + start_parser.add_argument( + "--data-dir", + type=str, + default=None, + help="Directory for downloading/caching datasets (default: ./datasets)", + ) + start_parser.add_argument( + "--episodes", type=str, default=None, + help="Episode filter: IDs (2,5,7), ranges (10:20), or mixed (2,10:20,40). " + "Use :N for first N episodes.", + ) + start_parser.add_argument( + "--llm-url", type=str, default=None, dest="llm_base_url", help="LLM server URL" + ) + start_parser.add_argument("--seed", type=int, default=None, dest="agent_seed") + start_parser.add_argument( + "--backend", + type=str, + default=None, + help="LLM backend: vllm, custom, openai, anthropic, gemini, dummy", + ) + start_parser.add_argument( + "--model", + type=str, + default=None, + help="Model name (HF path for vLLM, API name for proprietary)", + ) + start_parser.add_argument( + "--port", + type=int, + default=None, + help="Port for local inference server (default: 8080)", + ) + start_parser.add_argument( + "--llm-kwargs", + type=str, + default=None, + dest="llm_kwargs_raw", + help="JSON string of extra kwargs, e.g. '{\"tensor_parallel_size\": 4}'", + ) + start_parser.add_argument( + "--max-retries", + type=int, + default=None, + help="Max LLM retry attempts on transient errors (default: 3)", + ) + start_parser.add_argument( + "--num-parallel", + type=int, + default=None, + dest="num_parallel", + help="Number of parallel simulator instances (default: 1, sequential).", + ) + start_parser.add_argument( + "--llm-instances", + type=int, + default=None, + dest="llm_instances", + help="Number of local LLM server instances to start (default: 1). " + "Each instance runs on a subset of --llm-gpus.", + ) + start_parser.add_argument( + "--llm-gpus", + type=str, + default=None, + dest="llm_gpus", + help="Comma-separated GPU IDs for LLM inference (e.g., '0,1'). " + "GPUs are split evenly across --llm-instances.", + ) + start_parser.add_argument( + "--sim-gpus", + type=str, + default=None, + dest="sim_gpus", + help="Comma-separated GPU IDs for simulator rendering (e.g., '2,3'). 
" + "If not set, simulators use CPU rendering.", + ) + start_parser.add_argument( + "--resume", + type=str, + default=None, + dest="resume_dir", + help="Path to a previous run directory to resume from", + ) + start_parser.add_argument( + "--refresh-data", + action="store_true", + dest="refresh_data", + help="Delete cached dataset and re-download from source", + ) + start_parser.add_argument( + "--render-platform", + type=str, + default=None, + dest="render_platform", + help="Rendering platform: auto, native, xvfb, egl, headless, xorg (default: simulator's preference). " + "xorg starts a GPU X server (defaults to GPU 0, use --sim-gpus to specify).", + ) + + # --- model command --- + model_parser = subparsers.add_parser( + "model", help="Manage custom models", parents=[common] + ) + model_sub = model_parser.add_subparsers(dest="model_action") + model_sub.add_parser("list", help="List available custom models", parents=[common]) + model_info_parser = model_sub.add_parser( + "info", help="Show model details", parents=[common] + ) + model_info_parser.add_argument("model_name", help="Model name") + + # --- ps command --- + ps_parser = subparsers.add_parser( + "ps", + help="Show EASI-related processes (bridges, LLM servers)", + parents=[common], + ) + ps_parser.add_argument( + "--kill", action="store_true", help="Kill all found EASI processes" + ) + + # --- analyze command group --- + analyze_parser = subparsers.add_parser( + "analyze", help="Post-evaluation analysis tools", parents=[common] + ) + analyze_sub = analyze_parser.add_subparsers(dest="analyze_action") + + traj_parser = analyze_sub.add_parser( + "trajectory", help="Generate trajectory videos", parents=[common] + ) + traj_parser.add_argument( + "run_dir", type=str, help="Path to evaluation run directory" + ) + traj_parser.add_argument( + "--filter", choices=["success", "failed"], + help="Filter episodes by outcome", + ) + traj_parser.add_argument( + "--sample", type=int, help="Random sample N episodes" + ) + traj_parser.add_argument( + "--fps", type=int, default=4, help="Video frame rate (default: 4)" + ) + traj_parser.add_argument( + "--seed", type=int, default=42, help="Random seed for --sample (default: 42)" + ) + + # --- llm-server command --- + llm_parser = subparsers.add_parser( + "llm-server", help="Start dummy LLM server", parents=[common] + ) + llm_parser.add_argument("--port", type=int, default=8000) + llm_parser.add_argument("--host", type=str, default="127.0.0.1") + llm_parser.add_argument("--mode", choices=["fixed", "random"], default="random") + llm_parser.add_argument( + "--action-space", + type=str, + nargs="+", + default=["MoveAhead", "TurnLeft", "TurnRight", "Stop"], + ) + + return parser + + +# --- Command handlers --- + + +def cmd_env_list() -> None: + from easi.simulators.registry import get_simulator_entry, list_simulators + + sims = list_simulators() + if not sims: + logger.info("No simulators found.") + return + + # Deduplicate: show each name:version pair once + seen = set() + for key in sims: + entry = get_simulator_entry(key) + pair = f"{entry.name}:{entry.version}" + if pair in seen: + continue + seen.add(pair) + default_marker = " (default)" if key == entry.name else "" + runtime_tag = f" [{entry.runtime}]" if entry.runtime != "conda" else "" + logger.info( + " %s%s%s -- %s", pair, default_marker, runtime_tag, entry.description + ) + + +def cmd_env_install( + simulator: str, reinstall: bool = False, with_task_deps: str | None = None +) -> None: + from easi.simulators.registry import create_env_manager + + 
env_manager = create_env_manager(simulator) + + if reinstall: + logger.info("Removing existing environment: %s", env_manager.get_env_name()) + env_manager.remove() + + logger.info("Installing environment: %s", env_manager.get_env_name()) + env_manager.install() + + if with_task_deps: + from easi.core.docker_env_manager import DockerEnvironmentManager + + if isinstance(env_manager, DockerEnvironmentManager): + logger.warning( + "--with-task-deps is not supported for Docker simulators (deps baked into image)." + ) + else: + from easi.tasks.registry import get_task_entry, load_task_class + + entry = get_task_entry(with_task_deps) + TaskClass = load_task_class(with_task_deps) + task = TaskClass(split_yaml_path=entry.config_path) + if task.additional_deps: + env_manager.install_additional_deps(task.additional_deps) + else: + logger.info("Task %s has no additional_deps.", with_task_deps) + + logger.info("Done.") + + +def cmd_env_check(simulator: str) -> None: + from easi.core.docker_env_manager import DockerEnvironmentManager + from easi.simulators.registry import create_env_manager + + env_manager = create_env_manager(simulator) + + missing = env_manager.check_system_deps() + if missing: + logger.info("Missing system deps: %s", missing) + + if env_manager.env_is_ready(): + logger.info("Environment %s is ready.", env_manager.get_env_name()) + if isinstance(env_manager, DockerEnvironmentManager): + logger.info("Runtime: docker (image: %s)", env_manager.image_name) + else: + logger.info("Python: %s", env_manager.get_python_executable()) + else: + logger.info("Environment %s is NOT ready.", env_manager.get_env_name()) + logger.info("Run: easi env install %s", simulator) + + +def cmd_task_list() -> None: + from easi.tasks.registry import get_task_entry, list_tasks + + tasks = list_tasks() + if not tasks: + logger.info("No tasks found.") + return + + for name in tasks: + entry = get_task_entry(name) + logger.info( + " %s -- %s (simulator: %s)", name, entry.display_name, entry.simulator_key + ) + + +def cmd_task_info(task_name: str) -> None: + from easi.tasks.registry import get_task_entry + + entry = get_task_entry(task_name) + logger.info("Task: %s", entry.display_name) + logger.info(" Name: %s", entry.name) + logger.info(" Description: %s", entry.description) + logger.info(" Simulator: %s", entry.simulator_key) + logger.info(" Max steps: %s", entry.max_steps) + + +def cmd_task_scaffold(name: str, simulator: str, max_steps: int) -> None: + from pathlib import Path + + from easi.tasks.scaffold import scaffold_task + + tasks_dir = Path(__file__).parent / "tasks" + tests_dir = Path(__file__).parent.parent / "tests" + task_dir = scaffold_task( + name, simulator, output_dir=tasks_dir, max_steps=max_steps, tests_dir=tests_dir + ) + logger.info("Created task scaffold at: %s", task_dir) + logger.info("Next steps:") + logger.info( + " 1. Edit %s/bridge.py — implement _create_env() and _extract_image()", + task_dir.name, + ) + logger.info(" 2. Edit %s/task.py — implement format_reset_config()", task_dir.name) + logger.info(" 3. Edit %s/%s.yaml — configure dataset source", task_dir.name, name) + logger.info(" 4. 
Run tests: pytest tests/test_%s.py -v", name) + + +def cmd_task_download(task_name: str, refresh_data: bool = False) -> None: + from easi.tasks.registry import load_task_class + + TaskClass = load_task_class(task_name) + task = TaskClass() + path = task.download_dataset(force=refresh_data) + if path and str(path): + logger.info("Dataset ready at: %s", path) + else: + logger.info("Task uses built-in episodes (no download needed).") + + +def cmd_sim_test( + simulator: str, + steps: int, + timeout: float, + render_platform_name: str | None = None, + sim_gpus: list[int] | None = None, +) -> None: + from pathlib import Path + + from easi.core.docker_env_manager import DockerEnvironmentManager + from easi.core.episode import Action + from easi.core.render_platforms import get_render_platform + from easi.simulators.registry import ( + create_env_manager, + get_simulator_entry, + load_simulator_class, + resolve_render_adapter, + resolve_render_platform, + ) + from easi.simulators.subprocess_runner import SubprocessRunner + + entry = get_simulator_entry(simulator) + env_manager = create_env_manager(simulator) + SimClass = load_simulator_class(simulator) + sim = SimClass() + + if entry.runtime == "docker": + # --- Docker launch path --- + assert isinstance(env_manager, DockerEnvironmentManager), ( + f"runtime='docker' but env_manager is {type(env_manager).__name__}, " + "expected DockerEnvironmentManager subclass" + ) + logger.info("Testing %s (Docker)...", simulator) + logger.info(" Image: %s", env_manager.image_name) + logger.info(" GPU: %s", env_manager.gpu_required) + + render_platform = get_render_platform("headless") + bridge_path = sim._get_bridge_script_path() + + runner = SubprocessRunner( + python_executable=env_manager.container_python_path, + bridge_script_path=bridge_path, + render_platform=render_platform, + startup_timeout=timeout, + command_timeout=timeout, + ) + + data_dir_str = ( + entry.data_dir.replace("~", str(Path.home())) if entry.data_dir else None + ) + + try: + runner.launch_docker( + docker_env_manager=env_manager, + data_dir=data_dir_str, + ) + sim.set_runner(runner) + + logger.info(" Reset...") + obs = sim.reset("smoke_test_001") + logger.info(" Reset OK (rgb: %s)", obs.rgb_path) + + for i in range(steps): + action = Action(action_name="MoveAhead") + result = sim.step(action) + logger.info( + " Step %d: done=%s, reward=%s", i + 1, result.done, result.reward + ) + if result.done: + break + + logger.info(" Closing...") + sim.close() + logger.info(" Close OK") + logger.info("Smoke test passed!") + + except KeyboardInterrupt: + logger.info("Interrupted, shutting down bridge...") + sim.close() + logger.info("Bridge process terminated.") + sys.exit(130) + except Exception as e: + logger.error("Smoke test FAILED: %s", e) + sim.close() + sys.exit(1) + + else: + # --- Conda launch path --- + platform_name = render_platform_name or env_manager.default_render_platform + if platform_name not in env_manager.supported_render_platforms: + logger.error( + "Render platform '%s' not supported by %s. 
Supported: %s", + platform_name, + simulator, + env_manager.supported_render_platforms, + ) + sys.exit(1) + render_platform = resolve_render_platform( + simulator, platform_name, env_manager=env_manager + ) + try: + try: + render_platform.setup(gpu_ids=sim_gpus or [0]) + worker_binding = render_platform.for_worker(0) + except RuntimeError as e: + if platform_name == "xorg": + logger.warning("%s", str(e)) + sys.exit(0) + raise + + logger.info("Testing %s...", simulator) + logger.info(" Python: %s", env_manager.get_python_executable()) + logger.info(" Render platform: %s", render_platform.log_name) + + from easi.core.render_platforms import EnvVars + + env_vars = env_manager.get_env_vars(render_platform_name=platform_name) + render_adapter = resolve_render_adapter(simulator, env_manager=env_manager) + + adapter_env = ( + render_adapter.get_env_vars(worker_binding) + if render_adapter + else EnvVars() + ) + binding_env = EnvVars.merge(worker_binding.extra_env, adapter_env) + if worker_binding.display: + binding_env = EnvVars.merge( + binding_env, EnvVars(replace={"DISPLAY": worker_binding.display}) + ) + if worker_binding.cuda_visible_devices is not None: + binding_env = EnvVars.merge( + binding_env, + EnvVars( + replace={ + "CUDA_VISIBLE_DEVICES": worker_binding.cuda_visible_devices + } + ), + ) + env_vars = EnvVars.merge(env_vars, binding_env) if env_vars else binding_env + + runner = SubprocessRunner( + python_executable=env_manager.get_python_executable(), + bridge_script_path=sim._get_bridge_script_path(), + render_platform=render_platform, + screen_config=env_manager.screen_config, + startup_timeout=timeout, + command_timeout=timeout, + extra_env=env_vars if env_vars else None, + render_adapter=render_adapter, + worker_binding=worker_binding, + ) + + runner.launch() + sim.set_runner(runner) + + logger.info(" Reset...") + obs = sim.reset("smoke_test_001") + logger.info(" Reset OK (rgb: %s)", obs.rgb_path) + + for i in range(steps): + action = Action(action_name="MoveAhead") + result = sim.step(action) + logger.info( + " Step %d: done=%s, reward=%s", i + 1, result.done, result.reward + ) + if result.done: + break + + logger.info(" Closing...") + sim.close() + logger.info(" Close OK") + logger.info("Smoke test passed!") + + except KeyboardInterrupt: + logger.info("Interrupted, shutting down bridge...") + sim.close() + logger.info("Bridge process terminated.") + sys.exit(130) + except Exception as e: + logger.error("Smoke test FAILED: %s", e) + sim.close() + sys.exit(1) + finally: + render_platform.teardown() + + +def _resolve_task_list(args_ns) -> list[str]: + """Build task list from positional args and/or --tasks flag.""" + tasks: list[str] = [] + if args_ns.tasks_csv: + tasks = [t.strip() for t in args_ns.tasks_csv.split(",") if t.strip()] + elif args_ns.task_names_positional: + tasks = args_ns.task_names_positional + return tasks + + +def cmd_start(args): + import json as _json + from pathlib import Path + + from easi.evaluation.runner import EvaluationRunner + + task_list = _resolve_task_list(args) + + # Collect explicitly-provided CLI args (argparse defaults are None) + raw = {k: v for k, v in vars(args).items() if v is not None} + # Remove argparse internals that aren't runner params + for key in ("command", "verbosity", "task_names_positional", "tasks_csv"): + raw.pop(key, None) + + # Extract session-specific params (not saved in config.json) + resume_dir = raw.pop("resume_dir", None) + redownload = raw.pop("refresh_data", False) + + if resume_dir: + if len(task_list) > 1: + 
logger.error("--resume cannot be used with multiple tasks.") + sys.exit(1) + config_path = Path(resume_dir) / "config.json" + if not config_path.exists(): + logger.error("Resume directory has no config.json: %s", resume_dir) + sys.exit(1) + saved = _json.loads(config_path.read_text()).get("cli_options", {}) + # Migrate legacy max_episodes -> episodes + if "max_episodes" in saved and "episodes" not in saved: + old_val = saved.pop("max_episodes") + if isinstance(old_val, int) and old_val > 0: + saved["episodes"] = f":{old_val}" + saved.pop("max_episodes", None) + # Saved values fill gaps; explicit CLI args win + run_kwargs = {**saved, **raw} + # If no task was given on CLI, pull from saved config + if not task_list: + saved_task = saved.get("task_name") + if saved_task: + task_list = [saved_task] + else: + run_kwargs = raw + + if not task_list: + logger.error( + "Task name is required. Provide it as a positional arg, --tasks, or use --resume." + ) + sys.exit(1) + + # Remove task_name from run_kwargs; it's passed per-task below + run_kwargs.pop("task_name", None) + num_parallel = run_kwargs.pop("num_parallel", None) or 1 + + # Parse comma-separated GPU strings into lists of ints + # When resuming, values from config.json may already be list[int] + llm_gpus_val = run_kwargs.pop("llm_gpus", None) + sim_gpus_val = run_kwargs.pop("sim_gpus", None) + if llm_gpus_val: + if isinstance(llm_gpus_val, list): + run_kwargs["llm_gpus"] = [int(g) for g in llm_gpus_val] + else: + run_kwargs["llm_gpus"] = [int(g) for g in llm_gpus_val.split(",")] + if sim_gpus_val: + if isinstance(sim_gpus_val, list): + run_kwargs["sim_gpus"] = [int(g) for g in sim_gpus_val] + else: + run_kwargs["sim_gpus"] = [int(g) for g in sim_gpus_val.split(",")] + + all_summaries: list[tuple[str, dict]] = [] + + for task_name in task_list: + logger.info("=== Starting evaluation: %s ===", task_name) + + if num_parallel > 1: + from easi.evaluation.parallel_runner import ParallelRunner + + runner = ParallelRunner( + task_name=task_name, + num_parallel=num_parallel, + **run_kwargs, + resume_dir=resume_dir, + refresh_data=redownload, + ) + else: + runner = EvaluationRunner( + task_name=task_name, + **run_kwargs, + resume_dir=resume_dir, + refresh_data=redownload, + ) + + results = runner.run() + logger.info("Completed %d episodes for %s.", len(results), task_name) + + # Read the summary.json that the runner just saved + run_dir = runner.run_dir if hasattr(runner, "run_dir") else None + summary = {} + if run_dir: + summary_file = Path(run_dir) / "summary.json" + if summary_file.exists(): + import json as _json + summary = _json.loads(summary_file.read_text()) + + all_summaries.append((task_name, summary)) + + # Log generic metrics (top-level) + for key in ("num_episodes", "success_rate", "avg_steps", "median_steps"): + if key in summary: + logger.info(" %s: %s", key, summary[key]) + + # Log task-specific metrics + metrics = summary.get("metrics", {}) + if isinstance(metrics, dict): + # If metrics has sub-groups (e.g. 
base/spot/stretch), log the base group + base = metrics.get("base", metrics) + for key, value in base.items(): + if key not in ("num_episodes", "success_rate"): + logger.info(" %s: %s", key, value) + + # Combined summary when multiple tasks were evaluated + if len(all_summaries) > 1: + logger.info("") + logger.info("=== Combined Summary ===") + for task_name, summary in all_summaries: + logger.info("[%s]", task_name) + for key, value in summary.items(): + logger.info(" %s: %s", key, value) + + +def cmd_llm_server(host: str, port: int, mode: str, action_space: list[str]) -> None: + from easi.llm.dummy_server import run_server + + run_server(host=host, port=port, mode=mode, action_space=action_space) + + +def cmd_ps(kill: bool = False) -> None: + """Show (and optionally kill) EASI-related processes.""" + import os + import signal + import subprocess + + # Patterns that identify EASI-spawned processes + patterns = [ + "easi.llm.models.http_server", # custom model server + "vllm.entrypoints.openai.api_server", # vLLM server + "easi.llm.dummy_server", # dummy LLM server + ] + # Also match bridge scripts by looking for bridge.py in easi paths + bridge_pattern = "easi/simulators/.*/bridge.py|easi/tasks/.*/bridge.py" + + my_pid = os.getpid() + + # Use ps to find matching processes + try: + result = subprocess.run( + ["ps", "aux"], + capture_output=True, + text=True, + timeout=10, + ) + except (FileNotFoundError, subprocess.TimeoutExpired): + logger.error("Failed to run 'ps aux'") + return + + found: list[dict] = [] + for line in result.stdout.strip().splitlines()[1:]: # skip header + parts = line.split(None, 10) + if len(parts) < 11: + continue + pid = int(parts[1]) + if pid == my_pid: + continue + cmd_str = parts[10] + stat = parts[7] + + matched_pattern = None + for pattern in patterns: + if pattern in cmd_str: + matched_pattern = pattern + break + if matched_pattern is None: + import re + + if re.search(bridge_pattern, cmd_str): + matched_pattern = "bridge" + + if matched_pattern is None: + continue + + is_zombie = "Z" in stat + found.append( + { + "pid": pid, + "user": parts[0], + "stat": stat, + "cpu": parts[2], + "mem": parts[3], + "start": parts[8], + "command": cmd_str[:120], + "pattern": matched_pattern, + "zombie": is_zombie, + } + ) + + if not found: + logger.info("No EASI-related processes found.") + return + + # Display + logger.info("Found %d EASI-related process(es):\n", len(found)) + logger.info( + " %-7s %-6s %-5s %-5s %-8s %s", + "PID", + "STAT", + "CPU%", + "MEM%", + "TYPE", + "COMMAND", + ) + logger.info(" %s", "-" * 80) + for p in found: + zombie_tag = " [ZOMBIE]" if p["zombie"] else "" + ptype = p["pattern"].split(".")[-1] if "." 
in p["pattern"] else p["pattern"] + logger.info( + " %-7d %-6s %-5s %-5s %-8s %s%s", + p["pid"], + p["stat"], + p["cpu"], + p["mem"], + ptype, + p["command"][:60], + zombie_tag, + ) + + # GPU usage summary + try: + gpu_result = subprocess.run( + [ + "nvidia-smi", + "--query-compute-apps=pid,gpu_uuid,used_memory", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=10, + ) + if gpu_result.returncode == 0 and gpu_result.stdout.strip(): + easi_pids = {p["pid"] for p in found} + gpu_lines = [] + for line in gpu_result.stdout.strip().splitlines(): + parts = [x.strip() for x in line.split(",")] + if len(parts) >= 3: + gpu_pid = int(parts[0]) + if gpu_pid in easi_pids: + gpu_lines.append((gpu_pid, parts[1][:12], parts[2])) + if gpu_lines: + logger.info("\n GPU memory held by EASI processes:") + for gpu_pid, gpu_id, mem_mb in gpu_lines: + logger.info(" PID %-7d GPU %s %s MiB", gpu_pid, gpu_id, mem_mb) + except (FileNotFoundError, subprocess.TimeoutExpired): + pass # no nvidia-smi + + # Kill if requested + if kill: + logger.info("") + for p in found: + try: + os.kill(p["pid"], signal.SIGTERM) + logger.info(" Sent SIGTERM to PID %d (%s)", p["pid"], p["pattern"]) + except ProcessLookupError: + logger.info(" PID %d already exited", p["pid"]) + except PermissionError: + logger.warning(" Cannot kill PID %d (permission denied)", p["pid"]) + # Wait briefly then SIGKILL any survivors + import time + + time.sleep(2) + for p in found: + try: + os.kill(p["pid"], 0) # check if still alive + os.kill(p["pid"], signal.SIGKILL) + logger.info(" Sent SIGKILL to PID %d", p["pid"]) + except (ProcessLookupError, PermissionError): + pass + logger.info(" Done.") + + +def cmd_model(args) -> None: + from easi.llm.models.registry import get_model_entry, list_models + + if args.model_action == "list": + names = list_models() + if not names: + logger.info("No custom models found.") + return + for name in names: + entry = get_model_entry(name) + logger.info(" %s -- %s", name, entry.display_name) + + elif args.model_action == "info": + entry = get_model_entry(args.model_name) + logger.info("Model: %s", entry.display_name) + logger.info(" Name: %s", entry.name) + logger.info(" Description: %s", entry.description) + logger.info(" Model class: %s", entry.model_class) + logger.info(" Default kwargs: %s", entry.default_kwargs) + + else: + build_parser().parse_args(["model", "--help"]) + + +# --- Main --- + + +def main() -> None: + try: + _main() + except KeyboardInterrupt: + logger.info("Interrupted by user.") + sys.exit(130) + + +def _main() -> None: + parser = build_parser() + args = parser.parse_args() + + setup_logging(args.verbosity) + + if args.command is None: + parser.print_help() + sys.exit(0) + + # Dispatch commands + if args.command == "env": + if args.env_action == "list": + cmd_env_list() + elif args.env_action == "install": + cmd_env_install( + args.simulator, + reinstall=args.reinstall, + with_task_deps=args.with_task_deps, + ) + elif args.env_action == "check": + cmd_env_check(args.simulator) + else: + parser.parse_args(["env", "--help"]) + + elif args.command == "task": + if args.task_action == "list": + cmd_task_list() + elif args.task_action == "info": + cmd_task_info(args.task) + elif args.task_action == "download": + cmd_task_download(args.task, refresh_data=args.refresh_data) + elif args.task_action == "scaffold": + cmd_task_scaffold(args.name, args.simulator, args.max_steps) + else: + parser.parse_args(["task", "--help"]) + + elif args.command == "sim": + if args.sim_action == 
"test": + raw_sim_gpus = getattr(args, "sim_gpus", None) + sim_gpus_parsed = ( + [int(g) for g in raw_sim_gpus.split(",")] if raw_sim_gpus else None + ) + cmd_sim_test( + args.simulator, + args.steps, + args.timeout, + getattr(args, "render_platform", None), + sim_gpus=sim_gpus_parsed, + ) + else: + parser.parse_args(["sim", "--help"]) + + elif args.command == "start": + cmd_start(args) + + elif args.command == "ps": + cmd_ps(kill=args.kill) + + elif args.command == "model": + cmd_model(args) + + elif args.command == "analyze": + if args.analyze_action == "trajectory": + from easi.analysis.trajectory_video import generate_trajectory_videos + generate_trajectory_videos( + run_dir=args.run_dir, + filter_by=getattr(args, "filter", None), + sample_n=args.sample, + fps=args.fps, + seed=args.seed, + ) + else: + parser.parse_args(["analyze", "--help"]) + + elif args.command == "llm-server": + cmd_llm_server(args.host, args.port, args.mode, args.action_space) + + +if __name__ == "__main__": + main() diff --git a/easi/communication/__init__.py b/easi/communication/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/easi/communication/filesystem.py b/easi/communication/filesystem.py new file mode 100644 index 0000000..c4404c4 --- /dev/null +++ b/easi/communication/filesystem.py @@ -0,0 +1,221 @@ +"""Filesystem-based IPC for command/response exchange between parent and bridge subprocess. + +All writes use the atomic write pattern: write to a .tmp file, then os.rename() to the +final path. This guarantees readers never see partial files (rename is atomic on Linux +when source and dest are on the same filesystem). + +The parent deletes response.json before writing a new command.json to avoid stale reads. +""" + +from __future__ import annotations + +import json +import os +import tempfile +import time +from pathlib import Path + +from easi.core.exceptions import SimulatorTimeoutError +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + +# Default filenames used in the IPC workspace +COMMAND_FILE = "command.json" +RESPONSE_FILE = "response.json" +STATUS_FILE = "status.json" + + +def create_workspace(prefix: str = "easi_") -> Path: + """Create a unique temporary workspace directory for IPC. + + Each EASI process/episode gets its own workspace, so concurrent processes + never collide on file paths. + """ + workspace = Path(tempfile.mkdtemp(prefix=prefix)) + logger.trace("Created IPC workspace: %s", workspace) + return workspace + + +def cleanup_workspace(workspace: Path) -> None: + """Remove the IPC workspace directory and all contents.""" + import shutil + + if workspace.exists(): + shutil.rmtree(workspace, ignore_errors=True) + logger.trace("Cleaned up IPC workspace: %s", workspace) + + +def atomic_write_json(path: Path, data: dict) -> None: + """Write JSON data atomically using write-to-tmp + rename pattern. + + This ensures readers never see a partially written file. + """ + tmp_path = path.with_suffix(".tmp") + tmp_path.write_text(json.dumps(data, indent=2)) + os.rename(str(tmp_path), str(path)) + + +def read_json(path: Path) -> dict | None: + """Read a JSON file, returning None if it doesn't exist or is invalid. + + Uses try/except instead of exists-then-read to avoid TOCTOU races. 
+ """ + try: + return json.loads(path.read_text()) + except (FileNotFoundError, json.JSONDecodeError): + return None + + +def delete_file(path: Path) -> None: + """Delete a file if it exists, ignoring FileNotFoundError.""" + try: + path.unlink() + except FileNotFoundError: + pass + + +def write_command(workspace: Path, command: dict) -> None: + """Write a command for the bridge subprocess to read. + + Deletes any existing response first to prevent stale reads. + """ + response_path = workspace / RESPONSE_FILE + command_path = workspace / COMMAND_FILE + + delete_file(response_path) + atomic_write_json(command_path, command) + logger.trace("Wrote command: %s", command.get("type", "unknown")) + + +def poll_for_response( + workspace: Path, + poll_interval: float = 0.1, + timeout: float = 60.0, + process: object | None = None, +) -> dict: + """Poll the workspace for a response.json file from the bridge subprocess. + + Args: + workspace: IPC workspace directory. + poll_interval: Seconds between poll attempts. + timeout: Maximum seconds to wait before raising SimulatorTimeoutError. + process: Optional subprocess.Popen instance. If provided, checks whether + the subprocess has exited (crashed) during polling. + + Returns: + Parsed response dict. + + Raises: + SimulatorTimeoutError: If timeout is exceeded. + SimulatorError: If the subprocess has exited unexpectedly. + """ + from easi.core.exceptions import SimulatorError + + response_path = workspace / RESPONSE_FILE + deadline = time.monotonic() + timeout + + while time.monotonic() < deadline: + # Check if subprocess has crashed + if process is not None and hasattr(process, "poll"): + if process.poll() is not None: + raise SimulatorError( + f"Bridge subprocess exited with code {process.returncode} " + f"while waiting for response" + ) + + data = read_json(response_path) + if data is not None: + logger.trace("Received response: status=%s", data.get("status", "unknown")) + return data + + time.sleep(poll_interval) + + raise SimulatorTimeoutError( + f"Timed out waiting for response after {timeout}s", + timeout=timeout, + ) + + +def poll_for_status( + workspace: Path, + poll_interval: float = 0.1, + timeout: float = 30.0, + process: object | None = None, +) -> dict: + """Poll the workspace for a status.json file (bridge startup health check). + + Same semantics as poll_for_response but reads status.json instead. + """ + from easi.core.exceptions import SimulatorError + + status_path = workspace / STATUS_FILE + deadline = time.monotonic() + timeout + + while time.monotonic() < deadline: + if process is not None and hasattr(process, "poll"): + if process.poll() is not None: + raise SimulatorError( + f"Bridge subprocess exited with code {process.returncode} " + f"during startup" + ) + + data = read_json(status_path) + if data is not None: + logger.trace("Received status: ready=%s", data.get("ready", False)) + return data + + time.sleep(poll_interval) + + raise SimulatorTimeoutError( + f"Bridge subprocess did not report ready within {timeout}s", + timeout=timeout, + ) + + +def poll_for_command( + workspace: Path, + poll_interval: float = 0.1, + timeout: float = 60.0, +) -> dict: + """Poll the workspace for a command.json file (bridge-side). + + Used by the bridge subprocess to wait for commands from the parent. + + Returns: + Parsed command dict. + + Raises: + SimulatorTimeoutError: If timeout is exceeded. 
+ """ + command_path = workspace / COMMAND_FILE + deadline = time.monotonic() + timeout + + while time.monotonic() < deadline: + data = read_json(command_path) + if data is not None: + # Delete the command file after reading to signal we've consumed it + delete_file(command_path) + logger.trace("Bridge received command: %s", data.get("type", "unknown")) + return data + + time.sleep(poll_interval) + + raise SimulatorTimeoutError( + f"No command received within {timeout}s", + timeout=timeout, + ) + + +def write_response(workspace: Path, response: dict) -> None: + """Write a response for the parent process to read (bridge-side).""" + response_path = workspace / RESPONSE_FILE + atomic_write_json(response_path, response) + logger.trace("Bridge wrote response: status=%s", response.get("status", "unknown")) + + +def write_status(workspace: Path, ready: bool) -> None: + """Write a status file to signal bridge readiness (bridge-side).""" + status_path = workspace / STATUS_FILE + atomic_write_json(status_path, {"ready": ready}) + logger.trace("Bridge wrote status: ready=%s", ready) diff --git a/easi/communication/schemas.py b/easi/communication/schemas.py new file mode 100644 index 0000000..16e77cc --- /dev/null +++ b/easi/communication/schemas.py @@ -0,0 +1,106 @@ +"""JSON schemas for command/response exchange between parent and bridge subprocess.""" + +from __future__ import annotations + +import json +from dataclasses import asdict +from pathlib import Path +from typing import Any + +from easi.core.episode import Action, Observation, StepResult + + +# --- Command schemas (parent → child) --- + +def make_reset_command( + episode_id: str, + reset_config: dict | None = None, + episode_output_dir: str | None = None, +) -> dict: + cmd = { + "type": "reset", + "episode_id": episode_id, + "reset_config": reset_config or {}, + } + if episode_output_dir is not None: + cmd["episode_output_dir"] = episode_output_dir + return cmd + + +def make_step_command(action: Action) -> dict: + return { + "type": "step", + "action": { + "action_name": action.action_name, + "params": action.params, + }, + } + + +def make_close_command() -> dict: + return {"type": "close"} + + +# --- Response schemas (child → parent) --- + +def make_observation_response( + rgb_path: str, + depth_path: str | None = None, + agent_pose: list[float] | None = None, + metadata: dict[str, str] | None = None, + reward: float = 0.0, + done: bool = False, + info: dict[str, float] | None = None, +) -> dict: + return { + "status": "ok", + "observation": { + "rgb_path": rgb_path, + "depth_path": depth_path, + "agent_pose": agent_pose or [], + "metadata": metadata or {}, + }, + "reward": reward, + "done": done, + "info": info or {}, + } + + +def make_error_response(error: str) -> dict: + return {"status": "error", "error": error} + + +def make_status_response(ready: bool) -> dict: + return {"ready": ready} + + +# --- Parsing helpers --- + +def parse_observation(data: dict) -> Observation: + obs = data["observation"] + # Merge top-level info into metadata so prompt builders can read it + metadata = dict(obs.get("metadata", {})) + metadata.update(data.get("info", {})) + return Observation( + rgb_path=obs["rgb_path"], + depth_path=obs.get("depth_path"), + agent_pose=obs.get("agent_pose", []), + metadata=metadata, + ) + + +def parse_step_result(data: dict) -> StepResult: + return StepResult( + observation=parse_observation(data), + reward=data.get("reward", 0.0), + done=data.get("done", False), + info=data.get("info", {}), + ) + + +def 
parse_action_from_command(data: dict) -> Action: + action_data = data["action"] + return Action( + action_name=action_data["action_name"], + params=action_data.get("params", {}), + ) diff --git a/easi/core/__init__.py b/easi/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/easi/core/base_agent.py b/easi/core/base_agent.py new file mode 100644 index 0000000..610db6e --- /dev/null +++ b/easi/core/base_agent.py @@ -0,0 +1,35 @@ +"""Abstract base class for agents.""" +from __future__ import annotations + +from abc import ABC, abstractmethod + +from easi.core.episode import Action, Observation +from easi.core.protocols import LLMClientProtocol +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + + +class BaseAgent(ABC): + """Abstract base for agents that bridge LLM inference and simulator actions.""" + + def __init__(self, llm_client: LLMClientProtocol | None, action_space: list[str]): + self.llm_client = llm_client + self.action_space = action_space + self._step_count: int = 0 + + @abstractmethod + def act(self, observation: Observation, task_description: str) -> Action: + """Return the next action given current observation and task.""" + ... + + def add_feedback(self, action_name: str, feedback: str) -> None: + """Record action feedback from the environment. + + Default: no-op. Subclasses (e.g., ReActAgent) override to track + action history and clear action buffer on failure. + """ + + def reset(self) -> None: + """Reset agent state for a new episode.""" + self._step_count = 0 diff --git a/easi/core/base_env_manager.py b/easi/core/base_env_manager.py new file mode 100644 index 0000000..e220f3a --- /dev/null +++ b/easi/core/base_env_manager.py @@ -0,0 +1,441 @@ +"""Abstract base class for per-simulator-version environment management. + +Each simulator version provides a concrete subclass that declares: +- conda_env.yaml path (conda-channel packages only) +- requirements.txt path (pip-installable Python deps, installed via uv) +- system dependencies (xvfb, EGL, etc.) +- validation import to confirm the env works + +The shared install() logic handles the full sequence: + check_system_deps -> conda create -> pip install uv -> uv pip install -> validate +""" + +from __future__ import annotations + +import subprocess +import tarfile +import urllib.request +from abc import ABC, abstractmethod +from pathlib import Path + +from easi.core.exceptions import EnvironmentSetupError +from easi.utils.locking import file_lock +from easi.utils.logging import get_logger +from easi.utils.paths import get_locks_dir +from easi.utils.spinner import spinner +from easi.utils.system_deps import SystemDependencyChecker + +logger = get_logger(__name__) + + +class BaseEnvironmentManager(ABC): + """Abstract base for per-simulator-version environment management.""" + + def __init__(self, conda_prefix: Path | None = None, installation_kwargs: dict | None = None): + self.conda_prefix = conda_prefix or self._default_conda_prefix() + self.installation_kwargs = installation_kwargs or {} + self._dep_checker = SystemDependencyChecker() + + @property + @abstractmethod + def simulator_name(self) -> str: + """Name of the simulator (e.g., 'ai2thor').""" + ... + + @property + @abstractmethod + def version(self) -> str: + """Version identifier (e.g., 'v2_1_0').""" + ... + + @abstractmethod + def get_conda_env_yaml_path(self) -> Path: + """Path to the conda environment YAML (conda-only deps).""" + ... 
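+
+    # A concrete subclass typically returns a path next to its own module,
+    # e.g. (illustrative sketch):
+    #     return Path(__file__).parent / "conda_env.yaml"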
+
+    @abstractmethod
+    def get_requirements_txt_path(self) -> Path:
+        """Path to requirements.txt (uv-installed Python deps)."""
+        ...
+
+    @abstractmethod
+    def get_system_deps(self) -> list[str]:
+        """List of required system packages (e.g., ['xvfb', 'conda'])."""
+        ...
+
+    @abstractmethod
+    def get_validation_import(self) -> str:
+        """Python import statement to validate env works.
+
+        Example: "import ai2thor; assert ai2thor.__version__.startswith('2.1')"
+        """
+        ...
+
+    @property
+    def default_render_platform(self) -> str:
+        """Default rendering platform for this simulator.
+
+        Override in subclasses. Common values:
+            "auto"     -- native display if available, xvfb fallback
+            "headless" -- no display (simulator handles internally)
+            "egl"      -- GPU-accelerated headless via EGL
+
+        See ``easi.core.render_platforms`` for all options.
+        """
+        return "headless"
+
+    @property
+    def supported_render_platforms(self) -> list[str]:
+        """Render platforms this simulator can use.
+
+        Override in subclasses to advertise which platforms are compatible.
+        Validated when the user passes ``--render-platform``.
+        """
+        return [self.default_render_platform]
+
+    @property
+    def screen_config(self) -> str:
+        """Screen resolution config (e.g. ``"1024x768x24"``).
+
+        Used by platforms that create a virtual display (xvfb).
+        Override for custom resolution/depth.
+        """
+        return "1024x768x24"
+
+    def get_env_vars(self, render_platform_name: str | None = None) -> "EnvVars":
+        """Return environment variables to inject into the bridge subprocess.
+
+        Override in subclasses to provide simulator-specific env vars.
+
+        Args:
+            render_platform_name: Active render platform name (e.g. "egl").
+                Subclasses can use this to conditionally set env vars.
+
+        Returns:
+            EnvVars instance. Empty by default.
+        """
+        from easi.core.render_platforms import EnvVars
+
+        return EnvVars()
+
+    def get_env_name(self) -> str:
+        """Conda environment name for this simulator version."""
+        return f"easi_{self.simulator_name}_{self.version}"
+
+    def get_python_executable(self) -> str:
+        """Return the full path to the Python executable in this conda env."""
+        env_path = self.conda_prefix / "envs" / self.get_env_name()
+        return str(env_path / "bin" / "python")
+
+    def check_system_deps(self) -> list[str]:
+        """Check system dependencies, returning list of missing ones."""
+        return self._dep_checker.check_all(self.get_system_deps())
+
+    def env_is_ready(self) -> bool:
+        """Check if the conda environment exists and passes validation."""
+        python_exec = self.get_python_executable()
+        if not Path(python_exec).exists():
+            return False
+
+        # Include simulator env vars (e.g.
LD_LIBRARY_PATH for CoppeliaSim) + env_vars = self.get_env_vars() + run_env = None + if env_vars: + import os + run_env = env_vars.apply_to_env(os.environ.copy()) + + try: + result = subprocess.run( + [python_exec, "-c", self.get_validation_import()], + capture_output=True, + text=True, + timeout=30, + env=run_env, + ) + return result.returncode == 0 + except (subprocess.TimeoutExpired, FileNotFoundError): + return False + + def remove(self) -> None: + """Remove the conda environment entirely (for --reinstall).""" + env_name = self.get_env_name() + env_path = self.conda_prefix / "envs" / env_name + if not env_path.exists(): + logger.info("Environment %s does not exist, nothing to remove", env_name) + return + with spinner(f"Removing environment {env_name}"): + self._run_command( + ["conda", "env", "remove", "-n", env_name, "-y"], + f"conda env remove {env_name}", + ) + logger.info("Environment %s removed", env_name) + + def install(self) -> None: + """Install the conda+uv environment with file-based locking. + + Serializes concurrent installs of the same env across processes. + """ + lock_path = get_locks_dir() / f"{self.get_env_name()}.lock" + with file_lock(lock_path): + if self.env_is_ready(): + logger.info("Environment %s already ready, skipping install", self.get_env_name()) + return + self._do_install() + + def install_additional_deps(self, packages: list[str]) -> None: + """Install extra pip packages into this conda env via uv (idempotent). + + Called by EvaluationRunner when a task declares additional_deps + in its YAML config. + """ + if not packages: + return + python_exec = self.get_python_executable() + with spinner(f"Installing task dependencies: {', '.join(packages)}"): + self._run_command( + [python_exec, "-m", "uv", "pip", "install"] + packages, + "uv pip install (task deps)", + ) + + def _do_install(self) -> None: + """Execute the full install sequence (called under lock).""" + env_name = self.get_env_name() + logger.info("Installing environment %s for %s %s", env_name, self.simulator_name, self.version) + + # Check system deps + self._dep_checker.assert_all(self.get_system_deps()) + + # Create/update conda env + conda_yaml = self.get_conda_env_yaml_path() + if conda_yaml.exists(): + with spinner(f"Creating conda environment {env_name}"): + self._run_conda_create(env_name, conda_yaml) + else: + logger.warning("No conda_env.yaml found at %s, skipping conda setup", conda_yaml) + + # Install uv in the conda env + python_exec = self.get_python_executable() + with spinner("Installing uv"): + self._run_command([python_exec, "-m", "pip", "install", "uv"], "pip install uv") + + # Install Python deps via uv + requirements = self.get_requirements_txt_path() + if requirements.exists(): + with spinner("Installing Python dependencies"): + self._run_command( + [python_exec, "-m", "uv", "pip", "install", "-r", str(requirements)], + "uv pip install", + ) + else: + logger.warning("No requirements.txt found at %s, skipping uv install", requirements) + + # Run post-install hook (binary downloads, file copies, etc.) + self._run_post_install() + + # Validate (with env vars so e.g. 
LD_LIBRARY_PATH is set) + env_vars = self.get_env_vars() + validation_env = None + if env_vars: + import os + validation_env = env_vars.apply_to_env(os.environ.copy()) + with spinner("Validating environment"): + self._run_command( + [python_exec, "-c", self.get_validation_import()], + "environment validation", + env=validation_env, + ) + + logger.info("Environment %s installed and validated successfully", env_name) + + # ── Post-install hook and helpers ────────────────────────────────── + + def get_extras_dir(self) -> Path: + """Directory for downloaded binaries and other extras (inside conda env dir).""" + env_path = self.conda_prefix / "envs" / self.get_env_name() + return env_path / "extras" + + @staticmethod + def _resolve_template(template: str, variables: dict[str, str]) -> str: + """Resolve {var} placeholders in a string.""" + result = template + for key, value in variables.items(): + result = result.replace(f"{{{key}}}", value) + return result + + def _get_template_variables(self) -> dict[str, str]: + """Return template variables for env_vars and post_install use. + + Available variables: + {env_dir} — conda env directory + {extras_dir} — extras directory for binaries/downloads + """ + env_dir = str(self.conda_prefix / "envs" / self.get_env_name()) + return { + "env_dir": env_dir, + "extras_dir": str(self.get_extras_dir()), + } + + def post_install(self, context: dict) -> None: + """Override for custom post-install steps (binary downloads, file copies, etc.). + + Called after conda + pip installs, before validation. Use helper methods + like _download_and_extract() for common operations. Use _run_command() + with an env dict to run pip install with custom env vars. + + Args: + context: Dict with keys: + env_dir — conda env directory path + extras_dir — directory for downloaded extras + env_vars — resolved env vars from get_env_vars() + + Does nothing by default. + """ + + def _run_post_install(self) -> None: + """Build context and call post_install() hook.""" + ctx = self._get_template_variables() + ctx["env_vars"] = self.get_env_vars().to_flat_dict() + self.post_install(ctx) + + def _download_and_extract( + self, + url: str, + filename: str, + dest_dir: Path, + extract: bool = True, + strip_components: int = 0, + ) -> None: + """Download a file and optionally extract it. Idempotent (skips if done). + + Helper for use inside post_install() overrides. + + Args: + url: Download URL. + filename: Local filename to save as. + dest_dir: Directory to download/extract into. + extract: Whether to extract archives (tar.xz, tar.gz, zip). + strip_components: Remove N leading path components when extracting. 
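+
+        Example (hypothetical usage inside a post_install() override; the
+        URL is illustrative only)::
+
+            self._download_and_extract(
+                url="https://example.com/assets/sim_binaries.tar.gz",
+                filename="sim_binaries.tar.gz",
+                dest_dir=Path(context["extras_dir"]),
+                strip_components=1,
+            )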
+ """ + dest_dir = Path(dest_dir) + dest_dir.mkdir(parents=True, exist_ok=True) + dest = dest_dir / filename + + # Idempotency: check marker for extracted archives, or file existence + marker = dest_dir / f".{filename}.done" + if marker.exists(): + logger.info("Already installed: %s, skipping", filename) + return + if not extract and dest.exists(): + logger.info("Already downloaded: %s, skipping", filename) + return + + logger.trace("Download target: %s -> %s", url, dest) + with spinner(f"Downloading {filename}"): + logger.info("Downloading %s", url) + req = urllib.request.Request(url, headers={"User-Agent": "easi/1.0"}) + with urllib.request.urlopen(req) as response, open(str(dest), "wb") as out: + total = 0 + while True: + chunk = response.read(1024 * 1024) # 1MB chunks + if not chunk: + break + out.write(chunk) + total += len(chunk) + logger.trace("Download complete: %s (%.1f MB)", filename, total / 1024 / 1024) + + if extract: + logger.trace("Extracting %s to %s (strip_components=%d)", filename, dest_dir, strip_components) + with spinner(f"Extracting {filename}"): + self._extract_archive(dest, dest_dir, strip_components) + logger.trace("Extraction complete, removing archive %s", dest) + dest.unlink(missing_ok=True) # Remove archive to save space + # Log extracted contents (top-level only) + top_items = sorted(p.name for p in dest_dir.iterdir() if not p.name.startswith(".")) + logger.trace("Contents of %s after extraction: %s", dest_dir, top_items) + + marker.touch() + logger.trace("Wrote done marker: %s", marker) + + def _extract_archive(self, archive: Path, dest_dir: Path, strip_components: int = 0) -> None: + """Extract a tar.xz, tar.gz, tar.bz2, or zip archive.""" + name = archive.name + if name.endswith((".tar.xz", ".tar.gz", ".tgz", ".tar.bz2")): + with tarfile.open(str(archive)) as tf: + if strip_components > 0: + for member in tf.getmembers(): + parts = Path(member.name).parts + if len(parts) > strip_components: + member.name = str(Path(*parts[strip_components:])) + tf.extract(member, dest_dir) + else: + tf.extractall(dest_dir) + elif name.endswith(".zip"): + import zipfile + with zipfile.ZipFile(str(archive)) as zf: + zf.extractall(dest_dir) + else: + logger.warning("Unknown archive format: %s, skipping extraction", name) + + # ── Conda / command helpers ───────────────────────────────────── + + def _run_conda_create(self, env_name: str, yaml_path: Path) -> None: + """Create or update a conda environment from a YAML file.""" + env_path = self.conda_prefix / "envs" / env_name + + if env_path.exists(): + cmd = ["conda", "env", "update", "-f", str(yaml_path), "-n", env_name] + desc = "conda env update" + else: + cmd = ["conda", "env", "create", "-f", str(yaml_path), "-n", env_name] + desc = "conda env create" + + self._run_command(cmd, desc) + + def _run_command(self, cmd: list[str], description: str, env: dict[str, str] | None = None) -> None: + """Run a subprocess command, streaming output through the logger. + + Args: + cmd: Command and arguments. + description: Human-readable description for error messages. + env: Optional environment dict. If None, inherits parent env. 
+ """ + logger.trace("%s", " ".join(cmd)) + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + env=env, + ) + output_lines = [] + try: + for line in process.stdout: + line = line.rstrip() + output_lines.append(line) + logger.trace(" %s", line) + finally: + process.stdout.close() + process.wait() + if process.returncode != 0: + raise EnvironmentSetupError( + f"{description} failed (exit {process.returncode}):\n" + + "\n".join(output_lines[-20:]) + ) + + @staticmethod + def _default_conda_prefix() -> Path: + """Determine the conda prefix from the conda executable.""" + try: + result = subprocess.run( + ["conda", "info", "--base"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + return Path(result.stdout.strip()) + except (subprocess.TimeoutExpired, FileNotFoundError): + pass + # Fallback to common default + return Path.home() / "miniconda3" diff --git a/easi/core/base_simulator.py b/easi/core/base_simulator.py new file mode 100644 index 0000000..81ced52 --- /dev/null +++ b/easi/core/base_simulator.py @@ -0,0 +1,122 @@ +"""Abstract base class for simulators. + +Concrete simulators subclass this and implement: +- name/version properties +- _get_bridge_script_path() — path to the bridge.py for subprocess execution +- _parse_observation() — convert bridge output into Observation +- _format_action() — convert Action into dict the bridge understands + +The shared template methods (reset, step, close) handle subprocess communication +via the filesystem IPC layer. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path + +from easi.communication import filesystem, schemas +from easi.core.episode import Action, Observation, StepResult +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + + +class BaseSimulator(ABC): + """Abstract base for all simulators. Manages subprocess lifecycle via IPC.""" + + def __init__(self, workspace_dir: Path | None = None): + self._workspace_dir = workspace_dir + self._runner = None # Set when start() is called + + @property + @abstractmethod + def name(self) -> str: + """Simulator name (e.g., 'ai2thor').""" + ... + + @property + @abstractmethod + def version(self) -> str: + """Version identifier (e.g., 'v2_1_0').""" + ... + + @abstractmethod + def _get_bridge_script_path(self) -> Path: + """Return the absolute path to the bridge.py script for subprocess execution.""" + ... + + def _parse_observation(self, data: dict) -> Observation: + """Parse observation from response data. Override for custom parsing.""" + return schemas.parse_observation(data) + + def _format_action(self, action: Action) -> dict: + """Format an Action into the command dict. Override for custom formatting.""" + return schemas.make_step_command(action) + + def reset( + self, + episode_id: str, + reset_config: dict | None = None, + episode_output_dir: str | None = None, + ) -> Observation: + """Reset the simulator for a new episode. + + Sends a reset command to the bridge subprocess and waits for the + observation response. + + Args: + episode_id: Unique identifier for this episode. + reset_config: Task-specific configuration for the episode. + episode_output_dir: Directory where the bridge should save + observation images. If None, the bridge uses its IPC workspace. + """ + if self._runner is None: + raise RuntimeError("Simulator not started. 
Call start() first.") + + command = schemas.make_reset_command(episode_id, reset_config, episode_output_dir) + response = self._runner.send_command(command) + + if response.get("status") == "error": + from easi.core.exceptions import SimulatorError + raise SimulatorError(f"Reset failed: {response.get('error', 'unknown')}") + + return self._parse_observation(response) + + def step(self, action: Action) -> StepResult: + """Execute one action in the simulator. + + Sends a step command and returns the StepResult. + """ + if self._runner is None: + raise RuntimeError("Simulator not started. Call start() first.") + + command = self._format_action(action) + response = self._runner.send_command(command) + + if response.get("status") == "error": + from easi.core.exceptions import SimulatorError + raise SimulatorError(f"Step failed: {response.get('error', 'unknown')}") + + return schemas.parse_step_result(response) + + def close(self) -> None: + """Shut down the simulator subprocess.""" + if self._runner is not None: + command = schemas.make_close_command() + try: + self._runner.send_command(command, timeout=10.0) + except Exception: + logger.warning("Close command failed, force-killing subprocess") + self._runner.shutdown() + self._runner = None + + def is_running(self) -> bool: + """Check if the bridge subprocess is alive.""" + if self._runner is None: + return False + return self._runner.is_alive() + + def set_runner(self, runner: object) -> None: + """Attach a SubprocessRunner instance (called by orchestration code).""" + self._runner = runner diff --git a/easi/core/base_task.py b/easi/core/base_task.py new file mode 100644 index 0000000..66a0bbf --- /dev/null +++ b/easi/core/base_task.py @@ -0,0 +1,464 @@ +"""Abstract base class for tasks (benchmarks). + +A task owns: +- Which simulator+version to use (pinned via task.yaml) +- The action space available to the agent +- The dataset (episodes to evaluate on) +- Episode-to-simulator mapping (format_reset_config) +- Success criteria and metrics (evaluate_episode) + +Concrete tasks subclass this and implement: +- get_task_yaml_path() — where the task.yaml lives +- format_reset_config() — adapter from dataset episodes to simulator configs +- evaluate_episode() — compute metrics for a completed episode +""" + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from pathlib import Path + +from easi.core.episode import EpisodeRecord, StepResult +from easi.core.exceptions import DatasetError +from easi.tasks.yaml_utils import resolve_task_yaml +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + + +def hf_row_to_episode(row: dict) -> dict: + """Convert a HuggingFace dataset row to an episode dict. + + HF dataset rows contain all information for a single episode. + For EB-Alfred_easi: {id, task, repeat_idx, instruction, task_type, trial_id} + This is a passthrough — the row IS the episode. + """ + return dict(row) + + +class BaseTask(ABC): + """Abstract base for all tasks (benchmarks).""" + + def __init__( + self, + data_dir: Path | None = None, + split_yaml_path: Path | None = None, + ): + self._split_yaml_path = split_yaml_path + self._config = self._load_config() + self._episodes: list[dict] | None = None + self._action_space_cache: list[str] | None = None + self._data_dir = data_dir + self._simulator_kwarg_overrides: dict = {} + + @abstractmethod + def get_task_yaml_path(self) -> Path: + """Return path to this task's task.yaml.""" + ... 
+ + @abstractmethod + def format_reset_config(self, episode: dict) -> dict: + """Translate a dataset episode into simulator reset kwargs. + + This is the core adapter method between dataset format and simulator API. + """ + ... + + @abstractmethod + def evaluate_episode( + self, episode: dict, trajectory: list[StepResult] + ) -> dict[str, float]: + """Evaluate a completed episode, returning a metric dict. + + Example return: {"success": 1.0, "spl": 0.73, "distance_to_goal": 0.15} + """ + ... + + def aggregate_results( + self, records: list["EpisodeRecord"] + ) -> dict[str, float]: + """Aggregate metrics across all completed episodes. + + Default implementation: averages all numeric keys from each + record's episode_results. Override in subclasses for custom + aggregation logic (e.g., weighted scores, category breakdowns, + trajectory-level analysis). + + Args: + records: List of EpisodeRecord objects. Each contains: + - episode: raw dataset row dict + - trajectory: list of StepResult objects + - episode_results: dict from evaluate_episode() + + Returns: + Summary metrics dict (placed under "metrics" in summary.json). + """ + from easi.evaluation.metrics import default_aggregate + logger.info( + "Using default_aggregate for task '%s'. " + "Override aggregate_results() in your task class for custom aggregation.", + self.name, + ) + return default_aggregate(records) + + # --- Hooks --- + + def on_episode_reset(self, observation, agent) -> None: + """Called after simulator reset, before the agent-simulator loop. + + Override in subclasses to perform task-specific setup, e.g. updating + the agent's action space from bridge metadata. + + Args: + observation: The initial observation from sim.reset(). + agent: The agent instance (may have update_action_space, etc.). + """ + + # --- Shared implementation --- + + def get_bridge_script_path(self) -> Path | None: + """Return task-specific bridge script path, or None for simulator default. + + Override in subclasses to provide a task-specific bridge that extends + the generic simulator bridge (e.g., EBAlfredBridge extends AI2ThorBridge). + """ + return None + + @property + def simulator_configs(self) -> dict: + """Full simulator configuration from task YAML (includes additional_deps).""" + return self._config.get("simulator_configs", {}) + + @property + def additional_deps(self) -> list[str]: + """Extra pip packages to install in the simulator conda env.""" + return self.simulator_configs.get("additional_deps", []) + + @property + def simulator_kwargs(self) -> dict: + """Bridge-facing kwargs (simulator_configs minus runner/infra keys + max_steps).""" + cfg = dict(self.simulator_configs) + cfg.pop("additional_deps", None) + cfg.pop("env_vars", None) + cfg.pop("command_timeout", None) + cfg.pop("startup_timeout", None) + cfg.pop("render_platform", None) + cfg["max_steps"] = self.max_steps + # Merge per-worker overrides injected by the runner + cfg.update(self._simulator_kwarg_overrides) + return cfg + + def inject_simulator_kwarg(self, key: str, value) -> None: + """Inject a per-worker override into simulator_kwargs. + + Used by the runner to pass worker-specific values (e.g. assigned GPU ID) + without mutating the shared task config. 
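+
+        Example (sketch)::
+
+            task.inject_simulator_kwarg("gpu_id", 0)
+            # task.simulator_kwargs now includes {"gpu_id": 0}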
+ """ + self._simulator_kwarg_overrides[key] = value + + @property + def extra_env_vars(self) -> dict[str, str]: + """Task-level environment variables from simulator_configs.env_vars.""" + return self.simulator_configs.get("env_vars", {}) + + def get_instruction(self, episode: dict) -> str: + """Return human-readable task instruction for this episode. + + Default tries common field names. Override in subclasses + for benchmarks that use different keys. + """ + return episode.get("instruction", episode.get("task_description", self.name)) + + @property + def name(self) -> str: + return self._config["name"] + + @property + def simulator_key(self) -> str: + """Returns e.g. 'dummy:v1' — used to look up from simulator registry.""" + return self._config["simulator"] + + @property + def action_space(self) -> list[str]: + if self._action_space_cache is None: + self._action_space_cache = self._build_action_space() + return self._action_space_cache + + def _build_action_space(self) -> list[str]: + """Return the action space for this task. + + Override in subclasses to define the action space programmatically. + """ + return [] + + @property + def max_steps(self) -> int: + return self._config.get("max_steps", 500) + + def download_dataset(self, force: bool = False) -> Path: + """Download dataset if needed. Returns path to local data directory. + + Args: + force: If True, delete cached dataset and re-download. + + - source=local: validate path exists, return it + - source=huggingface: download via huggingface_hub, cache locally + """ + dataset_config = self._config.get("dataset", {}) + source = dataset_config.get("source", "local") + + if source == "local": + path = dataset_config.get("path") + if path is None: + # Use built-in episodes (no download needed) + return Path() + local_path = Path(path) + if not local_path.exists(): + raise DatasetError(f"Local dataset path does not exist: {local_path}") + return local_path + + elif source == "huggingface": + return self._download_huggingface(dataset_config, force=force) + + else: + raise DatasetError(f"Unknown dataset source: {source}") + + def load_episodes(self) -> list[dict]: + """Load and return all episodes from the dataset.""" + if self._episodes is not None: + return self._episodes + + self._episodes = self._load_episodes_from_config() + logger.info("Loaded %d episodes for task %s", len(self._episodes), self.name) + return self._episodes + + def get_episode(self, index: int) -> dict: + """Get a single episode by index.""" + episodes = self.load_episodes() + if index < 0 or index >= len(episodes): + raise IndexError(f"Episode index {index} out of range [0, {len(episodes)})") + return episodes[index] + + def __len__(self) -> int: + """Number of episodes.""" + return len(self.load_episodes()) + + def _load_config(self) -> dict: + """Load task config from split yaml (if provided) or default task.yaml.""" + yaml_path = self._split_yaml_path or self.get_task_yaml_path() + if not yaml_path.exists(): + raise DatasetError(f"Task config not found: {yaml_path}") + return resolve_task_yaml(yaml_path) + + def _load_episodes_from_config(self) -> list[dict]: + """Load episodes from the dataset. + + For HuggingFace datasets: downloads the repo, then loads the split + using the datasets library. Each row = one episode. + For local datasets: looks for episodes.json. 
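+
+        A local episodes.json is expected to hold a JSON list of episode
+        dicts, e.g. (illustrative)::
+
+            [
+                {"id": "ep_0", "instruction": "pick up the mug"},
+                {"id": "ep_1", "instruction": "open the fridge"}
+            ]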
+ """ + dataset_config = self._config.get("dataset", {}) + source = dataset_config.get("source", "local") + + if source == "huggingface": + return self._load_episodes_from_hf(dataset_config) + + # Local source — existing behavior + data_dir = self.download_dataset() + if not data_dir or data_dir == Path(): + return self._get_builtin_episodes() + + episodes_file = data_dir / "episodes.json" + if episodes_file.exists(): + return json.loads(episodes_file.read_text()) + + raise DatasetError( + f"No episodes.json found in {data_dir}. " + f"Override _load_episodes_from_config() for custom loading." + ) + + def _load_episodes_from_hf(self, dataset_config: dict) -> list[dict]: + """Load episodes from a HuggingFace dataset (subset + split). + + Each row in the dataset = one episode dict. + Downloads all files via snapshot_download, then loads locally. + """ + data_dir = self.download_dataset() + + subset = dataset_config.get("subset") + split_name = dataset_config.get("split") + # hf_data_dir restricts which subdirectory the datasets library + # scans for data files. Useful when the repo also contains large + # non-episode files (e.g. scene meshes) that would confuse auto- + # detection. + hf_data_dir = dataset_config.get("hf_data_dir") + + try: + from datasets import ( + get_dataset_config_names, + get_dataset_split_names, + load_dataset, + ) + except ImportError: + raise DatasetError( + "The 'datasets' library is required for HF episode loading. " + "Install with: pip install datasets" + ) + + local_path = str(data_dir) + + # Auto-detect subset if not specified + if subset is None: + configs = get_dataset_config_names(local_path, data_dir=hf_data_dir) + if len(configs) == 1: + subset = configs[0] + logger.info("Auto-detected single subset: %s", subset) + elif "default" in configs: + subset = "default" + else: + raise DatasetError( + f"Dataset at {local_path} has multiple subsets {configs} — " + f"please specify 'subset' in task yaml." + ) + + # Auto-detect split if not specified + if split_name is None: + splits = get_dataset_split_names( + local_path, subset, data_dir=hf_data_dir, + ) + if len(splits) == 1: + split_name = splits[0] + logger.info("Auto-detected single split: %s", split_name) + else: + raise DatasetError( + f"Dataset at {local_path} subset={subset} has " + f"multiple splits {splits} — " + f"please specify 'split' in task yaml." + ) + + logger.info( + "Loading episodes from local HF dataset %s subset=%s split=%s", + local_path, subset, split_name, + ) + + import tempfile + hf_cache = Path(tempfile.gettempdir()) / "easi_hf_cache" + ds = load_dataset(local_path, subset, split=split_name, + data_dir=hf_data_dir, cache_dir=str(hf_cache)) + episodes = [hf_row_to_episode(row) for row in ds] + + for ep in episodes: + ep["_data_dir"] = str(data_dir) + + logger.info("Loaded %d episodes from %s/%s/%s", + len(episodes), local_path, subset, split_name) + return episodes + + def _get_builtin_episodes(self) -> list[dict]: + """Return built-in episodes when no dataset download is needed. + + Override in subclasses that provide built-in test episodes (e.g., DummyTask). + """ + return [] + + def _download_huggingface(self, config: dict, force: bool = False) -> Path: + """Download a dataset from HuggingFace Hub with file-based locking. + + Uses snapshot_download to get the full repo (including .zip files), + then extracts any listed zip_files. 
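+
+        Example (sketch): for repo_id "user/my-dataset" the snapshot lands
+        in <datasets_dir>/user_my-dataset, guarded by the lock file
+        dataset_user_my-dataset.lock.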
+ """ + from easi.utils.locking import file_lock + from easi.utils.paths import get_locks_dir + + repo_id = config["repo_id"] + lock_path = get_locks_dir() / f"dataset_{repo_id.replace('/', '_')}.lock" + + # Use data_dir if set, otherwise default (./datasets) + if self._data_dir: + base_dir = self._data_dir + else: + from easi.utils.paths import get_datasets_dir + base_dir = get_datasets_dir() # returns ./datasets + + with file_lock(lock_path): + target = base_dir / repo_id.replace("/", "_") + + if force and target.exists(): + import shutil + logger.info("Force re-download: removing cached %s", target) + shutil.rmtree(target, ignore_errors=True) + + if not target.exists(): + try: + from huggingface_hub import snapshot_download + except ImportError: + raise DatasetError( + "huggingface_hub is required for HuggingFace downloads. " + "Install with: pip install huggingface_hub" + ) + + import time as _time + + max_attempts = 5 + for attempt in range(1, max_attempts + 1): + logger.info( + "Downloading dataset %s from HuggingFace (attempt %d/%d)...", + repo_id, attempt, max_attempts, + ) + try: + snapshot_download( + repo_id=repo_id, + local_dir=str(target), + repo_type="dataset", + ) + break # success + except Exception as e: + if attempt < max_attempts: + wait = 2 ** attempt # 2, 4, 8, 16s + logger.warning( + "Download attempt %d failed: %s. Retrying in %ds...", + attempt, e, wait, + ) + _time.sleep(wait) + else: + if target.exists(): + import shutil + shutil.rmtree(target, ignore_errors=True) + raise DatasetError( + f"Failed to download {repo_id} after {max_attempts} attempts: {e}" + ) + + logger.info("Downloaded dataset %s to %s", repo_id, target) + else: + logger.info("Dataset %s already cached at %s", repo_id, target) + + # Extract any .zip files listed in config + zip_files = config.get("zip_files", []) + if zip_files: + self._extract_zip_files(target, zip_files) + + return target + + @staticmethod + def _extract_zip_files(dataset_dir: Path, zip_filenames: list[str]) -> None: + """Extract listed .zip files within a downloaded dataset directory.""" + import zipfile as zf + + for zip_name in zip_filenames: + zip_path = dataset_dir / zip_name + if not zip_path.exists(): + logger.warning("Zip file not found: %s", zip_path) + continue + + marker = dataset_dir / f".{zip_name}.extracted" + if marker.exists(): + logger.trace("Already extracted: %s", zip_name) + continue + + logger.info("Extracting %s...", zip_path) + with zf.ZipFile(zip_path, "r") as z: + z.extractall(dataset_dir) + + marker.write_text("extracted") + logger.info("Extracted %s to %s", zip_name, dataset_dir) diff --git a/easi/core/docker_env_manager.py b/easi/core/docker_env_manager.py new file mode 100644 index 0000000..0f84a46 --- /dev/null +++ b/easi/core/docker_env_manager.py @@ -0,0 +1,203 @@ +"""Docker-based environment manager for simulators that require containerization. + +Parallel to BaseEnvironmentManager (conda-based). Docker simulators subclass this +instead. The bridge code inside the container is identical — same BaseBridge, +same filesystem IPC. +""" + +from __future__ import annotations + +import subprocess +from abc import ABC, abstractmethod +from pathlib import Path + +from easi.utils.logging import get_logger +from easi.utils.system_deps import SystemDependencyChecker + +logger = get_logger(__name__) + + +class DockerEnvironmentManager(ABC): + """Abstract base for Docker-isolated simulator environments. 
+
+    Subclasses must define: simulator_name, version, image_name,
+    dockerfile_path, gpu_required, container_python_path,
+    container_data_mount, easi_mount, get_system_deps().
+    """
+
+    def __init__(self, installation_kwargs: dict | None = None):
+        self.installation_kwargs = installation_kwargs or {}
+        self._dep_checker = SystemDependencyChecker()
+
+    # --- Abstract properties (subclass must implement) ---
+
+    @property
+    @abstractmethod
+    def simulator_name(self) -> str:
+        """Name of the simulator (e.g., 'matterport3d')."""
+        ...
+
+    @property
+    @abstractmethod
+    def version(self) -> str:
+        """Version identifier (e.g., 'v0_1')."""
+        ...
+
+    @property
+    @abstractmethod
+    def image_name(self) -> str:
+        """Docker image name (e.g., 'easi_matterport3d_v0_1')."""
+        ...
+
+    @property
+    @abstractmethod
+    def dockerfile_path(self) -> Path:
+        """Path to Dockerfile for building the image."""
+        ...
+
+    @property
+    @abstractmethod
+    def gpu_required(self) -> bool:
+        """Whether the container needs GPU access (--gpus all)."""
+        ...
+
+    @property
+    @abstractmethod
+    def container_python_path(self) -> str:
+        """Python executable path inside the container."""
+        ...
+
+    @property
+    @abstractmethod
+    def container_data_mount(self) -> str:
+        """Mount point for simulator scene data inside the container."""
+        ...
+
+    @property
+    @abstractmethod
+    def easi_mount(self) -> str:
+        """Mount point for EASI repo inside the container (read-only)."""
+        ...
+
+    @abstractmethod
+    def get_system_deps(self) -> list[str]:
+        """System dependencies (e.g., ['docker'] or ['docker', 'nvidia-docker'])."""
+        ...
+
+    # --- Concrete methods ---
+
+    def check_system_deps(self) -> list[str]:
+        """Check system dependencies, returning list of missing ones."""
+        return self._dep_checker.check_all(self.get_system_deps())
+
+    def get_env_vars(self) -> dict[str, str]:
+        """Environment variables to set inside the container. Override if needed."""
+        return {}
+
+    def get_env_name(self) -> str:
+        """Return a name for this environment (used for display/logging)."""
+        return f"docker:{self.image_name}"
+
+    def env_is_ready(self) -> bool:
+        """Check if the Docker image exists."""
+        try:
+            result = subprocess.run(
+                ["docker", "image", "inspect", self.image_name],
+                capture_output=True,
+                timeout=10,
+            )
+            return result.returncode == 0
+        except Exception:
+            return False
+
+    def install(self) -> None:
+        """Build the Docker image and run post_install (e.g., dataset download)."""
+        if not self.env_is_ready():
+            dockerfile = self.dockerfile_path
+            if not dockerfile.exists():
+                raise FileNotFoundError(f"Dockerfile not found: {dockerfile}")
+
+            build_context = dockerfile.parent
+            logger.info(
+                "Building Docker image %s from %s ...",
+                self.image_name,
+                dockerfile,
+            )
+            subprocess.run(
+                [
+                    "docker", "build",
+                    "-t", self.image_name,
+                    "-f", str(dockerfile),
+                    str(build_context),
+                ],
+                check=True,
+            )
+            logger.info("Docker image %s built successfully.", self.image_name)
+        else:
+            logger.info("Docker image %s already exists.", self.image_name)
+
+        # Run post-install hook (e.g., dataset download)
+        self.post_install()
+
+    def post_install(self) -> None:
+        """Hook for subclasses to run after image build (e.g., download datasets).
+
+        Called by install() after the Docker image is ready.
+        Default is a no-op.
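+
+        Example (hypothetical override)::
+
+            def post_install(self) -> None:
+                # download scene data onto the host; the runner mounts it
+                # into the container at container_data_mount
+                self._download_scene_data()  # hypothetical helper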
+ """ + pass + + def remove(self) -> None: + """Remove the Docker image.""" + logger.info("Removing Docker image %s ...", self.image_name) + subprocess.run( + ["docker", "rmi", self.image_name], + capture_output=True, + ) + + def build_docker_run_command( + self, + bridge_command: list[str], + workspace_dir: str | None = None, + episode_output_dir: str | None = None, + data_dir: str | None = None, + ) -> list[str]: + """Build a `docker run` command for launching the bridge. + + Mounts IPC workspace, episode output dir, EASI repo, and scene data + at the same host paths (so rgb_path in response.json works on both sides). + """ + easi_repo_root = str(Path(__file__).resolve().parents[1]) + + cmd = ["docker", "run", "--rm"] + + # GPU + if self.gpu_required: + cmd.extend(["--gpus", "all"]) + + # Volume mounts (same path on host and container for IPC compatibility) + if workspace_dir: + cmd.extend(["-v", f"{workspace_dir}:{workspace_dir}"]) + if episode_output_dir: + cmd.extend(["-v", f"{episode_output_dir}:{episode_output_dir}"]) + + # Data mount (host path -> container mount point) + if data_dir: + cmd.extend(["-v", f"{data_dir}:{self.container_data_mount}:ro"]) + + # EASI repo (read-only) + cmd.extend(["-v", f"{easi_repo_root}:{self.easi_mount}:ro"]) + + # Environment variables + cmd.extend(["-e", "PYTHONUNBUFFERED=1"]) # real-time log output + env_vars = self.get_env_vars() + for key, value in env_vars.items(): + cmd.extend(["-e", f"{key}={value}"]) + + # Image name + cmd.append(self.image_name) + + # Bridge command + cmd.extend(bridge_command) + + return cmd diff --git a/easi/core/episode.py b/easi/core/episode.py new file mode 100644 index 0000000..a334ab2 --- /dev/null +++ b/easi/core/episode.py @@ -0,0 +1,46 @@ +"""Core data structures for observations, actions, and step results.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass +class Observation: + """Observation produced by a simulator after reset or step.""" + + rgb_path: str + depth_path: str | None = None + agent_pose: list[float] = field(default_factory=list) + metadata: dict[str, str] = field(default_factory=dict) + + +@dataclass +class Action: + """An action to be executed in the simulator.""" + + action_name: str + params: dict[str, float] = field(default_factory=dict) + + +@dataclass +class StepResult: + """Result of executing one step in the simulator.""" + + observation: Observation + reward: float = 0.0 + done: bool = False + info: dict[str, float] = field(default_factory=dict) + + +@dataclass +class EpisodeRecord: + """Bundles all data for one completed episode. + + Used by aggregate_results() to give the aggregation function + access to both the raw episode data and the full trajectory. 
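+
+    Example (sketch of a custom aggregate_results override)::
+
+        def aggregate_results(self, records: list[EpisodeRecord]) -> dict[str, float]:
+            n = len(records) or 1
+            success = sum(r.episode_results.get("success", 0.0) for r in records)
+            return {"success_rate": success / n}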
+ """ + + episode: dict + trajectory: list[StepResult] + episode_results: dict[str, float] diff --git a/easi/core/exceptions.py b/easi/core/exceptions.py new file mode 100644 index 0000000..a093341 --- /dev/null +++ b/easi/core/exceptions.py @@ -0,0 +1,39 @@ +"""Custom exception hierarchy for EASI.""" + +from __future__ import annotations + + +class EASIError(Exception): + """Base exception for all EASI errors.""" + + +class EnvironmentSetupError(EASIError): + """Raised when a simulator environment fails to install or validate.""" + + def __init__(self, message: str, missing_deps: list[str] | None = None): + super().__init__(message) + self.missing_deps = missing_deps or [] + + +class SimulatorError(EASIError): + """Raised when a simulator encounters an error during operation.""" + + +class SimulatorTimeoutError(SimulatorError): + """Raised when a simulator operation exceeds the configured timeout.""" + + def __init__(self, message: str, timeout: float): + super().__init__(message) + self.timeout = timeout + + +class ActionParseError(EASIError): + """Raised when an LLM response cannot be parsed into a valid Action.""" + + def __init__(self, message: str, raw_response: str): + super().__init__(message) + self.raw_response = raw_response + + +class DatasetError(EASIError): + """Raised when a dataset cannot be loaded or downloaded.""" diff --git a/easi/core/memory.py b/easi/core/memory.py new file mode 100644 index 0000000..55fb1dd --- /dev/null +++ b/easi/core/memory.py @@ -0,0 +1,74 @@ +"""Agent memory: shared state between agent and prompt builder.""" +from __future__ import annotations + +from dataclasses import dataclass, field + +from easi.core.episode import Action, Observation + + +@dataclass +class StepRecord: + """Record of a single agent step.""" + + observation: Observation + action: Action | None = None + feedback: str | None = None + llm_response: str | None = None # None for buffered actions + # Flattened text preview of the prompt sent to the LLM on this step. + # ``None`` for buffered / fallback steps that didn't re-query. Image + # blocks are replaced with ``[img_N]`` markers to keep the string small. 
+ prompt_text: str | None = None + step_number: int = 0 + + +@dataclass +class AgentMemory: + """Shared state that the agent populates and the prompt builder reads.""" + + task_description: str = "" + action_space: list[str] = field(default_factory=list) + steps: list[StepRecord] = field(default_factory=list) + current_observation: Observation | None = None + + @property + def is_first_turn(self) -> bool: + """True when no completed steps exist yet.""" + return len(self.steps) == 0 + + @property + def action_history(self) -> list[tuple[str, str]]: + """(action_name, feedback) for completed steps with feedback.""" + return [ + (s.action.action_name, s.feedback) + for s in self.steps + if s.action and s.feedback is not None + ] + + def record_step( + self, + observation: Observation, + action: Action, + llm_response: str | None, + prompt_text: str | None = None, + ) -> None: + """Record a completed step (action taken, awaiting feedback).""" + self.steps.append( + StepRecord( + observation=observation, + action=action, + llm_response=llm_response, + prompt_text=prompt_text, + step_number=len(self.steps), + ) + ) + + def record_feedback(self, feedback: str) -> None: + """Attach feedback to the most recent step.""" + if self.steps: + self.steps[-1].feedback = feedback + + def clear(self) -> None: + """Reset memory for a new episode.""" + self.steps.clear() + self.current_observation = None + self.task_description = "" diff --git a/easi/core/protocols.py b/easi/core/protocols.py new file mode 100644 index 0000000..21a00f5 --- /dev/null +++ b/easi/core/protocols.py @@ -0,0 +1,77 @@ +"""Protocol classes defining interfaces for EASI components.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Protocol, runtime_checkable + +from easi.core.episode import Action, Observation, StepResult + + +@runtime_checkable +class SimulatorProtocol(Protocol): + """Any simulator must satisfy this interface.""" + + name: str + version: str + + def reset(self, episode_id: str, reset_config: dict | None = None) -> Observation: ... + def step(self, action: Action) -> StepResult: ... + def close(self) -> None: ... + def is_running(self) -> bool: ... + + +@runtime_checkable +class EnvironmentManagerProtocol(Protocol): + """Manages conda+uv environment for a specific simulator version.""" + + simulator_name: str + version: str + + def check_system_deps(self) -> list[str]: ... + def install(self) -> None: ... + def env_is_ready(self) -> bool: ... + def get_python_executable(self) -> str: ... + def get_env_name(self) -> str: ... + + +@runtime_checkable +class LLMClientProtocol(Protocol): + """Calls an LLM inference server.""" + + def generate(self, messages: list[dict]) -> str: ... + + +@runtime_checkable +class AgentProtocol(Protocol): + """An agent that decides actions given observations.""" + + def reset(self) -> None: ... + def act(self, observation: Observation, task_description: str) -> Action: ... + def add_feedback(self, action_name: str, feedback: str) -> None: ... + + +@runtime_checkable +class TaskProtocol(Protocol): + """A benchmark task that maps dataset episodes to simulator configs.""" + + name: str + simulator_key: str + action_space: list[str] + max_steps: int + + def download_dataset(self) -> Path: ... + def load_episodes(self) -> list[dict]: ... + def get_episode(self, index: int) -> dict: ... + def format_reset_config(self, episode: dict) -> dict: ... + def evaluate_episode( + self, episode: dict, trajectory: list[StepResult] + ) -> dict[str, float]: ... 
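+    # Being @runtime_checkable, this protocol supports structural
+    # isinstance() checks, e.g. isinstance(task, TaskProtocol); note that
+    # only member names are verified at runtime, not signatures.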
+ def get_bridge_script_path(self) -> Path | None: ... + @property + def simulator_configs(self) -> dict: ... + @property + def additional_deps(self) -> list[str]: ... + @property + def simulator_kwargs(self) -> dict: ... + def __len__(self) -> int: ... diff --git a/easi/core/render_platforms/__init__.py b/easi/core/render_platforms/__init__.py new file mode 100644 index 0000000..01e2b84 --- /dev/null +++ b/easi/core/render_platforms/__init__.py @@ -0,0 +1,47 @@ +"""Render platform package — pluggable display/rendering backends. + +Each render platform encapsulates how to launch a bridge subprocess with +the correct display/rendering environment. Simulators declare a default +platform; users can override via CLI (--render-platform) or task YAML. +Simulator-specific render quirks belong in a ``SimulatorRenderAdapter`` +registered from the simulator manifest, not in backend-specific subclasses. + +Built-in platforms: + auto — use native DISPLAY if available, fall back to xvfb + native — require existing DISPLAY + xvfb — always wrap with xvfb-run + egl — GPU-accelerated headless via EGL + headless — no display (simulator has native headless support) + xorg — GPU-accelerated X11 via auto-managed Xorg servers +""" + +from .auto import AutoPlatform +from .base import EnvVars, RenderPlatform, SimulatorRenderAdapter, WorkerBinding +from .egl import EGLPlatform +from .headless import HeadlessPlatform +from .native import NativePlatform +from .registry import available_platforms, get_render_platform +from .xorg import XorgPlatform +from .xorg_manager import XorgInstance, XorgManager +from .xvfb import XvfbPlatform + +__all__ = [ + # Base types + "RenderPlatform", + "EnvVars", + "WorkerBinding", + "SimulatorRenderAdapter", + # Built-in platforms + "AutoPlatform", + "NativePlatform", + "XvfbPlatform", + "EGLPlatform", + "HeadlessPlatform", + "XorgPlatform", + # Xorg internals + "XorgManager", + "XorgInstance", + # Registry + "get_render_platform", + "available_platforms", +] diff --git a/easi/core/render_platforms/auto.py b/easi/core/render_platforms/auto.py new file mode 100644 index 0000000..7b4ad92 --- /dev/null +++ b/easi/core/render_platforms/auto.py @@ -0,0 +1,25 @@ +"""Auto render platform — native display if available, falls back to xvfb.""" + +from __future__ import annotations + +import os + +from .base import RenderPlatform +from .xvfb import XvfbPlatform + + +class AutoPlatform(RenderPlatform): + """Detect native display; fall back to xvfb if unavailable.""" + + @property + def name(self) -> str: + return "auto" + + @property + def resolved_name(self) -> str: + return "native" if os.environ.get("DISPLAY", "") else "xvfb" + + def wrap_command(self, cmd: list[str], screen_config: str) -> list[str]: + if os.environ.get("DISPLAY", ""): + return cmd + return XvfbPlatform().wrap_command(cmd, screen_config) diff --git a/easi/core/render_platforms/base.py b/easi/core/render_platforms/base.py new file mode 100644 index 0000000..04dc5cd --- /dev/null +++ b/easi/core/render_platforms/base.py @@ -0,0 +1,169 @@ +"""Base classes for render platforms. + +Defines the RenderPlatform ABC and EnvVars dataclass used by all built-in +render backends plus simulator render adapters. + +Also defines WorkerBinding (resolved per-worker render facts) and +SimulatorRenderAdapter (optional simulator-specific launch adjustments). 
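+
+A typical multi-worker flow (sketch; ``get_render_platform`` comes from
+``easi.core.render_platforms.registry``)::
+
+    platform = get_render_platform("xorg")
+    platform.setup(gpu_ids=[0, 1])      # start managed services, if any
+    binding = platform.for_worker(0)    # WorkerBinding with display/GPU facts
+    env = binding.extra_env.apply_to_env(dict(os.environ))
+    ...
+    platform.teardown()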
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class EnvVars: + """Structured environment variables with replace/prepend semantics. + + ``replace`` vars overwrite any existing value. + ``prepend`` vars are prepended with ':' to any existing value (for PATH-like vars). + """ + + replace: dict[str, str] = field(default_factory=dict) + prepend: dict[str, str] = field(default_factory=dict) + + def to_flat_dict(self) -> dict[str, str]: + """Combine into single dict (for internal use like post_install).""" + return {**self.replace, **self.prepend} + + def apply_to_env(self, base: dict[str, str]) -> dict[str, str]: + """Merge into a base env dict (e.g. os.environ.copy()).""" + env = dict(base) + for k, v in self.replace.items(): + env[k] = v + for k, v in self.prepend.items(): + env[k] = f"{v}:{env[k]}" if k in env else v + return env + + def __bool__(self) -> bool: + return bool(self.replace) or bool(self.prepend) + + @classmethod + def merge(cls, *env_vars: EnvVars) -> EnvVars: + """Merge multiple EnvVars. Later values win for replace; prepend values concatenate.""" + replace: dict[str, str] = {} + prepend: dict[str, str] = {} + for ev in env_vars: + if ev is None: + continue + replace.update(ev.replace) + for k, v in ev.prepend.items(): + prepend[k] = f"{v}:{prepend[k]}" if k in prepend else v + return cls(replace=replace, prepend=prepend) + + +@dataclass +class WorkerBinding: + """Resolved per-worker render facts produced by a render backend. + + Carries the concrete display and GPU assignment for one worker subprocess, + plus any extra env vars and arbitrary metadata the backend wants to pass + downstream (e.g. to a SimulatorRenderAdapter). + + Fields: + display: X display string (e.g. ":10"), or None for headless/EGL. + cuda_visible_devices: GPU id(s) string (e.g. "0" or "0,1"), or None. + extra_env: Additional env vars contributed by the render backend. + metadata: Arbitrary backend-specific data for adapter consumption. + """ + + display: str | None = None + cuda_visible_devices: str | None = None + extra_env: EnvVars = field(default_factory=EnvVars) + metadata: dict[str, Any] = field(default_factory=dict) + + +class SimulatorRenderAdapter(ABC): + """Extension point for simulator-specific render launch adjustments. + + Simulators that need to inject render-related env vars or wrap the launch + command beyond what the core render backend provides should subclass this + and register it via the simulator manifest. + + Default implementations are no-ops so simulators that need no adjustments + do not have to implement anything. + """ + + def get_env_vars(self, binding: WorkerBinding) -> EnvVars: + return EnvVars() + + def wrap_command(self, cmd: list[str], binding: WorkerBinding) -> list[str]: + return cmd + + +class RenderPlatform(ABC): + """Strategy interface for display/rendering backends. + + Lifecycle hooks (``setup`` / ``teardown``) allow platforms that manage + external services (e.g. Xorg) to start and stop them without + if/else logic in the callers. ``for_worker`` always returns a + ``WorkerBinding`` so callers have a uniform interface; backends that + need per-worker GPU/display assignment (e.g. ``XorgPlatform``) override it. + """ + + def __init__(self, env_manager=None): + self._env_manager = env_manager + + @property + @abstractmethod + def name(self) -> str: + """Short identifier (e.g. 'xvfb', 'egl').""" + ... 
+ + @property + def resolved_name(self) -> str: + """Actual backend after auto-detection. Defaults to :attr:`name`.""" + return self.name + + @property + def log_name(self) -> str: + """Human-readable name for log messages.""" + resolved = self.resolved_name + if resolved != self.name: + return f"{resolved} (via auto-detection)" + return resolved + + @abstractmethod + def wrap_command(self, cmd: list[str], screen_config: str) -> list[str]: + """Optionally wrap the bridge launch command. + + Args: + cmd: The original command ``[python, bridge.py, ...]``. + screen_config: Screen resolution string, e.g. ``"1024x768x24"``. + + Returns: + The (possibly wrapped) command. + """ + ... + + def get_env_vars(self) -> EnvVars: + """Extra env vars needed by this platform (merged into subprocess).""" + return EnvVars() + + def get_system_deps(self) -> list[str]: + """System dependency names required by this platform.""" + return [] + + def is_available(self) -> bool: + """Whether this platform can run in the current environment.""" + return True + + # -- Lifecycle hooks (override in platforms that manage services) ---------- + + def setup(self, gpu_ids: list[int] | None = None) -> None: + """Called once before any simulator is created. Start external services.""" + + def teardown(self) -> None: + """Called once after all simulators are done. Stop external services.""" + + def for_worker(self, worker_id: int) -> WorkerBinding: + """Return the per-worker render binding for this platform. + + The default returns a ``WorkerBinding`` carrying the backend name in + ``metadata``. Platforms that assign per-worker displays or GPUs + (e.g. ``XorgPlatform``) override this to return a populated binding. + """ + return WorkerBinding(metadata={"backend": self.name}) diff --git a/easi/core/render_platforms/egl.py b/easi/core/render_platforms/egl.py new file mode 100644 index 0000000..56835db --- /dev/null +++ b/easi/core/render_platforms/egl.py @@ -0,0 +1,29 @@ +"""EGL render platform — GPU-accelerated headless rendering via EGL (no X11 needed).""" + +from __future__ import annotations + +from pathlib import Path + +from .base import EnvVars, RenderPlatform + + +class EGLPlatform(RenderPlatform): + """GPU-accelerated headless rendering via EGL (no X11 needed).""" + + @property + def name(self) -> str: + return "egl" + + def wrap_command(self, cmd: list[str], screen_config: str) -> list[str]: + return cmd + + def get_env_vars(self) -> EnvVars: + replace: dict[str, str] = {"PYOPENGL_PLATFORM": "egl"} + # Don't set __EGL_VENDOR_LIBRARY_FILENAMES — let glvnd discover + # vendors from its default search path (/usr/share/glvnd/egl_vendor.d). + # Setting it explicitly restricts discovery to a single vendor file, + # which breaks GPU rendering when NVIDIA is available but Mesa is forced. 
+ return EnvVars(replace=replace) + + def get_system_deps(self) -> list[str]: + return ["egl"] diff --git a/easi/core/render_platforms/headless.py b/easi/core/render_platforms/headless.py new file mode 100644 index 0000000..5f4567e --- /dev/null +++ b/easi/core/render_platforms/headless.py @@ -0,0 +1,16 @@ +"""Headless render platform — no display, for simulators with native headless support.""" + +from __future__ import annotations + +from .base import RenderPlatform + + +class HeadlessPlatform(RenderPlatform): + """No display at all -- for simulators with native headless support.""" + + @property + def name(self) -> str: + return "headless" + + def wrap_command(self, cmd: list[str], screen_config: str) -> list[str]: + return cmd diff --git a/easi/core/render_platforms/native.py b/easi/core/render_platforms/native.py new file mode 100644 index 0000000..114d107 --- /dev/null +++ b/easi/core/render_platforms/native.py @@ -0,0 +1,21 @@ +"""Native render platform — use existing DISPLAY environment variable.""" + +from __future__ import annotations + +import os + +from .base import RenderPlatform + + +class NativePlatform(RenderPlatform): + """Use the existing DISPLAY. Fails at validation if none is set.""" + + @property + def name(self) -> str: + return "native" + + def wrap_command(self, cmd: list[str], screen_config: str) -> list[str]: + return cmd + + def is_available(self) -> bool: + return bool(os.environ.get("DISPLAY", "")) diff --git a/easi/core/render_platforms/registry.py b/easi/core/render_platforms/registry.py new file mode 100644 index 0000000..2318a44 --- /dev/null +++ b/easi/core/render_platforms/registry.py @@ -0,0 +1,40 @@ +"""Render platform registry — maps names to platform classes.""" + +from __future__ import annotations + +from .auto import AutoPlatform +from .base import RenderPlatform +from .egl import EGLPlatform +from .headless import HeadlessPlatform +from .native import NativePlatform +from .xorg import XorgPlatform +from .xvfb import XvfbPlatform + +_BUILTIN: dict[str, type[RenderPlatform]] = { + "auto": AutoPlatform, + "native": NativePlatform, + "xvfb": XvfbPlatform, + "egl": EGLPlatform, + "headless": HeadlessPlatform, + "xorg": XorgPlatform, +} + + +def get_render_platform(name: str) -> RenderPlatform: + """Instantiate a render platform by name. + + Raises: + ValueError: If name is not recognised. + """ + cls = _BUILTIN.get(name) + if cls is None: + raise ValueError( + f"Unknown render platform '{name}'. " + f"Available: {', '.join(sorted(_BUILTIN))}" + ) + return cls() + + +def available_platforms() -> list[str]: + """Return sorted list of registered platform names.""" + return sorted(_BUILTIN) diff --git a/easi/core/render_platforms/xorg.py b/easi/core/render_platforms/xorg.py new file mode 100644 index 0000000..03c1247 --- /dev/null +++ b/easi/core/render_platforms/xorg.py @@ -0,0 +1,76 @@ +"""Xorg render platform — GPU-accelerated X11 display managed by EASI. + +``XorgPlatform`` owns the ``XorgManager`` lifecycle and resolves per-worker +bindings for adapter-driven simulator launch wiring. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from easi.utils.logging import get_logger + +from .base import EnvVars, RenderPlatform, WorkerBinding +from .xorg_manager import XorgManager + +if TYPE_CHECKING: + from .xorg_manager import XorgInstance + +logger = get_logger(__name__) + + +class XorgPlatform(RenderPlatform): + """Render platform backed by auto-managed Xorg servers. 
+ + Call ``setup(gpu_ids=...)`` to start Xorg servers, then + ``for_worker(worker_id)`` to get per-worker bindings. + ``teardown()`` stops all servers. + """ + + def __init__(self, env_manager=None): + super().__init__(env_manager=env_manager) + self._xorg_mgr: XorgManager | None = None + self._instances: list[XorgInstance] = [] + + @property + def name(self) -> str: + return "xorg" + + def wrap_command(self, cmd: list[str], screen_config: str) -> list[str]: + return cmd + + def get_env_vars(self) -> EnvVars: + return EnvVars() + + def is_available(self) -> bool: + return True + + def setup(self, gpu_ids: list[int] | None = None) -> None: + """Start one Xorg server per GPU.""" + self._xorg_mgr = XorgManager(gpu_ids=gpu_ids or [0]) + self._instances = self._xorg_mgr.start() + + def teardown(self) -> None: + """Stop all Xorg servers.""" + if self._xorg_mgr is not None: + self._xorg_mgr.stop() + self._xorg_mgr = None + self._instances = [] + + def for_worker(self, worker_id: int) -> WorkerBinding: + """Resolve a per-worker binding for a specific Xorg instance.""" + if not self._instances: + raise RuntimeError( + "XorgPlatform.setup() must be called before for_worker()" + ) + inst = self._instances[worker_id % len(self._instances)] + return _build_worker_binding(inst.display, inst.gpu_id) + + +def _build_worker_binding(display_num: int, gpu_id: int) -> WorkerBinding: + return WorkerBinding( + display=f":{display_num}", + cuda_visible_devices=str(gpu_id), + extra_env=EnvVars(replace={"EASI_GPU_DISPLAY": "1"}), + metadata={"backend": "xorg", "display_num": display_num, "gpu_id": gpu_id}, + ) diff --git a/easi/core/render_platforms/xorg_manager.py b/easi/core/render_platforms/xorg_manager.py new file mode 100644 index 0000000..0a71e55 --- /dev/null +++ b/easi/core/render_platforms/xorg_manager.py @@ -0,0 +1,345 @@ +"""Manages lifecycle of Xorg servers for GPU-accelerated rendering. + +Starts one Xorg per GPU, waits for health, stops on exit. +Follows the same lifecycle pattern as ``easi.llm.server_manager.ServerManager``. 
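+
+Typical lifecycle (an illustrative sketch; display and GPU numbers depend on
+the host)::
+
+    mgr = XorgManager(gpu_ids=[0, 1])
+    try:
+        instances = mgr.start()  # one XorgInstance per GPU, e.g. displays :10, :11
+        for inst in instances:
+            worker_env = {"DISPLAY": f":{inst.display}"}  # hand to a worker
+    finally:
+        mgr.stop()  # SIGTERM, then SIGKILL stragglers; removes temp xorg.conf files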
+""" + +from __future__ import annotations + +import os +import shutil +import signal +import subprocess +import time +from typing import NamedTuple + +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + +_STARTUP_TIMEOUT = 10.0 +_HEALTH_POLL_INTERVAL = 0.5 + + +class XorgInstance(NamedTuple): + """A running Xorg server bound to a GPU.""" + + display: int + gpu_id: int + pid: int + + +def _find_available_display(start: int, max_probe: int = 50) -> int: + """Find the first available X display number starting from *start*.""" + for num in range(start, start + max_probe): + lock_file = f"/tmp/.X{num}-lock" + if not os.path.exists(lock_file): + return num + raise RuntimeError( + f"No available X display in range :{start}-:{start + max_probe - 1}" + ) + + +def _get_pci_bus_id(gpu_index: int) -> str: + """Query PCI BusID for a GPU via nvidia-smi, return Xorg format (PCI:B:D:F).""" + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=pci.bus_id", + "--format=csv,noheader", + "-i", + str(gpu_index), + ], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode != 0 or not result.stdout.strip(): + raise RuntimeError( + f"nvidia-smi failed for GPU {gpu_index}: {result.stderr.strip()}" + ) + + raw = result.stdout.strip() + # Format: 00000000:3F:00.0 → strip domain, parse hex bus:dev.func + no_domain = raw.split(":", 1)[1] # "3F:00.0" + bus_hex, rest = no_domain.split(":", 1) # "3F", "00.0" + dev_hex, func = rest.split(".", 1) # "00", "0" + bus_dec = int(bus_hex, 16) + dev_dec = int(dev_hex, 16) + return f"PCI:{bus_dec}:{dev_dec}:{func}" + + +def _write_xorg_conf(gpu_index: int, pci_bus_id: str) -> str: + """Write a minimal xorg.conf and return its path.""" + conf_path = f"/tmp/easi-xorg-gpu{gpu_index}.conf" + conf = f"""\ +Section "Device" + Identifier "Device{gpu_index}" + Driver "nvidia" + BusID "{pci_bus_id}" + Option "AllowEmptyInitialConfiguration" "True" +EndSection + +Section "Screen" + Identifier "Screen{gpu_index}" + Device "Device{gpu_index}" + DefaultDepth 24 + SubSection "Display" + Depth 24 + Virtual 1920 1080 + EndSubSection +EndSection + +Section "ServerLayout" + Identifier "Layout{gpu_index}" + Screen "Screen{gpu_index}" +EndSection +""" + with open(conf_path, "w") as f: + f.write(conf) + return conf_path + + +class XorgManager: + """Manages Xorg server processes for GPU-accelerated rendering. + + Starts one Xorg server per GPU ID, waits for each to become ready, + and cleans up all on stop. + """ + + def __init__(self, gpu_ids: list[int], base_display: int = 10): + self.gpu_ids = gpu_ids + self.base_display = base_display + self._processes: list[subprocess.Popen] = [] + self._used_sudo: list[bool] = [] + self._instances: list[XorgInstance] = [] + self._conf_files: list[str] = [] + + def start(self) -> list[XorgInstance]: + """Start Xorg on each GPU. Returns list of XorgInstance.""" + xorg_path = shutil.which("Xorg") + if xorg_path is None: + raise RuntimeError( + "Xorg is not installed. 
Install with: apt install xserver-xorg" + ) + + try: + next_display = self.base_display + for gpu_id in self.gpu_ids: + display_num = _find_available_display(next_display) + next_display = display_num + 1 + instance = self._start_one(xorg_path, gpu_id, display_num) + self._instances.append(instance) + except Exception: + logger.warning( + "Xorg startup failed, stopping %d already-started servers", + len(self._processes), + ) + self.stop() + raise + + logger.info( + "All %d Xorg servers ready: %s", + len(self._instances), + [(f":{i.display}", f"GPU {i.gpu_id}") for i in self._instances], + ) + return list(self._instances) + + def _start_one( + self, + xorg_path: str, + gpu_id: int, + display_num: int, + ) -> XorgInstance: + """Start a single Xorg server on the given GPU and display.""" + pci_bus_id = _get_pci_bus_id(gpu_id) + conf_path = _write_xorg_conf(gpu_id, pci_bus_id) + self._conf_files.append(conf_path) + + display_str = f":{display_num}" + cmd = [ + xorg_path, + display_str, + "-config", + conf_path, + "-noreset", + "-nolisten", + "tcp", + ] + + logger.info( + "Starting Xorg on display %s using GPU %d (%s)", + display_str, + gpu_id, + pci_bus_id, + ) + + proc, used_sudo = self._launch_xorg(cmd, xorg_path) + self._processes.append(proc) + self._used_sudo.append(used_sudo) + + try: + self._wait_for_ready(display_num, proc) + except RuntimeError as exc: + if not used_sudo and proc.poll() is not None: + if self._passwordless_sudo_available(xorg_path): + logger.info( + "Direct Xorg launch exited early on display %s; retrying with sudo", + display_str, + ) + self._processes.pop() + self._used_sudo.pop() + proc, used_sudo = self._launch_xorg_with_sudo(cmd, xorg_path) + self._processes.append(proc) + self._used_sudo.append(used_sudo) + self._wait_for_ready(display_num, proc) + else: + raise RuntimeError( + self._sudo_required_message(xorg_path, display_str) + ) from exc + else: + raise + + logger.info( + "Xorg ready on display %s (PID %d, GPU %d)", + display_str, + proc.pid, + gpu_id, + ) + return XorgInstance(display=display_num, gpu_id=gpu_id, pid=proc.pid) + + def _launch_xorg( + self, cmd: list[str], xorg_path: str + ) -> tuple[subprocess.Popen, bool]: + """Try launching Xorg directly, fall back to sudo on PermissionError. + + Returns (process, used_sudo) so stop() knows whether to use sudo kill. + """ + try: + proc = subprocess.Popen( + cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + preexec_fn=os.setsid, + ) + return proc, False + except PermissionError: + logger.info( + "Direct Xorg launch failed (permission denied), retrying with sudo" + ) + + return self._launch_xorg_with_sudo(cmd, xorg_path) + + def _launch_xorg_with_sudo( + self, + cmd: list[str], + xorg_path: str, + ) -> tuple[subprocess.Popen, bool]: + sudo_cmd = ["sudo", "-n"] + cmd + try: + proc = subprocess.Popen( + sudo_cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + preexec_fn=os.setsid, + ) + return proc, True + except (PermissionError, FileNotFoundError) as exc: + raise RuntimeError(self._sudo_required_message(xorg_path)) from exc + + def _wait_for_ready(self, display_num: int, proc: subprocess.Popen) -> None: + """Poll until the X server responds or timeout.""" + deadline = time.monotonic() + _STARTUP_TIMEOUT + display_str = f":{display_num}" + + while time.monotonic() < deadline: + if proc.poll() is not None: + raise RuntimeError( + f"Xorg exited with code {proc.returncode} on display {display_str}. " + f"Check /var/log/Xorg.{display_num}.log for details." 
+                )
+            try:
+                result = subprocess.run(
+                    ["xset", "-display", display_str, "q"],
+                    capture_output=True,
+                    timeout=2,
+                )
+                if result.returncode == 0:
+                    return
+            except (FileNotFoundError, subprocess.TimeoutExpired):
+                pass
+            time.sleep(_HEALTH_POLL_INTERVAL)
+
+        raise RuntimeError(
+            f"Xorg on display {display_str} did not become ready within "
+            f"{_STARTUP_TIMEOUT}s"
+        )
+
+    def stop(self) -> None:
+        """Stop all Xorg servers and clean up."""
+        for proc, sudo in zip(self._processes, self._used_sudo):
+            self._kill_proc(proc, signal.SIGTERM, sudo)
+
+        for proc, sudo in zip(self._processes, self._used_sudo):
+            try:
+                proc.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                self._kill_proc(proc, signal.SIGKILL, sudo)
+                try:
+                    proc.wait(timeout=2)
+                except subprocess.TimeoutExpired:
+                    pass
+
+        self._processes.clear()
+        self._used_sudo.clear()
+        self._instances.clear()
+
+        for conf in self._conf_files:
+            try:
+                os.unlink(conf)
+            except OSError:
+                pass
+        self._conf_files.clear()
+
+    @staticmethod
+    def _kill_proc(proc: subprocess.Popen, sig: int, used_sudo: bool) -> None:
+        """Send a signal to a process group, using sudo if the process was sudo-launched."""
+        try:
+            pgid = os.getpgid(proc.pid)
+            if used_sudo:
+                subprocess.run(
+                    ["sudo", "-n", "kill", f"-{sig}", f"-{pgid}"],
+                    capture_output=True,
+                    timeout=5,
+                )
+            else:
+                os.killpg(pgid, sig)
+        except (ProcessLookupError, PermissionError, subprocess.TimeoutExpired):
+            pass
+
+    @staticmethod
+    def _passwordless_sudo_available(xorg_path: str) -> bool:
+        """Probe ``sudo -n -l <command>`` — sudoers may only whitelist specific commands."""
+        try:
+            result = subprocess.run(
+                ["sudo", "-n", "-l", xorg_path],
+                capture_output=True,
+                timeout=5,
+            )
+        except (FileNotFoundError, subprocess.TimeoutExpired):
+            return False
+        return result.returncode == 0
+
+    @staticmethod
+    def _sudo_required_message(xorg_path: str, display_str: str | None = None) -> str:
+        display_hint = ""
+        if display_str is not None:
+            display_hint = (
+                f"Xorg exited before becoming ready on display {display_str}. "
+            )
+        user = os.environ.get("USER", "$USER")
+        return (
+            f"{display_hint}Launching Xorg usually requires root or a console session.\n"
+            f"To fix this, run EASI as root or authorize passwordless sudo:\n\n"
+            f"  sudo bash -c 'echo \"{user} ALL=(ALL) NOPASSWD: {xorg_path}, /usr/bin/kill\" > /etc/sudoers.d/easi-xorg'"
+        )
diff --git a/easi/core/render_platforms/xvfb.py b/easi/core/render_platforms/xvfb.py
new file mode 100644
index 0000000..8afcf0e
--- /dev/null
+++ b/easi/core/render_platforms/xvfb.py
@@ -0,0 +1,22 @@
+"""Xvfb render platform — virtual framebuffer via xvfb-run."""
+
+from __future__ import annotations
+
+from .base import RenderPlatform
+
+
+class XvfbPlatform(RenderPlatform):
+    """Always wrap with ``xvfb-run``."""
+
+    @property
+    def name(self) -> str:
+        return "xvfb"
+
+    def wrap_command(self, cmd: list[str], screen_config: str) -> list[str]:
+        return [
+            "xvfb-run", "-a",
+            "-s", f"-screen 0 {screen_config}",
+        ] + cmd
+
+    def get_system_deps(self) -> list[str]:
+        return ["xvfb"]
diff --git a/easi/evaluation/__init__.py b/easi/evaluation/__init__.py
new file mode 100644
index 0000000..a88296e
--- /dev/null
+++ b/easi/evaluation/__init__.py
@@ -0,0 +1 @@
+"""Evaluation orchestration."""
diff --git a/easi/evaluation/episode_filter.py b/easi/evaluation/episode_filter.py
new file mode 100644
index 0000000..691fb5d
--- /dev/null
+++ b/easi/evaluation/episode_filter.py
@@ -0,0 +1,119 @@
+"""Parse and apply ``--episodes`` filter expressions.
+ +Syntax (comma-separated tokens, freely mixed):: + + :N index slice — first N episodes + M:N index slice — episodes at indices M..N-1 + M: index slice — from index M onwards + 42 episode ID "42" + +Examples:: + + --episodes :10 first 10 (replaces old --max-episodes) + --episodes 30:40 index range 30-39 + --episodes 2,5,7 episode IDs 2, 5, 7 + --episodes 2,10:20,40 episode ID 2 + range 10-19 + episode ID 40 + +Semantics: +- All selections are unioned and deduplicated. +- Original dataset order is preserved. +- If a requested episode ID is not found, raises ValueError. +""" + +from __future__ import annotations + + +def parse_episodes_flag(value: str) -> tuple[list[tuple[int | None, int | None]], list[str]]: + """Parse ``--episodes`` value into index slices and episode IDs. + + Returns: + (slices, ids) where slices is a list of (start, stop) tuples + and ids is a list of episode ID strings. + """ + slices: list[tuple[int | None, int | None]] = [] + ids: list[str] = [] + + for token in value.split(","): + token = token.strip() + if not token: + continue + + if ":" in token: + parts = token.split(":", 1) + start_s, stop_s = parts[0].strip(), parts[1].strip() + start = int(start_s) if start_s else None + stop = int(stop_s) if stop_s else None + + # Validate + if start is not None and stop is not None and start >= stop: + raise ValueError( + f"Invalid range '{token}': start ({start}) must be less than stop ({stop})" + ) + if start is not None and start < 0: + raise ValueError(f"Invalid range '{token}': negative index") + if stop is not None and stop < 0: + raise ValueError(f"Invalid range '{token}': negative index") + + slices.append((start, stop)) + else: + ids.append(token) + + return slices, ids + + +def filter_episodes( + episodes: list[dict], + episodes_flag: str, +) -> list[dict]: + """Filter an episode list according to an ``--episodes`` expression. + + Args: + episodes: Full episode list from task.load_episodes(). + episodes_flag: Raw ``--episodes`` CLI value. + + Returns: + Filtered episode list in original order. + + Raises: + ValueError: If a requested episode ID is not found, or the + expression is invalid. 
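+
+    Example (a minimal sketch with illustrative episode dicts)::
+
+        eps = [{"episode_id": str(i)} for i in range(50)]
+        filter_episodes(eps, ":2,10:12,40")
+        # -> indices 0, 1, 10, 11, plus the episode with ID "40",
+        #    deduplicated and kept in original dataset order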
+    """
+    slices, ids = parse_episodes_flag(episodes_flag)
+
+    if not slices and not ids:
+        raise ValueError("Empty --episodes value")
+
+    # Collect selected indices (as a set for deduplication)
+    selected_indices: set[int] = set()
+
+    # Apply index slices
+    for start, stop in slices:
+        s = start if start is not None else 0
+        e = stop if stop is not None else len(episodes)
+        # Clamp to valid range
+        s = max(0, min(s, len(episodes)))
+        e = max(0, min(e, len(episodes)))
+        for i in range(s, e):
+            selected_indices.add(i)
+
+    # Apply episode ID selections
+    if ids:
+        # Build ID -> index map
+        id_to_indices: dict[str, int] = {}
+        for i, ep in enumerate(episodes):
+            ep_id = str(ep.get("episode_id", ""))
+            # First occurrence wins (IDs should be unique, but be safe)
+            if ep_id not in id_to_indices:
+                id_to_indices[ep_id] = i
+
+        missing = [eid for eid in ids if eid not in id_to_indices]
+        if missing:
+            raise ValueError(
+                f"Episode IDs not found in dataset: {', '.join(missing)}"
+            )
+
+        for eid in ids:
+            selected_indices.add(id_to_indices[eid])
+
+    # Return in original order
+    return [episodes[i] for i in sorted(selected_indices)]
diff --git a/easi/evaluation/metrics.py b/easi/evaluation/metrics.py
new file mode 100644
index 0000000..5af56a5
--- /dev/null
+++ b/easi/evaluation/metrics.py
@@ -0,0 +1,94 @@
+"""Metric aggregation utilities."""
+from __future__ import annotations
+
+from easi.core.episode import EpisodeRecord
+
+
+def generic_aggregate(records: list[EpisodeRecord]) -> dict:
+    """Compute generic metrics shared across all benchmarks.
+
+    These are computed from per-episode results and included at the
+    top level of summary.json alongside task-specific "metrics".
+
+    Keys produced:
+    - num_episodes: total count
+    - success_rate: mean of task_success (falling back to success) over
+      all episodes; episodes missing both keys count as 0
+    - avg_steps / median_steps: mean and median of num_steps
+    - early_stop_rate: fraction of episodes forced to stop early by the
+      consecutive-fallback limit
+    """
+    if not records:
+        return {}
+
+    n = len(records)
+
+    # Success: try task_success first, fall back to success
+    successes = []
+    for r in records:
+        er = r.episode_results
+        s = er.get("task_success", er.get("success"))
+        if isinstance(s, (int, float)):
+            successes.append(float(s))
+
+    # Steps
+    steps = []
+    for r in records:
+        er = r.episode_results
+        s = er.get("num_steps", er.get("steps_taken"))
+        if isinstance(s, (int, float)):
+            steps.append(float(s))
+
+    # Early stops (forced by consecutive fallback limit)
+    early_stops = sum(
+        1 for r in records
+        if r.episode_results.get("forced_early_stop", False)
+    )
+
+    result = {"num_episodes": n}
+    if successes:
+        result["success_rate"] = round(sum(successes) / n, 4)
+    if steps:
+        result["avg_steps"] = round(sum(steps) / n, 1)
+        sorted_steps = sorted(steps)
+        result["median_steps"] = round(sorted_steps[len(sorted_steps) // 2], 1)
+    result["early_stop_rate"] = round(early_stops / n, 4)
+
+    return result
+
+
+def default_aggregate(records: list[EpisodeRecord]) -> dict:
+    """Default aggregation: average all numeric keys from episode_results.
+
+    Used by tasks that don't override aggregate_results().
+
+    Args:
+        records: List of EpisodeRecord objects (one per episode).
+
+    Returns:
+        Summary metrics dict with avg_<key> for each numeric key.
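+
+    Example (illustrative values)::
+
+        # episode_results of [{"success": 1.0, "dist": 2.0}, {"success": 0.0}]
+        # yields {"avg_success": 0.5, "avg_dist": 1.0, "success_rate": 0.5};
+        # "dist" is averaged over both episodes, so the one missing it adds 0.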
+ """ + if not records: + return {} + + summary: dict = {} + + # Collect all numeric keys from episode_results + numeric_keys: dict[str, list[float]] = {} + for r in records: + for key, value in r.episode_results.items(): + if isinstance(value, (int, float)): + numeric_keys.setdefault(key, []).append(float(value)) + + # Average each numeric metric over ALL episodes (not just those that emitted the key). + # Failed episodes may not emit task-specific keys — they should contribute 0, not be excluded. + total = len(records) + for key, values in numeric_keys.items(): + summary[f"avg_{key}"] = round(sum(values) / total, 4) + + # Convenience aliases + if "avg_success" in summary: + summary["success_rate"] = summary["avg_success"] + if "avg_task_success" in summary: + summary["success_rate"] = summary["avg_task_success"] + if "avg_num_steps" in summary: + summary["avg_steps"] = summary["avg_num_steps"] + + return summary diff --git a/easi/evaluation/parallel_runner.py b/easi/evaluation/parallel_runner.py new file mode 100644 index 0000000..4cb01c2 --- /dev/null +++ b/easi/evaluation/parallel_runner.py @@ -0,0 +1,495 @@ +"""Thread-pool based parallel evaluation runner. + +Extends EvaluationRunner with concurrent episode execution: +1. Load task once (shared read-only across workers) +2. Fill a queue with (index, episode) tuples +3. Launch N worker threads, each with its own simulator + agent +4. Workers pull episodes from the queue and run them via inherited _run_episode() +5. Collect results thread-safely, aggregate metrics, save summary +""" + +from __future__ import annotations + +import json +import queue +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +from easi.core.episode import EpisodeRecord +from easi.evaluation.runner import EvaluationRunner, _sanitize_dirname +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + + +def _get_gpu_count() -> int | None: + """Detect the number of GPUs via nvidia-smi. + + Returns the GPU count, or None if detection fails (e.g., no GPUs, + nvidia-smi not installed). + """ + import subprocess + + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"], + capture_output=True, text=True, timeout=10, + ) + if result.returncode == 0: + lines = [l.strip() for l in result.stdout.strip().splitlines() if l.strip()] + return len(lines) + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + return None + + +class ParallelRunner(EvaluationRunner): + """Thread-pool based parallel evaluation runner. + + Each worker thread owns its own simulator and agent instance. + Episodes are distributed via a shared queue. + """ + + def __init__(self, *, num_parallel: int = 2, **kwargs): + super().__init__(**kwargs) + self.num_parallel = num_parallel + self._validate_gpu_args() + + def _serialize_cli_options(self) -> dict: + """Add num_parallel to the serialized config.""" + base = super()._serialize_cli_options() + base["num_parallel"] = self.num_parallel + return base + + def _validate_gpu_args(self): + """Validate GPU allocation arguments.""" + if self.llm_instances and self.llm_instances > 1 and not self.llm_gpus: + raise ValueError( + "--llm-gpus is required when --llm-instances > 1. " + "Specify which GPUs to use for LLM inference." + ) + if self.llm_gpus and self.sim_gpus: + overlap = set(self.llm_gpus) & set(self.sim_gpus) + if overlap: + raise ValueError( + f"--llm-gpus and --sim-gpus must not overlap. 
" + f"Overlapping GPU IDs: {overlap}" + ) + # Warn if local-server args are set but backend is not a local backend + if self.backend and self.backend not in ("vllm", "custom"): + ignored = [] + if self.llm_instances: + ignored.append("--llm-instances") + if self.llm_gpus: + ignored.append("--llm-gpus") + if ignored: + logger.warning( + "%s will be ignored because --backend is '%s' (not a local LLM backend).", + ", ".join(ignored), self.backend, + ) + # Validate GPU IDs against hardware + all_gpu_ids = set() + if self.llm_gpus: + all_gpu_ids.update(self.llm_gpus) + if self.sim_gpus: + all_gpu_ids.update(self.sim_gpus) + if all_gpu_ids: + gpu_count = _get_gpu_count() + if gpu_count is not None: + invalid = {g for g in all_gpu_ids if g < 0 or g >= gpu_count} + if invalid: + raise ValueError( + f"GPU IDs {sorted(invalid)} do not exist. " + f"This machine has {gpu_count} GPU(s) " + f"(valid IDs: 0-{gpu_count - 1})." + ) + + def _parse_base_urls(self) -> list[str | None]: + """Parse base URL(s) into list for round-robin assignment.""" + if self.llm_base_url: + return [u.strip() for u in self.llm_base_url.split(",") if u.strip()] + return [None] + + def run(self) -> list[dict]: + """Run evaluation with thread-pool parallelism.""" + logger.trace( + "ParallelRunner.run() called: task=%s, num_parallel=%d", + self.task_name, self.num_parallel, + ) + + # --- Resolve LLM backend and vLLM URLs --- + backend, base_url = self._resolve_llm_backend() + server_mgr = None + self._render_platform = None + + try: + if backend in ("vllm", "custom") and base_url is None: + # Auto-manage vLLM instances + from easi.llm.server_manager import MultiServerManager + from easi.llm.utils import parse_llm_kwargs, split_kwargs as _split + + all_kw = parse_llm_kwargs(self.llm_kwargs_raw) + srv_kw, _ = _split(all_kw) + + num_instances = self.llm_instances or 1 + gpu_ids = self.llm_gpus + + startup_timeout = float(srv_kw.pop("startup_timeout", 300.0)) + + server_mgr = MultiServerManager( + model=self.model, + num_instances=num_instances, + gpu_ids=gpu_ids, + base_port=self.port, + server_kwargs=srv_kw, + startup_timeout=startup_timeout, + backend=backend, + ) + base_urls = server_mgr.start() + elif base_url: + base_urls = self._parse_base_urls() + else: + base_urls = [None] + + # Setup render platform (starts external services like Xorg if needed) + self._render_platform = self._setup_render_platform(backend) + + # --- Load task --- + logger.trace("Loading task") + task = self._create_task() + if self.refresh_data: + task.download_dataset(force=True) + episodes = task.load_episodes() + if self.episodes_filter is not None: + from easi.evaluation.episode_filter import filter_episodes + episodes = filter_episodes(episodes, self.episodes_filter) + logger.trace( + "Task loaded. 
%d episodes, simulator_key=%s", + len(episodes), task.simulator_key, + ) + + # --- Resolve LLM backend + handle resume --- + logger.trace("Resolved LLM backend=%s, base_url=%s", backend, base_url) + + # Compute resolved generation kwargs (YAML defaults + CLI overrides) + from easi.llm.utils import parse_llm_kwargs, split_kwargs + + agent_config = task._config.get("agent", {}) + yaml_gen_kwargs = agent_config.get("generation_kwargs", {}) + all_llm_kwargs = parse_llm_kwargs(self.llm_kwargs_raw) + _, cli_gen_kwargs = split_kwargs(all_llm_kwargs) + resolved_gen_kwargs = {**yaml_gen_kwargs, **cli_gen_kwargs} + + # Handle resume + if self.resume_dir: + run_dir = self.resume_dir + self.run_dir = run_dir + completed_results, start_index = self._load_completed_results( + run_dir, len(episodes), + ) + self._reattach_resume_data(completed_results, episodes, run_dir) + logger.info( + "Resuming from %s — %d completed, starting from index %d", + run_dir, len(completed_results), start_index, + ) + else: + run_dir = self.output_dir / self.task_name / self.run_id + self.run_dir = run_dir + completed_results = [] + start_index = 0 + + # --- Create output directory and save config --- + logger.trace("Creating output directory and saving config") + episodes_dir = run_dir / "episodes" + episodes_dir.mkdir(parents=True, exist_ok=True) + + config = { + "run_id": self.run_id, + "total_episodes": len(episodes), + "num_parallel": self.num_parallel, + "cli_options": self._serialize_cli_options(), + "resolved_backend": backend, + "resolved_base_url": base_url, + "resolved_generation_kwargs": resolved_gen_kwargs, + "task_config": task._config, + } + (run_dir / "config.json").write_text(json.dumps(config, indent=2)) + logger.trace("Run config:\n%s", json.dumps(config, indent=2, default=str)) + + # Check if all episodes already complete (resume edge case) + if start_index >= len(episodes): + logger.info("All %d episodes already complete, re-aggregating summary.", len(episodes)) + all_results = completed_results + # Skip to aggregation + wall_seconds = 0.0 + results_list = [(i, r) for i, r in enumerate(all_results)] + else: + # --- Fill episode queue (from start_index) --- + episode_queue: queue.Queue[tuple[int, dict]] = queue.Queue() + for i, episode in enumerate(episodes): + if i >= start_index: + episode_queue.put((i, episode)) + remaining = episode_queue.qsize() + logger.trace("Queued %d episodes (skipped %d completed)", remaining, start_index) + + # --- Prepare thread-safe collection --- + results_lock = threading.Lock() + new_results: list[tuple[int, dict]] = [] + progress = {"completed": 0, "failed": 0} + progress_lock = threading.Lock() + total_episodes = len(episodes) + + num_workers = min(self.num_parallel, remaining) + + # --- Progress bar --- + from easi.utils.progress import ProgressBar + + progress_bar = ProgressBar( + total=total_episodes, + num_workers=num_workers, + start_index=start_index, + ) + def _worker(worker_id: int) -> None: + """Worker thread: owns a simulator + agent, pulls episodes from queue.""" + logger.trace("[Worker %d] Starting up", worker_id) + episodes_done = 0 + sim = None + + try: + # Create simulator + logger.trace( + "[Worker %d] Creating simulator (key=%s)", + worker_id, task.simulator_key, + ) + sim, sim_runner = self._create_simulator( + task.simulator_key, task=task, label=f"bridge-{worker_id}", + worker_id=worker_id, + ) + logger.trace( + "[Worker %d] Simulator ready (PID=%s)", + worker_id, + getattr(sim_runner, 'pid', 'unknown'), + ) + + # Create agent + logger.trace("[Worker 
%d] Creating agent", worker_id) + # Round-robin URL assignment + worker_url = base_urls[worker_id % len(base_urls)] + agent = self._create_agent( + task.action_space, task._config, + backend=backend, base_url=worker_url, + ) + logger.trace("[Worker %d] Agent ready", worker_id) + while True: + # Pull next episode + try: + idx, episode = episode_queue.get_nowait() + except queue.Empty: + break + + logger.trace( + "[Worker %d] Queue remaining: ~%d", + worker_id, episode_queue.qsize(), + ) + + episode_id = episode.get("episode_id", f"ep_{idx}") + episode_dir = episodes_dir / f"{idx:03d}_{_sanitize_dirname(episode_id)}" + episode_dir.mkdir(exist_ok=True) + + result = None + for attempt in range(1, self.max_retries + 1): + logger.trace( + "[Worker %d] Running episode %s (attempt %d/%d)", + worker_id, episode_id, attempt, self.max_retries, + ) + try: + result = self._run_episode( + sim, agent, task, episode, idx, episode_dir, + ) + logger.trace( + "[Worker %d] Episode %s completed in %.1fs: %s", + worker_id, episode_id, + result.get("elapsed_seconds", 0), + {k: v for k, v in result.items() + if k in ("success", "num_steps", "elapsed_seconds")}, + ) + break + except Exception as exc: + logger.warning( + "[Worker %d] Episode %s attempt %d/%d failed: %s", + worker_id, episode_id, attempt, self.max_retries, exc, + ) + logger.trace( + "[Worker %d] Exception details:", + worker_id, exc_info=True, + ) + self._clear_episode_dir(episode_dir) + if attempt < self.max_retries: + logger.info( + "[Worker %d] Re-launching simulator for retry...", + worker_id, + ) + try: + sim.close() + except Exception: + pass + try: + sim, sim_runner = self._create_simulator( + task.simulator_key, task=task, + label=f"bridge-{worker_id}", + worker_id=worker_id, + ) + except Exception as restart_exc: + logger.error( + "[Worker %d] Simulator restart failed: %s", + worker_id, restart_exc, + ) + result = { + "episode_id": episode_id, + "instruction": task.get_instruction(episode), + "success": 0.0, + "num_steps": 0, + "elapsed_seconds": 0.0, + "error": f"simulator restart failed: {restart_exc}", + } + sim = None + break + else: + logger.error( + "[Worker %d] Episode %s failed after %d attempts, skipping", + worker_id, episode_id, self.max_retries, + ) + result = { + "episode_id": episode_id, + "instruction": task.get_instruction(episode), + "success": 0.0, + "num_steps": 0, + "elapsed_seconds": 0.0, + "error": str(exc), + } + + # Save per-episode result (strip internal keys) + result_to_save = { + k: v for k, v in result.items() + if not k.startswith("_") + } + (episode_dir / "result.json").write_text( + json.dumps(result_to_save, indent=2) + ) + + # Thread-safe results collection + failed = "error" in result + with results_lock: + new_results.append((idx, result)) + + with progress_lock: + progress["completed"] += 1 + if failed: + progress["failed"] += 1 + current_completed = progress["completed"] + start_index + current_failed = progress["failed"] + + progress_bar.update( + completed=current_completed, + failed=current_failed, + active_workers=num_workers, + ) + + episodes_done += 1 + + # If simulator restart failed, stop this worker + if sim is None: + logger.warning( + "[Worker %d] No simulator, stopping worker", worker_id, + ) + break + + finally: + logger.trace("[Worker %d] Shutting down simulator", worker_id) + if sim is not None: + try: + sim.close() + except Exception: + pass + logger.trace( + "[Worker %d] Shutdown complete (%d episodes done)", + worker_id, episodes_done, + ) + + # --- Launch worker threads --- + 
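# Each worker owns one simulator and one agent; pulling work from the
+            # shared queue balances load naturally (faster workers take more episodes).
+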
logger.trace("Launching %d worker threads", num_workers) + wall_start = time.monotonic() + + with progress_bar, ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for wid in range(num_workers): + futures.append(executor.submit(_worker, wid)) + logger.trace("All %d worker threads submitted", num_workers) + + # Wait for all workers to complete and propagate exceptions + for future in futures: + future.result() + + wall_seconds = round(time.monotonic() - wall_start, 2) + + # Merge completed results from resume with new results + new_results.sort(key=lambda x: x[0]) + results_list = [(i, r) for i, r in enumerate(completed_results)] + results_list.extend(new_results) + + # --- Sort results and aggregate --- + results_list.sort(key=lambda x: x[0]) + all_results = [r for _, r in results_list] + + num_successful = sum(1 for r in all_results if "error" not in r) + num_failed = len(all_results) - num_successful + logger.trace( + "Results sorted. %d successful, %d failed", + num_successful, num_failed, + ) + + # Build EpisodeRecords for aggregate_results + effective = sum(1 for r in all_results if "error" not in r) + records = [] + for r in all_results: + trajectory = r.pop("_trajectory", []) + episode = r.pop("_episode", {}) + episode_results = dict(r) + records.append(EpisodeRecord( + episode=episode, + trajectory=trajectory, + episode_results=episode_results, + )) + + # Aggregate and save summary + from easi.evaluation.metrics import generic_aggregate + + try: + metric_results = task.aggregate_results(records) + except Exception as exc: + logger.error("aggregate_results() failed: %s", exc, exc_info=True) + metric_results = {"aggregation_error": str(exc)} + + generic = generic_aggregate(records) + summary = { + **generic, + "effective_episodes": effective, + "metrics": metric_results, + } + summary["num_parallel"] = self.num_parallel + summary["wall_clock_seconds"] = wall_seconds + if backend and backend != "legacy": + summary["llm_usage"] = self._aggregate_llm_usage(all_results) + summary["model"] = self.model + summary["backend"] = backend + (run_dir / "summary.json").write_text(json.dumps(summary, indent=2)) + logger.info("Results saved to: %s", run_dir) + + return all_results + finally: + if server_mgr is not None: + server_mgr.stop() + if self._render_platform is not None: + self._render_platform.teardown() diff --git a/easi/evaluation/runner.py b/easi/evaluation/runner.py new file mode 100644 index 0000000..47fc98e --- /dev/null +++ b/easi/evaluation/runner.py @@ -0,0 +1,954 @@ +"""Sequential evaluation runner. + +Ties together Task + Simulator + Agent into an episode loop: +1. Load task -> get episodes, simulator key, action space +2. Start simulator subprocess +3. For each episode: + a. Reset simulator with format_reset_config(episode) + b. Loop: agent.act(observation) -> simulator.step(action) until done or max_steps + c. Evaluate: task.evaluate_episode(episode, trajectory) + d. Save per-episode metrics + trajectory.jsonl + images +4. Aggregate metrics into summary.json + +Output directory structure: + /// + config.json + summary.json + episodes/ + 000_/ + result.json + trajectory.jsonl + step_0000.png, step_0001.png, ... 
+""" + +from __future__ import annotations + +import inspect +import json +import re +import time +from datetime import datetime +from pathlib import Path + +from easi.core.episode import EpisodeRecord, StepResult +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + + +def _sanitize_dirname(name: str) -> str: + """Replace characters unsafe for directory names.""" + return re.sub(r"[^\w\-.]", "_", name) + + +class EvaluationRunner: + """Sequential evaluation runner.""" + + # Session-specific params excluded from config.json + _EXCLUDE_FROM_CONFIG = frozenset({"resume_dir", "refresh_data"}) + + def __init__( + self, + task_name: str, + agent_type: str = "react", + output_dir: Path | str = "./logs", + data_dir: Path | str | None = None, + episodes: str | None = None, + llm_base_url: str | None = None, + agent_seed: int | None = None, + backend: str | None = None, + model: str = "default", + port: int = 8080, + llm_kwargs_raw: str | None = None, + max_retries: int = 3, + resume_dir: Path | str | None = None, + refresh_data: bool = False, + render_platform: str | None = None, + llm_instances: int | None = None, + llm_gpus: list[int] | None = None, + sim_gpus: list[int] | None = None, + ): + # Auto-capture all init args for config.json (before any mutation) + frame = inspect.currentframe() + self._cli_options = { + k: v + for k, v in inspect.getargvalues(frame).locals.items() + if k not in ("self", "frame") and k not in self._EXCLUDE_FROM_CONFIG + } + + self.task_name = task_name + self.agent_type = agent_type + self.output_dir = Path(output_dir) + if data_dir is not None: + self.data_dir = Path(data_dir) + else: + from easi.utils.paths import get_datasets_dir + self.data_dir = get_datasets_dir() + # Update cli_options with resolved value so config.json is accurate + self._cli_options["data_dir"] = "./datasets" if data_dir is None else str(self.data_dir) + self.episodes_filter = episodes + self.llm_base_url = llm_base_url + self.agent_seed = agent_seed + self.backend = backend + self.model = model + self.port = port + self.llm_kwargs_raw = llm_kwargs_raw + self.max_retries = max_retries + self.resume_dir = Path(resume_dir) if resume_dir else None + self.refresh_data = refresh_data + self.render_platform_name = render_platform + self.llm_instances = llm_instances + self.llm_gpus = llm_gpus + self.sim_gpus = sim_gpus + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + if self.model: + safe_model = self.model.replace("/", "_") + # For custom backend, append model_path to distinguish variants + if self.backend == "custom" and self.llm_kwargs_raw: + from easi.llm.utils import parse_llm_kwargs + + model_path = parse_llm_kwargs(self.llm_kwargs_raw).get("model_path", "") + if model_path: + # Use last 2 path components (e.g. Qwen_Qwen3-VL-8B-Instruct) + path_suffix = "_".join(model_path.rstrip("/").split("/")[-2:]) + path_suffix = path_suffix.replace("/", "_") + safe_model = f"{safe_model}_{path_suffix}" + self.run_id = f"{timestamp}_{safe_model}" + else: + self.run_id = timestamp + + def _resolve_llm_backend(self) -> tuple[str | None, str | None]: + """Resolve which LLM backend to use. 
+
+        Returns (backend, base_url):
+        - (None, None) for dummy agent
+        - ("legacy", url) for --llm-url without --backend
+        - (backend_name, url_or_none) for --backend
+        """
+        if self.agent_type == "dummy":
+            return None, None
+
+        if self.backend:
+            return self.backend, self.llm_base_url
+
+        if self.llm_base_url:
+            return "legacy", self.llm_base_url
+
+        raise ValueError(
+            f"Agent '{self.agent_type}' requires --backend or --llm-url. "
+            f"Use --backend vllm|openai|anthropic|gemini or --llm-url <url>."
+        )
+
+    def _serialize_cli_options(self) -> dict:
+        """Serialize _cli_options for JSON output (convert Paths to strings)."""
+        return {
+            k: str(v) if isinstance(v, Path) else v
+            for k, v in self._cli_options.items()
+        }
+
+    def _setup_render_platform(self, backend: str | None = None):
+        """Resolve, set up, and return the global render platform (if any).
+
+        Calls ``platform.setup(gpu_ids=...)`` so lifecycle platforms (xorg)
+        can start external services. For built-in platforms, ``setup()`` is a
+        no-op. Returns ``None`` when no ``--render-platform`` was specified
+        (the platform is then resolved per-simulator in ``_create_simulator``).
+        """
+        from easi.core.render_platforms import get_render_platform
+
+        if not self.render_platform_name:
+            return None
+
+        platform = get_render_platform(self.render_platform_name)
+
+        # Warn about GPU contention before setup
+        if (
+            platform.name == "xorg"
+            and not self.sim_gpus
+            and backend in ("vllm", "custom")
+            and not self.llm_gpus
+        ):
+            logger.warning(
+                "Xorg and LLM server will both use GPU 0. "
+                "Use --llm-gpus and --sim-gpus to separate them."
+            )
+
+        logger.info("Render platform: %s", platform.log_name)
+        platform.setup(gpu_ids=self.sim_gpus)
+        return platform
+
+    def run(self) -> list[dict]:
+        """Run evaluation and return per-episode metric dicts."""
+        if self.resume_dir:
+            run_dir = self.resume_dir
+        else:
+            run_dir = self.output_dir / self.task_name / self.run_id
+        self.run_dir = run_dir
+
+        episodes_dir = run_dir / "episodes"
+        episodes_dir.mkdir(parents=True, exist_ok=True)
+
+        # 1. Load task (before resume so we know total_episodes)
+        task = self._create_task()
+        if self.refresh_data:
+            task.download_dataset(force=True)
+        episodes = task.load_episodes()
+        if self.episodes_filter is not None:
+            from easi.evaluation.episode_filter import filter_episodes
+            episodes = filter_episodes(episodes, self.episodes_filter)
+
+        # Handle resume: load completed results and find start point
+        if self.resume_dir:
+            all_results, start_index = self._load_completed_results(
+                run_dir, len(episodes)
+            )
+            self._reattach_resume_data(all_results, episodes, run_dir)
+            logger.info(
+                "Resuming from %s — %d completed episodes, starting from index %d",
+                run_dir,
+                len(all_results),
+                start_index,
+            )
+        else:
+            all_results = []
+            start_index = 0
+
+        # 2.
Resolve LLM backend and optionally start server + backend, base_url = self._resolve_llm_backend() + server = None + self._render_platform = None + + try: + if backend in ("vllm", "custom") and base_url is None: + from easi.llm.server_manager import ServerManager + from easi.llm.utils import parse_llm_kwargs, split_kwargs + + all_kwargs = parse_llm_kwargs(self.llm_kwargs_raw) + server_kwargs, _ = split_kwargs(all_kwargs) + startup_timeout = float(server_kwargs.pop("startup_timeout", 300.0)) + server = ServerManager( + backend, + self.model, + port=self.port, + server_kwargs=server_kwargs, + startup_timeout=startup_timeout, + ) + base_url = server.start() + + # Resolve and setup render platform (starts external services if needed) + self._render_platform = self._setup_render_platform(backend) + + # Compute resolved generation kwargs (YAML defaults + CLI overrides) + from easi.llm.utils import parse_llm_kwargs, split_kwargs + + agent_config = task._config.get("agent", {}) + yaml_gen_kwargs = agent_config.get("generation_kwargs", {}) + all_llm_kwargs = parse_llm_kwargs(self.llm_kwargs_raw) + _, cli_gen_kwargs = split_kwargs(all_llm_kwargs) + resolved_gen_kwargs = {**yaml_gen_kwargs, **cli_gen_kwargs} + + # Save run config + config = { + "run_id": self.run_id, + "total_episodes": len(episodes), + "cli_options": self._serialize_cli_options(), + "resolved_backend": backend, + "resolved_base_url": base_url, + "resolved_generation_kwargs": resolved_gen_kwargs, + "task_config": task._config, + } + (run_dir / "config.json").write_text(json.dumps(config, indent=2)) + logger.trace("Run config:\n%s", json.dumps(config, indent=2, default=str)) + + # Skip simulator/agent if all episodes already complete (resume) + if start_index >= len(episodes): + logger.info( + "All %d episodes already complete, re-aggregating summary.", + len(episodes), + ) + else: + # 3. Create agent + agent = self._create_agent( + task.action_space, task._config, backend=backend, base_url=base_url + ) + + # 4. Start simulator + sim, sim_runner = self._create_simulator(task.simulator_key, task=task) + + # 5. 
Progress bar + from easi.utils.progress import ProgressBar + + progress_bar = ProgressBar( + total=len(episodes), + num_workers=1, + start_index=start_index, + ) + progress_bar.start() + + try: + for i, episode in enumerate(episodes): + if i < start_index: + continue + episode_id = episode.get("episode_id", f"ep_{i}") + logger.info( + "Episode %d/%d: %s", + i + 1, + len(episodes), + episode_id, + ) + + episode_dir = ( + episodes_dir / f"{i:03d}_{_sanitize_dirname(episode_id)}" + ) + episode_dir.mkdir(exist_ok=True) + + result = None + for attempt in range(1, self.max_retries + 1): + try: + result = self._run_episode( + sim, + agent, + task, + episode, + i, + episode_dir, + ) + break + except Exception as exc: + logger.warning( + "Episode %s attempt %d/%d failed: %s", + episode_id, + attempt, + self.max_retries, + exc, + ) + self._clear_episode_dir(episode_dir) + if attempt < self.max_retries: + logger.info("Re-launching simulator for retry...") + try: + sim.close() + except Exception: + pass + try: + sim, sim_runner = self._create_simulator( + task.simulator_key, + task=task, + ) + except Exception as restart_exc: + logger.error( + "Simulator restart failed: %s", + restart_exc, + ) + result = { + "episode_id": episode_id, + "instruction": task.get_instruction( + episode + ), + "success": 0.0, + "num_steps": 0, + "elapsed_seconds": 0.0, + "error": f"simulator restart failed: {restart_exc}", + } + sim = None + break + else: + logger.error( + "Episode %s failed after %d attempts, skipping", + episode_id, + self.max_retries, + ) + result = { + "episode_id": episode_id, + "instruction": task.get_instruction(episode), + "success": 0.0, + "num_steps": 0, + "elapsed_seconds": 0.0, + "error": str(exc), + } + + all_results.append(result) + + # Update progress bar + failed_count = sum(1 for r in all_results if "error" in r) + progress_bar.update( + completed=len(all_results) + start_index, + failed=failed_count, + ) + + # Save per-episode result (strip internal keys) + result_to_save = { + k: v for k, v in result.items() if not k.startswith("_") + } + (episode_dir / "result.json").write_text( + json.dumps(result_to_save, indent=2) + ) + + # If simulator restart failed, stop evaluation + if sim is None: + logger.error( + "No simulator available, stopping evaluation early." + ) + break + + finally: + progress_bar.stop() + if sim is not None: + sim.close() + finally: + if server: + server.stop() + if self._render_platform is not None: + self._render_platform.teardown() + + # 5. Build EpisodeRecords for aggregate_results + effective = sum(1 for r in all_results if "error" not in r) + records = [] + for r in all_results: + trajectory = r.pop("_trajectory", []) + episode = r.pop("_episode", {}) + episode_results = dict(r) + records.append( + EpisodeRecord( + episode=episode, + trajectory=trajectory, + episode_results=episode_results, + ) + ) + + # 6. 
Aggregate and save summary + from easi.evaluation.metrics import generic_aggregate + + try: + metric_results = task.aggregate_results(records) + except Exception as exc: + logger.error("aggregate_results() failed: %s", exc, exc_info=True) + metric_results = {"aggregation_error": str(exc)} + + generic = generic_aggregate(records) + summary = { + **generic, + "effective_episodes": effective, + "metrics": metric_results, + } + if backend and backend != "legacy": + summary["llm_usage"] = self._aggregate_llm_usage(all_results) + summary["model"] = self.model + summary["backend"] = backend + (run_dir / "summary.json").write_text(json.dumps(summary, indent=2)) + logger.info("Results saved to: %s", run_dir) + + return all_results + + def _load_completed_results( + self, run_dir: Path, total_episodes: int + ) -> tuple[list[dict], int]: + """Scan episode dirs to find the first incomplete episode. + + Walks episode directories in ascending order (by index prefix). + An episode is "complete" if its directory has a valid result.json. + Returns results for all consecutive complete episodes from the start, + and clears all directories from the first incomplete episode onward. + + Args: + run_dir: The run directory containing episodes/. + total_episodes: Total number of episodes in the evaluation. + + Returns: + (completed_results, start_index) tuple. + """ + import shutil + + episodes_dir = run_dir / "episodes" + if not episodes_dir.exists(): + return [], 0 + + # Collect all episode dirs, sorted by name (which starts with {i:03d}_) + episode_dirs = sorted( + [d for d in episodes_dir.iterdir() if d.is_dir()], + key=lambda d: d.name, + ) + if not episode_dirs: + return [], 0 + + # Walk in order, loading results until we hit an incomplete episode + completed_results = [] + start_index = 0 + + for ep_dir in episode_dirs: + result_file = ep_dir / "result.json" + if result_file.exists(): + try: + result = json.loads(result_file.read_text()) + completed_results.append(result) + start_index += 1 + continue + except (json.JSONDecodeError, OSError): + logger.warning( + "Corrupt result.json in %s, treating as incomplete", ep_dir + ) + # First incomplete episode found — stop here + break + + # Clear all episode dirs from start_index onward + dirs_to_clear = episode_dirs[start_index:] + if dirs_to_clear: + logger.info( + "Resume: clearing %d episode dirs from index %d onward", + len(dirs_to_clear), + start_index, + ) + for d in dirs_to_clear: + shutil.rmtree(d) + + return completed_results, start_index + + @staticmethod + def _reattach_resume_data( + completed_results: list[dict], + episodes: list[dict], + run_dir: Path, + ) -> None: + """Re-attach trajectory and episode data to resumed results. + + On resume, result.json lacks ``_trajectory`` and ``_episode`` (they are + stripped on save). This method reads ``trajectory.jsonl`` from each + episode dir and pairs results with the original episode dicts so that + ``aggregate_results()`` has access to the full data. 
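+
+        Note that pairing is positional: ``completed_results[i]`` is matched
+        with ``episodes[i]`` and the i-th episode directory. This holds
+        because ``_load_completed_results`` keeps only a consecutive prefix
+        of complete episodes and clears everything after it.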
+ """ + episodes_dir = run_dir / "episodes" + episode_dirs = sorted( + [d for d in episodes_dir.iterdir() if d.is_dir()], + key=lambda d: d.name, + ) + for idx, result in enumerate(completed_results): + # Attach episode from the loaded episode list + if idx < len(episodes): + result["_episode"] = episodes[idx] + + # Read trajectory from trajectory.jsonl + if idx < len(episode_dirs): + traj_file = episode_dirs[idx] / "trajectory.jsonl" + if traj_file.exists(): + try: + lines = traj_file.read_text().strip().splitlines() + result["_trajectory"] = [json.loads(l) for l in lines] + except (json.JSONDecodeError, OSError): + result["_trajectory"] = [] + + def _run_episode( + self, + sim, + agent, + task, + episode: dict, + index: int, + episode_dir: Path, + ) -> dict: + """Run a single episode and return metrics.""" + agent.reset() + + episode_id = episode.get("episode_id", f"ep_{index}") + + # Reset simulator (bridge saves images to episode_dir) + reset_config = task.format_reset_config(episode) + observation = sim.reset( + episode_id, + reset_config, + episode_output_dir=str(episode_dir), + ) + + # Task-specific post-reset setup (e.g., per-episode action space) + task.on_episode_reset(observation, agent) + + # Write reset entry to trajectory + trajectory_path = episode_dir / "trajectory.jsonl" + self._write_trajectory_entry( + trajectory_path, + { + "step": 0, + "type": "reset", + "rgb_path": Path(observation.rgb_path).name, + "agent_pose": observation.agent_pose, + "reward": 0.0, + "done": False, + "info": {}, + }, + ) + + # Agent-simulator loop + trajectory: list[StepResult] = [] + task_description = task.get_instruction(episode) + start_time = time.monotonic() + + for step in range(task.max_steps): + action = agent.act(observation, task_description) + + # Handle stop signal (e.g., empty plan from LLM) + if action.action_name == "<>": + logger.info("Agent signalled stop (empty plan), ending episode") + break + + step_result = sim.step(action) + trajectory.append(step_result) + + # Get LLM response + prompt text from agent memory. Both are + # ``None`` for buffered actions (the agent didn't re-query the + # LLM this step). 
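+            # (A buffered action typically comes from a multi-step plan returned
+            # by an earlier LLM call, so there is no fresh prompt or response to
+            # record for it.)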
+ llm_response = None + prompt_text = None + if hasattr(agent, "memory") and agent.memory.steps: + last_step = agent.memory.steps[-1] + llm_response = last_step.llm_response + prompt_text = last_step.prompt_text + + # Write step entry to trajectory + triggered_fallback = getattr(agent, "triggered_fallback", False) + self._write_trajectory_entry( + trajectory_path, + { + "step": step + 1, + "type": "step", + "action": action.action_name, + "llm_response": llm_response, + "prompt": prompt_text, + "triggered_fallback": triggered_fallback, + "rgb_path": Path(step_result.observation.rgb_path).name, + "agent_pose": step_result.observation.agent_pose, + "reward": step_result.reward, + "done": step_result.done, + "info": step_result.info, + }, + ) + + # Feed action outcome back to agent for ReAct reasoning + last_success = step_result.info.get("last_action_success", 1.0) + feedback = step_result.info.get( + "feedback", + "success" if last_success else "failed", + ) + agent.add_feedback(action.action_name, feedback) + + observation = step_result.observation + + if step_result.done: + break + + elapsed = time.monotonic() - start_time + + # Evaluate + metrics = task.evaluate_episode(episode, trajectory) + metrics["episode_id"] = episode_id + metrics["instruction"] = task_description + metrics["elapsed_seconds"] = round(elapsed, 2) + metrics["forced_early_stop"] = getattr(agent, "forced_early_stop", False) + + # Attach trajectory and episode for aggregate_results() + metrics["_trajectory"] = trajectory + metrics["_episode"] = episode + + # Snapshot LLM usage for this episode + if hasattr(agent, "llm_client") and hasattr(agent.llm_client, "get_usage"): + metrics["llm_usage"] = agent.llm_client.get_usage() + agent.llm_client.reset_usage() + + return metrics + + @staticmethod + def _clear_episode_dir(episode_dir: Path) -> None: + """Remove all files in an episode directory for a clean retry.""" + for f in episode_dir.iterdir(): + if f.is_file(): + f.unlink() + + @staticmethod + def _write_trajectory_entry(path: Path, entry: dict) -> None: + """Append a single JSON line to the trajectory file.""" + with path.open("a") as f: + f.write(json.dumps(entry) + "\n") + + @staticmethod + def _aggregate_llm_usage(results: list[dict]) -> dict: + """Sum up llm_usage from per-episode results.""" + total = { + "total_calls": 0, + "total_prompt_tokens": 0, + "total_completion_tokens": 0, + "total_tokens": 0, + "total_cost_usd": 0.0, + } + for r in results: + usage = r.get("llm_usage", {}) + total["total_calls"] += usage.get("num_calls", 0) + total["total_prompt_tokens"] += usage.get("prompt_tokens", 0) + total["total_completion_tokens"] += usage.get("completion_tokens", 0) + total["total_cost_usd"] += usage.get("cost_usd", 0.0) + total["total_tokens"] = ( + total["total_prompt_tokens"] + total["total_completion_tokens"] + ) + n = len(results) or 1 + total["avg_prompt_tokens_per_episode"] = round(total["total_prompt_tokens"] / n) + total["avg_cost_per_episode_usd"] = round(total["total_cost_usd"] / n, 6) + return total + + def _create_task(self): + from easi.tasks.registry import get_task_entry, load_task_class + + entry = get_task_entry(self.task_name) + TaskClass = load_task_class(self.task_name) + return TaskClass( + split_yaml_path=entry.config_path, + data_dir=self.data_dir, + ) + + def _create_agent( + self, + action_space: list[str], + task_config: dict, + backend: str | None = None, + base_url: str | None = None, + ): + from easi.utils.import_utils import import_class + + if self.agent_type == "dummy": + from 
easi.agents.dummy_agent import DummyAgent + + return DummyAgent(action_space=action_space, seed=self.agent_seed) + + elif self.agent_type == "react": + from easi.agents.react_agent import ReActAgent + + agent_config = task_config.get("agent", {}) + + # Create LLM client based on backend + if backend and backend != "legacy": + from easi.llm.client import LLMClient + from easi.llm.utils import ( + build_litellm_model, + parse_llm_kwargs, + split_kwargs, + validate_backend, + ) + + validate_backend(backend) + litellm_model = build_litellm_model(backend, self.model) + all_kwargs = parse_llm_kwargs(self.llm_kwargs_raw) + _, client_kwargs = split_kwargs(all_kwargs) + + # Merge YAML generation_kwargs with CLI kwargs (CLI overrides) + yaml_gen_kwargs = agent_config.get("generation_kwargs", {}) + merged_kwargs = {**yaml_gen_kwargs, **client_kwargs} + + # Local backends need longer timeout (generation is slower than API) + if base_url and "timeout" not in merged_kwargs: + merged_kwargs["timeout"] = 600.0 + + llm = LLMClient( + model=litellm_model, + base_url=base_url, + num_retries=self.max_retries, + **merged_kwargs, + ) + else: + # Legacy path: existing LLMApiClient + from easi.llm.api_client import LLMApiClient + + llm = LLMApiClient(base_url=base_url or "http://127.0.0.1:8000") + + # Load task-specific prompt builder + prompt_builder = None + builder_class_name = agent_config.get("prompt_builder") + if builder_class_name: + BuilderClass = import_class(builder_class_name) + builder_kwargs = agent_config.get("prompt_builder_kwargs", {}) + prompt_builder = BuilderClass(**builder_kwargs) + + return ReActAgent( + llm_client=llm, + action_space=action_space, + prompt_builder=prompt_builder, + fallback_action=agent_config.get("fallback_action"), + fallback_strategy=agent_config.get("fallback_strategy", "default_action"), + max_fallback_retries=agent_config.get("max_fallback_retries", 1), + max_consecutive_fallbacks=agent_config.get("max_consecutive_fallbacks", 0), + ) + else: + raise ValueError(f"Unknown agent type: {self.agent_type}") + + def _create_simulator( + self, simulator_key: str, task=None, label: str = "bridge", worker_id: int = 0 + ): + import json as _json + + from easi.simulators.registry import ( + create_env_manager, + get_simulator_entry, + load_simulator_class, + ) + from easi.simulators.subprocess_runner import SubprocessRunner + + entry = get_simulator_entry(simulator_key) + env_manager = create_env_manager(simulator_key) + SimClass = load_simulator_class(simulator_key) + sim = SimClass() + + # Auto-install simulator env if not ready + if not env_manager.env_is_ready(): + logger.info("Simulator environment not ready, auto-installing...") + env_manager.install() + + # Task-specific bridge overrides simulator default + bridge_path = ( + task.get_bridge_script_path() if task else None + ) or sim._get_bridge_script_path() + + extra_args = ["--data-dir", str(self.data_dir)] + if task and task.simulator_kwargs: + extra_args.extend( + ["--simulator-kwargs", _json.dumps(task.simulator_kwargs)] + ) + + # Extract runner-level timeouts from simulator_configs + sim_configs = task.simulator_configs if task else {} + runner_kwargs = {} + if sim_configs.get("command_timeout"): + runner_kwargs["command_timeout"] = float(sim_configs["command_timeout"]) + if sim_configs.get("startup_timeout"): + runner_kwargs["startup_timeout"] = float(sim_configs["startup_timeout"]) + + # --- Docker runtime path --- + if entry.runtime == "docker": + from easi.core.docker_env_manager import DockerEnvironmentManager + 
from easi.core.render_platforms import get_render_platform + + assert isinstance(env_manager, DockerEnvironmentManager), ( + f"Simulator {simulator_key} declares runtime=docker but env_manager " + f"is not a DockerEnvironmentManager" + ) + + runner = SubprocessRunner( + python_executable=env_manager.container_python_path, + bridge_script_path=bridge_path, + render_platform=get_render_platform("headless"), + extra_args=extra_args, + label=label, + **runner_kwargs, + ) + data_dir_str = ( + str(self.data_dir) + if self.data_dir + else ( + entry.data_dir.replace("~", str(Path.home())) + if entry.data_dir + else None + ) + ) + runner.launch_docker( + docker_env_manager=env_manager, + data_dir=data_dir_str, + ) + sim.set_runner(runner) + return sim, runner + + # --- Conda runtime path (existing) --- + + # Install task-level additional deps + if task and task.additional_deps: + env_manager.install_additional_deps(task.additional_deps) + + # Resolve render platform: CLI > task YAML > env_manager default + from easi.simulators.registry import resolve_render_platform + + yaml_platform = sim_configs.get("render_platform") if task else None + platform_name = ( + self.render_platform_name + or yaml_platform + or env_manager.default_render_platform + ) + + if platform_name not in env_manager.supported_render_platforms: + raise ValueError( + f"Render platform '{platform_name}' is not supported by " + f"{env_manager.simulator_name}:{env_manager.version}. " + f"Supported: {env_manager.supported_render_platforms}" + ) + + # Use pre-setup global platform if available, else resolve per-simulator + if getattr(self, "_render_platform", None) is not None: + render_platform = self._render_platform + else: + render_platform = resolve_render_platform( + simulator_key, platform_name, env_manager=env_manager + ) + render_platform.setup(gpu_ids=self.sim_gpus) + # Register so teardown() is called in the finally block + self._render_platform = render_platform + logger.info("Render platform: %s", render_platform.log_name) + + from easi.core.render_platforms import EnvVars + + env_vars = env_manager.get_env_vars(render_platform_name=platform_name) + + if task and task.extra_env_vars: + env_vars = EnvVars.merge(env_vars, EnvVars(replace=task.extra_env_vars)) + + from easi.simulators.registry import ( + resolve_render_adapter as _resolve_render_adapter, + ) + + adapter = _resolve_render_adapter(simulator_key, env_manager=env_manager) + + base_render_platform = render_platform + binding = render_platform.for_worker(worker_id) + + adapter_env = adapter.get_env_vars(binding) if adapter else EnvVars() + binding_env = EnvVars.merge(binding.extra_env, adapter_env) + + if binding.display: + binding_env = EnvVars.merge( + binding_env, EnvVars(replace={"DISPLAY": binding.display}) + ) + if binding.cuda_visible_devices is not None: + binding_env = EnvVars.merge( + binding_env, + EnvVars(replace={"CUDA_VISIBLE_DEVICES": binding.cuda_visible_devices}), + ) + + if self.sim_gpus is not None and binding.cuda_visible_devices is None: + gpu_id = self.sim_gpus[worker_id % len(self.sim_gpus)] + # Inject assigned GPU ID for simulators that manage GPU + # selection natively (e.g. Habitat-Sim via gpu_device_id). + # These declare "device: gpu" in their YAML. + if task is not None: + task.inject_simulator_kwarg("_assigned_gpu_id", gpu_id) + # Only set CUDA_VISIBLE_DEVICES for simulators that don't + # handle GPU selection themselves (i.e. no "device" config). 
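+                # A "device" key (whether "gpu" or "cpu") means the simulator
+                # owns device placement, so we must not mask GPUs from it.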
+ uses_native_gpu = ( + task is not None + and task.simulator_configs.get("device") in ("gpu", "cpu") + ) + if not uses_native_gpu: + env_vars = EnvVars.merge( + env_vars, EnvVars(replace={"CUDA_VISIBLE_DEVICES": str(gpu_id)}) + ) + + env_vars = EnvVars.merge(env_vars, binding_env) + render_platform = base_render_platform + active_binding = binding + active_adapter = adapter + + runner = SubprocessRunner( + python_executable=env_manager.get_python_executable(), + bridge_script_path=bridge_path, + render_platform=render_platform, + screen_config=env_manager.screen_config, + extra_args=extra_args, + extra_env=env_vars if env_vars else None, + render_adapter=active_adapter, + worker_binding=active_binding, + label=label, + **runner_kwargs, + ) + runner.launch() + sim.set_runner(runner) + + return sim, runner diff --git a/easi/llm/__init__.py b/easi/llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/easi/llm/api_client.py b/easi/llm/api_client.py new file mode 100644 index 0000000..4aab026 --- /dev/null +++ b/easi/llm/api_client.py @@ -0,0 +1,80 @@ +"""HTTP client for OpenAI-compatible LLM inference servers. + +Works with vLLM, SGLang, Ollama, and the built-in dummy server — +any server that implements the /v1/chat/completions endpoint. +""" + +from __future__ import annotations + +from typing import Any + +import requests + +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + + +class LLMApiClient: + """Stateless HTTP client for LLM inference endpoints.""" + + def __init__( + self, + base_url: str = "http://127.0.0.1:8000", + model: str = "default", + timeout: float = 120.0, + max_tokens: int = 512, + temperature: float = 0.0, + ): + self.base_url = base_url.rstrip("/") + self.model = model + self.timeout = timeout + self.max_tokens = max_tokens + self.temperature = temperature + + def generate( + self, + messages: list[dict], + ) -> str: + """Send a chat completion request and return the assistant's response text. + + Args: + messages: Chat history in OpenAI format. Images are embedded as + content parts with type "image_url" (Decision #10). + + Returns: + The assistant's response text. + """ + payload: dict[str, Any] = { + "model": self.model, + "messages": messages, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + } + + url = f"{self.base_url}/v1/chat/completions" + logger.trace("POST %s (messages: %d)", url, len(messages)) + + try: + response = requests.post(url, json=payload, timeout=self.timeout) + response.raise_for_status() + except requests.ConnectionError: + raise ConnectionError( + f"Cannot connect to LLM server at {self.base_url}. " + f"Start one with: easi llm-server" + ) + except requests.Timeout: + raise TimeoutError( + f"LLM server request timed out after {self.timeout}s" + ) + except requests.HTTPError as e: + raise RuntimeError(f"LLM server returned error: {e}") + + data = response.json() + choices = data.get("choices", []) + if not choices: + raise RuntimeError(f"LLM server returned no choices: {data}") + + content = choices[0].get("message", {}).get("content", "") + logger.trace("Response: %s", content[:100]) + return content diff --git a/easi/llm/client.py b/easi/llm/client.py new file mode 100644 index 0000000..4eda94d --- /dev/null +++ b/easi/llm/client.py @@ -0,0 +1,139 @@ +"""Unified LLM client wrapping LiteLLM. + +Provides text generation with optional response_format pass-through +for API-level JSON schema enforcement. 
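+
+A minimal usage sketch (the model string is illustrative)::
+
+    client = LLMClient(model="openai/gpt-4o-mini", temperature=0)
+    reply = client.generate([{"role": "user", "content": "Hello"}])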
+ +Usage tracking is cumulative — call get_usage() to snapshot, reset_usage() between episodes. +""" +from __future__ import annotations + +from typing import Any + +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + +# Lazy imports to avoid requiring litellm when not needed. +# Parameters accepted by litellm.completion() / OpenAI chat completions API. +# Anything not in this set is silently dropped to avoid provider rejections. +_LITELLM_PARAMS = frozenset({ + "temperature", "max_tokens", "top_p", "n", "stop", "seed", + "frequency_penalty", "presence_penalty", "logit_bias", + "logprobs", "top_logprobs", + "response_format", "tools", "tool_choice", + "stream", "stream_options", + "user", "metadata", +}) + +# Parameters passed to vLLM via extra_body (not standard OpenAI params). +_VLLM_EXTRA_BODY_PARAMS = frozenset({ + "skip_special_tokens", +}) + +litellm = None + + +def _ensure_imports() -> None: + """Import litellm on first use.""" + global litellm + if litellm is None: + try: + import litellm as _litellm + except ImportError as e: + raise ImportError( + "LLMClient requires litellm. " + "Install with: pip install easi[llm]" + ) from e + litellm = _litellm + # Suppress litellm's verbose logging + litellm.suppress_debug_info = True + + +class LLMClient: + """Unified LLM client for all backends.""" + + def __init__( + self, + model: str, + base_url: str | None = None, + num_retries: int = 3, + timeout: float = 120.0, + **kwargs: Any, + ): + self.model = model + self.base_url = base_url + self.num_retries = num_retries + self.timeout = timeout + # Split kwargs into litellm params, vLLM extra_body, and unknown (dropped). + self.default_kwargs = {k: v for k, v in kwargs.items() if k in _LITELLM_PARAMS} + self._extra_body = {k: v for k, v in kwargs.items() if k in _VLLM_EXTRA_BODY_PARAMS} + dropped = {k: v for k, v in kwargs.items() + if k not in _LITELLM_PARAMS and k not in _VLLM_EXTRA_BODY_PARAMS} + if dropped: + logger.debug("Dropping unsupported generation kwargs: %s", dropped) + self._usage = { + "prompt_tokens": 0, + "completion_tokens": 0, + "num_calls": 0, + "cost_usd": 0.0, + } + + def generate(self, messages: list[dict], response_format: dict | None = None) -> str: + """Generate text completion. Drop-in for LLMApiClient.generate().""" + _ensure_imports() + + call_kwargs: dict[str, Any] = { + "model": self.model, + "messages": messages, + "num_retries": self.num_retries, + "timeout": self.timeout, + "drop_params": True, + **self.default_kwargs, + } + if self._extra_body: + call_kwargs["extra_body"] = self._extra_body + if self.base_url: + call_kwargs["api_base"] = self.base_url + # Local servers (vLLM, custom) don't need a real API key, + # but LiteLLM requires one for the openai/ prefix. 
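+            # (Any placeholder value works; local OpenAI-compatible servers
+            # typically do not validate it.)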
+ call_kwargs.setdefault("api_key", "dummy") + if response_format is not None: + call_kwargs["response_format"] = response_format + + logger.trace("LLM call: model=%s, messages=%d", self.model, len(messages)) + try: + response = litellm.completion(**call_kwargs) + except Exception as e: + logger.trace("LLM API error: %s: %s", type(e).__name__, e) + raise + self._track_usage(response) + + content = response.choices[0].message.content or "" + logger.trace("LLM response: %s", content[:200] if content else "") + return content + + def get_usage(self) -> dict: + """Return cumulative usage stats (copy).""" + return dict(self._usage) + + def reset_usage(self) -> None: + """Reset usage counters.""" + self._usage = { + "prompt_tokens": 0, + "completion_tokens": 0, + "num_calls": 0, + "cost_usd": 0.0, + } + + def _track_usage(self, response: Any) -> None: + """Accumulate token usage and cost from a LiteLLM response.""" + usage = getattr(response, "usage", None) + if usage: + self._usage["prompt_tokens"] += getattr(usage, "prompt_tokens", 0) + self._usage["completion_tokens"] += getattr(usage, "completion_tokens", 0) + self._usage["num_calls"] += 1 + try: + cost = litellm.completion_cost(completion_response=response) + self._usage["cost_usd"] += float(cost) + except Exception: + pass # Cost unavailable for local/unknown models diff --git a/easi/llm/dummy_server.py b/easi/llm/dummy_server.py new file mode 100644 index 0000000..67df20a --- /dev/null +++ b/easi/llm/dummy_server.py @@ -0,0 +1,155 @@ +"""Minimal dummy LLM server for testing. + +Implements OpenAI-compatible /v1/chat/completions endpoint using stdlib +http.server. Returns fixed or random actions. + +Usage: + python -m easi.llm.dummy_server --port 8000 --mode random + # or via CLI: + easi llm-server --port 8000 --mode random +""" + +from __future__ import annotations + +import argparse +import json +import random +from http.server import BaseHTTPRequestHandler, HTTPServer + +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + +DEFAULT_ACTION_SPACE = ["MoveAhead", "TurnLeft", "TurnRight", "Stop"] + + +class DummyLLMHandler(BaseHTTPRequestHandler): + """HTTP handler for the dummy LLM server.""" + + # Set by the server factory + mode: str = "random" + action_space: list[str] = DEFAULT_ACTION_SPACE + + def do_POST(self) -> None: + if self.path == "/v1/chat/completions": + self._handle_chat_completions() + else: + self.send_error(404, f"Not found: {self.path}") + + def do_GET(self) -> None: + if self.path == "/health": + self._send_json({"status": "ok"}) + else: + self.send_error(404, f"Not found: {self.path}") + + def _handle_chat_completions(self) -> None: + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) + + try: + request = json.loads(body) + except json.JSONDecodeError: + self.send_error(400, "Invalid JSON") + return + + # Generate response based on mode + if self.mode == "fixed": + action = self.action_space[0] if self.action_space else "MoveAhead" + else: # random + action = random.choice(self.action_space) + + response_text = f"I will take the following action.\nAction: {action}" + + response = { + "id": "dummy-001", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": response_text, + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + } + + self._send_json(response) + + def _send_json(self, data: dict) -> None: + body = 
json.dumps(data).encode("utf-8") + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def log_message(self, format: str, *args) -> None: + logger.trace(format, *args) + + +def create_handler(mode: str, action_space: list[str]) -> type: + """Create a handler class with the given configuration.""" + + class ConfiguredHandler(DummyLLMHandler): + pass + + ConfiguredHandler.mode = mode + ConfiguredHandler.action_space = action_space + return ConfiguredHandler + + +def run_server( + host: str = "127.0.0.1", + port: int = 8000, + mode: str = "random", + action_space: list[str] | None = None, +) -> None: + """Start the dummy LLM server.""" + action_space = action_space or DEFAULT_ACTION_SPACE + handler_class = create_handler(mode, action_space) + + server = HTTPServer((host, port), handler_class) + logger.info("Dummy LLM server running on http://%s:%d", host, port) + logger.info("Mode: %s, Actions: %s", mode, action_space) + logger.info("Press Ctrl+C to stop") + + try: + server.serve_forever() + except KeyboardInterrupt: + logger.info("Shutting down...") + finally: + server.server_close() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Dummy LLM server") + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--mode", choices=["fixed", "random"], default="random") + parser.add_argument( + "--action-space", + type=str, + nargs="+", + default=DEFAULT_ACTION_SPACE, + ) + args = parser.parse_args() + + from easi.utils.logging import setup_logging + setup_logging("INFO") + run_server( + host=args.host, + port=args.port, + mode=args.mode, + action_space=args.action_space, + ) + + +if __name__ == "__main__": + main() diff --git a/easi/llm/models/__init__.py b/easi/llm/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/easi/llm/models/base_model_server.py b/easi/llm/models/base_model_server.py new file mode 100644 index 0000000..8f927cf --- /dev/null +++ b/easi/llm/models/base_model_server.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + + +class BaseModelServer(ABC): + """Abstract base class for custom model servers. + + Subclasses must implement ``load`` and ``generate``. The ``unload`` + method is optional and defaults to a no-op. + """ + + @abstractmethod + def load(self, model_path: str, device: str, **kwargs) -> None: + """Load model weights, tokenizer, processors.""" + + @abstractmethod + def generate(self, messages: list[dict], **kwargs) -> str: + """Generate response from OpenAI-format messages.""" + + def unload(self) -> None: + """Release GPU memory. Optional override.""" + pass diff --git a/easi/llm/models/echo/__init__.py b/easi/llm/models/echo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/easi/llm/models/echo/manifest.yaml b/easi/llm/models/echo/manifest.yaml new file mode 100644 index 0000000..3be5b52 --- /dev/null +++ b/easi/llm/models/echo/manifest.yaml @@ -0,0 +1,5 @@ +name: echo +display_name: "Echo Model" +description: "Testing model that echoes back user input. No GPU required." 
+model_class: "easi.llm.models.echo.model.EchoModel" +default_kwargs: {} diff --git a/easi/llm/models/echo/model.py b/easi/llm/models/echo/model.py new file mode 100644 index 0000000..812b6c4 --- /dev/null +++ b/easi/llm/models/echo/model.py @@ -0,0 +1,16 @@ +"""Echo model for testing the custom model server pipeline.""" +from __future__ import annotations + +from easi.llm.models.base_model_server import BaseModelServer +from easi.llm.models.helpers import extract_text_only + + +class EchoModel(BaseModelServer): + """Returns the user's message back. Useful for testing.""" + + def load(self, model_path: str, device: str, **kwargs) -> None: + pass + + def generate(self, messages: list[dict], **kwargs) -> str: + text = extract_text_only(messages) + return f"Echo: {text}" diff --git a/easi/llm/models/helpers.py b/easi/llm/models/helpers.py new file mode 100644 index 0000000..98bb1c2 --- /dev/null +++ b/easi/llm/models/helpers.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import base64 +import io +from typing import Any + +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + + +def _extract_text_from_content(content: str | list[dict[str, Any]]) -> str: + """Extract concatenated text from an OpenAI-format content field. + + Handles both plain string content and the list-of-parts format. + Image parts are silently skipped. + """ + if isinstance(content, str): + return content + parts: list[str] = [] + for part in content: + if part.get("type") == "text": + parts.append(part.get("text", "")) + return "\n".join(parts) + + +def extract_images(messages: list[dict[str, Any]]) -> list[Any]: + """Extract PIL Images from base64-encoded image_url parts in messages. + + Returns a list of ``PIL.Image.Image`` objects. PIL is imported lazily + so the function can be defined even when Pillow is not installed. + """ + from PIL import Image # lazy import + + images: list[Any] = [] + for msg in messages: + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if part.get("type") != "image_url": + continue + image_url = part.get("image_url", {}) + url = image_url.get("url", "") + if url.startswith("data:") and ";base64," in url: + # Format: data:;base64, + _, encoded = url.split(",", 1) + raw = base64.b64decode(encoded) + images.append(Image.open(io.BytesIO(raw))) + elif url.startswith("data:"): + # Non-base64 data URI (e.g. data:text/plain,...) — skip + logger.debug("Skipping non-base64 data URI") + else: + # HTTP/HTTPS URLs — not yet supported for extraction + logger.debug("Skipping non-data image URL: %s", url[:80]) + return images + + +def extract_text_only(messages: list[dict[str, Any]]) -> str: + """Concatenate all text content from messages, ignoring roles and images.""" + parts: list[str] = [] + for msg in messages: + content = msg.get("content", "") + text = _extract_text_from_content(content) + if text: + parts.append(text) + return "\n".join(parts) + + +def extract_by_role(messages: list[dict[str, Any]]) -> dict[str, str]: + """Group text content by role. + + Returns a mapping from role name to the concatenated text for that role. + If a role appears multiple times its texts are joined with newlines. 
+ """ + grouped: dict[str, list[str]] = {} + for msg in messages: + role = msg.get("role", "unknown") + content = msg.get("content", "") + text = _extract_text_from_content(content) + if text: + grouped.setdefault(role, []).append(text) + return {role: "\n".join(texts) for role, texts in grouped.items()} diff --git a/easi/llm/models/http_server.py b/easi/llm/models/http_server.py new file mode 100644 index 0000000..aaff480 --- /dev/null +++ b/easi/llm/models/http_server.py @@ -0,0 +1,158 @@ +"""FastAPI HTTP server wrapping a BaseModelServer in OpenAI-compatible endpoints. + +Provides ``create_app`` to build a FastAPI application and ``main`` for +subprocess launch via ``python -m easi.llm.models.http_server``. +""" + +from __future__ import annotations + +import asyncio +import json +import time +import uuid +from functools import partial +from typing import Any + +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + +# Generation kwargs recognised from the request body. +_GENERATION_KWARGS = frozenset( + { + "temperature", + "max_tokens", + "top_p", + "top_k", + "frequency_penalty", + "presence_penalty", + "stop", + "seed", + } +) + + +def create_app(model: Any) -> Any: + """Create a FastAPI application that serves *model* over HTTP. + + Parameters + ---------- + model: + A loaded :class:`BaseModelServer` instance. + + Returns + ------- + FastAPI + The application, ready to be passed to ``uvicorn.run``. + """ + from fastapi import FastAPI + from fastapi.responses import JSONResponse + + app = FastAPI(title="EASI Model Server") + + @app.get("/health") + async def health() -> dict: + return {"status": "ok"} + + @app.post("/v1/chat/completions") + async def chat_completions(request: dict) -> JSONResponse: # type: ignore[arg-type] + messages = request.get("messages", []) + req_model = request.get("model", "custom") + + # Extract recognised generation kwargs. + gen_kwargs: dict[str, Any] = {} + for key in _GENERATION_KWARGS: + if key in request: + gen_kwargs[key] = request[key] + + try: + loop = asyncio.get_running_loop() + content = await loop.run_in_executor( + None, partial(model.generate, messages, **gen_kwargs) + ) + except Exception as e: + logger.error("Generation failed: %s", e, exc_info=True) + return JSONResponse( + status_code=500, + content={"error": {"message": str(e), "type": "server_error"}}, + ) + + response = { + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion", + "created": int(time.time()), + "model": req_model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + } + return JSONResponse(content=response) + + return app + + +def main() -> None: + """Entry point for subprocess launch. 
+ + Usage:: + + python -m easi.llm.models.http_server \ + --model-name my_model \ + --model-path /path/to/weights \ + --device cuda:0 \ + --port 8000 \ + --kwargs '{"key": "value"}' + """ + import argparse + + import uvicorn + + from easi.llm.models.registry import get_model_entry, load_model_class + + parser = argparse.ArgumentParser(description="EASI custom model HTTP server") + parser.add_argument("--model-name", required=True, help="Registered model name") + parser.add_argument("--model-path", required=True, help="Path to model weights") + parser.add_argument("--device", default="cuda:0", help="Device (default: cuda:0)") + parser.add_argument("--port", type=int, default=8000, help="Port (default: 8000)") + parser.add_argument( + "--kwargs", + default="{}", + help="Extra kwargs as JSON string (default: '{}')", + ) + args = parser.parse_args() + + extra_kwargs: dict[str, Any] = json.loads(args.kwargs) + + # Merge manifest default_kwargs with CLI overrides (CLI wins) + entry = get_model_entry(args.model_name) + merged_kwargs = {**entry.default_kwargs, **extra_kwargs} + + logger.info( + "Loading model '%s' from %s on %s (kwargs=%s)", + args.model_name, + args.model_path, + args.device, + merged_kwargs, + ) + + cls = load_model_class(args.model_name) + model_instance = cls() + model_instance.load(args.model_path, args.device, **merged_kwargs) + + app = create_app(model_instance) + + logger.info("Starting HTTP server on port %d", args.port) + uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="info") + + +if __name__ == "__main__": + main() diff --git a/easi/llm/models/internvl3/__init__.py b/easi/llm/models/internvl3/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/easi/llm/models/internvl3/manifest.yaml b/easi/llm/models/internvl3/manifest.yaml new file mode 100644 index 0000000..29a03b5 --- /dev/null +++ b/easi/llm/models/internvl3/manifest.yaml @@ -0,0 +1,11 @@ +name: internvl3 +display_name: "InternVL3" +description: > + InternVL3 vision-language model family (1B, 2B, 8B, 78B, etc.). + Supports image understanding with OpenAI-format messages. + Requires: pip install transformers torch torchvision pillow + Recommended: pip install flash-attn (for flash_attention_2) +model_class: "easi.llm.models.internvl3.model.InternVL3Model" +default_kwargs: + dtype: "bfloat16" + attn_implementation: "flash_attention_2" diff --git a/easi/llm/models/internvl3/model.py b/easi/llm/models/internvl3/model.py new file mode 100644 index 0000000..d0f1e7f --- /dev/null +++ b/easi/llm/models/internvl3/model.py @@ -0,0 +1,380 @@ +"""InternVL3 custom model server for EASI. + +Loads InternVL3 models via transformers and serves them through the +EASI custom backend HTTP server. Uses the model's built-in ``.chat()`` +method which handles chat template formatting and image preprocessing +internally. 
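+
+Requires:
+    pip install transformers torch torchvision pillow
+    pip install flash-attn  # optional, enables flash_attention_2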
+ +Usage:: + + easi start --backend custom --model internvl3 \\ + --llm-kwargs '{"model_path": "OpenGVLab/InternVL3-8B"}' +""" +from __future__ import annotations + +from typing import Any + +import torchvision.transforms as T +from PIL import Image +from torchvision.transforms.functional import InterpolationMode + +from easi.llm.models.base_model_server import BaseModelServer +from easi.llm.models.helpers import extract_images +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + +_ALLOWED_DTYPES = {"bfloat16", "float16", "float32", "auto"} + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +# --------------------------------------------------------------------------- +# Image preprocessing (InternVL3 dynamic resolution tiling) +# --------------------------------------------------------------------------- + +def _build_transform(input_size: int) -> T.Compose: + return T.Compose([ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ]) + + +def _find_closest_aspect_ratio( + aspect_ratio: float, + target_ratios: list[tuple[int, int]], + width: int, + height: int, + image_size: int, +) -> tuple[int, int]: + best = (1, 1) + best_diff = float("inf") + area = width * height + for ratio in target_ratios: + target_ar = ratio[0] / ratio[1] + diff = abs(aspect_ratio - target_ar) + if diff < best_diff or (diff == best_diff and area > 0.5 * image_size * image_size * ratio[0] * ratio[1]): + best_diff = diff + best = ratio + return best + + +def _dynamic_preprocess( + image: Image.Image, + min_num: int = 1, + max_num: int = 12, + image_size: int = 448, + use_thumbnail: bool = True, +) -> list[Image.Image]: + """Split image into tiles using InternVL3's dynamic resolution strategy.""" + width, height = image.size + aspect_ratio = width / height + + target_ratios = set() + for n in range(min_num, max_num + 1): + for i in range(1, n + 1): + for j in range(1, n + 1): + if i * j <= max_num and i * j >= min_num: + target_ratios.add((i, j)) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + best = _find_closest_aspect_ratio( + aspect_ratio, target_ratios, width, height, image_size + ) + + target_w = best[0] * image_size + target_h = best[1] * image_size + blocks = best[0] * best[1] + + resized = image.resize((target_w, target_h)) + processed = [] + for i in range(blocks): + box = ( + (i % best[0]) * image_size, + (i // best[0]) * image_size, + ((i % best[0]) + 1) * image_size, + ((i // best[0]) + 1) * image_size, + ) + processed.append(resized.crop(box)) + + if use_thumbnail and blocks > 1: + thumbnail = image.resize((image_size, image_size)) + processed.append(thumbnail) + + return processed + + +def _load_image(image: Image.Image, max_num: int = 12) -> Any: + """Preprocess a PIL image into a pixel_values tensor for InternVL3. + + Returns float32 tensors; the caller should cast to the model's dtype + via ``.to(dtype=model.dtype, device=model.device)``. 
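+
+    Shape is ``(num_tiles, 3, 448, 448)``; the tile count includes one
+    extra thumbnail tile whenever the image was split into more than one.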
+ """ + import torch + + transform = _build_transform(448) + tiles = _dynamic_preprocess(image, image_size=448, max_num=max_num) + pixel_values = torch.stack([transform(tile) for tile in tiles]) + return pixel_values + + +# --------------------------------------------------------------------------- +# Message conversion +# --------------------------------------------------------------------------- + +def _openai_to_internvl_messages( + messages: list[dict], +) -> list[dict]: + """Convert OpenAI-format messages to InternVL3 format. + + OpenAI format uses ``image_url`` content parts with base64 data URIs. + InternVL3 expects ```` placeholder tokens in the text content, + with actual pixel tensors passed separately. + + Returns a new message list with image_url parts replaced by ``\\n`` + text tokens. The caller is responsible for extracting and preprocessing + the actual PIL images via ``extract_images()`` + ``_load_image()``. + """ + converted = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + + if isinstance(content, str): + converted.append({"role": role, "content": content}) + continue + + # Multimodal content list + text_parts: list[str] = [] + for part in content: + ptype = part.get("type", "") + if ptype == "image_url": + text_parts.append("\n") + elif ptype == "text": + text_parts.append(part.get("text", "")) + + converted.append({"role": role, "content": "".join(text_parts)}) + + return converted + + +# --------------------------------------------------------------------------- +# Model server +# --------------------------------------------------------------------------- + +class InternVL3Model(BaseModelServer): + """InternVL3 vision-language model server.""" + + def load(self, model_path: str, device: str, **kwargs: Any) -> None: + """Load InternVL3 model and tokenizer. + + Args: + model_path: HuggingFace model ID or local path. + device: Device string (e.g. ``"cuda:0"``). + **kwargs: ``torch_dtype``, ``attn_implementation``. 
+ """ + import torch + from transformers import AutoModel, AutoTokenizer + + # Resolve dtype + dtype_str = kwargs.pop("torch_dtype", None) or kwargs.pop("dtype", "auto") + if dtype_str not in _ALLOWED_DTYPES: + logger.warning("Unrecognised dtype '%s', falling back to 'auto'", dtype_str) + dtype_str = "auto" + torch_dtype = getattr(torch, dtype_str, "auto") if dtype_str != "auto" else "auto" + + attn_impl = kwargs.pop("attn_implementation", None) + + # Device mapping + try: + import accelerate # noqa: F401 + load_kwargs: dict[str, Any] = {"torch_dtype": torch_dtype, "device_map": "auto"} + except ImportError: + logger.info("accelerate not installed, loading on %s", device) + load_kwargs = {"torch_dtype": torch_dtype} + + # Attention implementation with fallback + if attn_impl: + if attn_impl == "flash_attention_2": + try: + import flash_attn # noqa: F401 + load_kwargs["attn_implementation"] = attn_impl + except ImportError: + logger.warning("flash_attn not installed, falling back to sdpa") + load_kwargs["attn_implementation"] = "sdpa" + else: + load_kwargs["attn_implementation"] = attn_impl + + load_kwargs["trust_remote_code"] = True + + logger.info( + "Loading InternVL3 from %s (dtype=%s, attn=%s)", + model_path, dtype_str, attn_impl, + ) + self.model = AutoModel.from_pretrained(model_path, **load_kwargs).eval() + + if "device_map" not in load_kwargs: + self.model = self.model.to(device) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True + ) + self.device = next(self.model.parameters()).device + logger.info("InternVL3 loaded on %s", self.device) + + def generate(self, messages: list[dict], **kwargs: Any) -> str: + """Generate response from OpenAI-format messages.""" + import torch + + # Extract and preprocess images, cast to model dtype + pil_images = extract_images(messages) + pixel_values = None + if pil_images: + model_dtype = next(self.model.parameters()).dtype + tensors = [ + _load_image(img).to(dtype=model_dtype, device=self.device) + for img in pil_images + ] + pixel_values = torch.cat(tensors, dim=0) + + # Convert messages + internvl_messages = _openai_to_internvl_messages(messages) + + # Extract system message (prepend to first user question if present) + system_prefix = "" + start = 0 + if internvl_messages and internvl_messages[0]["role"] == "system": + system_prefix = internvl_messages[0]["content"] + "\n" + start = 1 + + # Build question from last user message (model.chat expects this) + question = internvl_messages[-1]["content"] if internvl_messages else "" + + # Build history from prior messages (pairs of user/assistant) + history: list[tuple[str, str]] = [] + i = start + while i < len(internvl_messages) - 1: + if ( + internvl_messages[i]["role"] == "user" + and i + 1 < len(internvl_messages) + and internvl_messages[i + 1]["role"] == "assistant" + ): + user_content = internvl_messages[i]["content"] + # Prepend system message to the first user turn + if i == start and system_prefix: + user_content = system_prefix + user_content + history.append(( + user_content, + internvl_messages[i + 1]["content"], + )) + i += 2 + else: + i += 1 + + # If no history, prepend system to the question directly + if not history and system_prefix: + question = system_prefix + question + + # Generation config + max_new_tokens = kwargs.get("max_tokens", 4096) + temperature = kwargs.get("temperature", 0.0) + top_p = kwargs.get("top_p", 0.95) + + generation_config = { + "max_new_tokens": max_new_tokens, + "do_sample": temperature > 0, + } + if temperature > 0: + 
generation_config["temperature"] = temperature + generation_config["top_p"] = top_p + + skip_special = kwargs.get("skip_special_tokens", True) + + with torch.no_grad(): + if skip_special: + # Default path: use model.chat() which hardcodes + # skip_special_tokens=True + response = self.model.chat( + self.tokenizer, + pixel_values, + question, + generation_config, + history=history, + ) + else: + # SFT path: replicate model.chat() logic but decode + # with skip_special_tokens=False to preserve action tokens + response = self._chat_keep_special( + pixel_values, question, generation_config, history, + ) + + return response + + def _chat_keep_special( + self, + pixel_values, + question: str, + generation_config: dict, + history: list[tuple[str, str]], + ) -> str: + """Like model.chat() but with skip_special_tokens=False.""" + import torch + from internvl.conversation import get_conv_template + + IMG_START_TOKEN = '' + IMG_END_TOKEN = '' + IMG_CONTEXT_TOKEN = '' + + template = get_conv_template(self.model.template) + template.system_message = self.model.system_message + eos_token_id = self.tokenizer.convert_tokens_to_ids(template.sep.strip()) + + for old_question, old_answer in history: + template.append_message(template.roles[0], old_question) + template.append_message(template.roles[1], old_answer) + template.append_message(template.roles[0], question) + template.append_message(template.roles[1], None) + query = template.get_prompt() + + if pixel_values is not None: + img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) + self.model.img_context_token_id = img_context_token_id + num_patches = pixel_values.shape[0] + image_tokens = ( + IMG_START_TOKEN + + IMG_CONTEXT_TOKEN * self.model.num_image_token * num_patches + + IMG_END_TOKEN + ) + query = query.replace('', image_tokens, 1) + + model_inputs = self.tokenizer(query, return_tensors='pt') + input_ids = model_inputs['input_ids'].to(self.model.device) + attention_mask = model_inputs['attention_mask'].to(self.model.device) + generation_config['eos_token_id'] = eos_token_id + + generation_output = self.model.generate( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + **generation_config, + ) + response = self.tokenizer.batch_decode( + generation_output, skip_special_tokens=False, + )[0] + response = response.split(template.sep.strip())[0].strip() + return response + + def unload(self) -> None: + """Release GPU memory.""" + if hasattr(self, "model"): + del self.model + if hasattr(self, "tokenizer"): + del self.tokenizer + + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.info("InternVL3 model unloaded") diff --git a/easi/llm/models/qwen3_vl/__init__.py b/easi/llm/models/qwen3_vl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/easi/llm/models/qwen3_vl/manifest.yaml b/easi/llm/models/qwen3_vl/manifest.yaml new file mode 100644 index 0000000..22da378 --- /dev/null +++ b/easi/llm/models/qwen3_vl/manifest.yaml @@ -0,0 +1,10 @@ +name: qwen3_vl +display_name: "Qwen3-VL" +description: > + Qwen3-VL vision-language model family (8B, 72B, etc.). + Supports image understanding with OpenAI-format messages. 
+  Requires: uv pip install transformers torch torchvision pillow
+model_class: "easi.llm.models.qwen3_vl.model.Qwen3VLModel"
+default_kwargs:
+  dtype: "bfloat16"
+  attn_implementation: "flash_attention_2"
diff --git a/easi/llm/models/qwen3_vl/model.py b/easi/llm/models/qwen3_vl/model.py
new file mode 100644
index 0000000..2a44d51
--- /dev/null
+++ b/easi/llm/models/qwen3_vl/model.py
@@ -0,0 +1,184 @@
+"""Qwen3-VL custom model server for EASI.
+
+Loads Qwen3-VL models (8B, 72B, etc.) via HuggingFace Transformers and
+serves them through the custom model server pipeline.
+
+Requires:
+    pip install transformers torch torchvision pillow
+
+Usage:
+    easi start --backend custom --model qwen3_vl \
+        --llm-kwargs '{"model_path": "Qwen/Qwen3-VL-8B-Instruct"}'
+"""
+from __future__ import annotations
+
+from easi.llm.models.base_model_server import BaseModelServer
+from easi.llm.models.helpers import extract_images
+from easi.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+# Recognised dtype strings for from_pretrained
+_ALLOWED_DTYPES = {"bfloat16", "float16", "float32", "auto"}
+
+
+def _openai_to_qwen_messages(messages: list[dict], images: list) -> list[dict]:
+    """Convert OpenAI-format messages to Qwen3-VL format.
+
+    OpenAI format:
+        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
+    Qwen format:
+        {"type": "image", "image": <PIL.Image.Image>}
+    """
+    image_idx = 0
+    converted = []
+    for msg in messages:
+        role = msg.get("role", "user")
+        content = msg.get("content", "")
+
+        if isinstance(content, str):
+            converted.append({"role": role, "content": content})
+            continue
+
+        new_content = []
+        for part in content:
+            ptype = part.get("type", "")
+            if ptype == "image_url" and image_idx < len(images):
+                new_content.append({"type": "image", "image": images[image_idx]})
+                image_idx += 1
+            elif ptype == "text":
+                new_content.append({"type": "text", "text": part.get("text", "")})
+
+        converted.append({"role": role, "content": new_content})
+
+    return converted
+
+
+class Qwen3VLModel(BaseModelServer):
+    """Qwen3-VL vision-language model server."""
+
+    def load(self, model_path: str, device: str, **kwargs) -> None:
+        """Load Qwen3-VL model and processor.
+
+        Args:
+            model_path: HuggingFace model ID (e.g., "Qwen/Qwen3-VL-8B-Instruct")
+                or local path to model weights.
+            device: Device string (e.g., "cuda:0"). When using device_map="auto",
+                this is used as fallback.
+            **kwargs: Extra kwargs passed to from_pretrained.
+                Supported: torch_dtype, attn_implementation.
+        """
+        import torch
+        from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
+
+        # Resolve torch dtype — newer transformers uses "dtype" instead of "torch_dtype"
+        dtype_str = kwargs.pop("torch_dtype", None) or kwargs.pop("dtype", "auto")
+        if dtype_str not in _ALLOWED_DTYPES:
+            logger.warning("Unrecognised dtype '%s', falling back to 'auto'", dtype_str)
+            dtype_str = "auto"
+        torch_dtype = getattr(torch, dtype_str, "auto") if dtype_str != "auto" else "auto"
+
+        attn_impl = kwargs.pop("attn_implementation", None)
+
+        # Use device_map="auto" only if accelerate is available; otherwise
+        # fall back to loading on the specified device directly.
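+        # (device_map="auto" lets accelerate shard the checkpoint across all
+        # visible GPUs, which the larger variants generally require.)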
+ try: + import accelerate # noqa: F401 + load_kwargs = {"dtype": torch_dtype, "device_map": "auto"} + except ImportError: + logger.info("accelerate not installed, loading model on %s without device_map", device) + load_kwargs = {"dtype": torch_dtype} + + if attn_impl: + # Validate flash_attention_2 availability before requesting it + if attn_impl == "flash_attention_2": + try: + import flash_attn # noqa: F401 + load_kwargs["attn_implementation"] = attn_impl + except ImportError: + logger.warning( + "flash_attn not installed, falling back to sdpa attention. " + "Install with: pip install flash-attn --no-build-isolation" + ) + load_kwargs["attn_implementation"] = "sdpa" + else: + load_kwargs["attn_implementation"] = attn_impl + + logger.info("Loading Qwen3-VL from %s (dtype=%s, attn=%s)", model_path, dtype_str, attn_impl) + self.model = Qwen3VLForConditionalGeneration.from_pretrained( + model_path, **load_kwargs + ) + # Move to device if device_map was not used + if "device_map" not in load_kwargs: + self.model = self.model.to(device) + + self.processor = AutoProcessor.from_pretrained(model_path) + # device_map="auto" shards across GPUs; .device would raise RuntimeError. + # Use the device of the first parameter instead. + self.device = next(self.model.parameters()).device + logger.info("Qwen3-VL loaded on %s", self.device) + + def generate(self, messages: list[dict], **kwargs) -> str: + """Generate response from OpenAI-format messages. + + Converts OpenAI message format to Qwen3-VL format, processes + images via the Qwen processor, and runs generation. + """ + import torch + + # Extract images from OpenAI-format base64 entries + images = extract_images(messages) + + # Convert message format + qwen_messages = _openai_to_qwen_messages(messages, images) + + # Process with Qwen processor (handles tokenization + image processing) + inputs = self.processor.apply_chat_template( + qwen_messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + ) + inputs = {k: v.to(self.device) if hasattr(v, "to") else v for k, v in inputs.items()} + + # Generation kwargs + max_new_tokens = kwargs.get("max_tokens", 4096) + temperature = kwargs.get("temperature", 0.0) + top_p = kwargs.get("top_p", 0.95) + + gen_kwargs = {"max_new_tokens": max_new_tokens} + if temperature > 0: + gen_kwargs["temperature"] = temperature + gen_kwargs["top_p"] = top_p + gen_kwargs["do_sample"] = True + else: + gen_kwargs["do_sample"] = False + + with torch.no_grad(): + generated_ids = self.model.generate(**inputs, **gen_kwargs) + + # Trim input tokens from output + generated_ids_trimmed = [ + out_ids[len(in_ids):] + for in_ids, out_ids in zip(inputs["input_ids"], generated_ids) + ] + output_text = self.processor.batch_decode( + generated_ids_trimmed, + skip_special_tokens=kwargs.get("skip_special_tokens", True), + clean_up_tokenization_spaces=False, + ) + + return output_text[0] if output_text else "" + + def unload(self) -> None: + """Release GPU memory.""" + if hasattr(self, "model"): + del self.model + if hasattr(self, "processor"): + del self.processor + + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.info("Qwen3-VL model unloaded") diff --git a/easi/llm/models/registry.py b/easi/llm/models/registry.py new file mode 100644 index 0000000..71e2758 --- /dev/null +++ b/easi/llm/models/registry.py @@ -0,0 +1,114 @@ +"""Model registry with manifest-based auto-discovery. 
+ +Scans easi/llm/models/*/manifest.yaml to discover available custom model +server configurations. Follows the same pattern as the simulator registry. + +Lookup semantics: +- list_models() → all registered model names +- get_model_entry("my_model") → ModelEntry dataclass +- load_model_class("my_model") → imported class +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path + +import yaml + +from easi.utils.import_utils import import_class as _import_class +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class ModelEntry: + """Registry entry for a custom model server.""" + + name: str + display_name: str + description: str + model_class: str # fully qualified class name + default_kwargs: dict = field(default_factory=dict) + + +# Module-level registry populated on first access +_registry: dict[str, ModelEntry] | None = None + + +def _get_models_dir() -> Path: + """Return the directory containing model subdirectories.""" + return Path(__file__).parent + + +def _discover_models() -> dict[str, ModelEntry]: + """Scan model directories for manifest.yaml files.""" + models_dir = _get_models_dir() + entries: dict[str, ModelEntry] = {} + + for manifest_path in sorted(models_dir.glob("*/manifest.yaml")): + try: + manifest = yaml.safe_load(manifest_path.read_text()) + except Exception as e: + logger.warning("Failed to load %s: %s", manifest_path, e) + continue + + try: + entry = ModelEntry( + name=manifest["name"], + display_name=manifest.get("display_name", manifest["name"]), + description=manifest.get("description", ""), + model_class=manifest["model_class"], + default_kwargs=manifest.get("default_kwargs", {}), + ) + entries[entry.name] = entry + logger.trace("Discovered model: %s (%s)", entry.name, entry.display_name) + except KeyError as e: + logger.warning( + "Invalid manifest %s: missing required field %s", manifest_path, e + ) + continue + + return entries + + +def _get_registry() -> dict[str, ModelEntry]: + """Get the model registry, discovering on first access.""" + global _registry + if _registry is None: + _registry = _discover_models() + return _registry + + +def list_models() -> list[str]: + """List all registered model names.""" + return sorted(_get_registry().keys()) + + +def get_model_entry(name: str) -> ModelEntry: + """Look up a model entry by name. + + Args: + name: The model name as defined in its manifest.yaml. + + Raises: + KeyError: If the model is not found. + """ + registry = _get_registry() + if name not in registry: + available = list_models() + raise KeyError(f"Model '{name}' not found. Available: {available}") + return registry[name] + + +def load_model_class(name: str): + """Import and return the model class for the given name.""" + entry = get_model_entry(name) + return _import_class(entry.model_class) + + +def refresh() -> None: + """Force re-discovery of models (useful after adding new ones at runtime).""" + global _registry + _registry = None diff --git a/easi/llm/server_manager.py b/easi/llm/server_manager.py new file mode 100644 index 0000000..c656342 --- /dev/null +++ b/easi/llm/server_manager.py @@ -0,0 +1,369 @@ +"""Manages lifecycle of local LLM inference servers (vLLM, etc.). + +Starts the server as a subprocess, waits for health check, and stops on exit. 
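+
+A context-manager sketch (the model name is illustrative)::
+
+    with ServerManager(backend="vllm", model="Qwen/Qwen2.5-7B-Instruct", port=8080) as base_url:
+        ...  # base_url -> "http://localhost:8080/v1"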
+""" +from __future__ import annotations + +import json +import os +import signal +import socket +import subprocess +import sys +import threading +import time +from pathlib import Path + +import requests + +from easi.utils.logging import get_logger + +logger = get_logger(__name__) + +_HEALTH_POLL_INTERVAL = 5.0 +_DEFAULT_STARTUP_TIMEOUT = 600.0 +_DEFAULT_VLLM_SERVER_FLAGS = { + "enable_prefix_caching": True, + "enable_log_requests": False, +} + + +def _port_is_available(port: int) -> bool: + """Check if a TCP port is available on localhost.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + sock.bind(("127.0.0.1", port)) + return True + except OSError: + return False + finally: + sock.close() + + +class ServerManager: + """Manages a local inference server subprocess.""" + + def __init__( + self, + backend: str, + model: str, + port: int = 8080, + server_kwargs: dict | None = None, + startup_timeout: float = _DEFAULT_STARTUP_TIMEOUT, + cuda_visible_devices: str | None = None, + label: str = "server", + ): + self.backend = backend + self.model = model + self.port = port + self.server_kwargs = server_kwargs or {} + self.startup_timeout = startup_timeout + self.cuda_visible_devices = cuda_visible_devices + self.label = label + self._process: subprocess.Popen | None = None + self._log_thread: threading.Thread | None = None + logger.trace( + "[%s] ServerManager init: backend=%s, model=%s, port=%d, " + "server_kwargs=%s, cuda_visible_devices=%s", + label, backend, model, port, self.server_kwargs, cuda_visible_devices, + ) + + def start(self) -> str: + """Start the server, wait for health, return base_url.""" + self.launch() + return self.wait_until_ready() + + def launch(self) -> None: + """Spawn the server process without waiting for health.""" + self._check_port() + + cmd, extra_env = self._build_command() + logger.info("[%s] Starting %s server: %s", self.label, self.backend, " ".join(cmd)) + + spawn_env = os.environ.copy() + spawn_env.update(extra_env) + + self._process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=spawn_env, + preexec_fn=os.setsid, + ) + self._log_thread = threading.Thread( + target=self._stream_output, + args=(self._process, self.label), + daemon=True, + ) + self._log_thread.start() + + def wait_until_ready(self) -> str: + """Poll health endpoint until the server is ready. Returns base_url.""" + base_url = f"http://localhost:{self.port}/v1" + self._wait_for_health(base_url) + logger.info("[%s] Server ready at %s", self.label, base_url) + return base_url + + def stop(self) -> None: + """Terminate the server process and all its children. + + Uses process-group kill (SIGTERM → SIGKILL) to ensure child + processes (e.g., vLLM tensor-parallel workers) are cleaned up. 
+ """ + if self._process is not None: + pid = self._process.pid + logger.info("[%s] Stopping %s server (pid=%d)", self.label, self.backend, pid) + try: + pgid = os.getpgid(pid) + os.killpg(pgid, signal.SIGTERM) + except (ProcessLookupError, PermissionError): + pass # already dead + try: + self._process.wait(timeout=30) + except subprocess.TimeoutExpired: + logger.warning("[%s] Server did not terminate, killing process group...", self.label) + try: + os.killpg(os.getpgid(pid), signal.SIGKILL) + except (ProcessLookupError, PermissionError): + self._process.kill() + self._process.wait(timeout=10) + self._process = None + if self._log_thread is not None: + self._log_thread.join(timeout=5) + self._log_thread = None + + def is_running(self) -> bool: + """Check if server process is alive.""" + if self._process is None: + return False + return self._process.poll() is None + + def _check_port(self, retries: int = 6, delay: float = 5.0) -> None: + """Raise if port is already in use. + + Retries a few times to handle TIME_WAIT from a recently stopped + server (common when running tasks back-to-back). + """ + logger.trace("[%s] Checking if port %d is available...", self.label, self.port) + for attempt in range(retries): + if _port_is_available(self.port): + return + if attempt < retries - 1: + logger.trace( + "[%s] Port %d in use, waiting %.0fs (%d/%d)...", + self.label, self.port, delay, attempt + 1, retries, + ) + time.sleep(delay) + raise RuntimeError( + f"Port {self.port} is still in use after {retries * delay:.0f}s. " + f"Use --port to specify a different port, " + f"or --llm-url to connect to an existing server." + ) + + def _build_command(self) -> tuple[list[str], dict]: + """Build the server launch command and environment overrides. + + Returns: + Tuple of (command list, env dict). The env dict contains + ``CUDA_VISIBLE_DEVICES`` when *cuda_visible_devices* is set. 
+ """ + if self.backend == "vllm": + cmd = [ + sys.executable, "-m", "vllm.entrypoints.openai.api_server", + "--model", self.model, + "--port", str(self.port), + ] + # Merge defaults with user overrides (user wins) + merged_kwargs = {**_DEFAULT_VLLM_SERVER_FLAGS, **self.server_kwargs} + overridden = { + k: v for k, v in self.server_kwargs.items() + if k in _DEFAULT_VLLM_SERVER_FLAGS and v != _DEFAULT_VLLM_SERVER_FLAGS[k] + } + if overridden: + logger.trace("[%s] User overrides for default vLLM flags: %s", self.label, overridden) + logger.trace("[%s] Merged vLLM kwargs: %s", self.label, merged_kwargs) + for key, value in merged_kwargs.items(): + flag = "--" + key.replace("_", "-") + if isinstance(value, bool): + if value: + cmd.append(flag) + else: + no_flag = "--no-" + key.replace("_", "-") + cmd.append(no_flag) + else: + cmd.extend([flag, str(value)]) + elif self.backend == "custom": + model_path = self.server_kwargs.get("model_path", self.model) + extra_kwargs = {k: v for k, v in self.server_kwargs.items() if k != "model_path"} + device = "cuda:0" # CUDA_VISIBLE_DEVICES handles GPU remapping + cmd = [ + sys.executable, "-m", "easi.llm.models.http_server", + "--model-name", self.model, + "--model-path", str(model_path), + "--device", device, + "--port", str(self.port), + ] + if extra_kwargs: + cmd.extend(["--kwargs", json.dumps(extra_kwargs)]) + else: + raise ValueError(f"Unsupported server backend: {self.backend}") + + env: dict[str, str] = {} + if self.cuda_visible_devices is not None: + env["CUDA_VISIBLE_DEVICES"] = self.cuda_visible_devices + + logger.trace("[%s] Built command: %s", self.label, cmd) + logger.trace("[%s] Extra env: %s", self.label, env) + return cmd, env + + @staticmethod + def _stream_output(proc: subprocess.Popen, label: str = "server") -> None: + """Read server stdout/stderr line by line and log at TRACE level.""" + for raw_line in proc.stdout: + line = raw_line.decode("utf-8", errors="replace").rstrip() + if line: + logger.trace("[%s] %s", label, line) + proc.stdout.close() + + def _wait_for_health(self, base_url: str) -> None: + """Poll /health until the server responds or timeout.""" + health_url = base_url.replace("/v1", "") + "/health" + deadline = time.monotonic() + self.startup_timeout + logger.trace( + "[%s] Waiting for health at %s (timeout=%.0fs)", + self.label, health_url, self.startup_timeout, + ) + + while time.monotonic() < deadline: + if self._process and self._process.poll() is not None: + raise RuntimeError( + f"[{self.label}] {self.backend} server exited with code " + f"{self._process.returncode}. " + f"Run with --verbosity TRACE to see server output." + ) + try: + resp = requests.get(health_url, timeout=5) + if resp.status_code == 200: + logger.trace("[%s] Health check passed (status=%d)", self.label, resp.status_code) + return + logger.trace("[%s] Health check returned status %d, retrying...", self.label, resp.status_code) + except (requests.ConnectionError, requests.Timeout): + logger.trace("[%s] Health check connection refused/timed out, retrying...", self.label) + + time.sleep(_HEALTH_POLL_INTERVAL) + + self.stop() + raise RuntimeError( + f"[{self.label}] {self.backend} server failed to start within " + f"{self.startup_timeout}s. Run with --verbosity TRACE to see server output." 
+ ) + + def __enter__(self) -> str: + return self.start() + + def __exit__(self, *exc) -> None: + self.stop() + + +class MultiServerManager: + """Manages multiple local LLM server instances across GPUs.""" + + def __init__( + self, + model: str, + num_instances: int, + gpu_ids: list[int] | None = None, + base_port: int = 8000, + server_kwargs: dict | None = None, + startup_timeout: float = 300.0, + backend: str = "vllm", + ): + if gpu_ids is not None and len(gpu_ids) % num_instances != 0: + raise ValueError( + f"Cannot divide {len(gpu_ids)} GPUs evenly across " + f"{num_instances} instances" + ) + self.model = model + self.num_instances = num_instances + self.gpu_ids = gpu_ids + self.base_port = base_port + self.server_kwargs = server_kwargs or {} + self.startup_timeout = startup_timeout + self.backend = backend + self._managers: list[ServerManager] = [] + + def start(self) -> list[str]: + """Start all instances in parallel, return list of base_urls. + + All server processes are spawned first, then health checks run + concurrently via threads. Ports are assigned by probing from + *base_port* upward, skipping any that are already in use. If + any instance fails, all are stopped before re-raising. + """ + from concurrent.futures import ThreadPoolExecutor, as_completed + + gpus_per = len(self.gpu_ids) // self.num_instances if self.gpu_ids else None + next_port = self.base_port + + try: + # Phase 1: Spawn all processes (fast, no blocking) + for i in range(self.num_instances): + if gpus_per is not None: + instance_gpus = self.gpu_ids[i * gpus_per : (i + 1) * gpus_per] + cuda_devices = ",".join(str(g) for g in instance_gpus) + else: + cuda_devices = None + port = self._find_available_port(next_port) + next_port = port + 1 + mgr = ServerManager( + backend=self.backend, + model=self.model, + port=port, + server_kwargs=self.server_kwargs, + startup_timeout=self.startup_timeout, + cuda_visible_devices=cuda_devices, + label=f"{self.backend}-{i}", + ) + mgr.launch() + self._managers.append(mgr) + + # Phase 2: Wait for all health checks concurrently + logger.info( + "All %d %s processes spawned, waiting for health checks...", + self.num_instances, self.backend, + ) + urls = [None] * len(self._managers) + with ThreadPoolExecutor(max_workers=len(self._managers)) as pool: + future_to_idx = { + pool.submit(mgr.wait_until_ready): idx + for idx, mgr in enumerate(self._managers) + } + for future in as_completed(future_to_idx): + idx = future_to_idx[future] + urls[idx] = future.result() + + except Exception: + logger.warning( + "%s startup failed, stopping %d spawned instances", + self.backend, len(self._managers), + ) + self.stop() + raise + + return urls + + @staticmethod + def _find_available_port(start: int, max_probe: int = 100) -> int: + """Find the first available port starting from *start*.""" + for port in range(start, start + max_probe): + if _port_is_available(port): + return port + raise RuntimeError( + f"No available port found in range {start}-{start + max_probe - 1}" + ) + + def stop(self): + """Stop all managed instances.""" + for mgr in self._managers: + mgr.stop() + self._managers.clear() diff --git a/easi/llm/templates/internvl3.jinja b/easi/llm/templates/internvl3.jinja new file mode 100644 index 0000000..9867dff --- /dev/null +++ b/easi/llm/templates/internvl3.jinja @@ -0,0 +1,4 @@ +{#- InternVL3 chat template patched for OpenAI image_url compatibility. + Original source: OpenGVLab/InternVL3-8B tokenizer_config.json + Patch: accepts both type=="image" and type=="image_url" content parts. 
-#}
+{%- if messages[0]['role'] == 'system' %}{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}{%- else %}{{- '<|im_start|>system\n你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>\n' }}{%- endif %}{% for message in messages %}{%- if messages[0]['role'] != 'system' or not loop.first %}{{'<|im_start|>' + message['role'] + '\n'}}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or content['type'] == 'image_url' %}{{ '<image>\n' }}{% elif content['type'] == 'video' %}{{ '