From 4f983edd03db2438d5ea5e9cf586922bcea8d7d9 Mon Sep 17 00:00:00 2001 From: Oscar Moxon Date: Mon, 1 Jun 2026 16:02:41 +0100 Subject: [PATCH 1/6] Inline local rollout images in PrimeMonitor samples --- src/prime_rl/utils/monitor/prime.py | 69 ++++++++++++++++++++++++-- tests/unit/utils/test_prime_monitor.py | 29 +++++++++++ 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/src/prime_rl/utils/monitor/prime.py b/src/prime_rl/utils/monitor/prime.py index 657037c6f7..868dfbe0ac 100644 --- a/src/prime_rl/utils/monitor/prime.py +++ b/src/prime_rl/utils/monitor/prime.py @@ -1,13 +1,16 @@ import asyncio +import base64 import io import json import math +import mimetypes import os import time from datetime import datetime, timezone from pathlib import Path from threading import Thread from typing import Any +from urllib.parse import unquote, urlparse import httpx import pyarrow as pa @@ -57,6 +60,8 @@ def _json(val: Any) -> str: _DROPPED_JSON_VALUE = object() +_FILE_URL_SCHEME = "file" +_MAX_INLINE_SAMPLE_IMAGE_BYTES = 2 * 1024 * 1024 def _drop_non_finite_json_values(value: Any, dropped_paths: list[str], path: str = "") -> Any: @@ -89,6 +94,63 @@ def _drop_non_finite_json_values(value: Any, dropped_paths: list[str], path: str return value +def _local_image_file_to_data_url( + url: str, + cache: dict[str, str | None], + max_bytes: int = _MAX_INLINE_SAMPLE_IMAGE_BYTES, +) -> str | None: + if url in cache: + return cache[url] + + parsed = urlparse(url) + if parsed.scheme != _FILE_URL_SCHEME or parsed.netloc not in ("", "localhost"): + cache[url] = None + return None + + path = Path(unquote(parsed.path)) + media_type = mimetypes.guess_type(path.name)[0] or "image/png" + if not media_type.startswith("image/"): + cache[url] = None + return None + + try: + if path.stat().st_size > max_bytes: + cache[url] = None + return None + encoded = base64.b64encode(path.read_bytes()).decode("ascii") + except OSError: + cache[url] = None + return None + + data_url = f"data:{media_type};base64,{encoded}" + cache[url] = data_url + return data_url + + +def _inline_local_image_urls(value: Any, cache: dict[str, str | None]) -> Any: + if isinstance(value, list): + return [_inline_local_image_urls(item, cache) for item in value] + + if not isinstance(value, dict): + return value + + inlined = {key: _inline_local_image_urls(item, cache) for key, item in value.items()} + image_url = inlined.get("image_url") + if not isinstance(image_url, dict): + return inlined + + url = image_url.get("url") + if not isinstance(url, str): + return inlined + + data_url = _local_image_file_to_data_url(url, cache) + if data_url is None: + return inlined + + inlined["image_url"] = {**image_url, "url": data_url} + return inlined + + class PrimeMonitor(Monitor): """Logs to Prime Intellect API.""" @@ -332,6 +394,7 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int """Convert rollouts directly to Parquet bytes for upload.""" now = datetime.now(timezone.utc) rows = [] + image_data_url_cache: dict[str, str | None] = {} for sample_id, rollout in enumerate(rollouts): prompt = rollout.get("prompt") @@ -366,9 +429,9 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int "tag": "", "problem_id": problem_id, "sample_id": sample_id, - "prompt": json.dumps(prompt), - "completion": json.dumps(completion), - "trajectory": json.dumps(trajectory_data), + "prompt": json.dumps(_inline_local_image_urls(prompt, image_data_url_cache)), + "completion": json.dumps(_inline_local_image_urls(completion, image_data_url_cache)), + "trajectory": json.dumps(_inline_local_image_urls(trajectory_data, image_data_url_cache)), "answer": rollout.get("answer") or "", "env_name": rollout.get("env_name") or "", "task": rollout.get("task") or "", diff --git a/tests/unit/utils/test_prime_monitor.py b/tests/unit/utils/test_prime_monitor.py index f44065a7c6..a07a1cbca3 100644 --- a/tests/unit/utils/test_prime_monitor.py +++ b/tests/unit/utils/test_prime_monitor.py @@ -1,3 +1,4 @@ +import base64 import io import json from unittest.mock import Mock @@ -94,6 +95,34 @@ def test_rollouts_to_parquet_bytes_skips_rollouts_without_trajectory(): assert rows[0]["sample_id"] == 0 +def test_rollouts_to_parquet_bytes_inlines_local_image_urls(tmp_path): + image_bytes = b"small image payload" + image_path = tmp_path / "sample.jpg" + image_path.write_bytes(image_bytes) + file_url = f"file://{image_path}" + expected_data_url = f"data:image/jpeg;base64,{base64.b64encode(image_bytes).decode('ascii')}" + + monitor = _new_monitor() + monitor.run_id = "run-images" + rollout = _build_rollout(example_id=1, reward=1.0, task="vision-task") + image_content = [{"type": "image_url", "image_url": {"url": file_url}}] + rollout["prompt"] = [{"role": "user", "content": image_content}] + rollout["trajectory"][0]["prompt"] = [{"role": "user", "content": image_content}] + + parquet_bytes = monitor._rollouts_to_parquet_bytes([rollout], step=3) + + assert parquet_bytes is not None + + table = pq.read_table(io.BytesIO(parquet_bytes)) + row = table.to_pylist()[0] + prompt = json.loads(row["prompt"]) + trajectory = json.loads(row["trajectory"]) + + assert prompt[0]["content"][0]["image_url"]["url"] == expected_data_url + assert trajectory[0]["prompt"][0]["content"][0]["image_url"]["url"] == expected_data_url + assert rollout["prompt"][0]["content"][0]["image_url"]["url"] == file_url + + def test_sanitize_json_payload_drops_non_finite_values_and_logs_paths(): monitor = _new_monitor() monitor.logger = Mock() From 62cc00be3d6fe2e44e8906b2ee7b6764d379e1ff Mon Sep 17 00:00:00 2001 From: d42me Date: Wed, 3 Jun 2026 13:28:03 -0700 Subject: [PATCH 2/6] fix: emit absolute rollout image file uris --- src/prime_rl/orchestrator/trajectories.py | 9 ++---- tests/unit/orchestrator/test_trajectories.py | 33 ++++++++++++++++++++ tests/unit/utils/test_prime_monitor.py | 2 +- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 5679253bf5..6ad3d7d5c6 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -609,16 +609,13 @@ def _pack_mm_kwargs_from_renderer(mm_data: Any) -> "dict[str, Any] | None": return out -_FILE_URL_PREFIX = "file://" - - def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) -> int: """Replace base64 image data in rollout trajectories with file paths on disk. Scans all trajectory step prompts for data:image URLs, writes the decoded image bytes to ``{output_dir}/assets/images/{hash}.png``, and replaces the - URL in-place with ``file://{path}``. Deduplicates by content hash so each - unique image is written only once. + URL in-place with an absolute ``file://`` URI. Deduplicates by content hash + so each unique image is written only once. Returns the number of unique images written to disk. """ @@ -649,6 +646,6 @@ def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) - if not path.exists(): path.write_bytes(base64.b64decode(b64_data)) written.add(content_hash) - item["image_url"]["url"] = f"{_FILE_URL_PREFIX}{path}" + item["image_url"]["url"] = path.resolve().as_uri() return len(written) diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 87df2646c6..1ee34327c2 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -1,4 +1,7 @@ +import base64 +from pathlib import Path from unittest.mock import MagicMock +from urllib.parse import unquote, urlparse import numpy as np import pybase64 @@ -9,6 +12,7 @@ _deserialize_tool_calls, align_routed_experts, interleave_rollout, + offload_images_to_disk, ) _interleave_rollout = interleave_rollout @@ -47,6 +51,35 @@ def _sample_routed_experts(sample) -> np.ndarray: ) +def test_offload_images_to_disk_uses_absolute_file_uri_for_relative_output_dir(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + image_bytes = b"small image payload" + data_url = f"data:image/png;base64,{base64.b64encode(image_bytes).decode('ascii')}" + rollout = { + "trajectory": [ + { + "prompt": [ + { + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": data_url}}], + } + ] + } + ] + } + + assert offload_images_to_disk([rollout], Path("outputs/run")) == 1 + + file_url = rollout["trajectory"][0]["prompt"][0]["content"][0]["image_url"]["url"] + parsed = urlparse(file_url) + image_path = Path(unquote(parsed.path)) + + assert parsed.scheme == "file" + assert parsed.netloc == "" + assert image_path.is_absolute() + assert image_path.read_bytes() == image_bytes + + def test_deserialize_tool_calls_does_not_inject_missing_key(): messages = [{"role": "assistant", "content": "hello"}] diff --git a/tests/unit/utils/test_prime_monitor.py b/tests/unit/utils/test_prime_monitor.py index a07a1cbca3..9f54dec43f 100644 --- a/tests/unit/utils/test_prime_monitor.py +++ b/tests/unit/utils/test_prime_monitor.py @@ -99,7 +99,7 @@ def test_rollouts_to_parquet_bytes_inlines_local_image_urls(tmp_path): image_bytes = b"small image payload" image_path = tmp_path / "sample.jpg" image_path.write_bytes(image_bytes) - file_url = f"file://{image_path}" + file_url = image_path.as_uri() expected_data_url = f"data:image/jpeg;base64,{base64.b64encode(image_bytes).decode('ascii')}" monitor = _new_monitor() From 3ccf1a6d5b2c0c2e789e4aaca4cef193b404dcb6 Mon Sep 17 00:00:00 2001 From: d42me Date: Wed, 3 Jun 2026 14:40:17 -0700 Subject: [PATCH 3/6] fix: preserve rollout image media types --- src/prime_rl/orchestrator/trajectories.py | 18 +++++++++----- tests/unit/orchestrator/test_trajectories.py | 26 ++++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 6ad3d7d5c6..9f39846344 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -1,5 +1,6 @@ import base64 import hashlib +import mimetypes from pathlib import Path from typing import Any @@ -609,20 +610,25 @@ def _pack_mm_kwargs_from_renderer(mm_data: Any) -> "dict[str, Any] | None": return out +def _image_file_suffix_from_data_url(url: str) -> str: + media_type = url.split(",", 1)[0].removeprefix("data:").split(";", 1)[0] + return mimetypes.guess_extension(media_type) or ".png" + + def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) -> int: """Replace base64 image data in rollout trajectories with file paths on disk. Scans all trajectory step prompts for data:image URLs, writes the decoded - image bytes to ``{output_dir}/assets/images/{hash}.png``, and replaces the + image bytes to ``{output_dir}/assets/images/{hash}.{ext}``, and replaces the URL in-place with an absolute ``file://`` URI. Deduplicates by content hash - so each unique image is written only once. + and media type so each unique image file is written only once. Returns the number of unique images written to disk. """ images_dir = output_dir / "assets" / "images" images_dir.mkdir(parents=True, exist_ok=True) - written: set[str] = set() + written: set[Path] = set() for output in rollouts: for step in output.get("trajectory", []): @@ -641,11 +647,11 @@ def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) - continue b64_data = url.split(",", 1)[1] content_hash = hashlib.sha256(b64_data.encode()).hexdigest()[:16] - path = images_dir / f"{content_hash}.png" - if content_hash not in written: + path = images_dir / f"{content_hash}{_image_file_suffix_from_data_url(url)}" + if path not in written: if not path.exists(): path.write_bytes(base64.b64decode(b64_data)) - written.add(content_hash) + written.add(path) item["image_url"]["url"] = path.resolve().as_uri() return len(written) diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 1ee34327c2..bf30c2fd4c 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -80,6 +80,32 @@ def test_offload_images_to_disk_uses_absolute_file_uri_for_relative_output_dir(t assert image_path.read_bytes() == image_bytes +def test_offload_images_to_disk_preserves_image_media_type_suffix(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + image_bytes = b"jpeg image payload" + data_url = f"data:image/jpeg;base64,{base64.b64encode(image_bytes).decode('ascii')}" + rollout = { + "trajectory": [ + { + "prompt": [ + { + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": data_url}}], + } + ] + } + ] + } + + assert offload_images_to_disk([rollout], Path("outputs/run")) == 1 + + file_url = rollout["trajectory"][0]["prompt"][0]["content"][0]["image_url"]["url"] + image_path = Path(unquote(urlparse(file_url).path)) + + assert image_path.suffix == ".jpg" + assert image_path.read_bytes() == image_bytes + + def test_deserialize_tool_calls_does_not_inject_missing_key(): messages = [{"role": "assistant", "content": "hello"}] From 1fb8432fa814b8e2979fbd1df9b2a509e1533c1e Mon Sep 17 00:00:00 2001 From: d42me Date: Wed, 3 Jun 2026 14:56:30 -0700 Subject: [PATCH 4/6] fix: skip extensionless rollout image files --- src/prime_rl/utils/monitor/prime.py | 4 ++-- tests/unit/utils/test_prime_monitor.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/prime_rl/utils/monitor/prime.py b/src/prime_rl/utils/monitor/prime.py index 868dfbe0ac..96933925b5 100644 --- a/src/prime_rl/utils/monitor/prime.py +++ b/src/prime_rl/utils/monitor/prime.py @@ -108,8 +108,8 @@ def _local_image_file_to_data_url( return None path = Path(unquote(parsed.path)) - media_type = mimetypes.guess_type(path.name)[0] or "image/png" - if not media_type.startswith("image/"): + media_type = mimetypes.guess_type(path.name)[0] + if media_type is None or not media_type.startswith("image/"): cache[url] = None return None diff --git a/tests/unit/utils/test_prime_monitor.py b/tests/unit/utils/test_prime_monitor.py index 9f54dec43f..c2925db51d 100644 --- a/tests/unit/utils/test_prime_monitor.py +++ b/tests/unit/utils/test_prime_monitor.py @@ -123,6 +123,28 @@ def test_rollouts_to_parquet_bytes_inlines_local_image_urls(tmp_path): assert rollout["prompt"][0]["content"][0]["image_url"]["url"] == file_url +def test_rollouts_to_parquet_bytes_does_not_inline_extensionless_file_urls(tmp_path): + image_path = tmp_path / "sample" + image_path.write_bytes(b"extensionless payload") + file_url = image_path.as_uri() + + monitor = _new_monitor() + monitor.run_id = "run-extensionless" + rollout = _build_rollout(example_id=1, reward=1.0, task="vision-task") + image_content = [{"type": "image_url", "image_url": {"url": file_url}}] + rollout["prompt"] = [{"role": "user", "content": image_content}] + + parquet_bytes = monitor._rollouts_to_parquet_bytes([rollout], step=3) + + assert parquet_bytes is not None + + table = pq.read_table(io.BytesIO(parquet_bytes)) + row = table.to_pylist()[0] + prompt = json.loads(row["prompt"]) + + assert prompt[0]["content"][0]["image_url"]["url"] == file_url + + def test_sanitize_json_payload_drops_non_finite_values_and_logs_paths(): monitor = _new_monitor() monitor.logger = Mock() From 8d2648012a9c2815ad29b5e3a38a8a64c4300e46 Mon Sep 17 00:00:00 2001 From: d42me Date: Wed, 3 Jun 2026 17:16:46 -0700 Subject: [PATCH 5/6] chore: remove rollout image tests --- tests/unit/orchestrator/test_trajectories.py | 83 +++----------------- tests/unit/utils/test_prime_monitor.py | 51 ------------ 2 files changed, 12 insertions(+), 122 deletions(-) diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index bf30c2fd4c..bda129cd43 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -1,7 +1,4 @@ -import base64 -from pathlib import Path from unittest.mock import MagicMock -from urllib.parse import unquote, urlparse import numpy as np import pybase64 @@ -12,14 +9,13 @@ _deserialize_tool_calls, align_routed_experts, interleave_rollout, - offload_images_to_disk, ) _interleave_rollout = interleave_rollout def interleave_rollout(output, *args, **kwargs): - output.setdefault("env_name", "test-env") + kwargs.setdefault("env_name", output.get("env_name", "test-env")) return _interleave_rollout(output, *args, **kwargs) @@ -51,61 +47,6 @@ def _sample_routed_experts(sample) -> np.ndarray: ) -def test_offload_images_to_disk_uses_absolute_file_uri_for_relative_output_dir(tmp_path, monkeypatch): - monkeypatch.chdir(tmp_path) - image_bytes = b"small image payload" - data_url = f"data:image/png;base64,{base64.b64encode(image_bytes).decode('ascii')}" - rollout = { - "trajectory": [ - { - "prompt": [ - { - "role": "user", - "content": [{"type": "image_url", "image_url": {"url": data_url}}], - } - ] - } - ] - } - - assert offload_images_to_disk([rollout], Path("outputs/run")) == 1 - - file_url = rollout["trajectory"][0]["prompt"][0]["content"][0]["image_url"]["url"] - parsed = urlparse(file_url) - image_path = Path(unquote(parsed.path)) - - assert parsed.scheme == "file" - assert parsed.netloc == "" - assert image_path.is_absolute() - assert image_path.read_bytes() == image_bytes - - -def test_offload_images_to_disk_preserves_image_media_type_suffix(tmp_path, monkeypatch): - monkeypatch.chdir(tmp_path) - image_bytes = b"jpeg image payload" - data_url = f"data:image/jpeg;base64,{base64.b64encode(image_bytes).decode('ascii')}" - rollout = { - "trajectory": [ - { - "prompt": [ - { - "role": "user", - "content": [{"type": "image_url", "image_url": {"url": data_url}}], - } - ] - } - ] - } - - assert offload_images_to_disk([rollout], Path("outputs/run")) == 1 - - file_url = rollout["trajectory"][0]["prompt"][0]["content"][0]["image_url"]["url"] - image_path = Path(unquote(urlparse(file_url).path)) - - assert image_path.suffix == ".jpg" - assert image_path.read_bytes() == image_bytes - - def test_deserialize_tool_calls_does_not_inject_missing_key(): messages = [{"role": "assistant", "content": "hello"}] @@ -408,7 +349,7 @@ def test_branching_equivalent_multi_step_trajectory(multi_step_trajectory_extens assert rollout.completion_ids == [3, 4] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.1, -0.2] - assert rollout.completion_temperatures == [1.0, 1.0] + assert rollout.completion_temperatures == [] # second step rollout = rollouts[1] @@ -417,7 +358,7 @@ def test_branching_equivalent_multi_step_trajectory(multi_step_trajectory_extens assert rollout.completion_ids == [7, 8] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.3, -0.4] - assert rollout.completion_temperatures == [1.0, 1.0] + assert rollout.completion_temperatures == [] def test_branching_equivalent_multi_step_trajectory_with_tool_calls( @@ -435,7 +376,7 @@ def test_branching_equivalent_multi_step_trajectory_with_tool_calls( assert rollout.completion_ids == [3, 4] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.1, -0.2] - assert rollout.completion_temperatures == [1.0, 1.0] + assert rollout.completion_temperatures == [] # second step rollout = rollouts[1] @@ -444,7 +385,7 @@ def test_branching_equivalent_multi_step_trajectory_with_tool_calls( assert rollout.completion_ids == [7, 8] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.3, -0.4] - assert rollout.completion_temperatures == [1.0, 1.0] + assert rollout.completion_temperatures == [] def test_interleave_rollout_single_step_trajectory(single_step_trajectory_output): @@ -459,7 +400,7 @@ def test_interleave_rollout_single_step_trajectory(single_step_trajectory_output assert rollout.completion_ids == [3, 4] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.1, -0.2] - assert rollout.completion_temperatures == [1.0, 1.0] + assert rollout.completion_temperatures == [] assert rollout.env_name == "test-env" @@ -474,8 +415,8 @@ def test_interleave_rollout_multi_step_trajectory(multi_step_trajectory_output): assert rollout.completion_ids == [3, 4, 5, 6, 7, 8] assert rollout.completion_mask == [True, True, False, False, True, True] assert rollout.completion_logprobs == [-0.1, -0.2, 0, 0, -0.3, -0.4] - # Temperatures: 2 completion tokens at temp 1.0, then 2 prompt tokens at temp 1.0, then 2 completion tokens at temp 1.0 - assert rollout.completion_temperatures == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + # ``completion_temperatures`` is filled by the orchestrator post-interleave; empty here. + assert rollout.completion_temperatures == [] def test_interleave_rollout_multi_step_trajectory_with_tool_calls(multi_step_trajectory_with_tool_calls_output): @@ -489,8 +430,8 @@ def test_interleave_rollout_multi_step_trajectory_with_tool_calls(multi_step_tra assert rollout.completion_ids == [3, 4, 5, 6, 7, 8] assert rollout.completion_mask == [True, True, False, False, True, True] assert rollout.completion_logprobs == [-0.1, -0.2, 0, 0, -0.3, -0.4] - # Temperatures: 2 completion tokens at temp 1.0, then 2 prompt tokens at temp 1.0, then 2 completion tokens at temp 1.0 - assert rollout.completion_temperatures == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + # ``completion_temperatures`` is filled by the orchestrator post-interleave; empty here. + assert rollout.completion_temperatures == [] @pytest.fixture @@ -1011,9 +952,9 @@ def test_interleave_rollout_error_masks_all_false(): # Extension holds so tokens merge, but ALL completion_mask should be False assert rollout.completion_ids == [3, 4, 5, 6, 7, 8] assert rollout.completion_mask == [False, False, False, False, False, False] - # Logprobs and temperatures still present + # Logprobs preserved; ``completion_temperatures`` is filled by the orchestrator post-interleave. assert rollout.completion_logprobs == [-0.1, -0.2, 0.0, 0.0, -0.3, -0.4] - assert rollout.completion_temperatures == [0.8] * 6 + assert rollout.completion_temperatures == [] def test_align_routed_experts_none(): diff --git a/tests/unit/utils/test_prime_monitor.py b/tests/unit/utils/test_prime_monitor.py index c2925db51d..f44065a7c6 100644 --- a/tests/unit/utils/test_prime_monitor.py +++ b/tests/unit/utils/test_prime_monitor.py @@ -1,4 +1,3 @@ -import base64 import io import json from unittest.mock import Mock @@ -95,56 +94,6 @@ def test_rollouts_to_parquet_bytes_skips_rollouts_without_trajectory(): assert rows[0]["sample_id"] == 0 -def test_rollouts_to_parquet_bytes_inlines_local_image_urls(tmp_path): - image_bytes = b"small image payload" - image_path = tmp_path / "sample.jpg" - image_path.write_bytes(image_bytes) - file_url = image_path.as_uri() - expected_data_url = f"data:image/jpeg;base64,{base64.b64encode(image_bytes).decode('ascii')}" - - monitor = _new_monitor() - monitor.run_id = "run-images" - rollout = _build_rollout(example_id=1, reward=1.0, task="vision-task") - image_content = [{"type": "image_url", "image_url": {"url": file_url}}] - rollout["prompt"] = [{"role": "user", "content": image_content}] - rollout["trajectory"][0]["prompt"] = [{"role": "user", "content": image_content}] - - parquet_bytes = monitor._rollouts_to_parquet_bytes([rollout], step=3) - - assert parquet_bytes is not None - - table = pq.read_table(io.BytesIO(parquet_bytes)) - row = table.to_pylist()[0] - prompt = json.loads(row["prompt"]) - trajectory = json.loads(row["trajectory"]) - - assert prompt[0]["content"][0]["image_url"]["url"] == expected_data_url - assert trajectory[0]["prompt"][0]["content"][0]["image_url"]["url"] == expected_data_url - assert rollout["prompt"][0]["content"][0]["image_url"]["url"] == file_url - - -def test_rollouts_to_parquet_bytes_does_not_inline_extensionless_file_urls(tmp_path): - image_path = tmp_path / "sample" - image_path.write_bytes(b"extensionless payload") - file_url = image_path.as_uri() - - monitor = _new_monitor() - monitor.run_id = "run-extensionless" - rollout = _build_rollout(example_id=1, reward=1.0, task="vision-task") - image_content = [{"type": "image_url", "image_url": {"url": file_url}}] - rollout["prompt"] = [{"role": "user", "content": image_content}] - - parquet_bytes = monitor._rollouts_to_parquet_bytes([rollout], step=3) - - assert parquet_bytes is not None - - table = pq.read_table(io.BytesIO(parquet_bytes)) - row = table.to_pylist()[0] - prompt = json.loads(row["prompt"]) - - assert prompt[0]["content"][0]["image_url"]["url"] == file_url - - def test_sanitize_json_payload_drops_non_finite_values_and_logs_paths(): monitor = _new_monitor() monitor.logger = Mock() From 27ec93dae9c4fd0c727999012e8527c10536f587 Mon Sep 17 00:00:00 2001 From: d42me Date: Wed, 3 Jun 2026 17:17:16 -0700 Subject: [PATCH 6/6] chore: fully drop test changes --- tests/unit/orchestrator/test_trajectories.py | 24 ++++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index bda129cd43..87df2646c6 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -15,7 +15,7 @@ def interleave_rollout(output, *args, **kwargs): - kwargs.setdefault("env_name", output.get("env_name", "test-env")) + output.setdefault("env_name", "test-env") return _interleave_rollout(output, *args, **kwargs) @@ -349,7 +349,7 @@ def test_branching_equivalent_multi_step_trajectory(multi_step_trajectory_extens assert rollout.completion_ids == [3, 4] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.1, -0.2] - assert rollout.completion_temperatures == [] + assert rollout.completion_temperatures == [1.0, 1.0] # second step rollout = rollouts[1] @@ -358,7 +358,7 @@ def test_branching_equivalent_multi_step_trajectory(multi_step_trajectory_extens assert rollout.completion_ids == [7, 8] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.3, -0.4] - assert rollout.completion_temperatures == [] + assert rollout.completion_temperatures == [1.0, 1.0] def test_branching_equivalent_multi_step_trajectory_with_tool_calls( @@ -376,7 +376,7 @@ def test_branching_equivalent_multi_step_trajectory_with_tool_calls( assert rollout.completion_ids == [3, 4] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.1, -0.2] - assert rollout.completion_temperatures == [] + assert rollout.completion_temperatures == [1.0, 1.0] # second step rollout = rollouts[1] @@ -385,7 +385,7 @@ def test_branching_equivalent_multi_step_trajectory_with_tool_calls( assert rollout.completion_ids == [7, 8] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.3, -0.4] - assert rollout.completion_temperatures == [] + assert rollout.completion_temperatures == [1.0, 1.0] def test_interleave_rollout_single_step_trajectory(single_step_trajectory_output): @@ -400,7 +400,7 @@ def test_interleave_rollout_single_step_trajectory(single_step_trajectory_output assert rollout.completion_ids == [3, 4] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.1, -0.2] - assert rollout.completion_temperatures == [] + assert rollout.completion_temperatures == [1.0, 1.0] assert rollout.env_name == "test-env" @@ -415,8 +415,8 @@ def test_interleave_rollout_multi_step_trajectory(multi_step_trajectory_output): assert rollout.completion_ids == [3, 4, 5, 6, 7, 8] assert rollout.completion_mask == [True, True, False, False, True, True] assert rollout.completion_logprobs == [-0.1, -0.2, 0, 0, -0.3, -0.4] - # ``completion_temperatures`` is filled by the orchestrator post-interleave; empty here. - assert rollout.completion_temperatures == [] + # Temperatures: 2 completion tokens at temp 1.0, then 2 prompt tokens at temp 1.0, then 2 completion tokens at temp 1.0 + assert rollout.completion_temperatures == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] def test_interleave_rollout_multi_step_trajectory_with_tool_calls(multi_step_trajectory_with_tool_calls_output): @@ -430,8 +430,8 @@ def test_interleave_rollout_multi_step_trajectory_with_tool_calls(multi_step_tra assert rollout.completion_ids == [3, 4, 5, 6, 7, 8] assert rollout.completion_mask == [True, True, False, False, True, True] assert rollout.completion_logprobs == [-0.1, -0.2, 0, 0, -0.3, -0.4] - # ``completion_temperatures`` is filled by the orchestrator post-interleave; empty here. - assert rollout.completion_temperatures == [] + # Temperatures: 2 completion tokens at temp 1.0, then 2 prompt tokens at temp 1.0, then 2 completion tokens at temp 1.0 + assert rollout.completion_temperatures == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] @pytest.fixture @@ -952,9 +952,9 @@ def test_interleave_rollout_error_masks_all_false(): # Extension holds so tokens merge, but ALL completion_mask should be False assert rollout.completion_ids == [3, 4, 5, 6, 7, 8] assert rollout.completion_mask == [False, False, False, False, False, False] - # Logprobs preserved; ``completion_temperatures`` is filled by the orchestrator post-interleave. + # Logprobs and temperatures still present assert rollout.completion_logprobs == [-0.1, -0.2, 0.0, 0.0, -0.3, -0.4] - assert rollout.completion_temperatures == [] + assert rollout.completion_temperatures == [0.8] * 6 def test_align_routed_experts_none():