Skip to content

Commit ea4e644

Browse files
committed
add openai/anthropic wrapper for rollout controller
1 parent bf4f8a0 commit ea4e644

20 files changed

+1625
-597
lines changed

tests/rl/test_camel_agent_loop.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import os
2+
import sys
3+
import tempfile
4+
import time
5+
import unittest
6+
from urllib import error as urllib_error
7+
from urllib import request as urllib_request
8+
9+
os.environ.setdefault("RAY_ENABLE_UV_RUN_RUNTIME_ENV", "0")
10+
11+
import ray
12+
import torch
13+
14+
from xtuner.v1.data_proto import RolloutState, SampleParams, Status
15+
from xtuner.v1.rl.agent_loop import CamelAgentLoopConfig
16+
from xtuner.v1.rl.rollout import RolloutController
17+
from xtuner.v1.rl.rollout.worker import RolloutConfig
18+
from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers
19+
20+
21+
# Path to the HF checkpoint served by the rollout workers. Use .get() with a
# default so that merely importing/collecting this module does not raise
# KeyError when ROLLOUT_MODEL_PATH is unset — the class-level skipIf would
# have skipped every test anyway, and a collection-time crash masks the skip.
MODEL_PATH = os.environ.get("ROLLOUT_MODEL_PATH", "")

# Maps a torch accelerator type to the Ray resource name requested for workers.
RESOURCE_MAP = {
    "npu": "NPU",
    "cuda": "GPU",
}
26+
27+
28+
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
class TestCamelAgentLoop(unittest.IsolatedAsyncioTestCase):
    """End-to-end test of the CAMEL agent loop against a live rollout controller.

    Spins up a Ray-hosted ``RolloutController`` actor backed by an lmdeploy API
    server, waits for its ``/healthz`` endpoint, runs a single-turn generation
    through ``CamelAgentLoopConfig``, and cross-checks the produced sample
    against the trace recorded by the controller.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Process-wide flags read by the rollout backend; presumably they must
        # be set before any worker process spawns — TODO confirm.
        os.environ["XTUNER_USE_FA3"] = "1"
        os.environ["LMD_SKIP_WARMUP"] = "1"

    @classmethod
    def tearDownClass(cls) -> None:
        # Undo the flags set in setUpClass so other test modules are unaffected.
        del os.environ["XTUNER_USE_FA3"]
        del os.environ["LMD_SKIP_WARMUP"]

    def setUp(self) -> None:
        # ignore_reinit_error lets this coexist with an already-running Ray
        # instance (e.g. when several test modules share one session).
        ray.init(num_cpus=80, ignore_reinit_error=True)
        self.temp_dir = tempfile.TemporaryDirectory()
        self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
        self.resources_cfg = AcceleratorResourcesConfig(
            # Resolve "cuda"/"npu" to the matching Ray resource name.
            accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type],
            num_workers=1,
            num_cpus_per_worker=8,
            cpu_memory_per_worker=16 * 1024**3,  # 16 GiB per worker
        )
        self.context_length = 1024

    def tearDown(self) -> None:
        ray.shutdown()
        self.temp_dir.cleanup()

    def _wait_until_ready(self, base_url: str) -> None:
        """Poll ``{base_url}/healthz`` until it returns 200 or 30 minutes pass.

        Raises:
            RuntimeError: if the server never becomes healthy; the message
                includes the last observed error/response for diagnosis.
        """
        deadline = time.time() + 1800  # 30-minute budget: model load can be slow
        last_error = None
        while time.time() < deadline:
            try:
                with urllib_request.urlopen(f"{base_url}/healthz", timeout=10.0) as response:
                    if response.status == 200:
                        return
                    # Non-200: keep the body text as the most recent error.
                    last_error = response.read().decode("utf-8", errors="ignore")
            except urllib_error.URLError as exc:
                # Server not up yet (connection refused, DNS, timeout wrapped in URLError).
                last_error = repr(exc)
            except Exception as exc:
                # Best-effort polling loop: record anything unexpected and retry.
                last_error = repr(exc)
            time.sleep(5)
        raise RuntimeError(f"API server at {base_url} did not become ready in time: {last_error}")

    def _build_controller(self, port: int):
        """Create a remote RolloutController actor bound to ``port``.

        Returns:
            Tuple of (controller actor handle, placement group). The caller is
            responsible for shutting the actor down and removing the placement
            group (see the test's ``finally`` block).
        """
        rollout_config = RolloutConfig(
            # Unique per-port env name so concurrent instances don't collide.
            env=f"test_camel_agent_loop_{port}",
            model_path=MODEL_PATH,
            model_name=os.path.basename(MODEL_PATH).lower(),
            tokenizer_path=MODEL_PATH,
            context_length=self.context_length,
            worker_log_dir=self.worker_log_dir,
            api_host="127.0.0.1",
            api_port=port,
        )
        pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg, name=f"camel_pg_{port}")
        rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
        return rollout_controller, pg

    async def test_camel_single_turn_builds_chat_agent_with_official_async_openai(self):
        # NOTE(review): AsyncOpenAI is not referenced below; presumably the
        # import asserts the official openai client is installed before the
        # CAMEL loop (which uses it internally) is built — confirm.
        from openai import AsyncOpenAI

        rollout_controller, pg = self._build_controller(port=28601)
        try:
            # Block until the controller's API server answers health checks.
            metadata = await rollout_controller.get_rollout_metadata.remote()
            self._wait_until_ready(metadata["api_server_url"])

            # Deterministic decoding (temperature=0) so the trace comparison
            # below is meaningful.
            cfg = CamelAgentLoopConfig(
                hf_checkpoint=MODEL_PATH,
                sample_params=SampleParams(max_tokens=64, temperature=0.0),
            )
            loop = cfg.build(rollout_controller=rollout_controller)
            state = RolloutState(
                message=[{"role": "user", "content": "Say hello in one short sentence."}],
                sample_params=SampleParams(max_tokens=64, temperature=0.0),
            )

            result = await loop.generate_sample(state)
            # Look up the controller-side record of the same request so we can
            # verify the loop's extracted fields round-trip to the stored trace.
            trace = await rollout_controller.get_openai_chat_trace_by_messages.remote(
                result.extra_fields["camel_request"],
                result.extra_fields["camel_response"],
                result.extra_fields["camel_finish_reason"],
            )

            # Diagnostic dump (kept in CI logs) of the CAMEL chat history and
            # the fields extracted from it.
            print(f"CHAT_HISTORY_BEFORE: {result.extra_fields['camel_chat_history_before']}")
            print(f"CHAT_HISTORY_AFTER: {result.extra_fields['camel_chat_history_after']}")
            print(
                f"NEW_ENTRIES: "
                f"{result.extra_fields['camel_chat_history_after'][len(result.extra_fields['camel_chat_history_before']):]}"
            )
            print(f"EXTRACTED_REQUEST: {result.extra_fields['camel_request']}")
            print(f"EXTRACTED_RESPONSE: {result.extra_fields['camel_response']}")
            print(f"EXTRACTED_FINISH_REASON: {result.extra_fields['camel_finish_reason']}")
            print(f"TRACE_LOOKUP: {trace}")

            # The sample must have completed, and every token-level field of the
            # result must match what the controller recorded for this request.
            self.assertEqual(result.status, Status.COMPLETED)
            self.assertIsNotNone(trace)
            self.assertEqual(result.prompt_ids, trace["prompt_ids"])
            self.assertEqual(result.response_ids, trace["response_ids"])
            self.assertEqual(result.finish_reason, trace["finish_reason"])
            self.assertEqual(result.response, loop.tokenizer.decode(trace["response_ids"]))
        finally:
            # Nested finally: the placement group is removed even if the actor
            # shutdown itself raises.
            try:
                await rollout_controller.shutdown.remote()
            finally:
                ray.util.remove_placement_group(pg)
136+
# Allow running this test module directly (outside the pytest/CI harness).
if __name__ == "__main__":
    unittest.main()

tests/rl/test_evaluator.py

Lines changed: 0 additions & 121 deletions
This file was deleted.

0 commit comments

Comments
 (0)