Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/prime_rl/orchestrator/trajectories.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import hashlib
import mimetypes
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -609,23 +610,25 @@ def _pack_mm_kwargs_from_renderer(mm_data: Any) -> "dict[str, Any] | None":
return out


_FILE_URL_PREFIX = "file://"
def _image_file_suffix_from_data_url(url: str) -> str:
media_type = url.split(",", 1)[0].removeprefix("data:").split(";", 1)[0]
return mimetypes.guess_extension(media_type) or ".png"


def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) -> int:
"""Replace base64 image data in rollout trajectories with file paths on disk.

Scans all trajectory step prompts for data:image URLs, writes the decoded
image bytes to ``{output_dir}/assets/images/{hash}.png``, and replaces the
URL in-place with ``file://{path}``. Deduplicates by content hash so each
unique image is written only once.
image bytes to ``{output_dir}/assets/images/{hash}.{ext}``, and replaces the
URL in-place with an absolute ``file://`` URI. Deduplicates by content hash
and media type so each unique image file is written only once.

Returns the number of unique images written to disk.
"""
images_dir = output_dir / "assets" / "images"
images_dir.mkdir(parents=True, exist_ok=True)

written: set[str] = set()
written: set[Path] = set()

for output in rollouts:
for step in output.get("trajectory", []):
Expand All @@ -644,11 +647,11 @@ def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) -
continue
b64_data = url.split(",", 1)[1]
content_hash = hashlib.sha256(b64_data.encode()).hexdigest()[:16]
path = images_dir / f"{content_hash}.png"
if content_hash not in written:
path = images_dir / f"{content_hash}{_image_file_suffix_from_data_url(url)}"
if path not in written:
if not path.exists():
path.write_bytes(base64.b64decode(b64_data))
written.add(content_hash)
item["image_url"]["url"] = f"{_FILE_URL_PREFIX}{path}"
written.add(path)
item["image_url"]["url"] = path.resolve().as_uri()

return len(written)
69 changes: 66 additions & 3 deletions src/prime_rl/utils/monitor/prime.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import asyncio
import base64
import io
import json
import math
import mimetypes
import os
import time
from datetime import datetime, timezone
from pathlib import Path
from threading import Thread
from typing import Any
from urllib.parse import unquote, urlparse

import httpx
import pyarrow as pa
Expand Down Expand Up @@ -57,6 +60,8 @@ def _json(val: Any) -> str:


# Unique sentinel object. NOTE(review): presumably marks values removed by
# _drop_non_finite_json_values — confirm against that function's body.
_DROPPED_JSON_VALUE = object()
# URL scheme that a URL must have to be treated as a local image file.
_FILE_URL_SCHEME = "file"
# Default size cap (2 MiB) above which a local image file is not inlined.
_MAX_INLINE_SAMPLE_IMAGE_BYTES = 2 * 1024 * 1024


def _drop_non_finite_json_values(value: Any, dropped_paths: list[str], path: str = "") -> Any:
Expand Down Expand Up @@ -89,6 +94,63 @@ def _drop_non_finite_json_values(value: Any, dropped_paths: list[str], path: str
return value


def _local_image_file_to_data_url(
    url: str,
    cache: dict[str, str | None],
    max_bytes: int = _MAX_INLINE_SAMPLE_IMAGE_BYTES,
) -> str | None:
    """Return a base64 ``data:`` URL for a local ``file://`` image, else ``None``.

    Args:
        url: Candidate URL. Only ``file:`` URLs whose host is empty or
            ``localhost`` and whose file name maps to an ``image/*`` media
            type are eligible.
        cache: Memo of previous results keyed by ``url``; both successful
            conversions and rejections (``None``) are stored, so each URL is
            read from disk at most once per cache.
        max_bytes: Files larger than this many bytes are rejected.

    Returns:
        The ``data:<media-type>;base64,<payload>`` string, or ``None`` when
        the URL is not eligible, is malformed, or the file cannot be read.
        This function never raises for a bad URL: callers use ``None`` to
        mean "leave the original URL untouched".
    """
    if url in cache:
        return cache[url]

    data_url = _encode_local_image_file(url, max_bytes)
    cache[url] = data_url
    return data_url


def _encode_local_image_file(url: str, max_bytes: int) -> str | None:
    """Encode the image behind a ``file://`` URL as a data URL, uncached."""
    try:
        parsed = urlparse(url)
    except ValueError:
        # urlparse rejects some malformed netlocs (e.g. an unclosed "[").
        return None
    if parsed.scheme != _FILE_URL_SCHEME or parsed.netloc not in ("", "localhost"):
        return None

    path = Path(unquote(parsed.path))
    media_type = mimetypes.guess_type(path.name)[0]
    if media_type is None or not media_type.startswith("image/"):
        return None

    try:
        if path.stat().st_size > max_bytes:
            return None
        raw = path.read_bytes()
    except (OSError, ValueError):
        # ValueError: a percent-decoded path containing a NUL byte is rejected
        # by the OS layer with ValueError rather than OSError.
        return None

    encoded = base64.b64encode(raw).decode("ascii")
    return f"data:{media_type};base64,{encoded}"


def _inline_local_image_urls(value: Any, cache: dict[str, str | None]) -> Any:
if isinstance(value, list):
return [_inline_local_image_urls(item, cache) for item in value]

if not isinstance(value, dict):
return value

inlined = {key: _inline_local_image_urls(item, cache) for key, item in value.items()}
image_url = inlined.get("image_url")
if not isinstance(image_url, dict):
return inlined

url = image_url.get("url")
if not isinstance(url, str):
return inlined

data_url = _local_image_file_to_data_url(url, cache)
if data_url is None:
return inlined

inlined["image_url"] = {**image_url, "url": data_url}
return inlined


class PrimeMonitor(Monitor):
"""Logs to Prime Intellect API."""

Expand Down Expand Up @@ -332,6 +394,7 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int
"""Convert rollouts directly to Parquet bytes for upload."""
now = datetime.now(timezone.utc)
rows = []
image_data_url_cache: dict[str, str | None] = {}

for sample_id, rollout in enumerate(rollouts):
prompt = rollout.get("prompt")
Expand Down Expand Up @@ -366,9 +429,9 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int
"tag": "",
"problem_id": problem_id,
"sample_id": sample_id,
"prompt": json.dumps(prompt),
"completion": json.dumps(completion),
"trajectory": json.dumps(trajectory_data),
"prompt": json.dumps(_inline_local_image_urls(prompt, image_data_url_cache)),
"completion": json.dumps(_inline_local_image_urls(completion, image_data_url_cache)),
"trajectory": json.dumps(_inline_local_image_urls(trajectory_data, image_data_url_cache)),
"answer": rollout.get("answer") or "",
"env_name": rollout.get("env_name") or "",
"task": rollout.get("task") or "",
Expand Down
Loading