Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions nemo_gym/rollout_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,32 @@ def _preprocess_rows_from_config(self, config: RolloutCollectionConfig) -> List[
rows_iterator: Iterator[tuple[int, str]] = zip(range_iterator, rows_iterator)
raw_rows = [(row_idx, row_str, orjson.loads(row_str)) for row_idx, row_str in rows_iterator]

# Fallback: if prompt_config wasn't explicitly set, try to infer it from the
# agent's dataset catalog. Each BenchmarkDatasetConfig entry has a
# (jsonl_fpath, prompt_config) pair; when the user specifies input_jsonl_fpath
# matching one of those entries, use its declared prompt_config. This avoids
# forcing the user to re-specify the prompt that's already in the agent YAML.
if prompt_cfg is None and config.agent_name:
global_cfg = get_global_config_dict()
agent_cfg = global_cfg.get(config.agent_name) if global_cfg is not None else None
if agent_cfg is not None:
agents_section = agent_cfg.get("responses_api_agents") or {}
# agents_section has exactly one entry (enforced by schema: Dict[...] min=1 max=1)
inner = next(iter(agents_section.values()), None) or {}
for ds in inner.get("datasets") or []:
ds_path = ds.get("jsonl_fpath")
ds_prompt = ds.get("prompt_config")
if not ds_path or not ds_prompt:
continue
try:
same = Path(str(ds_path)).resolve() == _input_path.resolve()
except OSError:
same = str(ds_path) == str(_input_path)
if same:
prompt_cfg = load_prompt_config(str(ds_prompt))
print(f"Using prompt config from agent '{config.agent_name}' dataset catalog: {ds_prompt}")
break

# Validate and apply prompt config before per-row processing
if prompt_cfg is not None:
validate_prompt_compatibility([row for _, _, row in raw_rows], prompt_cfg)
Expand Down
206 changes: 205 additions & 1 deletion tests/unit_tests/test_rollout_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ def test_preprocess_rows_prompt_config_preserves_rcp_fields(self, tmp_path: Path
assert result[0]["responses_create_params"]["tools"] == [{"type": "function", "name": "calc"}]
assert result[0]["responses_create_params"]["input"] == [{"role": "user", "content": "test"}]

def test_preprocess_rows_from_config(self, tmp_path: Path) -> None:
def test_preprocess_rows_from_config(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
import nemo_gym.rollout_collection as rc_module

monkeypatch.setattr(rc_module, "get_global_config_dict", lambda: {})

fpath = tmp_path / "input.jsonl"
samples = [json.dumps({"responses_create_params": {"input": []}, "x": i}) for i in range(10)]
fpath.write_text("\n".join(samples) + "\n")
Expand Down Expand Up @@ -542,3 +546,203 @@ async def test_call_aggregate_metrics_empty(self, tmp_path: Path) -> None:
output_fpath = tmp_path / "output.jsonl"
result = await helper._call_aggregate_metrics([], [], output_fpath)
assert result is None


class TestPromptConfigFromCatalog:
"""Tests for defaulting prompt_config from the agent's dataset catalog.

When +prompt_config is not explicitly specified, _preprocess_rows_from_config
looks up the agent_name in the global config dict and finds a matching
datasets[*] entry by jsonl_fpath, then uses its prompt_config.
"""

@staticmethod
def _write_prompt_yaml(path: Path, user_template: str, system: str | None = None) -> None:
data: dict = {"user": user_template}
if system is not None:
data["system"] = system
path.write_text(yaml.dump(data))

@staticmethod
def _write_input_jsonl(path: Path, rows: list[dict]) -> None:
path.write_text("\n".join(json.dumps(r) for r in rows) + "\n")

@staticmethod
def _patch_global_config(monkeypatch: pytest.MonkeyPatch, agent_name: str, datasets: list[dict]) -> None:
"""Patch get_global_config_dict to return an agent config with the given datasets."""
import nemo_gym.rollout_collection as rc_module

global_cfg = {
agent_name: {
"responses_api_agents": {
"inner_agent": {"datasets": datasets},
},
},
}
monkeypatch.setattr(rc_module, "get_global_config_dict", lambda: global_cfg)

def test_explicit_prompt_config_wins_over_catalog(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Explicit +prompt_config takes priority over the agent's catalog prompt_config."""
explicit_prompt = tmp_path / "explicit.yaml"
self._write_prompt_yaml(explicit_prompt, "EXPLICIT: {question}")

catalog_prompt = tmp_path / "catalog.yaml"
self._write_prompt_yaml(catalog_prompt, "CATALOG: {question}")

input_fpath = tmp_path / "input.jsonl"
self._write_input_jsonl(input_fpath, [{"question": "What is 2+2?"}])

self._patch_global_config(
monkeypatch,
agent_name="my_agent",
datasets=[
{"jsonl_fpath": str(input_fpath), "prompt_config": str(catalog_prompt)},
],
)

config = RolloutCollectionConfig(
agent_name="my_agent",
input_jsonl_fpath=str(input_fpath),
output_jsonl_fpath=str(tmp_path / "out.jsonl"),
prompt_config=str(explicit_prompt),
)

result = RolloutCollectionHelper._preprocess_rows_from_config(None, config)
assert len(result) == 1
assert result[0]["responses_create_params"]["input"] == [
{"role": "user", "content": "EXPLICIT: What is 2+2?"},
]

def test_catalog_fallback_used_when_input_matches(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Without explicit prompt_config, the agent's matching dataset entry is used."""
catalog_prompt = tmp_path / "catalog.yaml"
self._write_prompt_yaml(catalog_prompt, "CATALOG: {question}")

input_fpath = tmp_path / "input.jsonl"
self._write_input_jsonl(input_fpath, [{"question": "What is 2+2?"}])

self._patch_global_config(
monkeypatch,
agent_name="my_agent",
datasets=[
{"jsonl_fpath": str(input_fpath), "prompt_config": str(catalog_prompt)},
],
)

config = RolloutCollectionConfig(
agent_name="my_agent",
input_jsonl_fpath=str(input_fpath),
output_jsonl_fpath=str(tmp_path / "out.jsonl"),
)

result = RolloutCollectionHelper._preprocess_rows_from_config(None, config)
assert len(result) == 1
assert result[0]["responses_create_params"]["input"] == [
{"role": "user", "content": "CATALOG: What is 2+2?"},
]

def test_no_match_leaves_prompt_cfg_none(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""If no catalog entry matches the input, behavior is unchanged (rows pass through).

Rows must already have responses_create_params.input since no prompt is applied.
"""
catalog_prompt = tmp_path / "catalog.yaml"
self._write_prompt_yaml(catalog_prompt, "CATALOG: {question}")

# The input file the user passes:
input_fpath = tmp_path / "input.jsonl"
# These rows already have pre-rendered input (the legacy path).
self._write_input_jsonl(
input_fpath,
[{"responses_create_params": {"input": [{"role": "user", "content": "pre-rendered"}]}}],
)

# The agent's catalog has a different jsonl path, so it should NOT match.
other_jsonl = tmp_path / "other.jsonl"
other_jsonl.write_text("{}\n")

self._patch_global_config(
monkeypatch,
agent_name="my_agent",
datasets=[
{"jsonl_fpath": str(other_jsonl), "prompt_config": str(catalog_prompt)},
],
)

config = RolloutCollectionConfig(
agent_name="my_agent",
input_jsonl_fpath=str(input_fpath),
output_jsonl_fpath=str(tmp_path / "out.jsonl"),
)

result = RolloutCollectionHelper._preprocess_rows_from_config(None, config)
# prompt_cfg stayed None, so the pre-rendered input is preserved.
assert result[0]["responses_create_params"]["input"] == [
{"role": "user", "content": "pre-rendered"},
]

def test_no_agent_name_leaves_prompt_cfg_none(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""When agent_name is None (row-level agent_ref), fallback is skipped entirely."""
import nemo_gym.rollout_collection as rc_module

# Even if the global config has data, it should be ignored without agent_name.
monkeypatch.setattr(
rc_module,
"get_global_config_dict",
lambda: {"some_agent": {"responses_api_agents": {"inner": {"datasets": []}}}},
)

input_fpath = tmp_path / "input.jsonl"
self._write_input_jsonl(
input_fpath,
[
{
"responses_create_params": {"input": [{"role": "user", "content": "pre-rendered"}]},
"agent_ref": {"name": "inline_agent"},
}
],
)

config = RolloutCollectionConfig(
input_jsonl_fpath=str(input_fpath),
output_jsonl_fpath=str(tmp_path / "out.jsonl"),
)

result = RolloutCollectionHelper._preprocess_rows_from_config(None, config)
# prompt_cfg stayed None; pre-rendered input preserved.
assert result[0]["responses_create_params"]["input"] == [
{"role": "user", "content": "pre-rendered"},
]

def test_multiple_catalog_entries_picks_matching(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""With multiple dataset entries, the one with matching jsonl_fpath is selected."""
prompt_a = tmp_path / "prompt_a.yaml"
self._write_prompt_yaml(prompt_a, "A: {question}")
prompt_b = tmp_path / "prompt_b.yaml"
self._write_prompt_yaml(prompt_b, "B: {question}")

input_a = tmp_path / "dataset_a.jsonl"
self._write_input_jsonl(input_a, [{"question": "from_a"}])
input_b = tmp_path / "dataset_b.jsonl"
self._write_input_jsonl(input_b, [{"question": "from_b"}])

self._patch_global_config(
monkeypatch,
agent_name="my_agent",
datasets=[
{"jsonl_fpath": str(input_a), "prompt_config": str(prompt_a)},
{"jsonl_fpath": str(input_b), "prompt_config": str(prompt_b)},
],
)

# User points at input_b, so prompt_b should be selected.
config = RolloutCollectionConfig(
agent_name="my_agent",
input_jsonl_fpath=str(input_b),
output_jsonl_fpath=str(tmp_path / "out.jsonl"),
)

result = RolloutCollectionHelper._preprocess_rows_from_config(None, config)
assert result[0]["responses_create_params"]["input"] == [
{"role": "user", "content": "B: from_b"},
]
Loading