diff --git a/nemo_gym/rollout_collection.py b/nemo_gym/rollout_collection.py
index 38d709d7b..0c5c8f15e 100644
--- a/nemo_gym/rollout_collection.py
+++ b/nemo_gym/rollout_collection.py
@@ -172,6 +172,32 @@ def _preprocess_rows_from_config(self, config: RolloutCollectionConfig) -> List[
         rows_iterator: Iterator[tuple[int, str]] = zip(range_iterator, rows_iterator)
         raw_rows = [(row_idx, row_str, orjson.loads(row_str)) for row_idx, row_str in rows_iterator]
 
+        # Fallback: if prompt_config wasn't explicitly set, try to infer it from the
+        # agent's dataset catalog. Each BenchmarkDatasetConfig entry has a
+        # (jsonl_fpath, prompt_config) pair; when the user specifies input_jsonl_fpath
+        # matching one of those entries, use its declared prompt_config. This avoids
+        # forcing the user to re-specify the prompt that's already in the agent YAML.
+        if prompt_cfg is None and config.agent_name:
+            global_cfg = get_global_config_dict()
+            agent_cfg = global_cfg.get(config.agent_name) if global_cfg is not None else None
+            if agent_cfg is not None:
+                agents_section = agent_cfg.get("responses_api_agents") or {}
+                # agents_section has exactly one entry (enforced by schema: Dict[...] min=1 max=1)
+                inner = next(iter(agents_section.values()), None) or {}
+                for ds in inner.get("datasets") or []:
+                    ds_path = ds.get("jsonl_fpath")
+                    ds_prompt = ds.get("prompt_config")
+                    if not ds_path or not ds_prompt:
+                        continue
+                    try:
+                        same = Path(str(ds_path)).resolve() == _input_path.resolve()
+                    except (OSError, RuntimeError):  # symlink loop raises RuntimeError before Python 3.13
+                        same = str(ds_path) == str(_input_path)
+                    if same:
+                        prompt_cfg = load_prompt_config(str(ds_prompt))
+                        print(f"Using prompt config from agent '{config.agent_name}' dataset catalog: {ds_prompt}")
+                        break
+
         # Validate and apply prompt config before per-row processing
         if prompt_cfg is not None:
             validate_prompt_compatibility([row for _, _, row in raw_rows], prompt_cfg)
diff --git a/tests/unit_tests/test_rollout_collection.py b/tests/unit_tests/test_rollout_collection.py
index 7af09006e..2ff0e3ab6 100644
--- a/tests/unit_tests/test_rollout_collection.py
+++ b/tests/unit_tests/test_rollout_collection.py
@@ -98,7 +98,11 @@ def test_preprocess_rows_prompt_config_preserves_rcp_fields(self, tmp_path: Path
         assert result[0]["responses_create_params"]["tools"] == [{"type": "function", "name": "calc"}]
         assert result[0]["responses_create_params"]["input"] == [{"role": "user", "content": "test"}]
 
-    def test_preprocess_rows_from_config(self, tmp_path: Path) -> None:
+    def test_preprocess_rows_from_config(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+        import nemo_gym.rollout_collection as rc_module
+
+        monkeypatch.setattr(rc_module, "get_global_config_dict", lambda: {})
+
         fpath = tmp_path / "input.jsonl"
         samples = [json.dumps({"responses_create_params": {"input": []}, "x": i}) for i in range(10)]
         fpath.write_text("\n".join(samples) + "\n")
@@ -542,3 +546,203 @@ async def test_call_aggregate_metrics_empty(self, tmp_path: Path) -> None:
         output_fpath = tmp_path / "output.jsonl"
         result = await helper._call_aggregate_metrics([], [], output_fpath)
         assert result is None
+
+
+class TestPromptConfigFromCatalog:
+    """Tests for defaulting prompt_config from the agent's dataset catalog.
+
+    When +prompt_config is not explicitly specified, _preprocess_rows_from_config
+    looks up the agent_name in the global config dict and finds a matching
+    datasets[*] entry by jsonl_fpath, then uses its prompt_config.
+    """
+
+    @staticmethod
+    def _write_prompt_yaml(path: Path, user_template: str, system: str | None = None) -> None:
+        data: dict = {"user": user_template}
+        if system is not None:
+            data["system"] = system
+        path.write_text(yaml.dump(data))
+
+    @staticmethod
+    def _write_input_jsonl(path: Path, rows: list[dict]) -> None:
+        path.write_text("\n".join(json.dumps(r) for r in rows) + "\n")
+
+    @staticmethod
+    def _patch_global_config(monkeypatch: pytest.MonkeyPatch, agent_name: str, datasets: list[dict]) -> None:
+        """Patch get_global_config_dict to return an agent config with the given datasets."""
+        import nemo_gym.rollout_collection as rc_module
+
+        global_cfg = {
+            agent_name: {
+                "responses_api_agents": {
+                    "inner_agent": {"datasets": datasets},
+                },
+            },
+        }
+        monkeypatch.setattr(rc_module, "get_global_config_dict", lambda: global_cfg)
+
+    def test_explicit_prompt_config_wins_over_catalog(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+        """Explicit +prompt_config takes priority over the agent's catalog prompt_config."""
+        explicit_prompt = tmp_path / "explicit.yaml"
+        self._write_prompt_yaml(explicit_prompt, "EXPLICIT: {question}")
+
+        catalog_prompt = tmp_path / "catalog.yaml"
+        self._write_prompt_yaml(catalog_prompt, "CATALOG: {question}")
+
+        input_fpath = tmp_path / "input.jsonl"
+        self._write_input_jsonl(input_fpath, [{"question": "What is 2+2?"}])
+
+        self._patch_global_config(
+            monkeypatch,
+            agent_name="my_agent",
+            datasets=[
+                {"jsonl_fpath": str(input_fpath), "prompt_config": str(catalog_prompt)},
+            ],
+        )
+
+        config = RolloutCollectionConfig(
+            agent_name="my_agent",
+            input_jsonl_fpath=str(input_fpath),
+            output_jsonl_fpath=str(tmp_path / "out.jsonl"),
+            prompt_config=str(explicit_prompt),
+        )
+
+        result = RolloutCollectionHelper._preprocess_rows_from_config(None, config)
+        assert len(result) == 1
+        assert result[0]["responses_create_params"]["input"] == [
+            {"role": "user", "content": "EXPLICIT: What is 2+2?"},
+        ]
+
+    def test_catalog_fallback_used_when_input_matches(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+        """Without explicit prompt_config, the agent's matching dataset entry is used."""
+        catalog_prompt = tmp_path / "catalog.yaml"
+        self._write_prompt_yaml(catalog_prompt, "CATALOG: {question}")
+
+        input_fpath = tmp_path / "input.jsonl"
+        self._write_input_jsonl(input_fpath, [{"question": "What is 2+2?"}])
+
+        self._patch_global_config(
+            monkeypatch,
+            agent_name="my_agent",
+            datasets=[
+                {"jsonl_fpath": str(input_fpath), "prompt_config": str(catalog_prompt)},
+            ],
+        )
+
+        config = RolloutCollectionConfig(
+            agent_name="my_agent",
+            input_jsonl_fpath=str(input_fpath),
+            output_jsonl_fpath=str(tmp_path / "out.jsonl"),
+        )
+
+        result = RolloutCollectionHelper._preprocess_rows_from_config(None, config)
+        assert len(result) == 1
+        assert result[0]["responses_create_params"]["input"] == [
+            {"role": "user", "content": "CATALOG: What is 2+2?"},
+        ]
+
+    def test_no_match_leaves_prompt_cfg_none(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+        """If no catalog entry matches the input, behavior is unchanged (rows pass through).
+
+        Rows must already have responses_create_params.input since no prompt is applied.
+        """
+        catalog_prompt = tmp_path / "catalog.yaml"
+        self._write_prompt_yaml(catalog_prompt, "CATALOG: {question}")
+
+        # The input file the user passes:
+        input_fpath = tmp_path / "input.jsonl"
+        # These rows already have pre-rendered input (the legacy path).
+        self._write_input_jsonl(
+            input_fpath,
+            [{"responses_create_params": {"input": [{"role": "user", "content": "pre-rendered"}]}}],
+        )
+
+        # The agent's catalog has a different jsonl path, so it should NOT match.
+        other_jsonl = tmp_path / "other.jsonl"
+        other_jsonl.write_text("{}\n")
+
+        self._patch_global_config(
+            monkeypatch,
+            agent_name="my_agent",
+            datasets=[
+                {"jsonl_fpath": str(other_jsonl), "prompt_config": str(catalog_prompt)},
+            ],
+        )
+
+        config = RolloutCollectionConfig(
+            agent_name="my_agent",
+            input_jsonl_fpath=str(input_fpath),
+            output_jsonl_fpath=str(tmp_path / "out.jsonl"),
+        )
+
+        result = RolloutCollectionHelper._preprocess_rows_from_config(None, config)
+        # prompt_cfg stayed None, so the pre-rendered input is preserved.
+        assert result[0]["responses_create_params"]["input"] == [
+            {"role": "user", "content": "pre-rendered"},
+        ]
+
+    def test_no_agent_name_leaves_prompt_cfg_none(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+        """When agent_name is None (row-level agent_ref), fallback is skipped entirely."""
+        import nemo_gym.rollout_collection as rc_module
+
+        # Even if the global config has data, it should be ignored without agent_name.
+        monkeypatch.setattr(
+            rc_module,
+            "get_global_config_dict",
+            lambda: {"some_agent": {"responses_api_agents": {"inner": {"datasets": []}}}},
+        )
+
+        input_fpath = tmp_path / "input.jsonl"
+        self._write_input_jsonl(
+            input_fpath,
+            [
+                {
+                    "responses_create_params": {"input": [{"role": "user", "content": "pre-rendered"}]},
+                    "agent_ref": {"name": "inline_agent"},
+                }
+            ],
+        )
+
+        config = RolloutCollectionConfig(
+            input_jsonl_fpath=str(input_fpath),
+            output_jsonl_fpath=str(tmp_path / "out.jsonl"),
+        )
+
+        result = RolloutCollectionHelper._preprocess_rows_from_config(None, config)
+        # prompt_cfg stayed None; pre-rendered input preserved.
+        assert result[0]["responses_create_params"]["input"] == [
+            {"role": "user", "content": "pre-rendered"},
+        ]
+
+    def test_multiple_catalog_entries_picks_matching(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+        """With multiple dataset entries, the one with matching jsonl_fpath is selected."""
+        prompt_a = tmp_path / "prompt_a.yaml"
+        self._write_prompt_yaml(prompt_a, "A: {question}")
+        prompt_b = tmp_path / "prompt_b.yaml"
+        self._write_prompt_yaml(prompt_b, "B: {question}")
+
+        input_a = tmp_path / "dataset_a.jsonl"
+        self._write_input_jsonl(input_a, [{"question": "from_a"}])
+        input_b = tmp_path / "dataset_b.jsonl"
+        self._write_input_jsonl(input_b, [{"question": "from_b"}])
+
+        self._patch_global_config(
+            monkeypatch,
+            agent_name="my_agent",
+            datasets=[
+                {"jsonl_fpath": str(input_a), "prompt_config": str(prompt_a)},
+                {"jsonl_fpath": str(input_b), "prompt_config": str(prompt_b)},
+            ],
+        )
+
+        # User points at input_b, so prompt_b should be selected.
+        config = RolloutCollectionConfig(
+            agent_name="my_agent",
+            input_jsonl_fpath=str(input_b),
+            output_jsonl_fpath=str(tmp_path / "out.jsonl"),
+        )
+
+        result = RolloutCollectionHelper._preprocess_rows_from_config(None, config)
+        assert result[0]["responses_create_params"]["input"] == [
+            {"role": "user", "content": "B: from_b"},
+        ]