diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 5679253bf5..9f39846344 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -1,5 +1,6 @@ import base64 import hashlib +import mimetypes from pathlib import Path from typing import Any @@ -609,23 +610,25 @@ def _pack_mm_kwargs_from_renderer(mm_data: Any) -> "dict[str, Any] | None": return out -_FILE_URL_PREFIX = "file://" +def _image_file_suffix_from_data_url(url: str) -> str: + media_type = url.split(",", 1)[0].removeprefix("data:").split(";", 1)[0] + return mimetypes.guess_extension(media_type) or ".png" def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) -> int: """Replace base64 image data in rollout trajectories with file paths on disk. Scans all trajectory step prompts for data:image URLs, writes the decoded - image bytes to ``{output_dir}/assets/images/{hash}.png``, and replaces the - URL in-place with ``file://{path}``. Deduplicates by content hash so each - unique image is written only once. + image bytes to ``{output_dir}/assets/images/{hash}.{ext}``, and replaces the + URL in-place with an absolute ``file://`` URI. Deduplicates by content hash + and media type so each unique image file is written only once. Returns the number of unique images written to disk. """ images_dir = output_dir / "assets" / "images" images_dir.mkdir(parents=True, exist_ok=True) - written: set[str] = set() + written: set[Path] = set() for output in rollouts: for step in output.get("trajectory", []): @@ -644,11 +647,11 @@ def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) - continue b64_data = url.split(",", 1)[1] content_hash = hashlib.sha256(b64_data.encode()).hexdigest()[:16] - path = images_dir / f"{content_hash}.png" - if content_hash not in written: + path = images_dir / f"{content_hash}{_image_file_suffix_from_data_url(url)}" + if path not in written: if not path.exists(): path.write_bytes(base64.b64decode(b64_data)) - written.add(content_hash) - item["image_url"]["url"] = f"{_FILE_URL_PREFIX}{path}" + written.add(path) + item["image_url"]["url"] = path.resolve().as_uri() return len(written) diff --git a/src/prime_rl/utils/monitor/prime.py b/src/prime_rl/utils/monitor/prime.py index 657037c6f7..96933925b5 100644 --- a/src/prime_rl/utils/monitor/prime.py +++ b/src/prime_rl/utils/monitor/prime.py @@ -1,13 +1,16 @@ import asyncio +import base64 import io import json import math +import mimetypes import os import time from datetime import datetime, timezone from pathlib import Path from threading import Thread from typing import Any +from urllib.parse import unquote, urlparse import httpx import pyarrow as pa @@ -57,6 +60,8 @@ def _json(val: Any) -> str: _DROPPED_JSON_VALUE = object() +_FILE_URL_SCHEME = "file" +_MAX_INLINE_SAMPLE_IMAGE_BYTES = 2 * 1024 * 1024 def _drop_non_finite_json_values(value: Any, dropped_paths: list[str], path: str = "") -> Any: @@ -89,6 +94,63 @@ def _drop_non_finite_json_values(value: Any, dropped_paths: list[str], path: str return value +def _local_image_file_to_data_url( + url: str, + cache: dict[str, str | None], + max_bytes: int = _MAX_INLINE_SAMPLE_IMAGE_BYTES, +) -> str | None: + if url in cache: + return cache[url] + + parsed = urlparse(url) + if parsed.scheme != _FILE_URL_SCHEME or parsed.netloc not in ("", "localhost"): + cache[url] = None + return None + + path = Path(unquote(parsed.path)) + media_type = mimetypes.guess_type(path.name)[0] + if media_type is None or not media_type.startswith("image/"): + cache[url] = None + return None + + try: + if path.stat().st_size > max_bytes: + cache[url] = None + return None + encoded = base64.b64encode(path.read_bytes()).decode("ascii") + except OSError: + cache[url] = None + return None + + data_url = f"data:{media_type};base64,{encoded}" + cache[url] = data_url + return data_url + + +def _inline_local_image_urls(value: Any, cache: dict[str, str | None]) -> Any: + if isinstance(value, list): + return [_inline_local_image_urls(item, cache) for item in value] + + if not isinstance(value, dict): + return value + + inlined = {key: _inline_local_image_urls(item, cache) for key, item in value.items()} + image_url = inlined.get("image_url") + if not isinstance(image_url, dict): + return inlined + + url = image_url.get("url") + if not isinstance(url, str): + return inlined + + data_url = _local_image_file_to_data_url(url, cache) + if data_url is None: + return inlined + + inlined["image_url"] = {**image_url, "url": data_url} + return inlined + + class PrimeMonitor(Monitor): """Logs to Prime Intellect API.""" @@ -332,6 +394,7 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int """Convert rollouts directly to Parquet bytes for upload.""" now = datetime.now(timezone.utc) rows = [] + image_data_url_cache: dict[str, str | None] = {} for sample_id, rollout in enumerate(rollouts): prompt = rollout.get("prompt") @@ -366,9 +429,9 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int "tag": "", "problem_id": problem_id, "sample_id": sample_id, - "prompt": json.dumps(prompt), - "completion": json.dumps(completion), - "trajectory": json.dumps(trajectory_data), + "prompt": json.dumps(_inline_local_image_urls(prompt, image_data_url_cache)), + "completion": json.dumps(_inline_local_image_urls(completion, image_data_url_cache)), + "trajectory": json.dumps(_inline_local_image_urls(trajectory_data, image_data_url_cache)), "answer": rollout.get("answer") or "", "env_name": rollout.get("env_name") or "", "task": rollout.get("task") or "",