diff --git a/README.md b/README.md index 1e887deef..d9913d7b5 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,5 @@ # NeMo Gym -[![PyPI](https://img.shields.io/pypi/v/nemo-gym)](https://pypi.org/project/nemo-gym/) -[![Python](https://img.shields.io/pypi/pyversions/nemo-gym)](https://pypi.org/project/nemo-gym/) -[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) -[![CI](https://github.com/NVIDIA-NeMo/Gym/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/NVIDIA-NeMo/Gym/actions/workflows/unit-tests.yml) -[![Docs](https://img.shields.io/badge/docs-NVIDIA-brightgreen)](https://docs.nvidia.com/nemo/gym/latest/) - **[Requirements](#-requirements)** • **[Quick Start](#-quick-start)** • **[Available Environments](#-available-environments)** • **[Documentation & Resources](#-documentation--resources)** • **[Community & Support](#-community--support)** • **[Citations](#-citations)** NeMo Gym is a library for building reinforcement learning (RL) training environments for large language models (LLMs). It provides infrastructure to develop environments, scale rollout collection, and integrate seamlessly with your preferred training framework. @@ -164,7 +158,9 @@ The Dataset column links to publicly available datasets (e.g., on HuggingFace). | Arc Agi | knowledge | Solve puzzles designed to test intelligence. See https://arcprize.org/arc-agi. | Improve puzzle-solving capabilities. | - | ✓ | - | arc_agi.yaml | - | | Aviary | agent | Multi-hop question answering on the HotPotQA dataset with Wikipedia search | Improve knowledge and agentic capability | ✓ | ✓ | Apache 2.0 | hotpotqa_aviary.yaml | - | | Aviary | math | GSM8k benchmark with calculator tool | Test math and agentic capability | ✓ | ✓ | Apache 2.0 | gsm8k_aviary.yaml | - | +| Base Gymnasium | other | Base class for Gymnasium-style servers. Not a standalone server. 
| Reusable base class for step/reset style environments | - | - | - | base_gymnasium.yaml | - | | Bird Sql | coding | Text-to-SQL with execution-based evaluation on BIRD dev (1534 SQLite tasks). Binary reward from unordered result-set equality. | Improve text-to-SQL capabilities on BIRD's realistic dev split using execution-based binary reward without an LLM judge. | - | - | - | bird_sql.yaml | - | +| Blackjack | games | Blackjack. Model hits or stands. Reward +1 win, 0 draw, -1 loss/bust. | Example gymnasium-style multi-step environment | - | - | - | blackjack.yaml | - | | Browsecomp Advanced Harness | agent | Model uses search tools to satisfy a user query. | Measure agentic search capability | - | - | - | browsecomp_advanced_harness.yaml | - | | Calendar | agent | Multi-turn calendar scheduling dataset. User states events and constraints in natural language; model schedules events to satisfy all constraints. | Improve multi-turn instruction following capabilities | ✓ | ✓ | Apache 2.0 | calendar.yaml | Nemotron-RL-agent-calendar_scheduling | | Calendar | agent | Multi-turn calendar scheduling dataset. User states events and constraints in natural language; model schedules events to satisfy all constraints. | Improve multi-turn instruction following capabilities | ✓ | ✓ | Creative Commons Attribution 4.0 International | calendar_v2.yaml | Nemotron-RL-Instruction-Following-Calendar-v2 | diff --git a/resources_servers/base_gymnasium/README.md b/resources_servers/base_gymnasium/README.md new file mode 100644 index 000000000..0cef836b5 --- /dev/null +++ b/resources_servers/base_gymnasium/README.md @@ -0,0 +1,198 @@ +# Gymnasium + +`GymnasiumServer` is a [Gymnasium](https://gymnasium.farama.org/)-style base class for resources servers. Implement `step()`, optionally `reset()`, and use with `gymnasium_agent`. Not a standalone server. 
+ +```python +from resources_servers.base_gymnasium import GymnasiumServer +``` + +## Interface + +```python +from resources_servers.base_gymnasium import GymnasiumServer +from nemo_gym.openai_utils import NeMoGymResponse + +class MyEnv(GymnasiumServer): + async def reset(self, metadata: dict, session_id=None) -> tuple[str | None, dict]: + return None, {} # (observation, info); observation (if set) is appended to input + + async def step(self, action: NeMoGymResponse, metadata: dict, session_id=None) -> tuple[str | None, float, bool, bool, dict]: + ... # (observation, reward, terminated, truncated, info) +``` + +`reset()` runs once per episode. `step()` runs after each model response and returns the 5-tuple: + +- **observation**: next message to the model, or `None` to end the episode. +- **reward**: per-step reward; the agent sums across steps by default. +- **terminated**: episode ended naturally (task solved, game over). +- **truncated**: episode cut short (step limit, timeout). +- **info**: extra metadata the env wants to return (debug info, scores, stats). Also how the env sends `tool_outputs` to the agent (see Tool Use). + +Only `step()` is required. The default `reset()` returns `(None, {})`, meaning the prompt from the dataset is used as-is. A non-`None` observation from `reset()` is appended to the input as a user message before the first model call. + +The three arguments shown in the signatures above: + +- **`metadata`**: any extra fields from the JSONL row (e.g. `ground_truth`, `category`). Use for task config or scoring references. Access via `metadata.get("field")`. +- **`session_id`**: unique string per rollout. Use as a key into `self.session_state` for per-episode state (game boards, conversation history, etc.). Stateless envs can ignore it. +- **`action`**: the model's `NeMoGymResponse` for the current turn. Use `extract_text(action)` for text or iterate `action.output` for structured items like `function_call`. 
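The 5-tuple contract described above can be sketched without any NeMo Gym imports. This is a minimal stand-in, not the `gymnasium_agent` implementation: `CountdownEnv`, `run_episode`, the `turns` field, and the string actions are all hypothetical. The loop assumes only what the text states: rewards are summed across steps, and an episode ends on `terminated`, `truncated`, or a `None` observation.

```python
import asyncio
from typing import Optional


class CountdownEnv:
    """Hypothetical stand-in for a GymnasiumServer subclass (no NeMo Gym imports)."""

    def __init__(self) -> None:
        # Per-episode state keyed by session_id, mirroring self.session_state.
        self.session_state: dict[str, dict] = {}

    async def reset(self, metadata: dict, session_id: Optional[str] = None):
        # Task config comes from the row's extra fields (metadata).
        self.session_state[session_id] = {"remaining": metadata.get("turns", 1)}
        return "Say 'tick'.", {}

    async def step(self, action: str, metadata: dict, session_id: Optional[str] = None):
        state = self.session_state[session_id]
        reward = 1.0 if action == "tick" else 0.0
        state["remaining"] -= 1
        if state["remaining"] <= 0:
            # Episode ended naturally: no further observation.
            return None, reward, True, False, {}
        return "Say 'tick' again.", reward, False, False, {}


async def run_episode(env, policy, metadata: dict, session_id: str, max_steps: int) -> float:
    """Toy agent loop: sums per-step rewards until the episode ends or max_steps is hit."""
    observation, _info = await env.reset(metadata, session_id)
    total = 0.0
    for _ in range(max_steps):
        action = policy(observation)
        observation, reward, terminated, truncated, _info = await env.step(action, metadata, session_id)
        total += reward
        if terminated or truncated or observation is None:
            break
    return total


total = asyncio.run(run_episode(CountdownEnv(), lambda obs: "tick", {"turns": 3}, "sid-1", max_steps=10))
print(total)  # 3.0
```

In this sketch the step limit lives in the loop rather than in the environment, which is one plausible reading of the `max_steps` setting in the agent config; the real agent may handle truncation differently.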
+ +## Single-Step + +Single-step environments are the common non-agentic case: one model call, then grade the output. Implement `step()` so it always returns `terminated=True`. + +```python +class MySingleStepEnv(GymnasiumServer): + async def step(self, action, metadata, session_id=None): + response_text = extract_text(action) + answer = metadata.get("answer") + # Guard against a missing answer: `None in response_text` would raise TypeError. + reward = 1.0 if answer and answer in response_text else 0.0 + return None, reward, True, False, {} +``` + +## Multi-Step with Action Tags + +Multiple model calls per episode without native tool calling. The model wraps its decision in action tags in its output; `step()` parses them and returns the next observation or terminates. + +```python +import re + +class BlackjackEnv(GymnasiumServer): + async def reset(self, metadata, session_id=None): + hand = deal_hand() + self.session_state[session_id] = hand + return f"Your hand: {hand}. hit or stand?", {} + + async def step(self, action, metadata, session_id=None): + text = extract_text(action) + m = re.search(r"\s*(hit|stand)\s*", text, re.IGNORECASE) + decision = m.group(1).lower() if m else "stand" + + hand = self.session_state[session_id] + if decision == "hit": + hand = hit(hand) + # Persist the updated hand so the next step sees it. + self.session_state[session_id] = hand + if bust(hand): + return None, -1.0, True, False, {} + return f"Your hand: {hand}. hit or stand?", 0.0, False, False, {} + + reward = score_against_dealer(hand) + return None, reward, True, False, {} +``` + +## Tool Use + +For tool-calling environments, `step()` inspects `action.output` for items with `type == "function_call"`, executes them, and returns per-call outputs in `info["tool_outputs"]`. The agent synthesizes proper `function_call_output` items tied to each `call_id`, so the model sees the tool_call/response structure it was trained on. Tool schemas go in `responses_create_params.tools` in your JSONL so the model knows what tools are available. 
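A dataset row carrying a tool schema might be built as in the sketch below. The `get_weather` tool, its parameters, and the `ground_truth` value are hypothetical. The flat function-tool layout follows the OpenAI Responses API format, which `responses_create_params` appears to mirror, so check it against your model server. The `agent_ref` name assumes the agent instance key from your YAML config.

```python
import json

# Hypothetical tool schema in the flat Responses-API function-tool layout.
weather_tool = {
    "type": "function",
    "name": "get_weather",
    "description": "Look up the current weather for a city.",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}

row = {
    "responses_create_params": {
        "input": [{"role": "user", "content": "What is the weather in Paris?"}],
        "tools": [weather_tool],
    },
    # Per the README, extra fields from the row are passed to reset()/step() as `metadata`.
    "ground_truth": "sunny",
    # Assumed to match the agent instance name in the YAML config.
    "agent_ref": {"type": "responses_api_agents", "name": "my_gymnasium_agent_instance"},
}

# JSONL convention: one JSON object per line, no embedded newlines.
line = json.dumps(row)
print(line)
```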
+ +```python +import json + +class MyToolEnv(GymnasiumServer): + async def reset(self, metadata, session_id=None): + self.session_state[session_id] = initialize(metadata) + return None, {} + + async def step(self, action, metadata, session_id=None): + tool_calls = [o for o in action.output if o.type == "function_call"] + + if tool_calls: + fns = self.session_state[session_id]["functions"] + outputs = [self.tool_output(c, fns[c.name](**json.loads(c.arguments))) for c in tool_calls] + return None, 0.0, False, False, {"tool_outputs": outputs} + + reward = self._grade(action, metadata) + return None, reward, True, False, {} +``` + +`self.tool_output(call, result)` is a helper on `GymnasiumServer` that builds the `{"call_id", "output"}` dict the agent expects (JSON-serializes the result for you). + +## Multi-Turn + +`step()` returns the next user message as the observation. The `gymnasium_agent` appends it to the conversation and calls the model again. Return `None` as the observation to end. + +```python +class MyMultiTurnEnv(GymnasiumServer): + async def reset(self, metadata, session_id=None): + self.session_state[session_id] = {"turn": 0} + return None, {} + + async def step(self, action, metadata, session_id=None): + follow_ups = metadata.get("follow_ups", []) + state = self.session_state[session_id] + + if state["turn"] < len(follow_ups): + msg = follow_ups[state["turn"]] + state["turn"] += 1 + return msg, 0.0, False, False, {} + + reward = self._grade(action, metadata) + return None, reward, True, False, {} +``` + +## LLM-as-Judge + +Use `step()` to call a judge model through `self.server_client` and score the output. The judge model must be configured as a separate model server. 
+ +```python +class MyJudgeEnv(GymnasiumServer): + judge_server: str = "judge_model" # name of the model server in YAML + + async def step(self, action, metadata, session_id=None): + response_text = extract_text(action) + judge_input = f"Question: {metadata.get('question')}\nAnswer: {response_text}\nIs this correct? Say YES or NO." + judge_resp = await self.server_client.post( + server_name=self.judge_server, + url_path="/v1/responses", + json={"input": [{"role": "user", "content": judge_input}]}, + ) + judgment = await judge_resp.json() + reward = 1.0 if "YES" in str(judgment.get("output_text", "")).upper() else 0.0 + return None, reward, True, False, {} +``` + +## Reward Model + +Same pattern. Call a reward model endpoint and use its score directly. + +```python +class MyRewardModelEnv(GymnasiumServer): + rm_server: str = "reward_model" + + async def step(self, action, metadata, session_id=None): + resp = await self.server_client.post( + server_name=self.rm_server, + url_path="/v1/score", + json={"input": metadata.get("prompt"), "response": extract_text(action)}, + ) + score = (await resp.json()).get("score", 0.0) + return None, score, True, False, {} +``` + +## YAML Configuration + +`GymnasiumServer` pairs with `gymnasium_agent` instead of `simple_agent`. Same shape as the `simple_agent` config, with the agent referencing the environment through `resources_server`. + +```yaml +my_env_instance: + resources_servers: + my_env: + entrypoint: app.py + domain: knowledge + +my_gymnasium_agent_instance: + responses_api_agents: + gymnasium_agent: + entrypoint: app.py + resources_server: + type: resources_servers + name: my_env + model_server: + type: responses_api_models + name: policy_model + max_steps: 10 + datasets: + - name: example + type: example + jsonl_fpath: resources_servers/my_env/data/example.jsonl +``` + +## Examples + +- [`blackjack`](https://github.com/NVIDIA-NeMo/Gym/tree/main/resources_servers/blackjack): multi-step game with action tags. 
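Action-tag parsing of the kind the blackjack example relies on can be sketched as a small helper. This is an illustration under stated assumptions, not the code from the example: `parse_action` is a hypothetical name, and the tag name `action` is an assumption, so substitute whatever tag your prompt asks the model to emit.

```python
import re
from typing import Optional


def parse_action(text: str, choices: tuple[str, ...], tag: str = "action") -> Optional[str]:
    """Return the last tagged choice found in `text`, lowercased, or None if absent.

    The default tag name is an assumption for illustration only.
    """
    alternatives = "|".join(re.escape(c) for c in choices)
    open_tag, close_tag = f"<{tag}>", f"</{tag}>"
    pattern = re.escape(open_tag) + r"\s*(" + alternatives + r")\s*" + re.escape(close_tag)
    # findall with a single group returns the captured strings in order.
    matches = re.findall(pattern, text, re.IGNORECASE)
    return matches[-1].lower() if matches else None


print(parse_action("I have 12, so I will draw. <action>HIT</action>", ("hit", "stand")))  # hit
print(parse_action("no tag here", ("hit", "stand")))  # None
```

Taking the last match rather than the first is a design choice of this sketch: a model that reasons before answering may mention a tag earlier in its text. Returning `None` leaves the fallback (for example, defaulting to `stand`) to the caller.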
diff --git a/resources_servers/base_gymnasium/__init__.py b/resources_servers/base_gymnasium/__init__.py new file mode 100644 index 000000000..25779b972 --- /dev/null +++ b/resources_servers/base_gymnasium/__init__.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import ( + EnvResetRequest, + EnvResetResponse, + EnvStepRequest, + EnvStepResponse, + GymnasiumServer, + extract_text, +) + + +__all__ = [ + "EnvResetRequest", + "EnvResetResponse", + "EnvStepRequest", + "EnvStepResponse", + "GymnasiumServer", + "extract_text", +] diff --git a/resources_servers/base_gymnasium/base.py b/resources_servers/base_gymnasium/base.py new file mode 100644 index 000000000..211e30d03 --- /dev/null +++ b/resources_servers/base_gymnasium/base.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from abc import abstractmethod +from typing import Any, Dict, Optional + +from fastapi import FastAPI, Request +from pydantic import BaseModel, ConfigDict, Field + +from nemo_gym.base_resources_server import BaseVerifyRequest, SimpleResourcesServer +from nemo_gym.openai_utils import ( + NeMoGymResponse, + NeMoGymResponseCreateParamsNonStreaming, + NeMoGymResponseFunctionToolCall, +) +from nemo_gym.server_utils import SESSION_ID_KEY + + +class EnvResetRequest(BaseModel): + model_config = ConfigDict(extra="allow") + responses_create_params: NeMoGymResponseCreateParamsNonStreaming + + +class EnvResetResponse(BaseModel): + observation: Optional[str] = None + info: dict = {} + + +class EnvStepRequest(BaseModel): + model_config = ConfigDict(extra="allow") + responses_create_params: NeMoGymResponseCreateParamsNonStreaming + response: NeMoGymResponse + + +class EnvStepResponse(BaseModel): + observation: Optional[str] = None + reward: float = 0.0 + terminated: bool = False + truncated: bool = False + info: dict = {} + + +def extract_text(response: NeMoGymResponse) -> str: + """Extract all text content from a NeMoGymResponse.""" + parts = [] + for item in response.output: + if item.type == "message": + content = item.content + if isinstance(content, str): + parts.append(content) + else: + for c in content: + if c.type == "output_text": + parts.append(c.text) + return "".join(parts) + + +class GymnasiumServer(SimpleResourcesServer): + """Gymnasium-style base class. Used with gymnasium_agent. 
+ + step() returns (observation, reward, terminated, truncated, info). + """ + + session_state: Dict[str, Any] = Field(default_factory=dict) + + def setup_webserver(self) -> FastAPI: + app = FastAPI() + self.setup_session_middleware(app) + app.post("/reset")(self._reset_endpoint) + app.post("/step")(self._step_endpoint) + app.post("/aggregate_metrics")(self.aggregate_metrics) + return app + + async def _reset_endpoint(self, body: EnvResetRequest, request: Request) -> EnvResetResponse: + session_id = request.session.get(SESSION_ID_KEY) + obs, info = await self.reset(body.model_extra or {}, session_id) + return EnvResetResponse(observation=obs, info=info) + + async def _step_endpoint(self, body: EnvStepRequest, request: Request) -> EnvStepResponse: + session_id = request.session.get(SESSION_ID_KEY) + obs, reward, terminated, truncated, info = await self.step(body.response, body.model_extra or {}, session_id) + if terminated or truncated: + await self.close_session(session_id) + return EnvStepResponse(observation=obs, reward=reward, terminated=terminated, truncated=truncated, info=info) + + async def reset(self, metadata: dict, session_id: Optional[str] = None) -> tuple[Optional[str], dict]: + return None, {} + + @abstractmethod + async def step( + self, action: NeMoGymResponse, metadata: dict, session_id: Optional[str] = None + ) -> tuple[Optional[str], float, bool, bool, dict]: ... + + async def close_session(self, session_id: Optional[str]) -> None: + self.session_state.pop(session_id, None) + + @staticmethod + def tool_output(call: NeMoGymResponseFunctionToolCall, result: Any) -> dict: + return {"call_id": call.call_id, "output": json.dumps(result, default=str)} + + async def verify(self, body: BaseVerifyRequest) -> None: # type: ignore[override] + raise NotImplementedError("GymnasiumServer uses /step instead of /verify. 
Use with gymnasium_agent.") diff --git a/resources_servers/base_gymnasium/configs/base_gymnasium.yaml b/resources_servers/base_gymnasium/configs/base_gymnasium.yaml new file mode 100644 index 000000000..b24e8f386 --- /dev/null +++ b/resources_servers/base_gymnasium/configs/base_gymnasium.yaml @@ -0,0 +1,23 @@ +base_gymnasium: + resources_servers: + base_gymnasium: + entrypoint: base.py + domain: other + verified: false + description: Base class for Gymnasium-style servers. Not a standalone server. + value: Reusable base class for step/reset style environments +base_gymnasium_agent: + responses_api_agents: + gymnasium_agent: + entrypoint: app.py + resources_server: + type: resources_servers + name: base_gymnasium + model_server: + type: responses_api_models + name: policy_model + max_steps: 1 + datasets: + - name: example + type: example + jsonl_fpath: resources_servers/base_gymnasium/data/example.jsonl diff --git a/resources_servers/base_gymnasium/data/.gitignore b/resources_servers/base_gymnasium/data/.gitignore new file mode 100644 index 000000000..4424b6fde --- /dev/null +++ b/resources_servers/base_gymnasium/data/.gitignore @@ -0,0 +1,5 @@ +*train.jsonl +*validation.jsonl +*train_prepare.jsonl +*validation_prepare.jsonl +*example_prepare.jsonl diff --git a/resources_servers/base_gymnasium/data/example.jsonl b/resources_servers/base_gymnasium/data/example.jsonl new file mode 100644 index 000000000..0e2c6d230 --- /dev/null +++ b/resources_servers/base_gymnasium/data/example.jsonl @@ -0,0 +1,5 @@ +{"responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}, "agent_ref": {"type": "responses_api_agents", "name": "example_gymnasium_agent"}} +{"responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}, "agent_ref": {"type": "responses_api_agents", "name": "example_gymnasium_agent"}} +{"responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}, "agent_ref": {"type": "responses_api_agents", "name": 
"example_gymnasium_agent"}} +{"responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}, "agent_ref": {"type": "responses_api_agents", "name": "example_gymnasium_agent"}} +{"responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}, "agent_ref": {"type": "responses_api_agents", "name": "example_gymnasium_agent"}} diff --git a/resources_servers/base_gymnasium/data/example_metrics.json b/resources_servers/base_gymnasium/data/example_metrics.json new file mode 100644 index 000000000..a6b7d4e3e --- /dev/null +++ b/resources_servers/base_gymnasium/data/example_metrics.json @@ -0,0 +1,38 @@ +{ + "name": "example", + "type": "example", + "jsonl_fpath": "resources_servers/example_gymnasium/data/example.jsonl", + "num_repeats": 1, + "gitlab_identifier": null, + "huggingface_identifier": null, + "license": null, + "Number of examples": 5, + "Number of tools": { + "Total # non-null values": 0, + "Average": 0.0, + "Min": 0.0, + "Max": 0.0, + "Standard deviation": 0.0 + }, + "Json-dumped number of words (proxy for token count)": { + "Total # non-null values": 5, + "Average": 5.0, + "Min": 5.0, + "Max": 5.0, + "Standard deviation": 0.0 + }, + "Number of turns": { + "Total # non-null values": 5, + "Average": 1.0, + "Min": 1.0, + "Max": 1.0, + "Standard deviation": 0.0 + }, + "Temperature": { + "Total # non-null values": 0, + "Average": 0.0, + "Min": 0.0, + "Max": 0.0, + "Standard deviation": 0.0 + } +} \ No newline at end of file diff --git a/resources_servers/base_gymnasium/data/example_rollouts.jsonl b/resources_servers/base_gymnasium/data/example_rollouts.jsonl new file mode 100644 index 000000000..c8d8aafcc --- /dev/null +++ b/resources_servers/base_gymnasium/data/example_rollouts.jsonl @@ -0,0 +1,5 @@ +{"reward": 0.0, "terminated": true, "truncated": false, "responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}, "response": {"id": "x", "created_at": 0, "model": "x", "object": "response", "output": [], 
"parallel_tool_calls": true, "tool_choice": "auto", "tools": []}} +{"reward": 0.0, "terminated": true, "truncated": false, "responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}, "response": {"id": "x", "created_at": 0, "model": "x", "object": "response", "output": [], "parallel_tool_calls": true, "tool_choice": "auto", "tools": []}} +{"reward": 0.0, "terminated": true, "truncated": false, "responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}, "response": {"id": "x", "created_at": 0, "model": "x", "object": "response", "output": [], "parallel_tool_calls": true, "tool_choice": "auto", "tools": []}} +{"reward": 0.0, "terminated": true, "truncated": false, "responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}, "response": {"id": "x", "created_at": 0, "model": "x", "object": "response", "output": [], "parallel_tool_calls": true, "tool_choice": "auto", "tools": []}} +{"reward": 0.0, "terminated": true, "truncated": false, "responses_create_params": {"input": [{"role": "user", "content": "Hello"}]}, "response": {"id": "x", "created_at": 0, "model": "x", "object": "response", "output": [], "parallel_tool_calls": true, "tool_choice": "auto", "tools": []}} diff --git a/resources_servers/base_gymnasium/requirements.txt b/resources_servers/base_gymnasium/requirements.txt new file mode 100644 index 000000000..00ed83213 --- /dev/null +++ b/resources_servers/base_gymnasium/requirements.txt @@ -0,0 +1 @@ +-e nemo-gym[dev] @ ../../ diff --git a/resources_servers/base_gymnasium/tests/__init__.py b/resources_servers/base_gymnasium/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/resources_servers/base_gymnasium/tests/test_app.py b/resources_servers/base_gymnasium/tests/test_app.py new file mode 100644 index 000000000..31b9bde10 --- /dev/null +++ b/resources_servers/base_gymnasium/tests/test_app.py @@ -0,0 +1,162 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from nemo_gym.base_resources_server import BaseResourcesServerConfig +from nemo_gym.openai_utils import NeMoGymResponse, NeMoGymResponseOutputMessage, NeMoGymResponseOutputText +from nemo_gym.server_utils import SESSION_ID_KEY, ServerClient +from resources_servers.base_gymnasium import EnvResetRequest, EnvStepRequest, GymnasiumServer, extract_text + + +def _make_response(*parts: str) -> NeMoGymResponse: + return NeMoGymResponse( + id="r", + created_at=0.0, + model="m", + object="response", + output=[ + NeMoGymResponseOutputMessage( + id=f"msg_{i}", + content=[NeMoGymResponseOutputText(annotations=[], text=p, type="output_text")], + role="assistant", + status="completed", + type="message", + ) + for i, p in enumerate(parts) + ], + parallel_tool_calls=True, + tool_choice="auto", + tools=[], + ) + + +class _FakeRequest: + def __init__(self, session_id="sid-1"): + self.session = {SESSION_ID_KEY: session_id} + + +class _TerminatingEnv(GymnasiumServer): + async def step(self, action, metadata, session_id=None): + return None, 1.0, True, False, {} + + +class _OngoingEnv(GymnasiumServer): + async def step(self, action, metadata, session_id=None): + return "keep going", 0.0, False, False, {} + + +class _TruncatingEnv(GymnasiumServer): + async def step(self, action, 
metadata, session_id=None): + return None, 0.0, False, True, {} + + +_close_log: list = [] + + +class _CustomCloseEnv(GymnasiumServer): + async def step(self, action, metadata, session_id=None): + return None, 0.0, True, False, {} + + async def close_session(self, session_id): + _close_log.append(session_id) + await super().close_session(session_id) + + +def _make_env(cls): + config = BaseResourcesServerConfig(host="", port=0, entrypoint="", name="") + return cls(config=config, server_client=MagicMock(spec=ServerClient)) + + +class TestGymnasiumServer: + def test_routes_registered(self): + env = _make_env(_TerminatingEnv) + routes = {r.path for r in env.setup_webserver().routes} + assert {"/reset", "/step", "/aggregate_metrics"}.issubset(routes) + + def test_verify_raises(self): + env = _make_env(_TerminatingEnv) + with pytest.raises(NotImplementedError): + import asyncio + + asyncio.run(env.verify(SimpleNamespace())) + + @pytest.mark.asyncio + async def test_reset_default_returns_empty(self): + env = _make_env(_TerminatingEnv) + env.session_state["sid-1"] = {"x": 1} + body = EnvResetRequest(responses_create_params={"input": []}) + resp = await env._reset_endpoint(body, _FakeRequest()) + assert resp.observation is None + assert resp.info == {} + + @pytest.mark.asyncio + async def test_step_pops_on_terminated(self): + env = _make_env(_TerminatingEnv) + env.session_state["sid-1"] = {"x": 1} + body = EnvStepRequest(responses_create_params={"input": []}, response=_make_response("x")) + resp = await env._step_endpoint(body, _FakeRequest("sid-1")) + assert resp.terminated is True + assert "sid-1" not in env.session_state + + @pytest.mark.asyncio + async def test_step_pops_on_truncated(self): + env = _make_env(_TruncatingEnv) + env.session_state["sid-1"] = {"x": 1} + body = EnvStepRequest(responses_create_params={"input": []}, response=_make_response("x")) + resp = await env._step_endpoint(body, _FakeRequest("sid-1")) + assert resp.truncated is True + assert "sid-1" not 
in env.session_state + + @pytest.mark.asyncio + async def test_step_keeps_state_when_ongoing(self): + env = _make_env(_OngoingEnv) + env.session_state["sid-1"] = {"x": 1} + body = EnvStepRequest(responses_create_params={"input": []}, response=_make_response("x")) + resp = await env._step_endpoint(body, _FakeRequest("sid-1")) + assert resp.terminated is False + assert resp.truncated is False + assert "sid-1" in env.session_state + + @pytest.mark.asyncio + async def test_close_session_override_invoked(self): + _close_log.clear() + env = _make_env(_CustomCloseEnv) + env.session_state["sid-1"] = {"x": 1} + body = EnvStepRequest(responses_create_params={"input": []}, response=_make_response("x")) + await env._step_endpoint(body, _FakeRequest("sid-1")) + assert _close_log == ["sid-1"] + assert "sid-1" not in env.session_state + + +class TestExtractText: + def test_concats_output_text(self): + r = _make_response("hello ", "world") + assert extract_text(r) == "hello world" + + def test_empty_output(self): + r = NeMoGymResponse( + id="r", + created_at=0.0, + model="m", + object="response", + output=[], + parallel_tool_calls=True, + tool_choice="auto", + tools=[], + ) + assert extract_text(r) == "" diff --git a/resources_servers/blackjack/README.md b/resources_servers/blackjack/README.md new file mode 100644 index 000000000..54fbff84a --- /dev/null +++ b/resources_servers/blackjack/README.md @@ -0,0 +1,26 @@ +# Blackjack Env + +Multi-step gymnasium-style environment. + +The model hits or stands using action tags until the hand ends. Game state is managed per session. + +Example data is provided in `data/example.jsonl` (system prompt only, no verifier_metadata needed). No train/validation data. + +## Run + +```bash +ng_run "+config_paths=[resources_servers/blackjack/configs/blackjack.yaml,responses_api_models/vllm_model/configs/vllm_model.yaml]" +``` + +## Data + +Each game is generated on the fly during `reset()`, so every row in `example.jsonl` is identical. 
To create more data, duplicate the row. Each rollout gets a fresh random deal. Use `num_repeats` in the YAML config or the `+num_repeats` CLI flag to control how many games per row. + +## Collect rollouts + +```bash +ng_collect_rollouts \ + +agent_name=blackjack_gymnasium_agent \ + +input_jsonl_fpath=resources_servers/blackjack/data/example.jsonl \ + +output_jsonl_fpath=results/blackjack_rollouts.jsonl +``` diff --git a/resources_servers/blackjack/app.py b/resources_servers/blackjack/app.py new file mode 100644 index 000000000..d2f33af3f --- /dev/null +++ b/resources_servers/blackjack/app.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Blackjack environment in Gymnasium API style. + +Multi-step: the model hits or stands until the hand ends. +Reward: +1 win, 0 draw, -1 loss. 
+""" + +import random +import re +from typing import Optional + +from nemo_gym.openai_utils import NeMoGymResponse +from resources_servers.base_gymnasium import GymnasiumServer, extract_text + + +_RANKS = ["2", "3", "4", "5", "6", "7", "8", "9", "10", "J", "Q", "K", "A"] + + +def _deal(rng: random.Random) -> str: + return rng.choice(_RANKS) + + +def _hand_value(hand: list[str]) -> int: + total = sum(10 if r in ("J", "Q", "K") else 11 if r == "A" else int(r) for r in hand) + aces = hand.count("A") + while total > 21 and aces: + total -= 10 + aces -= 1 + return total + + +def _fmt(hand: list[str]) -> str: + return "[" + ", ".join(hand) + "]" + + +class BlackjackEnv(GymnasiumServer): + async def reset(self, metadata: dict, session_id: Optional[str] = None) -> tuple[Optional[str], dict]: + rng = random.Random() + player = [_deal(rng), _deal(rng)] + dealer = [_deal(rng), _deal(rng)] + self.session_state[session_id] = {"player": player, "dealer": dealer, "rng": rng} + obs = ( + f"Your hand: {_fmt(player)} = {_hand_value(player)}\n" + f"Dealer shows: {dealer[0]}\n" + f"Respond with hit or stand." + ) + return obs, {} + + async def step( + self, action: NeMoGymResponse, metadata: dict, session_id: Optional[str] = None + ) -> tuple[Optional[str], float, bool, bool, dict]: + state = self.session_state.get(session_id, {}) + player = state.get("player", []) + dealer = state.get("dealer", []) + rng = state.get("rng") or random.Random() + text = extract_text(action) + m = re.search(r"\s*(hit|stand)\s*", text, re.IGNORECASE) + decision = m.group(1).lower() if m else "stand" + + if decision == "hit": + player.append(_deal(rng)) + val = _hand_value(player) + if val > 21: + return None, -1.0, True, False, {"result": "bust", "player": _fmt(player), "value": val} + obs = ( + f"Your hand: {_fmt(player)} = {val}\n" + f"Dealer shows: {dealer[0]}\n" + f"Respond with hit or stand." 
+ ) + return obs, 0.0, False, False, {} + + while _hand_value(dealer) < 17: + dealer.append(_deal(rng)) + + pv, dv = _hand_value(player), _hand_value(dealer) + if dv > 21 or pv > dv: + reward, result = 1.0, "win" + elif pv == dv: + reward, result = 0.0, "draw" + else: + reward, result = -1.0, "loss" + + return ( + None, + reward, + True, + False, + { + "result": result, + "player": _fmt(player), + "player_value": pv, + "dealer": _fmt(dealer), + "dealer_value": dv, + }, + ) + + +if __name__ == "__main__": + BlackjackEnv.run_webserver() diff --git a/resources_servers/blackjack/configs/blackjack.yaml b/resources_servers/blackjack/configs/blackjack.yaml new file mode 100644 index 000000000..0dae1e2d8 --- /dev/null +++ b/resources_servers/blackjack/configs/blackjack.yaml @@ -0,0 +1,23 @@ +blackjack: + resources_servers: + blackjack: + entrypoint: app.py + domain: games + verified: false + description: Blackjack. Model hits or stands. Reward +1 win, 0 draw, -1 loss/bust. + value: Example gymnasium-style multi-step environment +blackjack_gymnasium_agent: + responses_api_agents: + gymnasium_agent: + entrypoint: app.py + resources_server: + type: resources_servers + name: blackjack + model_server: + type: responses_api_models + name: policy_model + max_steps: 5 + datasets: + - name: example + type: example + jsonl_fpath: resources_servers/blackjack/data/example.jsonl diff --git a/resources_servers/blackjack/data/.gitignore b/resources_servers/blackjack/data/.gitignore new file mode 100644 index 000000000..4424b6fde --- /dev/null +++ b/resources_servers/blackjack/data/.gitignore @@ -0,0 +1,5 @@ +*train.jsonl +*validation.jsonl +*train_prepare.jsonl +*validation_prepare.jsonl +*example_prepare.jsonl diff --git a/resources_servers/blackjack/data/example.jsonl b/resources_servers/blackjack/data/example.jsonl new file mode 100644 index 000000000..f48a084e1 --- /dev/null +++ b/resources_servers/blackjack/data/example.jsonl @@ -0,0 +1,5 @@ +{"responses_create_params": {"input": 
[{"role": "system", "content": "You are playing Blackjack. After seeing your hand, respond with hit or stand. Think briefly, then give your action tag."}, {"role": "user", "content": "Deal me in."}]}, "agent_ref": {"type": "responses_api_agents", "name": "blackjack_gymnasium_agent"}} +{"responses_create_params": {"input": [{"role": "system", "content": "You are playing Blackjack. After seeing your hand, respond with hit or stand. Think briefly, then give your action tag."}, {"role": "user", "content": "Deal me in."}]}, "agent_ref": {"type": "responses_api_agents", "name": "blackjack_gymnasium_agent"}} +{"responses_create_params": {"input": [{"role": "system", "content": "You are playing Blackjack. After seeing your hand, respond with hit or stand. Think briefly, then give your action tag."}, {"role": "user", "content": "Deal me in."}]}, "agent_ref": {"type": "responses_api_agents", "name": "blackjack_gymnasium_agent"}} +{"responses_create_params": {"input": [{"role": "system", "content": "You are playing Blackjack. After seeing your hand, respond with hit or stand. Think briefly, then give your action tag."}, {"role": "user", "content": "Deal me in."}]}, "agent_ref": {"type": "responses_api_agents", "name": "blackjack_gymnasium_agent"}} +{"responses_create_params": {"input": [{"role": "system", "content": "You are playing Blackjack. After seeing your hand, respond with hit or stand. 
Think briefly, then give your action tag."}, {"role": "user", "content": "Deal me in."}]}, "agent_ref": {"type": "responses_api_agents", "name": "blackjack_gymnasium_agent"}} diff --git a/resources_servers/blackjack/data/example_metrics.json b/resources_servers/blackjack/data/example_metrics.json new file mode 100644 index 000000000..ca1f0810a --- /dev/null +++ b/resources_servers/blackjack/data/example_metrics.json @@ -0,0 +1,38 @@ +{ + "name": "example", + "type": "example", + "jsonl_fpath": "resources_servers/blackjack/data/example.jsonl", + "num_repeats": 1, + "gitlab_identifier": null, + "huggingface_identifier": null, + "license": null, + "Number of examples": 5, + "Number of tools": { + "Total # non-null values": 0, + "Average": 0.0, + "Min": 0.0, + "Max": 0.0, + "Standard deviation": 0.0 + }, + "Json-dumped number of words (proxy for token count)": { + "Total # non-null values": 5, + "Average": 30.0, + "Min": 30.0, + "Max": 30.0, + "Standard deviation": 0.0 + }, + "Number of turns": { + "Total # non-null values": 5, + "Average": 1.0, + "Min": 1.0, + "Max": 1.0, + "Standard deviation": 0.0 + }, + "Temperature": { + "Total # non-null values": 0, + "Average": 0.0, + "Min": 0.0, + "Max": 0.0, + "Standard deviation": 0.0 + } +} \ No newline at end of file diff --git a/resources_servers/blackjack/data/example_rollouts.jsonl b/resources_servers/blackjack/data/example_rollouts.jsonl new file mode 100644 index 000000000..6175f946f --- /dev/null +++ b/resources_servers/blackjack/data/example_rollouts.jsonl @@ -0,0 +1,5 @@ +{"responses_create_params":{"background":null,"include":null,"input":[{"content":"You are playing Blackjack. After seeing your hand, respond with hit or stand. 
Think briefly, then give your action tag.","role":"system","type":"message"},{"content":"Deal me in.","role":"user","type":"message"}],"instructions":null,"max_output_tokens":1024,"max_tool_calls":null,"metadata":null,"model":null,"parallel_tool_calls":true,"previous_response_id":null,"prompt":null,"reasoning":null,"service_tier":null,"store":null,"temperature":0.0,"text":null,"tool_choice":"auto","tools":[],"top_logprobs":null,"top_p":null,"truncation":null,"user":null,"stream":null},"response":{"id":"resp_0816d85808c14ee38cb42be5b1ae4863","created_at":1775592436.0,"error":null,"incomplete_details":null,"instructions":null,"metadata":null,"model":"Qwen/Qwen3-4B-Instruct-2507","object":"response","output":[{"prompt_token_ids":[151644,872,198,7771,1424,25,508,19,11,220,17,60,284,220,21,198,93909,4933,25,220,19,198,65354,448,366,1311,29,22492,522,1311,29,476,366,1311,29,2685,522,1311,14276,151645,198,151644,77091,198],"generation_token_ids":[27,1311,29,2685,522,1311,29,151645],"generation_log_probs":[-0.0001250427303602919,-1.5497195136049413e-6,-1.1920928244535389e-7,-0.693234920501709,0.0,0.0,-3.576278118089249e-7,-3.2186455882765586e-6],"id":"msg_b266aeb4d0be4afc80b0931cb97cd083","content":[{"annotations":[],"text":"stand","type":"output_text","logprobs":null}],"role":"assistant","status":"completed","type":"message"}],"parallel_tool_calls":true,"temperature":0.0,"tool_choice":"auto","tools":[],"top_p":null,"background":null,"conversation":null,"max_output_tokens":1024,"max_tool_calls":null,"previous_response_id":null,"prompt":null,"prompt_cache_key":null,"reasoning":null,"safety_identifier":null,"service_tier":null,"status":null,"text":null,"top_logprobs":null,"truncation":null,"usage":{"input_tokens":44,"input_tokens_details":{"cached_tokens":0},"output_tokens":8,"output_tokens_details":{"reasoning_tokens":0},"total_tokens":52},"user":null},"reward":-1.0,"terminated":true,"truncated":false,"info":{"result":"loss","player":"[4, 2]","player_value":6,"dealer":"[4, 
2, 2, K]","dealer_value":18},"_ng_task_index":0,"_ng_rollout_index":1,"agent_ref":{"name":"blackjack_agent"}} +{"responses_create_params":{"background":null,"include":null,"input":[{"content":"You are playing Blackjack. After seeing your hand, respond with hit or stand. Think briefly, then give your action tag.","role":"system","type":"message"},{"content":"Deal me in.","role":"user","type":"message"}],"instructions":null,"max_output_tokens":1024,"max_tool_calls":null,"metadata":null,"model":null,"parallel_tool_calls":true,"previous_response_id":null,"prompt":null,"reasoning":null,"service_tier":null,"store":null,"temperature":0.0,"text":null,"tool_choice":"auto","tools":[],"top_logprobs":null,"top_p":null,"truncation":null,"user":null,"stream":null},"response":{"id":"resp_c018550a3d174fe987939919762706f8","created_at":1775592436.0,"error":null,"incomplete_details":null,"instructions":null,"metadata":null,"model":"Qwen/Qwen3-4B-Instruct-2507","object":"response","output":[{"prompt_token_ids":[151644,872,198,7771,1424,25,508,24,11,220,24,60,284,220,16,23,198,93909,4933,25,220,18,198,65354,448,366,1311,29,22492,522,1311,29,476,366,1311,29,2685,522,1311,14276,151645,198,151644,77091,198],"generation_token_ids":[27,1311,29,2685,522,1311,29,151645],"generation_log_probs":[-1.5497195136049413e-6,-1.6689286894688848e-6,0.0,-0.00012838016846217215,0.0,0.0,-2.3841855067985307e-7,-3.576272320060525e-6],"id":"msg_02f9a7cb0eac41a8b13c2c690534a43b","content":[{"annotations":[],"text":"stand","type":"output_text","logprobs":null}],"role":"assistant","status":"completed","type":"message"}],"parallel_tool_calls":true,"temperature":0.0,"tool_choice":"auto","tools":[],"top_p":null,"background":null,"conversation":null,"max_output_tokens":1024,"max_tool_calls":null,"previous_response_id":null,"prompt":null,"prompt_cache_key":null,"reasoning":null,"safety_identifier":null,"service_tier":null,"status":null,"text":null,"top_logprobs":null,"truncation":null,"usage":{"input_tokens":45,"inp
ut_tokens_details":{"cached_tokens":0},"output_tokens":8,"output_tokens_details":{"reasoning_tokens":0},"total_tokens":53},"user":null},"reward":1.0,"terminated":true,"truncated":false,"info":{"result":"win","player":"[9, 9]","player_value":18,"dealer":"[3, 9, 3, Q]","dealer_value":25},"_ng_task_index":0,"_ng_rollout_index":2,"agent_ref":{"name":"blackjack_agent"}} +{"responses_create_params":{"background":null,"include":null,"input":[{"content":"You are playing Blackjack. After seeing your hand, respond with hit or stand. Think briefly, then give your action tag.","role":"system","type":"message"},{"content":"Deal me in.","role":"user","type":"message"}],"instructions":null,"max_output_tokens":1024,"max_tool_calls":null,"metadata":null,"model":null,"parallel_tool_calls":true,"previous_response_id":null,"prompt":null,"reasoning":null,"service_tier":null,"store":null,"temperature":0.0,"text":null,"tool_choice":"auto","tools":[],"top_logprobs":null,"top_p":null,"truncation":null,"user":null,"stream":null},"response":{"id":"resp_1f8b38af629d4252818b94897f4e599a","created_at":1775592436.0,"error":null,"incomplete_details":null,"instructions":null,"metadata":null,"model":"Qwen/Qwen3-4B-Instruct-2507","object":"response","output":[{"prompt_token_ids":[151644,872,198,7771,1424,25,508,19,11,619,60,284,220,16,19,198,93909,4933,25,362,198,65354,448,366,1311,29,22492,522,1311,29,476,366,1311,29,2685,522,1311,14276,151645,198,151644,77091,198],"generation_token_ids":[27,1311,29,2685,522,1311,29,151645],"generation_log_probs":[-0.000011920858014491387,-2.0265558760002023e-6,0.0,-0.038049791008234024,0.0,0.0,-3.576278118089249e-7,-3.3378546504536644e-6],"id":"msg_9046c12edb9d48d992ce6cd8dbeed1eb","content":[{"annotations":[],"text":"stand","type":"output_text","logprobs":null}],"role":"assistant","status":"completed","type":"message"}],"parallel_tool_calls":true,"temperature":0.0,"tool_choice":"auto","tools":[],"top_p":null,"background":null,"conversation":null,"max_output_tokens
":1024,"max_tool_calls":null,"previous_response_id":null,"prompt":null,"prompt_cache_key":null,"reasoning":null,"safety_identifier":null,"service_tier":null,"status":null,"text":null,"top_logprobs":null,"truncation":null,"usage":{"input_tokens":43,"input_tokens_details":{"cached_tokens":0},"output_tokens":8,"output_tokens_details":{"reasoning_tokens":0},"total_tokens":51},"user":null},"reward":-1.0,"terminated":true,"truncated":false,"info":{"result":"loss","player":"[4, J]","player_value":14,"dealer":"[A, J]","dealer_value":21},"_ng_task_index":0,"_ng_rollout_index":4,"agent_ref":{"name":"blackjack_agent"}} +{"responses_create_params":{"background":null,"include":null,"input":[{"content":"You are playing Blackjack. After seeing your hand, respond with hit or stand. Think briefly, then give your action tag.","role":"system","type":"message"},{"content":"Deal me in.","role":"user","type":"message"}],"instructions":null,"max_output_tokens":1024,"max_tool_calls":null,"metadata":null,"model":null,"parallel_tool_calls":true,"previous_response_id":null,"prompt":null,"reasoning":null,"service_tier":null,"store":null,"temperature":0.0,"text":null,"tool_choice":"auto","tools":[],"top_logprobs":null,"top_p":null,"truncation":null,"user":null,"stream":null},"response":{"id":"resp_1d3acaf11fe142ffaaddc9925de94d78","created_at":1775592436.0,"error":null,"incomplete_details":null,"instructions":null,"metadata":null,"model":"Qwen/Qwen3-4B-Instruct-2507","object":"response","output":[{"prompt_token_ids":[151644,872,198,7771,1424,25,508,22,11,220,16,15,60,284,220,16,22,198,93909,4933,25,619,198,65354,448,366,1311,29,22492,522,1311,29,476,366,1311,29,2685,522,1311,14276,151645,198,151644,77091,198],"generation_token_ids":[27,1311,29,2685,522,1311,29,151645],"generation_log_probs":[-0.00004053033626405522,-1.6689286894688848e-6,0.0,-0.0011742371134459972,0.0,0.0,-2.3841855067985307e-7,-2.50339189733495e-6],"id":"msg_b2b71c2b4b5d4e059b77886de80db892","content":[{"annotations":[],"text"
:"stand","type":"output_text","logprobs":null}],"role":"assistant","status":"completed","type":"message"}],"parallel_tool_calls":true,"temperature":0.0,"tool_choice":"auto","tools":[],"top_p":null,"background":null,"conversation":null,"max_output_tokens":1024,"max_tool_calls":null,"previous_response_id":null,"prompt":null,"prompt_cache_key":null,"reasoning":null,"safety_identifier":null,"service_tier":null,"status":null,"text":null,"top_logprobs":null,"truncation":null,"usage":{"input_tokens":45,"input_tokens_details":{"cached_tokens":0},"output_tokens":8,"output_tokens_details":{"reasoning_tokens":0},"total_tokens":53},"user":null},"reward":-1.0,"terminated":true,"truncated":false,"info":{"result":"loss","player":"[7, 10]","player_value":17,"dealer":"[J, 5, 3]","dealer_value":18},"_ng_task_index":0,"_ng_rollout_index":3,"agent_ref":{"name":"blackjack_agent"}} +{"responses_create_params":{"background":null,"include":null,"input":[{"content":"You are playing Blackjack. After seeing your hand, respond with hit or stand. 
Think briefly, then give your action tag.","role":"system","type":"message"},{"content":"Deal me in.","role":"user","type":"message"}],"instructions":null,"max_output_tokens":1024,"max_tool_calls":null,"metadata":null,"model":null,"parallel_tool_calls":true,"previous_response_id":null,"prompt":null,"reasoning":null,"service_tier":null,"store":null,"temperature":0.0,"text":null,"tool_choice":"auto","tools":[],"top_logprobs":null,"top_p":null,"truncation":null,"user":null,"stream":null},"response":{"id":"resp_a97b2558ac874f2ba5684a67e9b941c2","created_at":1775592436.0,"error":null,"incomplete_details":null,"instructions":null,"metadata":null,"model":"Qwen/Qwen3-4B-Instruct-2507","object":"response","output":[{"prompt_token_ids":[151644,872,198,7771,1424,25,508,22,11,1207,60,284,220,16,22,198,93909,4933,25,220,17,198,65354,448,366,1311,29,22492,522,1311,29,476,366,1311,29,2685,522,1311,14276,151645,198,151644,77091,198],"generation_token_ids":[27,1311,29,22492,522,1311,29,151645],"generation_log_probs":[-2.0265558760002023e-6,-1.6689286894688848e-6,-1.1920928244535389e-7,-0.10021035373210907,0.0,0.0,0.0,-5.722029527532868e-6],"id":"msg_601a2716de0a4eb38494d2131674c825","content":[{"annotations":[],"text":"hit","type":"output_text","logprobs":null}],"role":"assistant","status":"completed","type":"message"},{"prompt_token_ids":[151644,872,198,7771,1424,25,508,22,11,1207,60,284,220,16,22,198,93909,4933,25,220,17,198,65354,448,366,1311,29,22492,522,1311,29,476,366,1311,29,2685,522,1311,14276,151645,198,151644,77091,198,27,1311,29,22492,522,1311,29,151645,198,151644,872,198,7771,1424,25,508,22,11,1207,11,220,18,60,284,220,17,15,198,93909,4933,25,220,17,198,65354,448,366,1311,29,22492,522,1311,29,476,366,1311,29,2685,522,1311,14276,151645,198,151644,77091,198],"generation_token_ids":[27,1311,29,2685,522,1311,29,151645],"generation_log_probs":[-1.1920928244535389e-7,-2.3841855067985307e-7,-7.152555099310121e-7,0.0,0.0,-1.1920928244535389e-7,0.0,-1.6689286894688848e-6],"id":"m
sg_7de40025e0a7414594284d76461fe177","content":[{"annotations":[],"text":"stand","type":"output_text","logprobs":null}],"role":"assistant","status":"completed","type":"message"}],"parallel_tool_calls":true,"temperature":0.0,"tool_choice":"auto","tools":[],"top_p":null,"background":null,"conversation":null,"max_output_tokens":1024,"max_tool_calls":null,"previous_response_id":null,"prompt":null,"prompt_cache_key":null,"reasoning":null,"safety_identifier":null,"service_tier":null,"status":null,"text":null,"top_logprobs":null,"truncation":null,"usage":{"input_tokens":144,"input_tokens_details":{"cached_tokens":0},"output_tokens":16,"output_tokens_details":{"reasoning_tokens":0},"total_tokens":160},"user":null},"reward":-1.0,"terminated":true,"truncated":false,"info":{"result":"loss","player":"[]","player_value":0,"dealer":"[8, A]","dealer_value":19},"_ng_task_index":0,"_ng_rollout_index":0,"agent_ref":{"name":"blackjack_agent"}} diff --git a/resources_servers/blackjack/data/example_rollouts_aggregate_metrics.json b/resources_servers/blackjack/data/example_rollouts_aggregate_metrics.json new file mode 100644 index 000000000..b48dc04c5 --- /dev/null +++ b/resources_servers/blackjack/data/example_rollouts_aggregate_metrics.json @@ -0,0 +1,87 @@ +[ + { + "agent_ref": { + "name": "blackjack_agent" + }, + "agent_metrics": { + "mean/reward": -0.6, + "mean/terminated": 1.0, + "mean/truncated": 0.0, + "mean/input_tokens": 64.2, + "mean/output_tokens": 9.6, + "mean/total_tokens": 73.8, + "max/reward": 1.0, + "max/terminated": 1.0, + "max/truncated": 0.0, + "max/input_tokens": 144.0, + "max/output_tokens": 16.0, + "max/total_tokens": 160.0, + "min/reward": -1.0, + "min/terminated": 1.0, + "min/truncated": 0.0, + "min/input_tokens": 43.0, + "min/output_tokens": 8.0, + "min/total_tokens": 51.0, + "median/reward": -1.0, + "median/terminated": 1.0, + "median/truncated": 0.0, + "median/input_tokens": 45.0, + "median/output_tokens": 8.0, + "median/total_tokens": 53.0, + "std/reward": 
0.8944271909999161, + "std/terminated": 0.0, + "std/truncated": 0.0, + "std/input_tokens": 44.61726123374226, + "std/output_tokens": 3.5777087639996643, + "std/total_tokens": 48.19439801470706 + }, + "key_metrics": { + "mean/reward": -0.6, + "mean/terminated": 1.0, + "mean/truncated": 0.0, + "mean/input_tokens": 64.2, + "mean/output_tokens": 9.6, + "mean/total_tokens": 73.8 + }, + "group_level_metrics": [ + { + "mean/reward": -0.6, + "mean/terminated": 1.0, + "mean/truncated": 0.0, + "mean/input_tokens": 64.2, + "mean/output_tokens": 9.6, + "mean/total_tokens": 73.8, + "max/reward": 1.0, + "max/terminated": 1.0, + "max/truncated": 0.0, + "max/input_tokens": 144.0, + "max/output_tokens": 16.0, + "max/total_tokens": 160.0, + "min/reward": -1.0, + "min/terminated": 1.0, + "min/truncated": 0.0, + "min/input_tokens": 43.0, + "min/output_tokens": 8.0, + "min/total_tokens": 51.0, + "median/reward": -1.0, + "median/terminated": 1.0, + "median/truncated": 0.0, + "median/input_tokens": 45.0, + "median/output_tokens": 8.0, + "median/total_tokens": 53.0, + "std/reward": 0.8944271909999161, + "std/terminated": 0.0, + "std/truncated": 0.0, + "std/input_tokens": 44.61726123374226, + "std/output_tokens": 3.5777087639996643, + "std/total_tokens": 48.19439801470706, + "sample": { + "agent_ref": { + "name": "agent" + } + }, + "_ng_task_index": 0 + } + ] + } +] \ No newline at end of file diff --git a/resources_servers/blackjack/data/example_rollouts_materialized_inputs.jsonl b/resources_servers/blackjack/data/example_rollouts_materialized_inputs.jsonl new file mode 100644 index 000000000..d717903c4 --- /dev/null +++ b/resources_servers/blackjack/data/example_rollouts_materialized_inputs.jsonl @@ -0,0 +1,5 @@ +{"responses_create_params":{"input":[{"role":"system","content":"You are playing Blackjack. After seeing your hand, respond with hit or stand. 
Think briefly, then give your action tag."},{"role":"user","content":"Deal me in."}],"max_output_tokens":1024,"temperature":0.0},"agent_ref":{"name":"blackjack_agent"},"_ng_task_index":0,"_ng_rollout_index":0} +{"responses_create_params":{"input":[{"role":"system","content":"You are playing Blackjack. After seeing your hand, respond with hit or stand. Think briefly, then give your action tag."},{"role":"user","content":"Deal me in."}],"max_output_tokens":1024,"temperature":0.0},"agent_ref":{"name":"blackjack_agent"},"_ng_task_index":0,"_ng_rollout_index":1} +{"responses_create_params":{"input":[{"role":"system","content":"You are playing Blackjack. After seeing your hand, respond with hit or stand. Think briefly, then give your action tag."},{"role":"user","content":"Deal me in."}],"max_output_tokens":1024,"temperature":0.0},"agent_ref":{"name":"blackjack_agent"},"_ng_task_index":0,"_ng_rollout_index":2} +{"responses_create_params":{"input":[{"role":"system","content":"You are playing Blackjack. After seeing your hand, respond with hit or stand. Think briefly, then give your action tag."},{"role":"user","content":"Deal me in."}],"max_output_tokens":1024,"temperature":0.0},"agent_ref":{"name":"blackjack_agent"},"_ng_task_index":0,"_ng_rollout_index":3} +{"responses_create_params":{"input":[{"role":"system","content":"You are playing Blackjack. After seeing your hand, respond with hit or stand. 
Think briefly, then give your action tag."},{"role":"user","content":"Deal me in."}],"max_output_tokens":1024,"temperature":0.0},"agent_ref":{"name":"blackjack_agent"},"_ng_task_index":0,"_ng_rollout_index":4} diff --git a/resources_servers/blackjack/requirements.txt b/resources_servers/blackjack/requirements.txt new file mode 100644 index 000000000..00ed83213 --- /dev/null +++ b/resources_servers/blackjack/requirements.txt @@ -0,0 +1 @@ +-e nemo-gym[dev] @ ../../ diff --git a/resources_servers/blackjack/tests/__init__.py b/resources_servers/blackjack/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/resources_servers/blackjack/tests/test_app.py b/resources_servers/blackjack/tests/test_app.py new file mode 100644 index 000000000..0c0b2e6c6 --- /dev/null +++ b/resources_servers/blackjack/tests/test_app.py @@ -0,0 +1,169 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest.mock import MagicMock + +import pytest + +from nemo_gym.base_resources_server import BaseResourcesServerConfig +from nemo_gym.openai_utils import NeMoGymResponse, NeMoGymResponseOutputMessage, NeMoGymResponseOutputText +from nemo_gym.server_utils import ServerClient +from resources_servers.blackjack.app import BlackjackEnv, _hand_value + + +def _make_env(): + config = BaseResourcesServerConfig(host="", port=0, entrypoint="", name="") + return BlackjackEnv(config=config, server_client=MagicMock(spec=ServerClient)) + + +def _response(text: str) -> NeMoGymResponse: + return NeMoGymResponse( + id="r", + created_at=0.0, + model="m", + object="response", + output=[ + NeMoGymResponseOutputMessage( + id="msg", + content=[NeMoGymResponseOutputText(annotations=[], text=text, type="output_text")], + role="assistant", + status="completed", + type="message", + ) + ], + parallel_tool_calls=True, + tool_choice="auto", + tools=[], + ) + + +class TestHandValue: + def test_basic(self): + assert _hand_value(["5", "7"]) == 12 + + def test_face_cards(self): + assert _hand_value(["K", "Q"]) == 20 + + def test_ace_low_when_busting(self): + assert _hand_value(["A", "K", "5"]) == 16 + + def test_ace_high_when_safe(self): + assert _hand_value(["A", "9"]) == 20 + + def test_two_aces(self): + assert _hand_value(["A", "A"]) == 12 + + +class TestReset: + @pytest.mark.asyncio + async def test_reset_populates_state(self): + env = _make_env() + obs, info = await env.reset({}, session_id="sid") + assert "sid" in env.session_state + state = env.session_state["sid"] + assert len(state["player"]) == 2 + assert len(state["dealer"]) == 2 + assert "rng" in state + assert "Your hand" in obs + assert "hit" in obs + + @pytest.mark.asyncio + async def test_per_session_rng_is_isolated(self): + # Two sessions get distinct RNG instances (not the module default). 
+ env = _make_env() + await env.reset({}, session_id="a") + await env.reset({}, session_id="b") + assert env.session_state["a"]["rng"] is not env.session_state["b"]["rng"] + + +class TestStep: + @pytest.mark.asyncio + async def test_stand_finishes_game(self): + env = _make_env() + # Preload a known safe state so the result is deterministic. + env.session_state["sid"] = { + "player": ["K", "9"], # 19 + "dealer": ["10"], # dealer will draw until ≥ 17 + "rng": __import__("random").Random(0), + } + _, reward, term, trunc, info = await env.step(_response("stand"), {}, session_id="sid") + assert term is True + assert trunc is False + assert reward in (-1.0, 0.0, 1.0) + assert info["result"] in ("win", "draw", "loss") + + @pytest.mark.asyncio + async def test_hit_bust_ends_game(self): + env = _make_env() + # Force a bust on the next draw regardless of RNG: player already at 20, drawing any card except A busts. + # Use a seeded RNG that deals non-A. + rng = __import__("random").Random(0) + env.session_state["sid"] = {"player": ["K", "K"], "dealer": ["5"], "rng": rng} + _, reward, term, _, info = await env.step(_response("hit"), {}, session_id="sid") + # Player at 20 + any non-ace → bust. With rng seed 0, the draw is deterministic. 
+ if term: + assert reward == -1.0 + assert info["result"] == "bust" + + @pytest.mark.asyncio + async def test_hit_continues_when_safe(self): + env = _make_env() + env.session_state["sid"] = { + "player": ["5", "3"], # 8 + "dealer": ["10"], + "rng": __import__("random").Random(0), + } + obs, reward, term, _, _ = await env.step(_response("hit"), {}, session_id="sid") + assert term is False + assert reward == 0.0 + assert "Your hand" in obs + + +class TestActionParser: + @pytest.mark.asyncio + async def _decide(self, text: str) -> str: + env = _make_env() + env.session_state["sid"] = { + "player": ["2", "2"], # 4, can't bust on hit + "dealer": ["10"], + "rng": __import__("random").Random(0), + } + obs, _, term, _, info = await env.step(_response(text), {}, session_id="sid") + # Terminal → stood; non-terminal → hit. + return "stand" if term else "hit" + + @pytest.mark.asyncio + async def test_hit_tag(self): + assert await self._decide("hit") == "hit" + + @pytest.mark.asyncio + async def test_stand_tag(self): + assert await self._decide("stand") == "stand" + + @pytest.mark.asyncio + async def test_whitespace_tolerated(self): + assert await self._decide(" hit ") == "hit" + + @pytest.mark.asyncio + async def test_case_insensitive(self): + assert await self._decide("HIT") == "hit" + + @pytest.mark.asyncio + async def test_no_tag_defaults_stand(self): + # Previously the fallback did a substring match, so "don't hit" would parse as hit. + assert await self._decide("i don't know, maybe hit?") == "stand" + + @pytest.mark.asyncio + async def test_unknown_action_defaults_stand(self): + assert await self._decide("fold") == "stand" diff --git a/responses_api_agents/gymnasium_agent/README.md b/responses_api_agents/gymnasium_agent/README.md new file mode 100644 index 000000000..8b8409ed9 --- /dev/null +++ b/responses_api_agents/gymnasium_agent/README.md @@ -0,0 +1,5 @@ +# Gymnasium Agent + +Agent for Gymnasium-style environments based on `GymnasiumServer` resources servers. 
Drives the reset/step loop. + +See `docs/resources-server/gymnasium-api.md` for more details. \ No newline at end of file diff --git a/responses_api_agents/gymnasium_agent/__init__.py b/responses_api_agents/gymnasium_agent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/responses_api_agents/gymnasium_agent/app.py b/responses_api_agents/gymnasium_agent/app.py new file mode 100644 index 000000000..0312f2fcf --- /dev/null +++ b/responses_api_agents/gymnasium_agent/app.py @@ -0,0 +1,183 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Agent for GymnasiumServer resources servers (resources_servers.base_gymnasium) which implements the Gymnasium API.""" + +from fastapi import Body, Request, Response +from pydantic import ConfigDict, Field + +from nemo_gym.base_resources_server import ( + BaseRunRequest, + BaseVerifyResponse, +) +from nemo_gym.base_responses_api_agent import BaseResponsesAPIAgentConfig, SimpleResponsesAPIAgent +from nemo_gym.config_types import AggregateMetrics, AggregateMetricsRequest, ModelServerRef, ResourcesServerRef +from nemo_gym.openai_utils import ( + NeMoGymEasyInputMessage, + NeMoGymFunctionCallOutput, + NeMoGymResponse, + NeMoGymResponseCreateParamsNonStreaming, +) +from nemo_gym.server_utils import get_response_json, raise_for_status +from resources_servers.base_gymnasium import EnvResetResponse, EnvStepResponse + + +class GymnasiumAgentConfig(BaseResponsesAPIAgentConfig): + resources_server: ResourcesServerRef + model_server: ModelServerRef + max_steps: int = Field(10, ge=1) + + +class GymnasiumAgentRunRequest(BaseRunRequest): + model_config = ConfigDict(extra="allow") + + +class GymnasiumRunResponse(BaseVerifyResponse): + model_config = ConfigDict(extra="allow") + terminated: bool = False + truncated: bool = False + info: dict = {} + + +class GymnasiumAgent(SimpleResponsesAPIAgent): + config: GymnasiumAgentConfig + + async def responses( + self, + request: Request, + response: Response, + body: NeMoGymResponseCreateParamsNonStreaming = Body(), + ) -> NeMoGymResponse: + model_resp = await self.server_client.post( + server_name=self.config.model_server.name, + url_path="/v1/responses", + json=body, + cookies=request.cookies, + ) + await raise_for_status(model_resp) + result = NeMoGymResponse.model_validate(await get_response_json(model_resp)) + for k, v in model_resp.cookies.items(): + response.set_cookie(k, v) + return result + + async def run(self, request: Request, body: GymnasiumAgentRunRequest) -> GymnasiumRunResponse: + env_cookies = request.cookies + + 
reset_resp = await self.server_client.post( + server_name=self.config.resources_server.name, + url_path="/reset", + json=body.model_dump(), + cookies=env_cookies, + ) + await raise_for_status(reset_resp) + reset_data = EnvResetResponse.model_validate(await get_response_json(reset_resp)) + env_cookies = reset_resp.cookies + + base_body = body.responses_create_params.model_copy(deep=True) + if isinstance(base_body.input, str): + base_body.input = [NeMoGymEasyInputMessage(role="user", content=base_body.input)] + if reset_data.observation: + base_body.input = list(base_body.input) + [ + NeMoGymEasyInputMessage(role="user", content=reset_data.observation) + ] + + new_outputs = [] + total_reward = 0.0 + usage = None + model_server_cookies = None + step_data = EnvStepResponse(terminated=False, truncated=True, reward=0.0) + last_model_response = None + finished = False + + for _ in range(self.config.max_steps): + new_body = base_body.model_copy(update={"input": base_body.input + new_outputs}) + + model_resp = await self.server_client.post( + server_name=self.config.model_server.name, + url_path="/v1/responses", + json=new_body, + cookies=model_server_cookies, + ) + await raise_for_status(model_resp) + model_response = NeMoGymResponse.model_validate(await get_response_json(model_resp)) + model_server_cookies = model_resp.cookies + last_model_response = model_response + + new_outputs.extend(model_response.output) + + if model_response.usage: + if usage is None: + usage = model_response.usage.model_copy(deep=True) + else: + usage.input_tokens += model_response.usage.input_tokens + usage.output_tokens += model_response.usage.output_tokens + usage.total_tokens += model_response.usage.total_tokens + usage.input_tokens_details.cached_tokens = 0 + usage.output_tokens_details.reasoning_tokens = 0 + + step_resp = await self.server_client.post( + server_name=self.config.resources_server.name, + url_path="/step", + json=body.model_dump() | {"response": model_response.model_dump()}, + 
+                cookies=env_cookies,
+            )
+            await raise_for_status(step_resp)
+            step_data = EnvStepResponse.model_validate(await get_response_json(step_resp))
+            total_reward += step_data.reward
+            env_cookies = step_resp.cookies
+
+            if step_data.terminated or step_data.truncated:
+                finished = True
+                break
+
+            for tool_output in (step_data.info or {}).get("tool_outputs", []):
+                new_outputs.append(
+                    NeMoGymFunctionCallOutput(
+                        type="function_call_output",
+                        call_id=tool_output["call_id"],
+                        output=tool_output["output"],
+                    )
+                )
+
+            if step_data.observation:
+                new_outputs.append(NeMoGymEasyInputMessage(role="user", content=step_data.observation))
+
+        if not finished:
+            step_data = step_data.model_copy(update={"truncated": True})
+
+        last_model_response.output = new_outputs
+        last_model_response.usage = usage
+
+        return GymnasiumRunResponse(
+            responses_create_params=body.responses_create_params,
+            response=last_model_response,
+            reward=total_reward,
+            terminated=step_data.terminated,
+            truncated=step_data.truncated,
+            info=step_data.info,
+        )
+
+    async def aggregate_metrics(self, body: AggregateMetricsRequest = Body()) -> AggregateMetrics:
+        response = await self.server_client.post(
+            server_name=self.config.resources_server.name,
+            url_path="/aggregate_metrics",
+            json=body,
+        )
+        await raise_for_status(response)
+        return AggregateMetrics.model_validate(await get_response_json(response))
+
+
+if __name__ == "__main__":
+    GymnasiumAgent.run_webserver()
diff --git a/responses_api_agents/gymnasium_agent/configs/gymnasium_agent.yaml b/responses_api_agents/gymnasium_agent/configs/gymnasium_agent.yaml
new file mode 100644
index 000000000..5a088281d
--- /dev/null
+++ b/responses_api_agents/gymnasium_agent/configs/gymnasium_agent.yaml
@@ -0,0 +1,4 @@
+gymnasium_agent:
+  responses_api_agents:
+    gymnasium_agent:
+      entrypoint: app.py
diff --git a/responses_api_agents/gymnasium_agent/requirements.txt b/responses_api_agents/gymnasium_agent/requirements.txt
new file mode 100644
index 000000000..00ed83213
--- /dev/null
+++ b/responses_api_agents/gymnasium_agent/requirements.txt
@@ -0,0 +1 @@
+-e nemo-gym[dev] @ ../../
diff --git a/responses_api_agents/gymnasium_agent/tests/__init__.py b/responses_api_agents/gymnasium_agent/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/responses_api_agents/gymnasium_agent/tests/test_app.py b/responses_api_agents/gymnasium_agent/tests/test_app.py
new file mode 100644
index 000000000..68ed67b4e
--- /dev/null
+++ b/responses_api_agents/gymnasium_agent/tests/test_app.py
@@ -0,0 +1,236 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from nemo_gym.config_types import ModelServerRef, ResourcesServerRef
+from nemo_gym.server_utils import ServerClient
+from responses_api_agents.gymnasium_agent.app import GymnasiumAgent, GymnasiumAgentConfig, GymnasiumAgentRunRequest
+
+
+def _make_agent(max_steps=10):
+    config = GymnasiumAgentConfig(
+        host="",
+        port=0,
+        entrypoint="",
+        name="test_gymnasium_agent",
+        resources_server=ResourcesServerRef(type="resources_servers", name="my_env"),
+        model_server=ModelServerRef(type="responses_api_models", name="policy_model"),
+        max_steps=max_steps,
+    )
+    return GymnasiumAgent(config=config, server_client=MagicMock(spec=ServerClient))
+
+
+def _model_response(text: str, input_toks=1, output_toks=1) -> dict:
+    return {
+        "id": "r",
+        "created_at": 0.0,
+        "model": "m",
+        "object": "response",
+        "output": [
+            {
+                "id": "msg",
+                "content": [{"annotations": [], "text": text, "type": "output_text"}],
+                "role": "assistant",
+                "status": "completed",
+                "type": "message",
+            }
+        ],
+        "parallel_tool_calls": True,
+        "tool_choice": "auto",
+        "tools": [],
+        "usage": {
+            "input_tokens": input_toks,
+            "input_tokens_details": {"cached_tokens": 0},
+            "output_tokens": output_toks,
+            "output_tokens_details": {"reasoning_tokens": 0},
+            "total_tokens": input_toks + output_toks,
+        },
+    }
+
+
+class _FakeHttpResp:
+    def __init__(self, payload: dict):
+        self._payload = payload
+        self.cookies = {}
+        self.status = 200
+        self.ok = True
+
+    async def json(self):
+        return self._payload
+
+    async def read(self):
+        return json.dumps(self._payload).encode()
+
+    @property
+    def content(self):
+        class _Body:
+            async def read(inner):
+                return json.dumps(self._payload).encode()
+
+        return _Body()
+
+    def raise_for_status(self):
+        return None
+
+
+def _wire_mock_client(agent, responses_per_url):
+    """Wire agent.server_client.post to return payloads keyed by url_path."""
+    call_log = []
+
+    async def _post(server_name, url_path, json=None, cookies=None, **kw):
+        call_log.append((server_name, url_path, json))
+        payload = responses_per_url[url_path].pop(0)
+        return _FakeHttpResp(payload)
+
+    agent.server_client.post = AsyncMock(side_effect=_post)
+    return call_log
+
+
+class TestRoutes:
+    def test_routes_registered(self):
+        app = _make_agent().setup_webserver()
+        routes = {r.path for r in app.routes}
+        assert {"/run", "/v1/responses", "/aggregate_metrics"}.issubset(routes)
+
+
+class TestConfig:
+    def test_max_steps_validator_rejects_zero(self):
+        with pytest.raises(Exception):
+            GymnasiumAgentConfig(
+                host="",
+                port=0,
+                entrypoint="",
+                name="x",
+                resources_server=ResourcesServerRef(type="resources_servers", name="e"),
+                model_server=ModelServerRef(type="responses_api_models", name="m"),
+                max_steps=0,
+            )
+
+    def test_default_max_steps(self):
+        assert _make_agent().config.max_steps == 10
+
+
+class TestRun:
+    @pytest.mark.asyncio
+    async def test_terminates_on_first_step(self):
+        agent = _make_agent()
+        call_log = _wire_mock_client(
+            agent,
+            {
+                "/reset": [{"observation": "go", "info": {}}],
+                "/v1/responses": [_model_response("move A")],
+                "/step": [{"observation": None, "reward": 1.0, "terminated": True, "truncated": False, "info": {}}],
+            },
+        )
+        req = MagicMock()
+        req.cookies = {}
+        body = GymnasiumAgentRunRequest(responses_create_params={"input": [{"role": "user", "content": "play"}]})
+        result = await agent.run(req, body)
+        assert result.terminated is True
+        assert result.reward == 1.0
+        # reset + exactly 1 model call + 1 step
+        urls = [u for (_s, u, _) in call_log]
+        assert urls.count("/reset") == 1
+        assert urls.count("/v1/responses") == 1
+        assert urls.count("/step") == 1
+
+    @pytest.mark.asyncio
+    async def test_multi_step_preserves_output_items_in_history(self):
+        agent = _make_agent(max_steps=3)
+        call_log = _wire_mock_client(
+            agent,
+            {
+                "/reset": [{"observation": "start", "info": {}}],
+                "/v1/responses": [
+                    _model_response("turn-1", output_toks=10),
+                    _model_response("turn-2", output_toks=20),
+                ],
+                "/step": [
+                    {"observation": "obs-1", "reward": 0.5, "terminated": False, "truncated": False, "info": {}},
+                    {"observation": None, "reward": 0.5, "terminated": True, "truncated": False, "info": {}},
+                ],
+            },
+        )
+        req = MagicMock()
+        req.cookies = {}
+        body = GymnasiumAgentRunRequest(responses_create_params={"input": [{"role": "user", "content": "play"}]})
+        result = await agent.run(req, body)
+        assert result.reward == 1.0
+        assert result.terminated is True
+        # Inspect turn-2 model call body: its input must contain the full turn-1 output item,
+        # not a flattened string, and the obs-1 appended as user message.
+        turn2_body = [body for (s, u, body) in call_log if u == "/v1/responses"][1]
+        turn2_input = turn2_body.input
+        # turn-1 full output item preserved (with structured content list)
+        assistant_items = [m for m in turn2_input if getattr(m, "role", None) == "assistant"]
+        assert any(
+            isinstance(getattr(m, "content", None), list)
+            and any(
+                getattr(c, "type", None) == "output_text" and getattr(c, "text", "") == "turn-1" for c in m.content
+            )
+            for m in assistant_items
+        ), f"turn-1 output not preserved in structured form: {assistant_items}"
+        # obs-1 appended as a user message after turn-1
+        assert any(getattr(m, "role", None) == "user" and getattr(m, "content", "") == "obs-1" for m in turn2_input)
+
+    @pytest.mark.asyncio
+    async def test_max_steps_sets_truncated(self):
+        agent = _make_agent(max_steps=2)
+        _wire_mock_client(
+            agent,
+            {
+                "/reset": [{"observation": None, "info": {}}],
+                "/v1/responses": [_model_response("a"), _model_response("b")],
+                "/step": [
+                    {"observation": "obs-1", "reward": 0.0, "terminated": False, "truncated": False, "info": {}},
+                    {"observation": "obs-2", "reward": 0.0, "terminated": False, "truncated": False, "info": {}},
+                ],
+            },
+        )
+        req = MagicMock()
+        req.cookies = {}
+        body = GymnasiumAgentRunRequest(responses_create_params={"input": [{"role": "user", "content": "x"}]})
+        result = await agent.run(req, body)
+        assert result.truncated is True
+        assert result.terminated is False
+
+    @pytest.mark.asyncio
+    async def test_usage_accumulates_across_turns(self):
+        agent = _make_agent(max_steps=3)
+        _wire_mock_client(
+            agent,
+            {
+                "/reset": [{"observation": None, "info": {}}],
+                "/v1/responses": [
+                    _model_response("a", input_toks=5, output_toks=7),
+                    _model_response("b", input_toks=11, output_toks=13),
+                ],
+                "/step": [
+                    {"observation": "o", "reward": 0.0, "terminated": False, "truncated": False, "info": {}},
+                    {"observation": None, "reward": 0.0, "terminated": True, "truncated": False, "info": {}},
+                ],
+            },
+        )
+        req = MagicMock()
+        req.cookies = {}
+        body = GymnasiumAgentRunRequest(responses_create_params={"input": [{"role": "user", "content": "x"}]})
+        result = await agent.run(req, body)
+        # usage summed across both turns
+        assert result.response.usage.input_tokens == 16
+        assert result.response.usage.output_tokens == 20
+        assert result.response.usage.total_tokens == 36