diff --git a/tests/rl/test_camel_agent_loop.py b/tests/rl/test_camel_agent_loop.py
new file mode 100644
index 000000000..8beb595fc
--- /dev/null
+++ b/tests/rl/test_camel_agent_loop.py
@@ -0,0 +1,233 @@
+import json
+import os
+import tempfile
+import time
+import unittest
+from types import SimpleNamespace
+from unittest.mock import patch
+from urllib import error as urllib_error
+from urllib import request as urllib_request
+
+os.environ.setdefault("RAY_ENABLE_UV_RUN_RUNTIME_ENV", "0")
+
+import ray
+import torch
+from camel.toolkits import SearchToolkit
+
+from xtuner.v1.data_proto import RolloutState, SampleParams, Status
+from xtuner.v1.rl.agent_loop import CamelAgentLoop, CamelAgentLoopConfig
+from xtuner.v1.rl.rollout import RolloutController
+from xtuner.v1.rl.rollout.chat_adapter.collector import append_current_trace_rollout_state
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers
+
+
+MODEL_PATH = os.environ.get("ROLLOUT_MODEL_PATH", "")
+RESOURCE_MAP = {
+ "npu": "NPU",
+ "cuda": "GPU",
+}
+
+
+class DummyTokenizer:
+ def encode(self, text, add_special_tokens=False):
+ return [ord(char) for char in text]
+
+ def decode(self, token_ids):
+ return "".join(chr(token_id) for token_id in token_ids)
+
+ def __call__(self, text, add_special_tokens=False):
+ return {"input_ids": self.encode(text, add_special_tokens=add_special_tokens)}
+
+ def apply_chat_template(self, messages, tools=None, add_generation_prompt=True, tokenize=False):
+ rendered = "\n".join(json.dumps(message, ensure_ascii=False, sort_keys=True) for message in messages)
+ if tools:
+ rendered += "\nTOOLS:" + json.dumps(tools, ensure_ascii=False, sort_keys=True)
+ if add_generation_prompt:
+ rendered += "\nassistant:"
+ if tokenize:
+ return self.encode(rendered, add_special_tokens=False)
+ return rendered
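+
+# Illustration only: DummyTokenizer maps each character to its Unicode code
+# point, so encode/decode round-trip exactly, e.g. encode("hi") -> [104, 105]
+# and decode([104, 105]) -> "hi".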
+
+
+class TestCamelAgentLoopUnit(unittest.IsolatedAsyncioTestCase):
+ def _build_loop(self, agent, sample_params=None, tools=None, tool_choice=None):
+ tokenizer = DummyTokenizer()
+ with patch("xtuner.v1.rl.agent_loop.agent_loop.load_tokenizer", return_value=tokenizer), patch(
+ "xtuner.v1.rl.agent_loop.agent_loop.load_processor", return_value=None
+ ), patch.object(CamelAgentLoop, "init_agent", return_value=agent):
+ cfg = CamelAgentLoopConfig(
+ hf_checkpoint="dummy",
+ sample_params=sample_params or SampleParams(max_tokens=256, temperature=0.0),
+ tools=tools,
+ tool_choice=tool_choice,
+ )
+ loop = cfg.build(rollout_controller=SimpleNamespace())
+ return loop, tokenizer
+
+ async def test_camel_generate_sample_returns_gateway_rollout_states(self):
+ class GatewayFakeAgent:
+ async def astep(self, content):
+ first_turn = RolloutState(
+ uid=101,
+ message=[{"role": "user", "content": content}],
+ prompt_ids=[11, 12],
+ tokens=[11, 12],
+ response="tool-call",
+ response_ids=[21, 22],
+ logprobs=[-0.1, -0.2],
+ response_mask=[1, 1],
+ finish_reason="tool_calls",
+ status=Status.COMPLETED,
+ extra_fields={
+ "tool_calls": [
+ {
+ "id": "call_search",
+ "type": "function",
+ "function": {"name": "search", "arguments": "{\"q\":\"camel\"}"},
+ }
+ ]
+ },
+ )
+ second_turn = RolloutState(
+ uid=102,
+ message=[{"role": "user", "content": content}],
+ prompt_ids=[11, 12, 21, 22],
+ tokens=[11, 12, 21, 22],
+ response="final-answer",
+ response_ids=[31, 32, 33],
+ logprobs=[-0.3, -0.4, -0.5],
+ response_mask=[1, 1, 1],
+ finish_reason="stop",
+ status=Status.COMPLETED,
+ )
+ append_current_trace_rollout_state(first_turn)
+ append_current_trace_rollout_state(second_turn)
+ return SimpleNamespace(info={"termination_reasons": ["stop"]})
+
+ loop, _ = self._build_loop(agent=GatewayFakeAgent())
+ state = RolloutState(
+ message=[{"role": "user", "content": "Where is CAMEL on GitHub?"}],
+ sample_params=SampleParams(max_tokens=128, temperature=0.0),
+ )
+
+ result = await loop.generate_sample(state)
+
+ self.assertEqual(len(result), 2)
+ self.assertEqual(result[0].status, Status.COMPLETED)
+ self.assertEqual(result[0].finish_reason, "tool_calls")
+ self.assertEqual(result[0].extra_fields["gateway_rollout_index"], 0)
+ self.assertEqual(result[1].finish_reason, "stop")
+ self.assertEqual(result[1].prompt_ids, [11, 12, 21, 22])
+ self.assertEqual(result[1].tokens, [11, 12, 21, 22])
+ self.assertEqual(result[1].response_ids, [31, 32, 33])
+ self.assertEqual(result[1].response_mask, [1, 1, 1])
+ self.assertEqual(result[1].logprobs, [-0.3, -0.4, -0.5])
+ self.assertEqual(result[1].response, "final-answer")
+ self.assertEqual(result[1].message, state.message)
+ self.assertEqual(len(result[1].extra_fields["gateway_trace_records"]), 2)
+ self.assertEqual(result[1].extra_fields["gateway_trace_records"][0]["request_id"], "101")
+ self.assertEqual(result[1].extra_fields["gateway_trace_records"][1]["finish_reason"], "stop")
+
+
+@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0" or not MODEL_PATH, "lmdeploy backend is not enabled or ROLLOUT_MODEL_PATH is not set")
+class TestCamelAgentLoopIntegration(unittest.IsolatedAsyncioTestCase):
+ @classmethod
+ def setUpClass(cls) -> None:
+ os.environ["XTUNER_USE_FA3"] = "1"
+ os.environ["LMD_SKIP_WARMUP"] = "1"
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ del os.environ["XTUNER_USE_FA3"]
+ del os.environ["LMD_SKIP_WARMUP"]
+
+ def setUp(self):
+ os.environ.pop("RAY_ADDRESS", None)
+ ray.init(address="local", num_cpus=80, ignore_reinit_error=True)
+ self.temp_dir = tempfile.TemporaryDirectory()
+ self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
+ self.resources_cfg = AcceleratorResourcesConfig(
+ accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type],
+ num_workers=1,
+ num_cpus_per_worker=8,
+ cpu_memory_per_worker=16 * 1024**3,
+ )
+ self.context_length = 1024
+
+ def tearDown(self):
+ ray.shutdown()
+ self.temp_dir.cleanup()
+
+ def _wait_until_ready(self, base_url: str):
+ deadline = time.time() + 1800
+ last_error = None
+ while time.time() < deadline:
+ try:
+ with urllib_request.urlopen(f"{base_url}/healthz", timeout=10.0) as response:
+ if response.status == 200:
+ return
+ last_error = response.read().decode("utf-8", errors="ignore")
+ except urllib_error.URLError as exc:
+ last_error = repr(exc)
+ except Exception as exc:
+ last_error = repr(exc)
+ time.sleep(5)
+ raise RuntimeError(f"API server at {base_url} did not become ready in time: {last_error}")
+
+ def _build_controller(self, port: int):
+ rollout_config = RolloutConfig(
+ env=f"test_camel_agent_loop_{port}",
+ model_path=MODEL_PATH,
+ model_name=os.path.basename(MODEL_PATH).lower(),
+ tokenizer_path=MODEL_PATH,
+ context_length=self.context_length,
+ worker_log_dir=self.worker_log_dir,
+ api_host="127.0.0.1",
+ api_port=port,
+ )
+ pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg, name=f"camel_pg_{port}")
+ rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
+ return rollout_controller, pg
+
+ async def test_camel_tool_call_integration_returns_gateway_rollout_batch(self):
+ rollout_controller, pg = self._build_controller(port=28602)
+ try:
+ metadata = await rollout_controller.get_rollout_metadata.remote()
+ self._wait_until_ready(metadata["api_server_url"])
+ search_tool = SearchToolkit().search_duckduckgo
+
+ cfg = CamelAgentLoopConfig(
+ hf_checkpoint=MODEL_PATH,
+ sample_params=SampleParams(max_tokens=256, temperature=0.0),
+ system_message="You are a helpful assistant to do search task.",
+ tools=[search_tool],
+ tool_choice={"type": "function", "function": {"name": "search_duckduckgo"}},
+ )
+ loop = cfg.build(rollout_controller=rollout_controller)
+ state = RolloutState(
+ message=[{"role": "user", "content": "What is the Github link to CAMEL framework?"}],
+ sample_params=SampleParams(max_tokens=256, temperature=0.0),
+ )
+
+ result = await loop.generate_sample(state)
+ print(f"CAMEL_INTEGRATION_RESULT: {result}")
+ if not any((state.extra_fields or {}).get("tool_calls") for state in result):
+ self.skipTest("Current backend/model did not emit tool calls for the integration prompt.")
+
+ self.assertGreaterEqual(len(result), 1)
+ self.assertEqual(result[-1].status, Status.COMPLETED)
+ self.assertIsNotNone(result[-1].response_ids)
+ self.assertIsNotNone(result[-1].response_mask)
+ self.assertIsNotNone(result[-1].response)
+ self.assertTrue(any(state.finish_reason == "tool_calls" for state in result))
+
+ finally:
+ try:
+ await rollout_controller.shutdown.remote()
+ finally:
+ ray.util.remove_placement_group(pg)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/rl/test_evaluator.py b/tests/rl/test_evaluator.py
deleted file mode 100644
index ffd52930a..000000000
--- a/tests/rl/test_evaluator.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import os
-import unittest
-import ray
-import tempfile
-from transformers import AutoTokenizer
-
-from xtuner.v1.rl.rollout.worker import RolloutConfig
-try:
- from xtuner.v1.ray.judger.controller import JudgerConfig
-except Exception:
- class JudgerConfig:
- def __init__(self, *args, **kwargs):
- self.__dict__.update(kwargs)
-from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers
-try:
- from xtuner.v1.ray.environment import SingleTurnEnvironment
-except Exception:
- SingleTurnEnvironment = None
-from xtuner.v1.rl.evaluator import Evaluator, EvaluatorConfig
-from xtuner.v1.data_proto.rl_data import SampleParams
-from xtuner.v1.datasets import RLTokenizeFnConfig, DatasetConfig, OpenaiTokenizeFunctionConfig
-
-
-MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
-TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"]
-
-
-@unittest.skipIf(SingleTurnEnvironment is None, "ray environment unavailable")
-class TestEvaluator(unittest.TestCase):
-
- @classmethod
- def setUpClass(cls) -> None:
- os.environ["XTUNER_USE_FA3"] = "1"
- os.environ["LMD_SKIP_WARMUP"] = "1"
-
- @classmethod
- def tearDownClass(cls) -> None:
- del os.environ["XTUNER_USE_FA3"]
- del os.environ["LMD_SKIP_WARMUP"]
-
- def init_config(self):
- self.resources_cfg = AcceleratorResourcesConfig(
- accelerator="GPU",
- num_workers=8,
- num_cpus_per_worker=8,
- cpu_memory_per_worker=16 * 1024**3, # 16 GB
- )
- self.max_prompt_length = 512
- self.max_response_length = 1024
- self.rollout_cfg = RolloutConfig(
- env="test_rollout",
- model_path=MODEL_PATH,
- model_name=os.path.basename(MODEL_PATH).lower(),
- tokenizer_path=MODEL_PATH,
- tensor_parallel_size=8,
- context_length=self.max_prompt_length + self.max_response_length,
- worker_log_dir=self.worker_log_dir
- )
- from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig
- gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router")
- self.judger_cfg = JudgerConfig(
- reward_judger_configs=[gsm8k_judger_config],
- worker_log_dir=self.worker_log_dir
- )
- self.eval_dataset_cfg = [
- {
- "dataset": DatasetConfig(name="gsm8k",
- anno_path=TEST_DATA_PATH,
- sample_ratio=1.0),
- "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length)
- },
- ]
- self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
- self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)
- self.test_env = SingleTurnEnvironment.remote(
- "test_env",
- self.pg,
- self.rollout_cfg,
- None,
- self.judger_cfg
- )
- self.sample_params = SampleParams(
- top_p=1.0,
- temperature=0.0,
- max_tokens=self.max_response_length,
- top_k=1
- )
-
- def setUp(self):
- ray.init(num_cpus=80)
- self.model_path = MODEL_PATH
- self.temp_dir = tempfile.TemporaryDirectory()
- self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
- self.init_config()
-
- def tearDown(self):
- ray.shutdown()
- self.temp_dir.cleanup()
-
- @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
- def test_lmdeploy_evaluator(self):
- def custom_compute_metric(samples):
- return {"custom_accuracy": sum(s.env.judger.reward["score"] > 0 for s in samples) / len(samples)}
-
- evaluator_cfg = EvaluatorConfig(
- dataset_cfg=self.eval_dataset_cfg,
- tokenizer=self.tokenizer,
- max_concurrent=16,
- eval_sample_ratio=0.004, # generate 5 samples
- compute_metric_func=custom_compute_metric,
- sample_params=self.sample_params,
- worker_log_dir=self.worker_log_dir
- )
- evaluator = Evaluator.remote(evaluator_cfg, self.test_env)
- try:
- ray.get(evaluator.run.remote())
- except Exception as e:
- self.fail(f"evaluator.run.remote() raised an exception: {e}")
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/tests/rl/test_rl_train_with_sft.py b/tests/rl/test_rl_train_with_sft.py
deleted file mode 100644
index e0476de71..000000000
--- a/tests/rl/test_rl_train_with_sft.py
+++ /dev/null
@@ -1,180 +0,0 @@
-import os
-import unittest
-from transformers import AutoTokenizer
-import shutil
-import tempfile
-import json
-import torch
-from xtuner.v1.data_proto.sequence_context import SequenceContext
-import ray
-from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers
-from xtuner.v1.config import (
- AdamWConfig,
- FSDPConfig,
- LRConfig,
-)
-from xtuner.v1.rl.trainer import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker
-from xtuner.v1.rl.loss import GRPOLossConfig as LossConfig
-from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig
-from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig
-from xtuner.v1.loss import CELossConfig
-from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
-from xtuner.v1.train.trainer import LoadCheckpointConfig
-
-QWEN3_PATH = os.environ["QWEN3_PATH"]
-ALPACA_PATH = os.environ["ALPACA_PATH"]
-
-
-class TestRLTrainWithSFT(unittest.TestCase):
- def setUp(self):
- ray.init(num_cpus=80, ignore_reinit_error=True)
-
- resources = AcceleratorResourcesConfig(
- accelerator="GPU",
- num_accelerators_per_worker=1,
- num_cpus_per_worker=8,
- num_workers=8,
- cpu_memory_per_worker=16 * 1024 ** 3, # 16 GB
- )
-
- pg = AutoAcceleratorWorkers.build_placement_group(resources)
- self.pg = pg
-
- self.temp_dir = tempfile.mkdtemp()
- tokenizer = AutoTokenizer.from_pretrained(QWEN3_PATH, trust_remote_code=True)
- self.tokenizer = tokenizer
- self.prompt_repeat_k = 8
- file = './tests/ray/rollout_output.jsonl'
- with open(file, 'r') as f:
- data = [json.loads(line) for line in f]
- data_groups = [data[i:i + self.prompt_repeat_k] for i in range(0, len(data), self.prompt_repeat_k)]
- data_groups = data_groups[:8]
- data_batches = []
- for group in data_groups:
- prompt_ids = tokenizer(group[0]['prompt'], return_tensors='pt')['input_ids'].flatten().tolist()
- rewards = [item['reward'] for item in group]
- rewards = torch.tensor(rewards, dtype=torch.float32)
- advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8)
-
- for i in range(self.prompt_repeat_k):
- item = group[i]
- response_ids = tokenizer(item['response'], return_tensors='pt')['input_ids'].flatten().tolist()
- input_ids = prompt_ids + response_ids
- shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + [-100]
- input_ids = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0)
- shifted_labels = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0)
- data_batches.append(
- dict(
- seq_ctx=SequenceContext.from_input_ids((input_ids,), device="cpu"),
- shifted_labels=shifted_labels,
- advantage=advantages[i].item(),
- )
- )
- self.data_batches = data_batches
-
- def tearDown(self):
- shutil.rmtree(self.temp_dir)
- ray.shutdown()
-
- def build_train_controller(self):
- model_cfg = Qwen3Dense8BConfig()
- optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False)
- fsdp_cfg: FSDPConfig = FSDPConfig(
- torch_compile=True,
- cpu_offload=False,
- ep_size=1,
- )
- lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7)
-
- dataset_config = []
- _data_cfg = {"dataset": DatasetConfig(name='apach',
- anno_path=ALPACA_PATH),
- "tokenize_fn": OpenaiTokenizeFunctionConfig(
- chat_template='qwen3',
- max_length=32768
- )
- }
- dataset_config.append(_data_cfg)
-
- sft_dataloader_cfg = DataloaderConfig(
- dataset_config_list=dataset_config,
- pack_max_length=32768,
- pack_to_max_length=True,
- num_workers=0,
- )
- sft_global_batch_size = 8
- loss_reduction = "square"
- sft_loss_cfg = CELossConfig(mode="chunk", chunk_size=1024, loss_reduction=loss_reduction)
-
- worker_cfg: WorkerConfig = WorkerConfig(
- sft_dataloader_cfg=sft_dataloader_cfg,
- sft_global_batch_size=sft_global_batch_size,
- sft_loss_cfg=sft_loss_cfg,
- seed=42,
- model_cfg=model_cfg,
- optim_cfg=optim_cfg,
- loss_cfg=LossConfig(
- policy_loss_cfg=dict(
- cliprange_high=0.28,
- cliprange_low=0.2,
- loss_type="vanilla",
- ),
- ignore_idx=-100,
- use_kl_loss=True,
- kl_loss_coef=0.001,
- kl_loss_type="low_var_kl",
- mode="eager"),
- lr_cfg=lr_cfg,
- fsdp_cfg=fsdp_cfg,
- load_from=QWEN3_PATH,
- sp_size=1,
- pack_max_length=8192,
- )
-
- TrainingWorker = ray.remote(
- runtime_env={
- "env_vars": {
- "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
- "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1",
- }
- },
- )(BaseTrainingWorker)
- train_workers, _ = AutoAcceleratorWorkers.from_placement_group(
- TrainingWorker, worker_cfg, self.pg
- )
- futures = [worker.test_all_reduce.remote() for worker in train_workers]
- print(ray.get(futures))
- train_controller = TrainingController.remote(
- workers=train_workers,
- )
- ray.get(train_controller.__ray_ready__.remote())
- return train_controller
-
- def test_rl_train_with_sft(self):
- train_controller = self.build_train_controller()
-
- ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=0))
- ray.get(train_controller.save.remote(os.path.join(self.temp_dir, "save_test"), no_save_optimizer=True))
-
- log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1))
- efficient_attn_ratio_list = []
- for log_info in log_infos:
- efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio'])
- assert all([efficient_attn_ratio > 0 for efficient_attn_ratio in efficient_attn_ratio_list])
-
- ray.kill(train_controller)
- train_controller = self.build_train_controller()
- load_checkpoint_cfg = LoadCheckpointConfig(checkpoint_path=os.path.join(self.temp_dir, "save_test"),
- load_optimizer_states=False,
- load_optimizer_args=False
- )
- ray.get(train_controller.resume.remote(load_checkpoint_cfg))
-
- log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1))
- new_efficient_attn_ratio_list = []
- for log_info in log_infos:
- new_efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio'])
-
- efficient_attn_ratio_list.sort()
- new_efficient_attn_ratio_list.sort()
- self.assertEqual(efficient_attn_ratio_list, new_efficient_attn_ratio_list)
diff --git a/tests/rl/test_rl_trainer.py b/tests/rl/test_rl_trainer.py
deleted file mode 100644
index 82688151d..000000000
--- a/tests/rl/test_rl_trainer.py
+++ /dev/null
@@ -1,266 +0,0 @@
-import os
-import tempfile
-import unittest
-from pathlib import Path
-
-import ray
-import torch
-
-from transformers import AutoTokenizer
-from xtuner.v1.config import (
- AdamWConfig,
- FSDPConfig,
- LRConfig,
-)
-from xtuner.v1.data_proto.rl_data import SampleParams
-from xtuner.v1.datasets import RLTokenizeFnConfig
-from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
-from xtuner.v1.model import get_model_config_from_hf
-from xtuner.v1.rl.utils import AcceleratorResourcesConfig, CPUResourcesConfig
-from xtuner.v1.rl.rollout.worker import RolloutConfig
-try:
- from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig
-except Exception:
- class DataFlowConfig:
- def __init__(self, *args, **kwargs):
- self.__dict__.update(kwargs)
-
- class ReplayBufferConfig:
- def __init__(self, *args, **kwargs):
- self.__dict__.update(kwargs)
-try:
- from xtuner.v1.ray.judger.controller import JudgerConfig
-except Exception:
- class JudgerConfig:
- def __init__(self, *args, **kwargs):
- self.__dict__.update(kwargs)
-from xtuner.v1.rl.trainer.worker import WorkerConfig
-from xtuner.v1.rl.loss import GRPOLossConfig
-from xtuner.v1.train.rl_trainer import RLTrainer, RLTrainerConfig
-
-
-MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
-TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
-TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"]
-resource_map = {
- "npu": "NPU",
- "cuda": "GPU",
-}
-
-
-class TestRLTrainer(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- os.environ["XTUNER_USE_FA3"] = "1"
- os.environ["LMD_SKIP_WARMUP"] = "1"
-
- @classmethod
- def tearDownClass(cls):
- del os.environ["XTUNER_USE_FA3"]
- del os.environ["LMD_SKIP_WARMUP"]
-
- def init_traine_worker_config(self, train_optimizer_steps, pack_max_length):
- model_cfg = get_model_config_from_hf(Path(MODEL_PATH))
- optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False)
- loss_cfg = GRPOLossConfig(
- policy_loss_cfg=dict(
- cliprange_high=0.28,
- cliprange_low=0.2,
- loss_type="vanilla",
- clip_ratio_c=10.0,
- log_prob_diff_min=-20.0,
- log_prob_diff_max=20.0,
- ),
- ignore_idx=-100,
- use_kl_loss=False,
- kl_loss_coef=0.0,
- kl_loss_type="low_var_kl",
- mode="chunk",
- chunk_size=512,
- )
- lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
- fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
- train_worker_cfg: WorkerConfig = WorkerConfig(
- model_cfg=model_cfg,
- load_from=MODEL_PATH,
- optim_cfg=optim_cfg,
- loss_cfg=loss_cfg,
- lr_cfg=lr_cfg,
- fsdp_cfg=fsdp_cfg,
- sp_size=1,
- optimizer_steps=train_optimizer_steps,
- pack_max_length=pack_max_length,
- )
- return train_worker_cfg
-
- def init_replay_buffer_config(self, max_prompt_length):
- train_dataset_cfg = [
- {
- "dataset": DatasetConfig(name="gsm8k", anno_path=TRAIN_DATA_PATH, sample_ratio=1.0),
- "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length),
- },
- ]
- dataloader_cfg = DataloaderConfig(
- collator="fake_collator",
- pack_level="none",
- group_by_length=False,
- )
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
- replay_buffer_cfg = ReplayBufferConfig(
- dataset_cfg=train_dataset_cfg,
- dataloader_cfg=dataloader_cfg,
- tokenizer=tokenizer,
- worker_log_dir=self.worker_log_dir,
- )
- return replay_buffer_cfg
-
- def init_resources_config(self, num_workers, num_cpus_per_worker, cpu_memory_per_worker):
- resources = AcceleratorResourcesConfig(
- accelerator=resource_map[torch.accelerator.current_accelerator().type],
- num_workers=num_workers,
- num_cpus_per_worker=num_cpus_per_worker,
- cpu_memory_per_worker=cpu_memory_per_worker,
- )
- return resources
-
- def init_cpu_resources_config(self, num_cpus_per_worker, cpu_memory_per_worker):
- cpu_resources = CPUResourcesConfig(
- num_cpus_per_worker=num_cpus_per_worker,
- cpu_memory_per_worker=cpu_memory_per_worker,
- )
- return cpu_resources
-
- def init_rollout_config(self, max_prompt_length, max_response_length):
- rollout_config = RolloutConfig(
- env="test_rl_trainer",
- model_path=MODEL_PATH,
- worker_log_dir=self.worker_log_dir,
- rollout_max_batch_size_per_instance=1024,
- context_length=max_response_length + max_prompt_length,
- )
- return rollout_config
-
- def init_dataflow_config(self, max_response_length, global_batch_size, prompt_repeat_k, enable_partial_rollout):
- sample_params = SampleParams(
- max_tokens=max_response_length,
- )
- dataflow_config = DataFlowConfig(
- env="test_rl_trainer",
- global_batch_size=global_batch_size,
- prompt_repeat_k=prompt_repeat_k,
- worker_log_dir=self.worker_log_dir,
- sample_params=sample_params,
- enable_partial_rollout=enable_partial_rollout,
- max_concurrent=1024,
- )
- return dataflow_config
-
- def init_judger_config(self):
- from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig
-
- gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router")
- judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config], worker_log_dir=self.worker_log_dir)
- return judger_cfg
-
- def init_multi_judger_config(self):
- from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig
-
- # 支持一个GSM8KJudgerConfig创建多个实例
- gsm8k_judger_config_1 = GSM8KJudgerConfig(judger_name="openai/gsm8k-1", judger_type="router")
- gsm8k_judger_config_2 = GSM8KJudgerConfig(judger_name="openai/gsm8k-2", judger_type="router")
- judger_cfg = JudgerConfig(
- reward_judger_configs=[gsm8k_judger_config_1, gsm8k_judger_config_2],
- worker_log_dir=self.worker_log_dir,
- )
- return judger_cfg
-
- def setUp(self):
- ray.init(num_cpus=80, ignore_reinit_error=True)
- self.temp_dir = tempfile.TemporaryDirectory()
- self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
-
- train_optimizer_steps = 2
- pack_max_length = 32768
- max_prompt_length = 2048
- max_response_length = 1024
- global_batch_size = 4
- prompt_repeat_k = 4
- enable_partial_rollout = False
-
- self.train_worker_cfg = self.init_traine_worker_config(train_optimizer_steps, pack_max_length)
- self.replay_buffer_cfg = self.init_replay_buffer_config(max_prompt_length)
- self.resources_cfg = self.init_resources_config(
- num_workers=8, num_cpus_per_worker=8, cpu_memory_per_worker=8 * 1024**3
- )
- self.cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=1, cpu_memory_per_worker=1 * 1024**3)
- self.rollout_config = self.init_rollout_config(
- max_response_length=max_response_length, max_prompt_length=max_prompt_length
- )
- self.dataflow_config = self.init_dataflow_config(
- max_response_length=max_response_length,
- global_batch_size=global_batch_size,
- prompt_repeat_k=prompt_repeat_k,
- enable_partial_rollout=enable_partial_rollout,
- )
- self.judger_config = self.init_judger_config()
-
- def tearDown(self):
- self.temp_dir.cleanup()
- ray.shutdown()
-
- def test_rl_trainer(self):
- multi_judger_config = self.init_multi_judger_config()
- cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=2, cpu_memory_per_worker=2 * 1024**3)
- trainer_config = RLTrainerConfig(
- load_from=MODEL_PATH,
- resources=self.resources_cfg,
- cpu_resources=cpu_resources,
- rollout_config=self.rollout_config,
- dataflow_config=self.dataflow_config,
- judger_config=multi_judger_config,
- replay_buffer_config=self.replay_buffer_cfg,
- train_worker_config=self.train_worker_cfg,
- work_dir=self.worker_log_dir,
- tokenizer_path=MODEL_PATH,
- total_epochs=1,
- rollout_steps=1,
- )
- trainer = RLTrainer.from_config(trainer_config)
- self.assertIsNotNone(trainer, "Trainer should be created successfully")
- try:
- trainer.fit()
- except Exception as e:
- self.fail(f"trainer.fit() raised unexpected exception: {e}")
- # assure all writers are closed before checking log files
- del trainer
- log_files = list(Path(self.worker_log_dir).rglob("*.log"))
- self.assertGreater(len(log_files), 0, "Should generate log files")
- trajectory_files = list(Path(self.worker_log_dir).rglob("*_trajectory.jsonl"))
- self.assertGreater(len(trajectory_files), 0, "Should generate trajectory files")
-
- def test_judger_cpu_pg_creation_with_error(self):
- """Test RLTrainer judger_cpu_pg creation."""
- multi_judger_config = self.init_multi_judger_config()
- # error resource with multi-judger
- cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=1, cpu_memory_per_worker=1 * 1024**3)
- trainer_config = RLTrainerConfig(
- load_from=MODEL_PATH,
- resources=self.resources_cfg,
- cpu_resources=cpu_resources,
- rollout_config=self.rollout_config,
- dataflow_config=self.dataflow_config,
- judger_config=multi_judger_config,
- replay_buffer_config=self.replay_buffer_cfg,
- train_worker_config=self.train_worker_cfg,
- work_dir=self.worker_log_dir,
- tokenizer_path=MODEL_PATH,
- total_epochs=1,
- rollout_steps=1,
- )
- with self.assertRaises(AssertionError) as cm:
- trainer = RLTrainer.from_config(trainer_config)
-
- print(f"Expected AssertionError caught: {cm.exception}")
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/rl/test_rollout_api_server.py b/tests/rl/test_rollout_api_server.py
new file mode 100644
index 000000000..b197f25d9
--- /dev/null
+++ b/tests/rl/test_rollout_api_server.py
@@ -0,0 +1,314 @@
+import os
+import subprocess
+import tempfile
+import time
+import unittest
+
+import httpx
+import ray
+import torch
+from transformers import AutoTokenizer
+
+from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams
+from xtuner.v1.rl.rollout import RolloutController
+from xtuner.v1.rl.rollout.worker import RolloutConfig
+from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers
+
+
+TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}]
+MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
+MOE_MODEL_PATH = os.environ.get("QWEN3_MOE_PATH") or os.environ["QWEN30B_MODEL_PATH"]
+RESOURCE_MAP = {
+ "npu": "NPU",
+ "cuda": "GPU",
+}
+
+
+@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
+class TestRolloutAPIServer(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls) -> None:
+ os.environ["XTUNER_USE_FA3"] = "1"
+ os.environ["LMD_SKIP_WARMUP"] = "1"
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ del os.environ["XTUNER_USE_FA3"]
+ del os.environ["LMD_SKIP_WARMUP"]
+
+ def init_config(self):
+ self.resources_cfg = AcceleratorResourcesConfig(
+ accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type],
+ num_workers=4,
+ num_cpus_per_worker=8,
+ cpu_memory_per_worker=16 * 1024**3,
+ )
+ self.max_prompt_length = 512
+ self.max_response_length = 1024
+ self.context_length = self.max_prompt_length + self.max_response_length
+ self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
+
+ def setUp(self):
+ ray.init(num_cpus=80, ignore_reinit_error=True)
+ self.temp_dir = tempfile.TemporaryDirectory()
+ self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
+ self.init_config()
+
+ def tearDown(self):
+ ray.shutdown()
+ self._cleanup_lmdeploy_ray_worker_wrapper()
+ self.temp_dir.cleanup()
+
+ def _cleanup_lmdeploy_ray_worker_wrapper(self):
+ try:
+ result = subprocess.run(
+ ["pkill", "-f", "ray::RayWorkerWrapper*"],
+ capture_output=True,
+ text=True,
+ timeout=10,
+ check=False,
+ )
+ if result.returncode != 0:
+                print(
+                    f"pkill command failed with return code {result.returncode}: {result.stderr}."
+                    " There may be no lmdeploy ray::RayWorkerWrapper processes to kill."
+                )
+        except Exception as exc:
+            print(f"Error killing ray::RayWorkerWrapper processes: {exc}")
+
+ def _wait_until_ready(self, base_url: str):
+ deadline = time.time() + 1800
+ last_error = None
+ while time.time() < deadline:
+ try:
+ response = httpx.get(f"{base_url}/healthz", timeout=10.0)
+ if response.status_code == 200:
+ return
+ last_error = f"healthz returned {response.status_code}: {response.text}"
+ except httpx.HTTPError as exc:
+ last_error = repr(exc)
+ time.sleep(5)
+ raise RuntimeError(f"API server at {base_url} did not become ready in time: {last_error}")
+
+ def test_dense_model(self):
+ resource_config = AcceleratorResourcesConfig(
+ accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type],
+ num_workers=4,
+ num_cpus_per_worker=16,
+ cpu_memory_per_worker=8 * 1024**3,
+ )
+ pg = AutoAcceleratorWorkers.build_placement_group(resource_config, name="dense_api_pg")
+ dense_worker_log_dir = os.path.join(self.worker_log_dir, "dense")
+ rollout_config = RolloutConfig(
+ env="test_rollout_api_server_dense",
+ model_path=MODEL_PATH,
+ model_name=os.path.basename(MODEL_PATH).lower(),
+ tokenizer_path=MODEL_PATH,
+ tensor_parallel_size=4,
+ expert_parallel_size=1,
+ context_length=self.context_length,
+ worker_log_dir=dense_worker_log_dir,
+ dist_port_base=38000,
+ api_host="127.0.0.1",
+ api_port=28000,
+ )
+ rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
+ try:
+ metadata = ray.get(rollout_controller.get_rollout_metadata.remote(), timeout=1800)
+ base_url = metadata["api_server_url"]
+ self._wait_until_ready(base_url)
+
+ text_prompt = self.tokenizer.apply_chat_template(
+ TEST_TEXT_MESSAGES,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ test_input_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"]
+
+ with httpx.Client(timeout=300.0) as client:
+ generate = client.post(
+ f"{base_url}/generate",
+ json={
+ "message": TEST_TEXT_MESSAGES,
+ "tokens": test_input_ids,
+ "sample_params": {
+ "return_token_ids": True,
+ "temperature": 0.0,
+ "top_k": 1,
+ "max_tokens": 16,
+ },
+ },
+ )
+ self.assertEqual(generate.status_code, 200, generate.text)
+ generate_body = generate.json()
+ self.assertEqual(generate_body["status"], "completed")
+ self.assertIn(generate_body["finish_reason"], {"stop", "length"})
+ self.assertTrue(generate_body["extra_fields"]["request_id"])
+ self.assertGreater(len(generate_body["response_ids"]), 0)
+ self.assertIsInstance(generate_body["response"], str)
+
+ chat = client.post(
+ f"{base_url}/v1/chat/completions",
+ json={
+ "model": rollout_config.model_name,
+ "messages": TEST_TEXT_MESSAGES,
+ "temperature": 0.0,
+ "top_p": 1.0,
+ "max_tokens": 16,
+ },
+ )
+ self.assertEqual(chat.status_code, 200, chat.text)
+ chat_body = chat.json()
+ print("chat_body: ", chat_body)
+ self.assertEqual(chat_body["object"], "chat.completion")
+ self.assertEqual(chat_body["model"], rollout_config.model_name)
+ self.assertTrue(chat_body["id"].startswith("chatcmpl-"))
+ self.assertEqual(chat_body["choices"][0]["message"]["role"], "assistant")
+ self.assertTrue(chat_body["choices"][0]["message"]["content"])
+ self.assertIn(chat_body["choices"][0]["finish_reason"], {"stop", "length"})
+ self.assertGreater(chat_body["usage"]["prompt_tokens"], 0)
+ self.assertGreater(chat_body["usage"]["total_tokens"], chat_body["usage"]["completion_tokens"])
+
+ anthropic = client.post(
+ f"{base_url}/v1/messages",
+ json={
+ "model": rollout_config.model_name,
+ "system": "You are helpful.",
+ "messages": TEST_TEXT_MESSAGES,
+ "max_tokens": 16,
+ "temperature": 0.0,
+ "top_p": 1.0,
+ },
+ )
+ self.assertEqual(anthropic.status_code, 200, anthropic.text)
+ anthropic_body = anthropic.json()
+ self.assertEqual(anthropic_body["type"], "message")
+ self.assertEqual(anthropic_body["role"], "assistant")
+ self.assertEqual(anthropic_body["model"], rollout_config.model_name)
+ self.assertTrue(anthropic_body["id"].startswith("msg_"))
+ self.assertTrue(anthropic_body["content"][0]["text"])
+ self.assertIn(anthropic_body["stop_reason"], {"stop", "length"})
+ self.assertGreater(anthropic_body["usage"]["input_tokens"], 0)
+ self.assertGreaterEqual(anthropic_body["usage"]["output_tokens"], 1)
+
+ invalid_block = client.post(
+ f"{base_url}/v1/messages",
+ json={
+ "model": rollout_config.model_name,
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "look"},
+ {"type": "image", "text": ""},
+ ],
+ }
+ ],
+ "max_tokens": 8,
+ },
+ timeout=30.0,
+ )
+ self.assertEqual(invalid_block.status_code, 400, invalid_block.text)
+ self.assertEqual(invalid_block.json()["type"], "error")
+ self.assertEqual(invalid_block.json()["error"]["type"], "invalid_request_error")
+
+ health = client.get(f"{base_url}/healthz", timeout=30.0)
+ meta = client.get(f"{base_url}/metadata", timeout=30.0)
+ self.assertEqual(health.status_code, 200, health.text)
+ self.assertEqual(health.json()["status"], "ok")
+ self.assertGreaterEqual(health.json()["active_workers"], 1)
+ self.assertEqual(meta.status_code, 200, meta.text)
+ self.assertEqual(meta.json()["api_server_url"], base_url)
+ self.assertEqual(metadata["api_server_url"], base_url)
+ self.assertEqual(meta.json()["api_server_url"].rsplit(":", 1)[-1], str(rollout_config.api_port))
+ self.assertTrue(all(meta.json()["worker_server_urls_status"].values()))
+
+ offload = client.post(f"{base_url}/offload", timeout=120.0)
+ self.assertEqual(offload.status_code, 200, offload.text)
+ self.assertEqual(offload.json()["action"], "offload")
+
+ onload = client.post(f"{base_url}/onload", timeout=120.0)
+ self.assertEqual(onload.status_code, 200, onload.text)
+ self.assertEqual(onload.json()["action"], "onload")
+
+ regenerated = client.post(
+ f"{base_url}/generate",
+ json={
+ "message": TEST_TEXT_MESSAGES,
+ "sample_params": {
+ "return_token_ids": True,
+ "temperature": 0.0,
+ "top_k": 1,
+ "max_tokens": 8,
+ },
+ },
+ )
+ self.assertEqual(regenerated.status_code, 200, regenerated.text)
+ self.assertEqual(regenerated.json()["status"], "completed")
+ finally:
+ try:
+ ray.get(rollout_controller.shutdown.remote(), timeout=300)
+ finally:
+ ray.util.remove_placement_group(pg)
+
+ def test_moe_model(self):
+ resource_config = AcceleratorResourcesConfig(
+ accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type],
+ num_workers=4,
+ num_cpus_per_worker=16,
+ cpu_memory_per_worker=8 * 1024**3,
+ )
+ pg = AutoAcceleratorWorkers.build_placement_group(resource_config, name="moe_api_pg")
+ moe_worker_log_dir = os.path.join(self.worker_log_dir, "moe")
+ rollout_config = RolloutConfig(
+ env="test_rollout_api_server_moe",
+ model_path=MOE_MODEL_PATH,
+ model_name=os.path.basename(MOE_MODEL_PATH).lower(),
+ tokenizer_path=MOE_MODEL_PATH,
+ tensor_parallel_size=1,
+ expert_parallel_size=4,
+ context_length=self.context_length,
+ worker_log_dir=moe_worker_log_dir,
+ dist_port_base=38000 + 1024 * 4,
+ api_host="127.0.0.1",
+ api_port=29000,
+ enable_return_routed_experts=True,
+ )
+ rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
+ try:
+ metadata = ray.get(rollout_controller.get_rollout_metadata.remote(), timeout=1800)
+ base_url = metadata["api_server_url"]
+ self._wait_until_ready(base_url)
+
+ request = RolloutState(
+ message=[{"role": "user", "content": "Briefly explain what mixture of experts means."}],
+ sample_params=SampleParams(
+ return_token_ids=True,
+ return_logprob=False,
+ temperature=0.0,
+ top_k=1,
+ max_tokens=32,
+ ),
+ )
+ with httpx.Client(timeout=300.0) as client:
+ response = client.post(
+ f"{base_url}/generate",
+ json=request.model_dump(mode="json"),
+ )
+ meta = client.get(f"{base_url}/metadata", timeout=30.0)
+
+ self.assertEqual(response.status_code, 200, response.text)
+ rollout_state = RolloutState.model_validate_json(response.text)
+ self.assertIsNotNone(rollout_state.routed_experts)
+ self.assertEqual(meta.status_code, 200, meta.text)
+ self.assertEqual(meta.json()["api_server_url"], base_url)
+ self.assertEqual(metadata["api_server_url"], base_url)
+ self.assertEqual(meta.json()["api_server_url"].rsplit(":", 1)[-1], str(rollout_config.api_port))
+ finally:
+ try:
+ ray.get(rollout_controller.shutdown.remote(), timeout=300)
+ finally:
+ ray.util.remove_placement_group(pg)
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py
index e8009eb34..0762b9744 100644
--- a/xtuner/v1/data_proto/rl_data.py
+++ b/xtuner/v1/data_proto/rl_data.py
@@ -1,10 +1,11 @@
from __future__ import annotations
+import base64
from enum import Enum
from typing import TYPE_CHECKING, Any, TypeAlias
import torch
-from pydantic import BaseModel, ConfigDict, field_serializer
+from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
from typing_extensions import NotRequired, TypedDict
# ====================================
@@ -76,11 +77,11 @@ class RolloutState(CacheObj, BaseModel):
reward_model: dict[str, Any] | None = None
num_tokens: int | None = None # 用于 cache 管理
    # --- InferEngine 输入 ---
session_uid: int | None = None
tokens: list[int] | None = None # 每一次推理引擎的实际输入
tools: list | None = None
- tool_choice: str | None = None
+ tool_choice: str | dict[str, Any] | None = None
sample_params: SampleParams = SampleParams()
# --- InferEngine 输出 ---
@@ -104,20 +105,44 @@ class RolloutState(CacheObj, BaseModel):
extra_fields: dict[str, Any] = {}
@field_serializer("routed_experts")
- def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> list[int] | None:
- """Dump 时跳过 ray.ObjectRef,序列化为 None,避免 PydanticSerializationError。"""
+ def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> list[int] | str | None:
+ """序列化 routed_experts 字段:
+
+ - None -> None
+ - list[int] -> list[int](原样保留)
+ - RayObjectRef -> base64 编码的字符串(通过 ray.cloudpickle 序列化)
+ """
+ import ray
+
if value is None:
return None
- try:
- import ray
-
- if isinstance(value, ray.ObjectRef):
- return None
- except ImportError:
- pass
- if type(value).__name__ == "ObjectRef" and "ray" in getattr(type(value), "__module__", ""):
+ if isinstance(value, ray.ObjectRef):
+ data = ray.cloudpickle.dumps(value)
+ return base64.b64encode(data).decode("utf-8")
+ return value
+
+ @field_validator("routed_experts", mode="before")
+ @classmethod
+ def _deserialize_routed_experts(cls, value: Any) -> list[int] | RayObjectRef | None:
+ """反序列化 routed_experts 字段:
+
+ - None -> None
+ - list[int] -> list[int](原样保留)
+ - str(base64 编码)-> RayObjectRef(通过 ray.cloudpickle 反序列化)
+ - RayObjectRef -> RayObjectRef(原样保留)
+ """
+ import ray
+
+ if value is None:
return None
- return value # list[int]
+ if isinstance(value, ray.ObjectRef):
+ return value
+ if isinstance(value, str):
+ data = base64.b64decode(value)
+ return ray.cloudpickle.loads(data)
+        return value
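+
+    # A minimal round-trip sketch (assumes an initialized Ray session; other
+    # RolloutState fields are left at their defaults):
+    #
+    #   ref = ray.put([1, 2, 3])
+    #   state = RolloutState(message=[], routed_experts=ref)
+    #   payload = state.model_dump(mode="json")   # routed_experts -> base64 str
+    #   restored = RolloutState.model_validate(payload)
+    #   assert ray.get(restored.routed_experts) == [1, 2, 3]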
def update_status_from_finish_reason(finish_reason: str | None) -> Status:
diff --git a/xtuner/v1/rl/agent_loop/camel_agent_loop.py b/xtuner/v1/rl/agent_loop/camel_agent_loop.py
new file mode 100644
index 000000000..4abd88c47
--- /dev/null
+++ b/xtuner/v1/rl/agent_loop/camel_agent_loop.py
@@ -0,0 +1,190 @@
+import copy
+import os
+from typing import Any, Literal
+
+import ray
+from camel.agents import ChatAgent
+from camel.models import OpenAICompatibleModel
+from camel.utils import BaseTokenCounter
+from openai import AsyncOpenAI
+from pydantic import ConfigDict
+
+from xtuner.v1.data_proto import RolloutState, SampleParams, Status
+from xtuner.v1.rl.rollout import RolloutController
+from xtuner.v1.rl.rollout.chat_adapter.collector import reset_current_trace_collector, set_current_trace_collector
+from xtuner.v1.rl.rollout.utils import ROLLOUT_RAY_GET_TIMEOUT
+
+from .agent_loop import AgentLoop, AgentLoopConfig
+
+
+class XTunerCamelTokenCounter(BaseTokenCounter):
+ def __init__(self, tokenizer):
+ self.tokenizer = tokenizer
+
+ def count_tokens_from_messages(self, messages: list[dict]) -> int:
+ return len(self.tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=True))
+
+ def encode(self, text: str) -> list[int]:
+ return list(self.tokenizer(text, add_special_tokens=False)["input_ids"])
+
+ def decode(self, token_ids: list[int]) -> str:
+ return self.tokenizer.decode(token_ids)
+
+
+class CamelAgentLoopConfig(AgentLoopConfig):
+ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+
+ mode: Literal["single_turn"] = "single_turn"
+ context_length: int | None = None
+ system_message: str | None = None
+ tools: list[Any] | None = None
+ tool_choice: str | dict[str, Any] | None = None
+
+ def build(self, rollout_controller, judger=None, logger=None) -> "CamelAgentLoop":
+ return CamelAgentLoop(
+ context_length=self.context_length,
+ system_message=self.system_message,
+ tools=self.tools,
+ tool_choice=self.tool_choice,
+ rollout_ctl=rollout_controller,
+ hf_checkpoint=self.hf_checkpoint,
+ sample_params=self.sample_params,
+ judger=judger,
+ logger=logger,
+ )
+
+
+class CamelAgentLoop(AgentLoop):
+ def __init__(
+ self,
+ context_length: int | None,
+ system_message: str | None,
+ tools: list[Any] | None,
+ tool_choice: str | dict[str, Any] | None,
+ rollout_ctl: RolloutController,
+ hf_checkpoint: str,
+ sample_params: SampleParams,
+ judger=None,
+ logger=None,
+ ) -> None:
+ super().__init__(
+ rollout_ctl=rollout_ctl,
+ hf_checkpoint=hf_checkpoint,
+ sample_params=sample_params,
+ judger=judger,
+ logger=logger,
+ )
+ self.context_length = context_length
+ self.system_message = system_message
+ self.tools = tools
+ self.tool_choice = tool_choice
+ self._api_server_url: str | None = None
+ self._model_name: str | None = None
+
+ def init_agent(self):
+ metadata = ray.get(self.rollout_ctl.get_rollout_metadata.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT)
+ self._api_server_url = metadata["api_server_url"]
+ self._model_name = getattr(metadata.get("rollout_config"), "model_name", None) or "rollout-controller"
+
+ client = AsyncOpenAI(
+ base_url=f"{self._api_server_url.rstrip('/')}/v1",
+ api_key=os.environ.get("OPENAI_API_KEY", "EMPTY"),
+ timeout=180.0,
+ max_retries=3,
+ )
+ model = OpenAICompatibleModel(
+ model_type=self._model_name,
+ client=client,
+ async_client=client,
+ model_config_dict=self._build_camel_model_config(copy.deepcopy(self.sample_params)),
+ token_counter=XTunerCamelTokenCounter(self.tokenizer),
+ )
+ return ChatAgent(
+ system_message=self.system_message,
+ model=model,
+ token_limit=self.context_length,
+ step_timeout=180.0,
+ tools=self.tools,
+ )
+
+ async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> list[RolloutState]:
+ try:
+ agent = self.init_agent()
+ messages = copy.deepcopy(rollout_state.message)
+ turn_rollout_states: list[RolloutState] = []
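+            # The chat adapter (xtuner.v1.rl.rollout.chat_adapter.collector)
+            # appends one RolloutState per inference call to the active trace
+            # collector, so every agent turn, including intermediate tool-call
+            # turns, lands in turn_rollout_states.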
+ collector_token = set_current_trace_collector(turn_rollout_states)
+ try:
+ await agent.astep(messages[-1]["content"])
+ finally:
+ reset_current_trace_collector(collector_token)
+
+ if not turn_rollout_states:
+ raise RuntimeError("no rollout states captured from gateway")
+
+ normalized_rollout_states: list[RolloutState] = []
+ trace_records: list[dict[str, Any]] = []
+ for turn_index, turn_state in enumerate(turn_rollout_states):
+ if turn_state.prompt_ids is None:
+ raise RuntimeError(f"captured Camel rollout turn {turn_index} is missing prompt_ids")
+ if turn_state.response_ids is None:
+ raise RuntimeError(f"captured Camel rollout turn {turn_index} is missing response_ids")
+ normalized_turn_state = turn_state.model_copy(deep=True)
+ normalized_turn_state.message_uid = rollout_state.message_uid
+ normalized_turn_state.message = copy.deepcopy(rollout_state.message)
+ normalized_turn_state.data_source = copy.deepcopy(rollout_state.data_source)
+ normalized_turn_state.mm_info = copy.deepcopy(rollout_state.mm_info)
+ normalized_turn_state.reward_model = copy.deepcopy(rollout_state.reward_model)
+ normalized_turn_state.sample_params = copy.deepcopy(rollout_state.sample_params)
+ normalized_turn_state.task_name = rollout_state.task_name
+ normalized_turn_state.seq_staleness = rollout_state.seq_staleness
+ normalized_turn_state.extra_fields = {
+ **copy.deepcopy(rollout_state.extra_fields),
+ **copy.deepcopy(normalized_turn_state.extra_fields),
+ "gateway_rollout_index": turn_index,
+ }
+ normalized_rollout_states.append(normalized_turn_state)
+ trace_records.append(
+ {
+ "request_id": str(normalized_turn_state.uid)
+ if normalized_turn_state.uid is not None
+ else None,
+ "prompt_ids": list(normalized_turn_state.prompt_ids),
+ "response_ids": list(normalized_turn_state.response_ids),
+ "logprobs": None
+ if normalized_turn_state.logprobs is None
+ else list(normalized_turn_state.logprobs),
+ "routed_experts": normalized_turn_state.routed_experts,
+ "finish_reason": normalized_turn_state.finish_reason,
+ "status": normalized_turn_state.status,
+ }
+ )
+ for normalized_turn_state in normalized_rollout_states:
+ normalized_turn_state.extra_fields["gateway_trace_records"] = copy.deepcopy(trace_records)
+ return normalized_rollout_states
+ except Exception as exc:
+ rollout_state.status = Status.FAILED
+ rollout_state.error_msg = f"Camel agent loop failed: {exc}"
+ return [rollout_state]
+
+ def _build_camel_model_config(self, sample_params: SampleParams) -> dict:
+ model_config = {
+ "temperature": sample_params.temperature,
+ "top_p": sample_params.top_p,
+ "max_tokens": sample_params.max_tokens,
+ "stream": False,
+ }
+ if self.tool_choice is not None:
+ model_config["tool_choice"] = copy.deepcopy(self.tool_choice)
+ if sample_params.presence_penalty:
+ model_config["presence_penalty"] = sample_params.presence_penalty
+ if sample_params.frequency_penalty:
+ model_config["frequency_penalty"] = sample_params.frequency_penalty
+ if sample_params.stops:
+ model_config["stop"] = sample_params.stops
+ return model_config
+
+ def _extract_finish_reason(self, response) -> str:
+ reasons = response.info.get("termination_reasons", []) if getattr(response, "info", None) else []
+ if reasons and reasons[-1] in {"stop", "length", "tool_calls"}:
+ return reasons[-1]
+ return "stop"
diff --git a/xtuner/v1/rl/rollout/__init__.py b/xtuner/v1/rl/rollout/__init__.py
index 349cd2fad..57e7cf13e 100644
--- a/xtuner/v1/rl/rollout/__init__.py
+++ b/xtuner/v1/rl/rollout/__init__.py
@@ -1,5 +1,6 @@
import os
+from .chat_adapter import AnthropicChatAdapter, OpenAIChatAdapter
from .controller import RolloutController
from .worker import RolloutWorker
diff --git a/xtuner/v1/rl/rollout/api_server.py b/xtuner/v1/rl/rollout/api_server.py
new file mode 100644
index 000000000..5a387b069
--- /dev/null
+++ b/xtuner/v1/rl/rollout/api_server.py
@@ -0,0 +1,419 @@
+from __future__ import annotations
+
+import json
+from collections.abc import Iterator
+from typing import TYPE_CHECKING, Any
+
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from xtuner.v1.data_proto.rl_data import RolloutState, Status
+
+from .chat_adapter import (
+ AnthropicChatAdapterError,
+ AnthropicCountTokensRequest,
+ AnthropicCountTokensResponse,
+ AnthropicMessagesRequest,
+ AnthropicMessagesResponse,
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ OpenAIChatAdapterError,
+ ResponsesRequest,
+ ResponsesResponse,
+)
+from .utils import ensure_rollout_request_id
+
+
+if TYPE_CHECKING:
+ from .controller import RolloutControllerProxy
+
+
+def _build_error_response(
+ status_code: int,
+ message: str,
+ error_type: str,
+ code: str | None = None,
+ request_id: str | None = None,
+ protocol: str = "openai",
+) -> JSONResponse:
+ if protocol == "anthropic":
+ payload = {
+ "type": "error",
+ "error": {
+ "type": error_type,
+ "message": message,
+ },
+ }
+ if request_id is not None:
+ payload["request_id"] = request_id
+ else:
+ payload = {
+ "error": {
+ "message": message,
+ "type": error_type,
+ "code": code,
+ "request_id": request_id,
+ }
+ }
+ return JSONResponse(status_code=status_code, content=payload)
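+
+# Shape reference: protocol="anthropic" yields
+#   {"type": "error", "error": {"type": ..., "message": ...}}
+# (plus a top-level "request_id" when one is provided), while the default
+# OpenAI-style payload nests message/type/code/request_id under a single
+# top-level "error" key, matching the handlers below.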
+
+
+def create_rollout_api_app(
+ rollout_controller: RolloutControllerProxy,
+ logger: Any,
+) -> FastAPI:
+ """Build the rollout API app around the provided rollout controller."""
+ app = FastAPI()
+
+ @app.exception_handler(HTTPException)
+ async def handle_http_exception(request: Request, exc: HTTPException) -> JSONResponse:
+ request_id = request.headers.get("X-Request-Id")
+ if isinstance(exc.detail, dict) and "error" in exc.detail:
+ return JSONResponse(status_code=exc.status_code, content=exc.detail)
+ return _build_error_response(
+ status_code=exc.status_code,
+ message=str(exc.detail),
+ error_type="invalid_request_error" if exc.status_code < 500 else "server_error",
+ code="http_error",
+ request_id=request_id,
+ )
+
+ @app.post("/generate")
+ async def generate(request: RolloutState) -> RolloutState:
+ request_id = ensure_rollout_request_id(request)
+ try:
+ response = await rollout_controller.generate(request)
+ if not response.extra_fields.get("request_id"):
+ response.extra_fields["request_id"] = request_id
+ return response
+ except Exception as exc:
+ logger.error(f"Generate failed in API server for request_id={request_id}: {exc}")
+ request.status = Status.FAILED
+ request.error_msg = f"Generate failed in API server with error: {str(exc)}"
+ return request
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(request: ChatCompletionRequest, http_request: Request) -> ChatCompletionResponse:
+ try:
+ return await rollout_controller.chat(request)
+ except OpenAIChatAdapterError as exc:
+ status_code = 400 if exc.error_type == "invalid_request_error" else 500
+ raise HTTPException(
+ status_code=status_code,
+ detail={
+ "error": {
+ "message": exc.message,
+ "type": exc.error_type,
+ "code": exc.code,
+ "request_id": exc.request_id,
+ }
+ },
+ )
+
+ @app.post("/v1/responses", response_model=None)
+ async def responses(request: ResponsesRequest, http_request: Request):
+ try:
+ if request.stream:
+ non_stream_request = request.model_copy(update={"stream": False})
+ response = await rollout_controller.responses(non_stream_request)
+ return StreamingResponse(
+ _iter_openai_responses_sse_events(response),
+ media_type="text/event-stream",
+ headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
+ )
+ return await rollout_controller.responses(request)
+ except OpenAIChatAdapterError as exc:
+ status_code = 400 if exc.error_type == "invalid_request_error" else 500
+ raise HTTPException(
+ status_code=status_code,
+ detail={
+ "error": {
+ "message": exc.message,
+ "type": exc.error_type,
+ "code": exc.code,
+ "request_id": exc.request_id,
+ }
+ },
+ )
+
+ @app.post("/v1/messages", response_model=None)
+ async def anthropic_messages(
+ request: AnthropicMessagesRequest, http_request: Request
+ ):
+ try:
+ if request.stream:
+ non_stream_request = request.model_copy(update={"stream": False})
+ response = await rollout_controller.anthropic_messages(non_stream_request)
+ return StreamingResponse(
+ _iter_anthropic_sse_events(response),
+ media_type="text/event-stream",
+ headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
+ )
+ return await rollout_controller.anthropic_messages(request)
+ except AnthropicChatAdapterError as exc:
+ status_code = 400 if exc.error_type == "invalid_request_error" else 500
+ return _build_error_response(
+ status_code=status_code,
+ message=exc.message,
+ error_type=exc.error_type,
+ request_id=exc.request_id,
+ protocol="anthropic",
+ )
+
+ @app.post("/v1/messages/count_tokens")
+ async def anthropic_count_tokens(request: AnthropicCountTokensRequest) -> AnthropicCountTokensResponse:
+ return await rollout_controller.anthropic_count_tokens(request)
+
+ @app.get("/healthz")
+ async def healthz():
+ is_ready, payload = rollout_controller.get_ready_status()
+ if is_ready:
+ return {"status": "ok", **payload}
+ return JSONResponse(status_code=503, content={"status": "not_ready", **payload})
+
+ @app.get("/metadata")
+ async def metadata():
+ return rollout_controller.get_rollout_metadata()
+
+ @app.post("/pause")
+ async def pause():
+ rollout_controller.pause_generation()
+ return {"status": "ok", "action": "pause"}
+
+ @app.post("/continue")
+ async def continue_generation():
+ rollout_controller.continue_generation()
+ return {"status": "ok", "action": "continue"}
+
+ @app.post("/offload")
+ async def offload():
+ rollout_controller.offload()
+ return {"status": "ok", "action": "offload"}
+
+ @app.post("/onload")
+ async def onload():
+ rollout_controller.onload()
+ return {"status": "ok", "action": "onload"}
+
+ @app.post("/shutdown")
+ async def shutdown():
+ rollout_controller.shutdown()
+ return {"status": "ok", "action": "shutdown"}
+
+ return app
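+
+# Serving sketch (assumes a RolloutControllerProxy instance `proxy` and a
+# logger; running under uvicorn is an illustrative choice, not something this
+# module mandates):
+#
+#   import uvicorn
+#   app = create_rollout_api_app(rollout_controller=proxy, logger=logger)
+#   uvicorn.run(app, host="127.0.0.1", port=28000)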
+
+
+def _iter_anthropic_sse_events(response: AnthropicMessagesResponse) -> Iterator[str]:
+ output_tokens = response.usage.output_tokens
+ message_start = {
+ "type": "message_start",
+ "message": {
+ "id": response.id,
+ "type": response.type,
+ "role": response.role,
+ "content": [],
+ "model": response.model,
+ "stop_reason": None,
+ "stop_sequence": None,
+ "usage": {
+ "input_tokens": response.usage.input_tokens,
+ "output_tokens": 1 if output_tokens > 0 else 0,
+ },
+ },
+ }
+ yield _format_sse("message_start", message_start)
+
+ chunk_size = 64
+ for index, block in enumerate(response.content):
+ block_type = block.get("type")
+ if block_type == "text":
+ yield _format_sse(
+ "content_block_start",
+ {"type": "content_block_start", "index": index, "content_block": {"type": "text", "text": ""}},
+ )
+ text = str(block.get("text", ""))
+ for offset in range(0, len(text), chunk_size):
+ chunk = text[offset : offset + chunk_size]
+ yield _format_sse(
+ "content_block_delta",
+ {"type": "content_block_delta", "index": index, "delta": {"type": "text_delta", "text": chunk}},
+ )
+ yield _format_sse("content_block_stop", {"type": "content_block_stop", "index": index})
+ elif block_type == "tool_use":
+ yield _format_sse(
+ "content_block_start",
+ {
+ "type": "content_block_start",
+ "index": index,
+ "content_block": {
+ "type": "tool_use",
+ "id": block["id"],
+ "name": block["name"],
+ "input": {},
+ },
+ },
+ )
+ input_json = json.dumps(block.get("input", {}), ensure_ascii=False)
+ for offset in range(0, len(input_json), chunk_size):
+ chunk = input_json[offset : offset + chunk_size]
+ yield _format_sse(
+ "content_block_delta",
+ {
+ "type": "content_block_delta",
+ "index": index,
+ "delta": {"type": "input_json_delta", "partial_json": chunk},
+ },
+ )
+ yield _format_sse("content_block_stop", {"type": "content_block_stop", "index": index})
+
+ yield _format_sse(
+ "message_delta",
+ {
+ "type": "message_delta",
+ "delta": {
+ "stop_reason": response.stop_reason,
+ "stop_sequence": response.stop_sequence,
+ },
+ "usage": {
+ "output_tokens": response.usage.output_tokens,
+ },
+ },
+ )
+ yield _format_sse("message_stop", {"type": "message_stop"})
+
+
+def _format_sse(event: str, data: dict[str, Any]) -> str:
+ return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
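+# For example, _format_sse("message_stop", {"type": "message_stop"}) yields:
+#   event: message_stop
+#   data: {"type": "message_stop"}
+# followed by the blank line that terminates the SSE frame.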
+
+
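+# Same replay strategy for the OpenAI Responses API: a finished ResponsesResponse is
+# emitted as response.created, per-item output/content/text (or function-call argument)
+# deltas, response.completed, and a terminal "data: [DONE]" frame, with a monotonically
+# increasing sequence_number across all events.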
+def _iter_openai_responses_sse_events(response: ResponsesResponse) -> Iterator[str]:
+ sequence_number = 0
+ response_snapshot = response.model_dump(mode="python")
+ in_progress_response = {**response_snapshot, "status": "in_progress"}
+ yield _format_openai_response_sse({"type": "response.created", "sequence_number": sequence_number, "response": in_progress_response})
+ sequence_number += 1
+
+ for output_index, item in enumerate(response.output):
+ item_type = item.get("type")
+ if item_type == "message":
+ yield _format_openai_response_sse(
+ {
+ "type": "response.output_item.added",
+ "sequence_number": sequence_number,
+ "output_index": output_index,
+ "item": item,
+ }
+ )
+ sequence_number += 1
+ for content_index, part in enumerate(item.get("content", [])):
+ if part.get("type") != "output_text":
+ continue
+ yield _format_openai_response_sse(
+ {
+ "type": "response.content_part.added",
+ "sequence_number": sequence_number,
+ "output_index": output_index,
+ "item_id": item["id"],
+ "content_index": content_index,
+ "part": {"type": "output_text", "text": "", "annotations": []},
+ }
+ )
+ sequence_number += 1
+ text = str(part.get("text", ""))
+ chunk_size = 64
+ for offset in range(0, len(text), chunk_size):
+ yield _format_openai_response_sse(
+ {
+ "type": "response.output_text.delta",
+ "sequence_number": sequence_number,
+ "output_index": output_index,
+ "item_id": item["id"],
+ "content_index": content_index,
+ "delta": text[offset : offset + chunk_size],
+ }
+ )
+ sequence_number += 1
+ yield _format_openai_response_sse(
+ {
+ "type": "response.output_text.done",
+ "sequence_number": sequence_number,
+ "output_index": output_index,
+ "item_id": item["id"],
+ "content_index": content_index,
+ "text": text,
+ }
+ )
+ sequence_number += 1
+ yield _format_openai_response_sse(
+ {
+ "type": "response.content_part.done",
+ "sequence_number": sequence_number,
+ "output_index": output_index,
+ "item_id": item["id"],
+ "content_index": content_index,
+ "part": {"type": "output_text", "text": text, "annotations": part.get("annotations", [])},
+ }
+ )
+ sequence_number += 1
+ yield _format_openai_response_sse(
+ {
+ "type": "response.output_item.done",
+ "sequence_number": sequence_number,
+ "output_index": output_index,
+ "item": item,
+ }
+ )
+ sequence_number += 1
+ elif item_type == "function_call":
+ added_item = {**item, "arguments": "", "status": "in_progress"}
+ yield _format_openai_response_sse(
+ {
+ "type": "response.output_item.added",
+ "sequence_number": sequence_number,
+ "output_index": output_index,
+ "item": added_item,
+ }
+ )
+ sequence_number += 1
+ arguments = str(item.get("arguments", ""))
+ chunk_size = 64
+ for offset in range(0, len(arguments), chunk_size):
+ yield _format_openai_response_sse(
+ {
+ "type": "response.function_call_arguments.delta",
+ "sequence_number": sequence_number,
+ "output_index": output_index,
+ "item_id": item["id"],
+ "delta": arguments[offset : offset + chunk_size],
+ }
+ )
+ sequence_number += 1
+ yield _format_openai_response_sse(
+ {
+ "type": "response.function_call_arguments.done",
+ "sequence_number": sequence_number,
+ "output_index": output_index,
+ "item_id": item["id"],
+ "arguments": arguments,
+ "name": item.get("name"),
+ }
+ )
+ sequence_number += 1
+ yield _format_openai_response_sse(
+ {
+ "type": "response.output_item.done",
+ "sequence_number": sequence_number,
+ "output_index": output_index,
+ "item": item,
+ }
+ )
+ sequence_number += 1
+
+ yield _format_openai_response_sse(
+ {"type": "response.completed", "sequence_number": sequence_number, "response": response_snapshot}
+ )
+ yield "data: [DONE]\n\n"
+
+
+def _format_openai_response_sse(data: dict[str, Any]) -> str:
+ return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
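+# Unlike the Anthropic frames above, these Responses-API frames carry only a data: line;
+# the event type is embedded in the JSON payload itself.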
diff --git a/xtuner/v1/rl/rollout/chat_adapter/__init__.py b/xtuner/v1/rl/rollout/chat_adapter/__init__.py
new file mode 100644
index 000000000..d1bd0189d
--- /dev/null
+++ b/xtuner/v1/rl/rollout/chat_adapter/__init__.py
@@ -0,0 +1,41 @@
+from .anthropic import (
+ AnthropicChatAdapter,
+ AnthropicChatAdapterError,
+ AnthropicCountTokensRequest,
+ AnthropicCountTokensResponse,
+ AnthropicMessagesRequest,
+ AnthropicMessagesResponse,
+ bind_anthropic_chat_interface,
+)
+from .base import BaseChatAPIAdapter
+from .openai import (
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ OpenAIChatAdapter,
+ OpenAIChatAdapterError,
+ bind_openai_chat_interface,
+)
+from .responses import ResponsesRequest, ResponsesResponse, bind_openai_responses_interface
+from .trace import ChatTraceRecord, ChatTraceStore
+
+
+__all__ = [
+ "AnthropicChatAdapter",
+ "AnthropicChatAdapterError",
+ "AnthropicCountTokensRequest",
+ "AnthropicCountTokensResponse",
+ "AnthropicMessagesRequest",
+ "AnthropicMessagesResponse",
+ "ChatCompletionRequest",
+ "ChatCompletionResponse",
+ "OpenAIChatAdapter",
+ "OpenAIChatAdapterError",
+ "ResponsesRequest",
+ "ResponsesResponse",
+ "BaseChatAPIAdapter",
+ "ChatTraceRecord",
+ "ChatTraceStore",
+ "bind_anthropic_chat_interface",
+ "bind_openai_chat_interface",
+ "bind_openai_responses_interface",
+]
diff --git a/xtuner/v1/rl/rollout/chat_adapter/anthropic.py b/xtuner/v1/rl/rollout/chat_adapter/anthropic.py
new file mode 100644
index 000000000..28989138e
--- /dev/null
+++ b/xtuner/v1/rl/rollout/chat_adapter/anthropic.py
@@ -0,0 +1,514 @@
+import json
+import re
+from typing import Any, Literal
+from uuid import uuid4
+
+from pydantic import BaseModel, ConfigDict
+
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
+
+from .base import BaseChatAPIAdapter
+from .trace import normalize_trace_payload
+
+
+class AnthropicTextContent(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+ type: str = "text"
+ text: str
+
+
+AnthropicContentBlock = dict[str, Any]
+
+
+class AnthropicMessage(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+ role: Literal["user", "assistant"]
+ content: str | list[AnthropicContentBlock]
+
+
+class AnthropicMessagesRequest(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+ session_uid: int | None = None
+ model: str | None = None
+ system: str | list[dict[str, Any]] | None = None
+ messages: list[AnthropicMessage]
+ max_tokens: int
+ stream: bool = False
+ temperature: float | None = None
+ top_p: float | None = None
+ stop_sequences: list[str] | None = None
+ tools: list[dict[str, Any]] | None = None
+ tool_choice: str | dict[str, Any] | None = None
+
+
+class AnthropicCountTokensRequest(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+ model: str | None = None
+ system: str | list[dict[str, Any]] | None = None
+ messages: list[AnthropicMessage]
+ tools: list[dict[str, Any]] | None = None
+
+
+class AnthropicCountTokensResponse(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+ input_tokens: int
+
+
+class AnthropicUsage(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+ input_tokens: int
+ output_tokens: int
+
+
+class AnthropicMessagesResponse(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+ id: str
+ type: Literal["message"] = "message"
+ role: Literal["assistant"] = "assistant"
+ content: list[dict[str, Any]]
+ model: str
+ stop_reason: str | None = None
+ stop_sequence: str | None = None
+ usage: AnthropicUsage
+
+
+class AnthropicChatAdapterError(RuntimeError):
+ def __init__(self, message: str, error_type: str, request_id: str | None = None):
+ super().__init__(message)
+ self.message = message
+ self.error_type = error_type
+ self.request_id = request_id
+
+
+class AnthropicChatAdapter(BaseChatAPIAdapter[AnthropicMessagesRequest, AnthropicMessagesResponse]):
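+    # Patterns for recovering tool calls that the model emitted as plain text: either a
+    # JSON payload wrapped in <tool_call>...</tool_call>, or Qwen-style
+    # <function=name><parameter=key>value</parameter></function> markup.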
+    _tool_call_pattern = re.compile(r"<tool_call>\n*(.*?)</tool_call>", re.DOTALL)
+    _qwen_function_pattern = re.compile(r"<function=([^>\n]+)>(.*?)</function>", re.DOTALL)
+    _qwen_parameter_pattern = re.compile(r"<parameter=([^>\n]+)>(.*?)</parameter>", re.DOTALL)
+
+ def __init__(
+ self,
+ generate_handler,
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str | None,
+ default_model_name: str | None = None,
+ context_length: int | None = None,
+ capture_path: str | None = None,
+ ):
+ if isinstance(tokenizer, str):
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
+ super().__init__(generate_handler, tokenizer=tokenizer, capture_path=capture_path)
+ self._default_model_name = default_model_name
+ self._context_length = context_length
+
+ async def messages(self, request: AnthropicMessagesRequest) -> AnthropicMessagesResponse:
+ return await self.handle_request(request)
+
+ async def count_tokens(self, request: AnthropicCountTokensRequest) -> AnthropicCountTokensResponse:
+ internal_messages = self._build_internal_messages(request)
+ rollout_state = RolloutState(message=internal_messages)
+ tokenizer_tools = self._normalize_tools_for_backend(request.tools)
+ if self._tokenizer is not None:
+ raw_prompt_ids = self._tokenizer.apply_chat_template(
+ internal_messages,
+ tools=tokenizer_tools,
+ tokenize=True,
+ add_generation_prompt=True,
+ )
+ rollout_state.prompt_ids = raw_prompt_ids.get("input_ids") if hasattr(raw_prompt_ids, "get") else list(raw_prompt_ids)
+ rollout_state.tokens = rollout_state.prompt_ids
+ return AnthropicCountTokensResponse(input_tokens=self._count_prompt_tokens(rollout_state))
+
+ def validate_request(self, request: AnthropicMessagesRequest) -> None:
+ if request.stream:
+ raise AnthropicChatAdapterError(
+ "stream=true is not supported yet",
+ "invalid_request_error",
+ )
+
+ def request_to_rollout_state(self, request: AnthropicMessagesRequest) -> RolloutState:
+ internal_messages = self._build_internal_messages(request)
+ tokenizer_tools = self._normalize_tools_for_backend(request.tools)
+ normalized_tool_choice = self._normalize_tool_choice_for_backend(request.tool_choice)
+ prompt_ids = None
+ if self._tokenizer is not None:
+ raw_prompt_ids = self._tokenizer.apply_chat_template(
+ internal_messages,
+ tools=tokenizer_tools,
+ tokenize=True,
+ add_generation_prompt=True,
+ )
+ prompt_ids = raw_prompt_ids.get("input_ids") if hasattr(raw_prompt_ids, "get") else list(raw_prompt_ids)
+ max_tokens = self._fit_max_tokens_to_context(prompt_ids=prompt_ids, requested_max_tokens=request.max_tokens)
+ return RolloutState(
+ uid=uuid4().int,
+ message=internal_messages,
+ prompt_ids=prompt_ids,
+ tokens=prompt_ids,
+ session_uid=request.session_uid,
+ tools=tokenizer_tools,
+ tool_choice=normalized_tool_choice,
+ sample_params=self._build_sample_params(request, max_tokens=max_tokens),
+ )
+
+ def raise_for_failed_response(self, response: RolloutState, request_id: str) -> None:
+ if response.status == Status.FAILED:
+ raise AnthropicChatAdapterError(
+ response.error_msg or "Rollout generation failed",
+ "api_error",
+ request_id,
+ )
+
+ def normalize_request(self, request: AnthropicMessagesRequest) -> dict[str, Any]:
+ return normalize_trace_payload(request.model_dump(mode="python", exclude_none=True))
+
+ def normalize_response(self, response: AnthropicMessagesResponse) -> dict[str, Any]:
+ return normalize_trace_payload(response.model_dump(mode="python", exclude_none=True))
+
+ def rollout_state_to_response(
+ self,
+ rollout_state: RolloutState,
+ request: AnthropicMessagesRequest,
+ ) -> AnthropicMessagesResponse:
+ assert rollout_state.uid is not None, "uid should not be None when generating response"
+ request_id = str(rollout_state.uid)
+ model_name = request.model or self._default_model_name or "rollout-controller"
+ prompt_tokens = self._count_prompt_tokens(rollout_state)
+ completion_tokens = self._count_completion_tokens(rollout_state)
+ content_blocks = self._build_response_content_blocks(rollout_state)
+ stop_reason = "tool_use" if any(block.get("type") == "tool_use" for block in content_blocks) else rollout_state.finish_reason
+
+ return AnthropicMessagesResponse(
+ id=f"msg_{request_id}",
+ content=content_blocks,
+ model=model_name,
+ stop_reason=stop_reason,
+ usage=AnthropicUsage(
+ input_tokens=prompt_tokens,
+ output_tokens=completion_tokens,
+ ),
+ )
+
+    def _build_internal_messages(self, request: AnthropicMessagesRequest) -> list[dict[str, Any]]:
+ messages: list[dict[str, Any]] = []
+
+ if request.system:
+ if isinstance(request.system, str):
+ system_text = request.system
+ else:
+ system_text = self._join_text_blocks(request.system, context="system")
+ messages.append({"role": "system", "content": system_text})
+
+ for message in request.messages:
+ if isinstance(message.content, str):
+ messages.append({"role": message.role, "content": message.content})
+ else:
+ messages.extend(self._convert_content_blocks_to_backend_messages(message.role, message.content))
+
+ return messages
+
+ def _join_text_blocks(self, blocks: list[dict[str, Any]], context: str) -> str:
+ unsupported_types = [block.get("type") for block in blocks if block.get("type") != "text"]
+ if unsupported_types:
+ unsupported_str = ", ".join(sorted(set(unsupported_types)))
+ raise AnthropicChatAdapterError(
+ f"Unsupported Anthropic content block type(s) in {context}: {unsupported_str}",
+ "invalid_request_error",
+ )
+ return "\n".join(str(block.get("text", "")) for block in blocks)
+
+ def _convert_content_blocks_to_backend_messages(self, role: str, blocks: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ backend_messages: list[dict[str, Any]] = []
+ text_chunks: list[str] = []
+ tool_calls: list[dict[str, Any]] = []
+
+ def flush_text_chunks():
+ if text_chunks:
+ backend_messages.append({"role": role, "content": "\n".join(text_chunks)})
+ text_chunks.clear()
+
+ for block in blocks:
+ block_type = block.get("type")
+ if block_type == "text":
+ text_value = str(block.get("text", ""))
+ if role == "assistant":
+ text_value = self._sanitize_assistant_text(text_value)
+ text_chunks.append(text_value)
+ elif block_type == "tool_use":
+ tool_calls.append(
+ {
+ "id": block.get("id") or f"toolu_{uuid4().hex}",
+ "type": "function",
+ "function": {
+ "name": str(block.get("name", "")),
+ "arguments": normalize_trace_payload(block.get("input", {})),
+ },
+ }
+ )
+ elif block_type == "tool_result":
+ flush_text_chunks()
+ backend_messages.append(
+ {
+ "role": "tool",
+ "content": self._serialize_tool_result_content(block.get("content")),
+ "tool_call_id": block.get("tool_use_id"),
+ }
+ )
+ else:
+ raise AnthropicChatAdapterError(
+ f"Unsupported Anthropic content block type in messages[{role}]: {block_type}",
+ "invalid_request_error",
+ )
+
+ if tool_calls:
+ backend_messages.append(
+ {
+ "role": role,
+ "content": "\n".join(text_chunks),
+ "tool_calls": tool_calls,
+ }
+ )
+ text_chunks.clear()
+ flush_text_chunks()
+ return backend_messages
+
+ def _serialize_tool_result_content(self, content: Any) -> str:
+ if content is None:
+ return ""
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ if all(isinstance(item, dict) and item.get("type") == "text" for item in content):
+ return "\n".join(str(item.get("text", "")) for item in content)
+ return json.dumps(content, ensure_ascii=False)
+ if isinstance(content, dict):
+ return json.dumps(content, ensure_ascii=False)
+ return str(content)
+
+ def _normalize_tools_for_backend(self, tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None:
+ if not tools:
+ return None
+ normalized_tools = []
+ for tool in tools:
+ if tool.get("type") == "function":
+ normalized_tools.append(normalize_trace_payload(tool))
+ else:
+ normalized_tools.append(
+ {
+ "type": "function",
+ "function": {
+ "name": tool["name"],
+ "description": tool.get("description", ""),
+ "parameters": tool["input_schema"],
+ },
+ }
+ )
+ return normalize_trace_payload(normalized_tools)
+
+ def _normalize_tool_choice_for_backend(self, tool_choice: str | dict[str, Any] | None) -> str | dict[str, Any] | None:
+ if tool_choice is None:
+ return None
+ if isinstance(tool_choice, str):
+ return tool_choice
+ choice_type = tool_choice.get("type")
+ if choice_type == "auto":
+ return "auto"
+ if choice_type == "none":
+ return "none"
+ if choice_type == "any":
+ return "required"
+ if choice_type == "tool":
+ return {
+ "type": "function",
+ "function": {
+ "name": tool_choice.get("name"),
+ },
+ }
+ return normalize_trace_payload(tool_choice)
+
+ def _build_response_content_blocks(self, rollout_state: RolloutState) -> list[dict[str, Any]]:
+ raw_response = rollout_state.response or ""
+ tool_calls = rollout_state.extra_fields.get("tool_calls")
+ if not tool_calls:
+ text_blocks, parsed_tool_calls = self._parse_textual_tool_calls(raw_response)
+ if parsed_tool_calls:
+ tool_calls = parsed_tool_calls
+ rollout_state.extra_fields["tool_calls"] = parsed_tool_calls
+ response_text = "".join(block["text"] for block in text_blocks if block["type"] == "text")
+ rollout_state.response = self._sanitize_assistant_text(response_text)
+ else:
+ rollout_state.response = self._sanitize_assistant_text(raw_response)
+
+ if not tool_calls:
+ return [{"type": "text", "text": rollout_state.response or ""}]
+
+ content_blocks: list[dict[str, Any]] = []
+ if rollout_state.response:
+ content_blocks.append({"type": "text", "text": rollout_state.response})
+ for tool_call in tool_calls:
+ content_blocks.append(
+ {
+ "type": "tool_use",
+ "id": tool_call.get("id") or f"toolu_{uuid4().hex}",
+ "name": tool_call["function"]["name"],
+ "input": self._parse_tool_arguments(tool_call["function"].get("arguments")),
+ }
+ )
+ return content_blocks
+
+ def _parse_textual_tool_calls(self, text: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+ if not text:
+ return [], []
+ content_blocks: list[dict[str, Any]] = []
+ tool_calls: list[dict[str, Any]] = []
+ last_end = 0
+ for match in self._tool_call_pattern.finditer(text):
+ if match.start() > last_end:
+ content_blocks.append({"type": "text", "text": text[last_end : match.start()]})
+ raw_payload = match.group(1).strip()
+ parsed_tool_call = self._parse_single_textual_tool_call(raw_payload)
+ if parsed_tool_call is not None:
+ tool_calls.append(parsed_tool_call)
+ else:
+ content_blocks.append({"type": "text", "text": match.group(0)})
+ last_end = match.end()
+ if last_end < len(text):
+ content_blocks.append({"type": "text", "text": text[last_end:]})
+ return content_blocks, tool_calls
+
+ def _parse_single_textual_tool_call(self, raw_payload: str) -> dict[str, Any] | None:
+ try:
+ parsed = json.loads(raw_payload)
+ return {
+ "id": f"call_{uuid4().hex}",
+ "type": "function",
+ "function": {
+ "name": parsed["name"],
+ "arguments": json.dumps(parsed.get("arguments", {}), ensure_ascii=False),
+ },
+ }
+ except Exception:
+ pass
+
+ function_match = self._qwen_function_pattern.search(raw_payload)
+ if function_match is None:
+ return None
+ function_name = function_match.group(1).strip()
+ function_body = function_match.group(2)
+ arguments: dict[str, Any] = {}
+ for parameter_match in self._qwen_parameter_pattern.finditer(function_body):
+ param_name = parameter_match.group(1).strip()
+ param_value = parameter_match.group(2).strip()
+ arguments[param_name] = self._coerce_parameter_value(param_value)
+ return {
+ "id": f"call_{uuid4().hex}",
+ "type": "function",
+ "function": {
+ "name": function_name,
+ "arguments": json.dumps(arguments, ensure_ascii=False),
+ },
+ }
+
+ def _parse_tool_arguments(self, arguments: Any) -> Any:
+ if isinstance(arguments, str):
+ try:
+ return json.loads(arguments)
+ except Exception:
+ return {"raw": arguments}
+ return arguments
+
+ def _coerce_parameter_value(self, value: str) -> Any:
+ stripped = value.strip()
+ if not stripped:
+ return ""
+ try:
+ return json.loads(stripped)
+ except Exception:
+ return stripped
+
+    def _sanitize_assistant_text(self, text: str) -> str:
+        # Strip residual chat-template and reasoning markers (the tag strings are
+        # assumed to be Qwen-style <think> delimiters).
+        cleaned = text.replace("<|im_end|>", "")
+        cleaned = cleaned.replace("<think>", "")
+        cleaned = cleaned.replace("</think>", "")
+        return cleaned.strip()
+
+ def _build_sample_params(self, request: AnthropicMessagesRequest, max_tokens: int | None = None) -> SampleParams:
+ kwargs = {
+ "return_token_ids": True,
+ "return_logprob": False,
+ "stream": request.stream,
+ "max_tokens": max_tokens if max_tokens is not None else request.max_tokens,
+ "stops": request.stop_sequences or [],
+ }
+ if request.temperature is not None:
+ kwargs["temperature"] = request.temperature
+ if request.top_p is not None:
+ kwargs["top_p"] = request.top_p
+ return SampleParams(**kwargs)
+
+ def _fit_max_tokens_to_context(self, prompt_ids: list[int] | None, requested_max_tokens: int) -> int:
+ if self._context_length is None or prompt_ids is None:
+ return requested_max_tokens
+ prompt_tokens = len(prompt_ids)
+ available_completion_tokens = self._context_length - prompt_tokens
+ if available_completion_tokens <= 0:
+ raise AnthropicChatAdapterError(
+ (
+ f"Input is too long for this model deployment: prompt_tokens={prompt_tokens}, "
+ f"context_length={self._context_length}."
+ ),
+ "invalid_request_error",
+ )
+ return min(requested_max_tokens, available_completion_tokens)
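+    # Worked example: with context_length=4096 and a 4000-token prompt, a request for
+    # max_tokens=512 is clamped to 96; a prompt of 4096 tokens or more raises
+    # invalid_request_error instead.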
+
+ def _count_prompt_tokens(self, rollout_state: RolloutState) -> int:
+ if rollout_state.tokens is not None:
+ return len(rollout_state.tokens)
+ if rollout_state.prompt_ids is not None:
+ return len(rollout_state.prompt_ids)
+ if self._tokenizer is not None and rollout_state.message:
+ text_prompt = self._tokenizer.apply_chat_template(
+ rollout_state.message,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ return len(self._tokenizer(text_prompt, add_special_tokens=False)["input_ids"])
+ return 0
+
+ def _count_completion_tokens(self, rollout_state: RolloutState) -> int:
+ if rollout_state.response_ids is not None:
+ return len(rollout_state.response_ids)
+ if self._tokenizer is not None and rollout_state.response:
+ return len(self._tokenizer(rollout_state.response, add_special_tokens=False)["input_ids"])
+ return 0
+
+ def build_output_message_list(
+ self,
+ rollout_state: RolloutState,
+ request: AnthropicMessagesRequest,
+ ) -> list[dict[str, Any]]:
+ return [{"role": "assistant", "content": self._build_response_content_blocks(rollout_state)}]
+
+
+def bind_anthropic_chat_interface(
+ rollout_controller: Any,
+ default_model_name: str | None = None,
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str | None = None,
+) -> Any:
+ if getattr(rollout_controller, "anthropic_chat_adapter", None) is None:
+ rollout_controller.anthropic_chat_adapter = AnthropicChatAdapter(
+ rollout_controller.generate,
+ default_model_name=default_model_name,
+ tokenizer=tokenizer,
+ context_length=getattr(rollout_controller.config, "context_length", None),
+ capture_path=str(getattr(rollout_controller.config, "worker_log_dir", ".")) + "/gateway_capture.jsonl",
+ )
+ rollout_controller.anthropic_messages = rollout_controller.anthropic_chat_adapter.messages
+ rollout_controller.anthropic_count_tokens = rollout_controller.anthropic_chat_adapter.count_tokens
+ return rollout_controller
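+# Minimal wiring sketch (caller names hypothetical): given a RolloutController exposing
+# an async .generate and a .config, binding attaches the Anthropic entry points in place:
+#
+#   controller = bind_anthropic_chat_interface(controller, tokenizer=tokenizer_path)
+#   response = await controller.anthropic_messages(messages_request)
+#   count = await controller.anthropic_count_tokens(count_tokens_request)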
diff --git a/xtuner/v1/rl/rollout/chat_adapter/base.py b/xtuner/v1/rl/rollout/chat_adapter/base.py
new file mode 100644
index 000000000..826210474
--- /dev/null
+++ b/xtuner/v1/rl/rollout/chat_adapter/base.py
@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from collections.abc import Awaitable, Callable
+from typing import Any, Generic, TypeVar
+
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from xtuner.v1.data_proto.rl_data import RolloutState
+
+from .capture import append_gateway_capture_record, render_blocks_as_text
+from .collector import append_current_trace_rollout_state
+from .trace import ChatTraceRecord, ChatTraceStore, snapshot_routed_experts
+
+
+GenerateHandler = Callable[[RolloutState], Awaitable[RolloutState]]
+RequestT = TypeVar("RequestT")
+ResponseT = TypeVar("ResponseT")
+
+
+class BaseChatAPIAdapter(ABC, Generic[RequestT, ResponseT]):
+ def __init__(
+ self,
+ generate_handler: GenerateHandler,
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None,
+ *,
+ capture_path: str | None = None,
+ trace_store_max_entries: int = 10000,
+ ):
+ self._generate_handler = generate_handler
+ self._tokenizer = tokenizer
+ self._capture_path = capture_path
+ self._trace_store = ChatTraceStore(max_entries=trace_store_max_entries)
+
+ async def handle_request(self, request: RequestT) -> ResponseT:
+ self.validate_request(request)
+ rollout_state = self.request_to_rollout_state(request)
+ if rollout_state.uid is None:
+ raise ValueError("request_to_rollout_state must assign rollout_state.uid before generate is called.")
+ request_id = str(rollout_state.uid)
+ rollout_state = await self._generate_handler(rollout_state)
+ append_current_trace_rollout_state(rollout_state)
+
+ self.raise_for_failed_response(rollout_state, request_id)
+ response = self.rollout_state_to_response(rollout_state, request)
+ self._trace_store.put(self._build_trace_record(request, response, rollout_state, request_id))
+ self._write_capture_record(request=request, response=response, rollout_state=rollout_state, request_id=request_id)
+ return response
+
+ def get_trace_by_request_response(self, request: RequestT, response: ResponseT) -> ChatTraceRecord | None:
+ response_hash = self._trace_store.build_hash(
+ request_snapshot=self.normalize_request(request),
+ response_snapshot=self.normalize_response(response),
+ )
+ return self._trace_store.get(response_hash)
+
+ def get_trace_by_response_hash(self, response_hash: str) -> ChatTraceRecord | None:
+ return self._trace_store.get(response_hash)
+
+ def _build_trace_record(
+ self,
+ request: RequestT,
+ response: ResponseT,
+ rollout_state: RolloutState,
+ request_id: str,
+ ) -> ChatTraceRecord:
+ request_snapshot = self.normalize_request(request)
+ response_snapshot = self.normalize_response(response)
+ response_hash = self._trace_store.build_hash(
+ request_snapshot=request_snapshot,
+ response_snapshot=response_snapshot,
+ )
+ return ChatTraceRecord(
+ response_hash=response_hash,
+ request_snapshot=request_snapshot,
+ response_snapshot=response_snapshot,
+ prompt_ids=list(rollout_state.prompt_ids or []),
+ response_ids=list(rollout_state.response_ids or []),
+ logprobs=None if rollout_state.logprobs is None else list(rollout_state.logprobs),
+ routed_experts=snapshot_routed_experts(rollout_state.routed_experts),
+ finish_reason=rollout_state.finish_reason,
+ status=rollout_state.status,
+ request_id=request_id,
+ )
+
+ def _write_capture_record(
+ self,
+ request: RequestT,
+ response: ResponseT,
+ rollout_state: RolloutState,
+ request_id: str,
+ ) -> None:
+ if self._capture_path is None:
+ return
+ try:
+ response_snapshot = self.normalize_response(response)
+ response_finish_reason = (
+ response_snapshot.get("stop_reason")
+ or response_snapshot.get("finish_reason")
+ or rollout_state.finish_reason
+ )
+ output_messages = self.build_output_message_list(rollout_state, request)
+ append_gateway_capture_record(
+ self._capture_path,
+ {
+ "protocol": self.__class__.__name__,
+ "request_id": request_id,
+ "session_uid": rollout_state.session_uid,
+ "status": rollout_state.status.value,
+ "finish_reason": response_finish_reason,
+ "rollout_finish_reason": rollout_state.finish_reason,
+ "prompt_tokens": len(rollout_state.prompt_ids or []),
+ "completion_tokens": len(rollout_state.response_ids or []),
+ "request": self.normalize_request(request),
+ "response": response_snapshot,
+ "internal_messages": rollout_state.message,
+ "output_messages": output_messages,
+ "input_text": self._render_prompt_text(rollout_state),
+ "output_text": render_blocks_as_text(output_messages),
+ },
+ )
+ except Exception:
+ # Capturing should never block serving.
+ return
+
+ def _render_prompt_text(self, rollout_state: RolloutState) -> str:
+ if self._tokenizer is None:
+ return ""
+ try:
+ rendered = self._tokenizer.apply_chat_template(
+ rollout_state.message,
+ tools=rollout_state.tools,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ if isinstance(rendered, str):
+ return rendered
+ return render_blocks_as_text(rendered)
+ except Exception:
+ return ""
+
+ @abstractmethod
+ def validate_request(self, request: RequestT) -> None:
+ return None
+
+ @abstractmethod
+ def request_to_rollout_state(self, request: RequestT) -> RolloutState:
+ raise NotImplementedError
+
+ @abstractmethod
+ def raise_for_failed_response(self, response: RolloutState, request_id: str) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def normalize_request(self, request: RequestT) -> dict[str, Any]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def normalize_response(self, response: ResponseT) -> dict[str, Any]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def rollout_state_to_response(
+ self,
+ rollout_state: RolloutState,
+ request: RequestT,
+ ) -> ResponseT:
+ raise NotImplementedError
+
+ @abstractmethod
+ def build_output_message_list(
+ self,
+ rollout_state: RolloutState,
+ request: RequestT,
+ ) -> list[dict[str, Any]]:
+ raise NotImplementedError
diff --git a/xtuner/v1/rl/rollout/chat_adapter/capture.py b/xtuner/v1/rl/rollout/chat_adapter/capture.py
new file mode 100644
index 000000000..0d6847382
--- /dev/null
+++ b/xtuner/v1/rl/rollout/chat_adapter/capture.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+import json
+import threading
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+
+_CAPTURE_LOCK = threading.RLock()
+
+
+def append_gateway_capture_record(path: str | Path, record: dict[str, Any]) -> None:
+ capture_path = Path(path)
+ capture_path.parent.mkdir(parents=True, exist_ok=True)
+ payload = {
+ "type": "gateway_turn",
+ "timestamp": datetime.now(timezone.utc).isoformat(),
+ **record,
+ }
+ with _CAPTURE_LOCK:
+ with capture_path.open("a", encoding="utf-8") as f:
+ f.write(json.dumps(payload, ensure_ascii=False) + "\n")
+
+
+def render_blocks_as_text(value: Any) -> str:
+ if value is None:
+ return ""
+ if isinstance(value, str):
+ return value
+ if isinstance(value, list):
+ rendered_parts = [render_blocks_as_text(item) for item in value]
+ return "\n".join(part for part in rendered_parts if part)
+ if isinstance(value, dict):
+ block_type = value.get("type")
+ if block_type == "text":
+ return str(value.get("text", ""))
+        if block_type == "tool_use":
+            name = value.get("name", "")
+            input_payload = json.dumps(value.get("input", {}), ensure_ascii=False, sort_keys=True)
+            # Tag-style rendering (exact markup assumed) keeps the call readable in logs.
+            return f"<tool_use name={name}>{input_payload}</tool_use>"
+        if block_type == "tool_result":
+            tool_use_id = value.get("tool_use_id", "")
+            content = render_blocks_as_text(value.get("content"))
+            return f"<tool_result tool_use_id={tool_use_id}>{content}</tool_result>"
+ if "content" in value:
+ return render_blocks_as_text(value["content"])
+ return json.dumps(value, ensure_ascii=False, sort_keys=True)
+ return str(value)
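+# For example, render_blocks_as_text([{"type": "text", "text": "hi"}, {"type": "text",
+# "text": "there"}]) returns "hi\nthere"; empty parts are dropped from the join.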
diff --git a/xtuner/v1/rl/rollout/chat_adapter/collector.py b/xtuner/v1/rl/rollout/chat_adapter/collector.py
new file mode 100644
index 000000000..ff3ef7caa
--- /dev/null
+++ b/xtuner/v1/rl/rollout/chat_adapter/collector.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+from contextvars import ContextVar, Token
+
+from xtuner.v1.data_proto.rl_data import RolloutState
+
+
+_CURRENT_TRACE_COLLECTOR: ContextVar[list[RolloutState] | None] = ContextVar(
+ "xtuner_rollout_trace_collector",
+ default=None,
+)
+
+
+def set_current_trace_collector(collector: list[RolloutState]) -> Token:
+ return _CURRENT_TRACE_COLLECTOR.set(collector)
+
+
+def reset_current_trace_collector(token: Token) -> None:
+ _CURRENT_TRACE_COLLECTOR.reset(token)
+
+
+def append_current_trace_rollout_state(rollout_state: RolloutState) -> None:
+ collector = _CURRENT_TRACE_COLLECTOR.get()
+ if collector is None:
+ return
+ collector.append(rollout_state.model_copy(deep=True))
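+# Minimal usage sketch (caller is hypothetical): gather every RolloutState produced while
+# handling a request on the current asyncio task/context.
+#
+#   states: list[RolloutState] = []
+#   token = set_current_trace_collector(states)
+#   try:
+#       ...  # adapters call append_current_trace_rollout_state() during generation
+#   finally:
+#       reset_current_trace_collector(token)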
diff --git a/xtuner/v1/rl/rollout/chat_adapter/openai.py b/xtuner/v1/rl/rollout/chat_adapter/openai.py
new file mode 100644
index 000000000..70b081b84
--- /dev/null
+++ b/xtuner/v1/rl/rollout/chat_adapter/openai.py
@@ -0,0 +1,227 @@
+import time
+from typing import Any
+from uuid import uuid4
+
+from lmdeploy.serve.openai.protocol import (
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ ChatCompletionResponseChoice,
+ ChatMessage,
+ UsageInfo,
+)
+
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
+
+from .base import BaseChatAPIAdapter
+from .trace import normalize_trace_payload
+
+
+class OpenAIChatAdapterError(RuntimeError):
+ def __init__(
+ self,
+ message: str,
+ error_type: str,
+ code: str,
+ request_id: str | None = None,
+ ):
+ super().__init__(message)
+ self.message = message
+ self.error_type = error_type
+ self.code = code
+ self.request_id = request_id
+
+
+class OpenAIChatAdapter(BaseChatAPIAdapter[ChatCompletionRequest, ChatCompletionResponse]):
+ def __init__(
+ self,
+ generate_handler,
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str,
+ default_model_name: str | None = None,
+ context_length: int | None = None,
+ capture_path: str | None = None,
+ trace_store_max_entries: int = 10000,
+ ):
+ if isinstance(tokenizer, str):
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
+ super().__init__(
+ generate_handler,
+ tokenizer=tokenizer,
+ capture_path=capture_path,
+ trace_store_max_entries=trace_store_max_entries,
+ )
+ self._default_model_name = default_model_name
+ self._context_length = context_length
+
+ async def chat(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+ return await self.handle_request(request)
+
+ def validate_request(self, request: ChatCompletionRequest) -> None:
+ if request.stream:
+ raise OpenAIChatAdapterError(
+ "stream=true is not supported yet",
+ "invalid_request_error",
+ "stream_not_supported",
+ )
+
+ def request_to_rollout_state(self, request: ChatCompletionRequest) -> RolloutState:
+ normalized_messages = normalize_trace_payload(request.messages)
+ tokenizer_tools = self._normalize_tools_for_tokenizer(request.tools)
+ normalized_tool_choice = normalize_trace_payload(request.tool_choice)
+ prompt_ids = None
+ if self._tokenizer:
+ raw_prompt_ids = self._tokenizer.apply_chat_template(
+ normalized_messages,
+ tools=tokenizer_tools,
+ tokenize=True,
+ add_generation_prompt=True,
+ )
+ if hasattr(raw_prompt_ids, "get"):
+ prompt_ids = raw_prompt_ids.get("input_ids")
+ else:
+ prompt_ids = list(raw_prompt_ids)
+ max_tokens = self._fit_max_tokens_to_context(prompt_ids=prompt_ids, requested_max_tokens=request.max_tokens)
+
+ return RolloutState(
+ uid=uuid4().int,
+ message=normalized_messages,
+ prompt_ids=prompt_ids,
+ tokens=prompt_ids,
+ session_uid=getattr(request, "session_uid", getattr(request, "session_id", None)),
+ tools=tokenizer_tools,
+ tool_choice=normalized_tool_choice,
+ sample_params=self._build_sample_params(request, max_tokens=max_tokens),
+ )
+
+ def raise_for_failed_response(self, response: RolloutState, request_id: str) -> None:
+ if response.status == Status.FAILED:
+ raise OpenAIChatAdapterError(
+ response.error_msg or "Rollout generation failed",
+ "server_error",
+ "rollout_failed",
+ request_id,
+ )
+
+ def rollout_state_to_response(
+ self,
+ rollout_state: RolloutState,
+ request: ChatCompletionRequest,
+ ) -> ChatCompletionResponse:
+ request_id = str(rollout_state.uid)
+ model_name = request.model or self._default_model_name or "rollout-controller"
+ assert rollout_state.response_ids is not None, "response_ids should not be None when generating response"
+ assert rollout_state.tokens is not None, "tokens should not be None when generating response"
+ prompt_tokens = len(rollout_state.tokens)
+ completion_tokens = len(rollout_state.response_ids)
+ tool_calls = rollout_state.extra_fields.get("tool_calls")
+ response_message = ChatMessage(
+ role="assistant",
+ content=None if tool_calls else rollout_state.response,
+ tool_calls=tool_calls,
+ )
+ return ChatCompletionResponse(
+ id=request_id,
+ created=int(time.time()),
+ model=model_name,
+ choices=[
+ ChatCompletionResponseChoice(
+ index=0,
+ message=response_message,
+ finish_reason=rollout_state.finish_reason,
+ )
+ ],
+ usage=UsageInfo(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ ),
+ )
+
+ def build_output_message_list(
+ self,
+ rollout_state: RolloutState,
+ request: ChatCompletionRequest,
+ ) -> list[dict[str, Any]]:
+ return [{"role": "assistant", "content": rollout_state.response or ""}]
+
+ def normalize_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
+ return normalize_trace_payload(
+ {
+ "messages": request.messages,
+ "tools": request.tools,
+ "tool_choice": request.tool_choice,
+ }
+ )
+
+ def normalize_response(self, response: ChatCompletionResponse) -> dict[str, Any]:
+ normalized_choices = []
+ for choice in response.choices:
+ normalized_choices.append(
+ {
+                    "message": choice.message.model_dump(mode="python", exclude_none=True)
+                    if choice.message is not None and hasattr(choice.message, "model_dump")
+                    else choice.message,
+ "finish_reason": choice.finish_reason,
+ }
+ )
+ return normalize_trace_payload({"choices": normalized_choices})
+
+ def _normalize_tools_for_tokenizer(self, tools: Any) -> Any:
+ if tools is None:
+ return None
+ return normalize_trace_payload(tools)
+
+ def _build_sample_params(self, request: ChatCompletionRequest, max_tokens: int | None = None) -> SampleParams:
+ stops = [] if request.stop is None else [request.stop] if isinstance(request.stop, str) else request.stop
+ kwargs = {
+ "stops": stops,
+ **{
+ key: value
+ for key, value in {
+ "temperature": request.temperature,
+ "top_p": request.top_p,
+ "n": request.n,
+ "max_tokens": max_tokens if max_tokens is not None else request.max_tokens,
+ "presence_penalty": request.presence_penalty,
+ "frequency_penalty": request.frequency_penalty,
+ }.items()
+ if value is not None
+ },
+ }
+ return SampleParams(**kwargs)
+
+ def _fit_max_tokens_to_context(self, prompt_ids: list[int] | None, requested_max_tokens: int | None) -> int | None:
+ if self._context_length is None or prompt_ids is None or requested_max_tokens is None:
+ return requested_max_tokens
+ prompt_tokens = len(prompt_ids)
+ available_completion_tokens = self._context_length - prompt_tokens
+ if available_completion_tokens <= 0:
+ raise OpenAIChatAdapterError(
+ (
+ f"Input is too long for this model deployment: prompt_tokens={prompt_tokens}, "
+ f"context_length={self._context_length}."
+ ),
+ "invalid_request_error",
+ "context_length_exceeded",
+ )
+ return min(requested_max_tokens, available_completion_tokens)
+
+
+def bind_openai_chat_interface(
+ rollout_controller: Any,
+ tokenizer: str | PreTrainedTokenizer | PreTrainedTokenizerFast | None,
+ default_model_name: str | None = None,
+) -> Any:
+ if getattr(rollout_controller, "openai_chat_adapter", None) is None:
+ rollout_controller.openai_chat_adapter = OpenAIChatAdapter(
+ rollout_controller.generate,
+ tokenizer=tokenizer,
+ default_model_name=default_model_name,
+ context_length=getattr(rollout_controller.config, "context_length", None),
+ capture_path=str(getattr(rollout_controller.config, "worker_log_dir", ".")) + "/gateway_capture.jsonl",
+ )
+ rollout_controller.chat = rollout_controller.openai_chat_adapter.chat
+ return rollout_controller
diff --git a/xtuner/v1/rl/rollout/chat_adapter/responses.py b/xtuner/v1/rl/rollout/chat_adapter/responses.py
new file mode 100644
index 000000000..dd0cd5fb4
--- /dev/null
+++ b/xtuner/v1/rl/rollout/chat_adapter/responses.py
@@ -0,0 +1,551 @@
+from __future__ import annotations
+
+import json
+import re
+import shlex
+import time
+from typing import Any, Literal
+from uuid import uuid4
+
+from pydantic import BaseModel, ConfigDict
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
+
+from .base import BaseChatAPIAdapter
+from .openai import OpenAIChatAdapterError
+from .trace import normalize_trace_payload
+
+
+class ResponsesRequest(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+ session_uid: int | None = None
+ model: str | None = None
+ instructions: str | None = None
+ input: str | list[dict[str, Any]] | None = None
+ tools: list[dict[str, Any]] | None = None
+ tool_choice: str | dict[str, Any] | None = None
+ stream: bool = False
+ store: bool = False
+ parallel_tool_calls: bool | None = None
+ include: list[Any] | None = None
+ reasoning: dict[str, Any] | None = None
+ max_output_tokens: int | None = None
+ temperature: float | None = None
+ top_p: float | None = None
+
+
+class ResponsesUsage(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+ input_tokens: int
+ output_tokens: int
+ total_tokens: int
+
+
+class ResponsesResponse(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+ id: str
+ object: Literal["response"] = "response"
+ created_at: int
+ status: Literal["completed"] = "completed"
+ model: str
+ output: list[dict[str, Any]]
+ output_text: str = ""
+ parallel_tool_calls: bool = False
+ store: bool = False
+ text: dict[str, Any] = {"format": {"type": "text"}}
+ usage: ResponsesUsage
+
+
+class OpenAIResponsesAdapter(BaseChatAPIAdapter[ResponsesRequest, ResponsesResponse]):
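+    # Patterns for recovering tool calls emitted as plain text: a JSON payload inside
+    # <tool_call>...</tool_call>, Qwen-style <function=...>/<parameter=...> markup, and,
+    # as a fallback, bare <tag>value</tag> pairs matched by _xml_tag_pattern.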
+    _tool_call_pattern = re.compile(r"<tool_call>\n*(.*?)</tool_call>", re.DOTALL)
+    _qwen_function_pattern = re.compile(r"<function=([^>\n]+)>(.*?)</function>", re.DOTALL)
+    _qwen_parameter_pattern = re.compile(r"<parameter=([^>\n]+)>(.*?)</parameter>", re.DOTALL)
+    _xml_tag_pattern = re.compile(r"<([a-zA-Z_][^>\n/]*)>(.*?)</\1>", re.DOTALL)
+ _disabled_tool_names = {
+ "list_mcp_resources",
+ "list_mcp_resource_templates",
+ "read_mcp_resource",
+ "request_user_input",
+ }
+ _tool_name_aliases = {
+ "read_file_dd": "exec_command",
+ "read_file": "exec_command",
+ "readfile": "exec_command",
+ }
+
+ def __init__(
+ self,
+ generate_handler,
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str | None,
+ default_model_name: str | None = None,
+ context_length: int | None = None,
+ capture_path: str | None = None,
+ ):
+ if isinstance(tokenizer, str):
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
+ super().__init__(generate_handler, tokenizer=tokenizer, capture_path=capture_path)
+ self._default_model_name = default_model_name
+ self._context_length = context_length
+
+ async def responses(self, request: ResponsesRequest) -> ResponsesResponse:
+ return await self.handle_request(request)
+
+ def validate_request(self, request: ResponsesRequest) -> None:
+ return None
+
+ def request_to_rollout_state(self, request: ResponsesRequest) -> RolloutState:
+ internal_messages = self._build_internal_messages(request)
+ tokenizer_tools = self._normalize_tools_for_backend(request.tools)
+ normalized_tool_choice = self._normalize_tool_choice_for_backend(request.tool_choice)
+
+ prompt_ids = None
+ if self._tokenizer is not None:
+ raw_prompt_ids = self._tokenizer.apply_chat_template(
+ internal_messages,
+ tools=tokenizer_tools,
+ tokenize=True,
+ add_generation_prompt=True,
+ )
+ prompt_ids = raw_prompt_ids.get("input_ids") if hasattr(raw_prompt_ids, "get") else list(raw_prompt_ids)
+
+ max_tokens = self._fit_max_tokens_to_context(
+ prompt_ids=prompt_ids, requested_max_tokens=request.max_output_tokens
+ )
+ return RolloutState(
+ uid=uuid4().int,
+ message=internal_messages,
+ prompt_ids=prompt_ids,
+ tokens=prompt_ids,
+ session_uid=request.session_uid,
+ tools=tokenizer_tools,
+ tool_choice=normalized_tool_choice,
+ sample_params=self._build_sample_params(request, max_tokens=max_tokens),
+ )
+
+ def raise_for_failed_response(self, response: RolloutState, request_id: str) -> None:
+ if response.status == Status.FAILED:
+ raise OpenAIChatAdapterError(
+ response.error_msg or "Rollout generation failed",
+ "server_error",
+ "rollout_failed",
+ request_id,
+ )
+
+ def normalize_request(self, request: ResponsesRequest) -> dict[str, Any]:
+ return normalize_trace_payload(request.model_dump(mode="python", exclude_none=True))
+
+ def normalize_response(self, response: ResponsesResponse) -> dict[str, Any]:
+ return normalize_trace_payload(response.model_dump(mode="python", exclude_none=True))
+
+ def rollout_state_to_response(
+ self,
+ rollout_state: RolloutState,
+ request: ResponsesRequest,
+ ) -> ResponsesResponse:
+ request_id = str(rollout_state.uid)
+ model_name = request.model or self._default_model_name or "rollout-controller"
+ prompt_tokens = self._count_prompt_tokens(rollout_state)
+ completion_tokens = self._count_completion_tokens(rollout_state)
+ output_items, output_text = self._build_response_output_items(rollout_state)
+ return ResponsesResponse(
+ id=f"resp_{request_id}",
+ created_at=int(time.time()),
+ model=model_name,
+ output=output_items,
+ output_text=output_text,
+ parallel_tool_calls=bool(request.parallel_tool_calls),
+ store=bool(request.store),
+ usage=ResponsesUsage(
+ input_tokens=prompt_tokens,
+ output_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ ),
+ )
+
+ def build_output_message_list(
+ self,
+ rollout_state: RolloutState,
+ request: ResponsesRequest,
+ ) -> list[dict[str, Any]]:
+ output_items, output_text = self._build_response_output_items(rollout_state)
+ content: list[dict[str, Any]] = []
+ if output_text:
+ content.append({"type": "text", "text": output_text})
+ for item in output_items:
+ if item.get("type") == "function_call":
+ try:
+ input_payload = json.loads(item.get("arguments", "{}"))
+ except Exception:
+ input_payload = {"raw": item.get("arguments", "")}
+ content.append({"type": "tool_use", "name": item.get("name", ""), "input": input_payload})
+ return [{"role": "assistant", "content": content}]
+
+ def _build_internal_messages(self, request: ResponsesRequest) -> list[dict[str, Any]]:
+ system_chunks: list[str] = []
+ messages: list[dict[str, Any]] = []
+ tool_name_by_call_id: dict[str, str] = {}
+ if request.instructions:
+ system_chunks.append(request.instructions)
+
+ if request.input is None:
+ return self._prepend_system_message(messages, system_chunks)
+ if isinstance(request.input, str):
+ messages.append({"role": "user", "content": request.input})
+ return self._prepend_system_message(messages, system_chunks)
+
+ for item in request.input:
+ item_type = item.get("type", "message")
+ if item_type == "message":
+ role = self._normalize_input_role(item.get("role"))
+ if role == "system":
+ system_text = self._extract_message_item_text(item.get("content"))
+ if system_text:
+ system_chunks.append(system_text)
+ else:
+ messages.extend(self._convert_message_item_to_backend_messages(role, item.get("content")))
+ elif item_type == "function_call":
+ call_id = item.get("call_id") or f"call_{uuid4().hex}"
+ tool_name_by_call_id[call_id] = str(item.get("name", ""))
+ messages.append(
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {
+ "id": call_id,
+ "type": "function",
+ "function": {
+ "name": str(item.get("name", "")),
+ "arguments": self._parse_json_string_or_mapping(item.get("arguments")),
+ },
+ }
+ ],
+ }
+ )
+ elif item_type == "function_call_output":
+ call_id = item.get("call_id")
+ messages.append(
+ {
+ "role": "tool",
+ "content": self._serialize_tool_output(
+ item.get("output"),
+ tool_name=tool_name_by_call_id.get(str(call_id or "")),
+ ),
+ "tool_call_id": call_id,
+ }
+ )
+ elif item_type == "reasoning":
+ continue
+ return self._prepend_system_message(messages, system_chunks)
+
+ def _prepend_system_message(
+ self,
+ messages: list[dict[str, Any]],
+ system_chunks: list[str],
+ ) -> list[dict[str, Any]]:
+ joined_system = "\n\n".join(chunk.strip() for chunk in system_chunks if chunk and chunk.strip())
+ if not joined_system:
+ return messages
+ return [{"role": "system", "content": joined_system}, *messages]
+
+ def _normalize_input_role(self, role: Any) -> str:
+ if role in {"developer", "system"}:
+ return "system"
+ if role in {"assistant", "tool"}:
+ return str(role)
+ return "user"
+
+ def _convert_message_item_to_backend_messages(self, role: str, content: Any) -> list[dict[str, Any]]:
+ text = self._extract_message_item_text(content)
+ if not text:
+ return []
+ if role == "assistant":
+ text = self._sanitize_assistant_text(text)
+ return [{"role": role, "content": text}]
+
+ def _extract_message_item_text(self, content: Any) -> str:
+ if isinstance(content, str):
+ return content
+ if not isinstance(content, list):
+ return str(content)
+
+ text_chunks: list[str] = []
+ for part in content:
+ part_type = part.get("type")
+ if part_type in {"input_text", "output_text", "text", "summary_text", "reasoning_text"}:
+ text_chunks.append(str(part.get("text", "")))
+ return "\n".join(chunk for chunk in text_chunks if chunk)
+
+ def _serialize_tool_output(self, output: Any, tool_name: str | None = None) -> str:
+ if output is None:
+ return ""
+ if isinstance(output, str):
+ return self._sanitize_tool_output_text(output, tool_name=tool_name)
+ if isinstance(output, list):
+ text_chunks = [str(part.get("text", "")) for part in output if isinstance(part, dict) and "text" in part]
+ if text_chunks:
+ return self._sanitize_tool_output_text("\n".join(text_chunks), tool_name=tool_name)
+ return json.dumps(output, ensure_ascii=False)
+ if isinstance(output, dict):
+ return json.dumps(output, ensure_ascii=False)
+ return str(output)
+
+ def _sanitize_tool_output_text(self, text: str, tool_name: str | None = None) -> str:
+ if tool_name not in {"exec_command", "write_stdin"}:
+ return text
+
+ marker = "\nOutput:\n"
+ if marker in text:
+ prefix, body = text.split(marker, 1)
+ exit_code = self._extract_exec_exit_code(prefix)
+ body = body.strip()
+ if exit_code is None:
+ return body
+ if body:
+ return f"[exit_code={exit_code}]\n{body}"
+ return f"[exit_code={exit_code}]"
+ return text
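+    # e.g. for exec_command output "Process exited with code 0\nOutput:\nhello", this
+    # returns "[exit_code=0]\nhello"; output from other tools passes through untouched.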
+
+ def _extract_exec_exit_code(self, text: str) -> int | None:
+ match = re.search(r"Process exited with code (\d+)", text)
+ if match is not None:
+ return int(match.group(1))
+ return None
+
+ def _normalize_tools_for_backend(self, tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None:
+ if not tools:
+ return None
+ normalized_tools = []
+ for tool in tools:
+ tool_name = tool.get("name")
+ if tool_name in self._disabled_tool_names:
+ continue
+ if tool.get("type") != "function":
+ continue
+ normalized_tools.append(
+ {
+ "type": "function",
+ "function": {
+ "name": tool["name"],
+ "description": tool.get("description", ""),
+ "parameters": tool.get("parameters", {}),
+ },
+ }
+ )
+ return normalize_trace_payload(normalized_tools) or None
+
+ def _normalize_tool_choice_for_backend(self, tool_choice: str | dict[str, Any] | None) -> str | dict[str, Any] | None:
+ if tool_choice is None:
+ return None
+ if isinstance(tool_choice, str):
+ return tool_choice
+ if tool_choice.get("type") == "function":
+ return {"type": "function", "function": {"name": tool_choice.get("name")}}
+ return normalize_trace_payload(tool_choice)
+
+ def _build_response_output_items(self, rollout_state: RolloutState) -> tuple[list[dict[str, Any]], str]:
+ tool_calls = rollout_state.extra_fields.get("tool_calls")
+ response_text = rollout_state.response or ""
+ if not tool_calls:
+ text_blocks, parsed_tool_calls = self._parse_textual_tool_calls(response_text)
+ if parsed_tool_calls:
+ tool_calls = parsed_tool_calls
+ rollout_state.extra_fields["tool_calls"] = parsed_tool_calls
+ response_text = "".join(block["text"] for block in text_blocks if block["type"] == "text")
+
+ response_text = self._sanitize_assistant_text(response_text)
+
+ output_items: list[dict[str, Any]] = []
+ if response_text:
+ output_items.append(
+ {
+ "id": f"msg_{uuid4().hex}",
+ "type": "message",
+ "status": "completed",
+ "role": "assistant",
+ "content": [{"type": "output_text", "text": response_text, "annotations": []}],
+ }
+ )
+
+ for tool_call in tool_calls or []:
+ call_id = tool_call.get("id") or f"call_{uuid4().hex}"
+ output_items.append(
+ {
+ "id": f"fc_{uuid4().hex}",
+ "type": "function_call",
+ "status": "completed",
+ "call_id": call_id,
+ "name": tool_call["function"]["name"],
+ "arguments": self._stringify_tool_arguments(tool_call["function"].get("arguments")),
+ }
+ )
+ return output_items, response_text
+
+ def _parse_textual_tool_calls(self, text: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+ if not text:
+ return [], []
+ content_blocks: list[dict[str, Any]] = []
+ tool_calls: list[dict[str, Any]] = []
+ last_end = 0
+ for match in self._tool_call_pattern.finditer(text):
+ if match.start() > last_end:
+ content_blocks.append({"type": "text", "text": text[last_end : match.start()]})
+ raw_payload = match.group(1).strip()
+ parsed_tool_call = self._parse_single_textual_tool_call(raw_payload)
+ if parsed_tool_call is not None:
+ tool_calls.append(parsed_tool_call)
+ else:
+ content_blocks.append({"type": "text", "text": match.group(0)})
+ last_end = match.end()
+ if last_end < len(text):
+ content_blocks.append({"type": "text", "text": text[last_end:]})
+ return content_blocks, tool_calls
+
+ def _parse_single_textual_tool_call(self, raw_payload: str) -> dict[str, Any] | None:
+ try:
+ parsed = json.loads(raw_payload)
+ return {
+ "id": f"call_{uuid4().hex}",
+ "type": "function",
+ "function": {
+ "name": parsed["name"],
+ "arguments": json.dumps(parsed.get("arguments", {}), ensure_ascii=False),
+ },
+ }
+ except Exception:
+ pass
+
+ function_match = self._qwen_function_pattern.search(raw_payload)
+ if function_match is None:
+ return None
+ function_name = function_match.group(1).strip()
+ function_body = function_match.group(2)
+ arguments: dict[str, Any] = {}
+ for parameter_match in self._qwen_parameter_pattern.finditer(function_body):
+ param_name = parameter_match.group(1).strip()
+ param_value = parameter_match.group(2).strip()
+ arguments[param_name] = self._coerce_parameter_value(param_value)
+ if not arguments:
+ for tag_match in self._xml_tag_pattern.finditer(function_body):
+ tag_name = tag_match.group(1).strip()
+ if tag_name.startswith("function="):
+ continue
+ tag_value = tag_match.group(2).strip()
+ if tag_name in {"path", "file_path"}:
+ arguments[tag_name] = tag_value
+ else:
+ arguments[tag_name] = self._coerce_parameter_value(tag_value)
+
+ function_name, arguments = self._normalize_tool_call(function_name, arguments)
+ return {
+ "id": f"call_{uuid4().hex}",
+ "type": "function",
+ "function": {
+ "name": function_name,
+ "arguments": json.dumps(arguments, ensure_ascii=False),
+ },
+ }
+
+ def _stringify_tool_arguments(self, arguments: Any) -> str:
+ if isinstance(arguments, str):
+ return arguments
+ return json.dumps(arguments or {}, ensure_ascii=False)
+
+ def _parse_json_string_or_mapping(self, value: Any) -> Any:
+ if isinstance(value, str):
+ try:
+ return json.loads(value)
+ except Exception:
+ return {"raw": value}
+ return value or {}
+
+ def _coerce_parameter_value(self, value: str) -> Any:
+ stripped = value.strip()
+ if not stripped:
+ return ""
+ try:
+ return json.loads(stripped)
+ except Exception:
+ return stripped
+
+ def _normalize_tool_call(self, function_name: str, arguments: dict[str, Any]) -> tuple[str, dict[str, Any]]:
+ normalized_name = self._tool_name_aliases.get(function_name, function_name)
+ normalized_arguments = dict(arguments)
+ if normalized_name == "exec_command" and function_name in self._tool_name_aliases:
+ path = normalized_arguments.pop("path", None) or normalized_arguments.pop("file_path", None)
+ if path:
+ quoted_path = shlex.quote(str(path))
+ normalized_arguments = {"cmd": f"cat {quoted_path}"}
+ return normalized_name, normalized_arguments
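+    # e.g. _normalize_tool_call("read_file", {"path": "a b.txt"}) returns
+    # ("exec_command", {"cmd": "cat 'a b.txt'"}).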
+
+    def _sanitize_assistant_text(self, text: str) -> str:
+        # Strip residual chat-template and reasoning markers (the tag strings are
+        # assumed to be Qwen-style <think> delimiters).
+        cleaned = text.replace("<|im_end|>", "")
+        cleaned = cleaned.replace("<think>", "")
+        cleaned = cleaned.replace("</think>", "")
+        return cleaned.strip()
+
+ def _build_sample_params(self, request: ResponsesRequest, max_tokens: int | None = None) -> SampleParams:
+ kwargs = {
+ "return_token_ids": True,
+ "return_logprob": False,
+ "stream": request.stream,
+ }
+ effective_max_tokens = max_tokens if max_tokens is not None else request.max_output_tokens
+ if effective_max_tokens is not None:
+ kwargs["max_tokens"] = effective_max_tokens
+ if request.temperature is not None:
+ kwargs["temperature"] = request.temperature
+ if request.top_p is not None:
+ kwargs["top_p"] = request.top_p
+ return SampleParams(**kwargs)
+
+ def _fit_max_tokens_to_context(self, prompt_ids: list[int] | None, requested_max_tokens: int | None) -> int | None:
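+ """Clamp the requested completion budget to the remaining context window.
+
+ Illustrative arithmetic: with context_length=4096 and a 4000-token
+ prompt, a requested max_tokens of 512 is clamped to 96; a prompt that
+ already fills the window raises context_length_exceeded.
+ """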
+ if self._context_length is None or prompt_ids is None or requested_max_tokens is None:
+ return requested_max_tokens
+ prompt_tokens = len(prompt_ids)
+ available_completion_tokens = self._context_length - prompt_tokens
+ if available_completion_tokens <= 0:
+ raise OpenAIChatAdapterError(
+ (
+ f"Input is too long for this model deployment: prompt_tokens={prompt_tokens}, "
+ f"context_length={self._context_length}."
+ ),
+ "invalid_request_error",
+ "context_length_exceeded",
+ )
+ return min(requested_max_tokens, available_completion_tokens)
+
+ def _count_prompt_tokens(self, rollout_state: RolloutState) -> int:
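+ """Best-effort prompt token count: full token ids first, then prompt ids, else 0."""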
+ if rollout_state.tokens is not None:
+ return len(rollout_state.tokens)
+ if rollout_state.prompt_ids is not None:
+ return len(rollout_state.prompt_ids)
+ return 0
+
+ def _count_completion_tokens(self, rollout_state: RolloutState) -> int:
+ if rollout_state.response_ids is not None:
+ return len(rollout_state.response_ids)
+ if self._tokenizer is not None and rollout_state.response:
+ return len(self._tokenizer(rollout_state.response, add_special_tokens=False)["input_ids"])
+ return 0
+
+
+def bind_openai_responses_interface(
+ rollout_controller: Any,
+ tokenizer: str | PreTrainedTokenizer | PreTrainedTokenizerFast | None,
+ default_model_name: str | None = None,
+) -> Any:
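+ """Attach an OpenAIResponsesAdapter to the controller (idempotent).
+
+ The adapter wraps ``rollout_controller.generate`` and is exposed as
+ ``rollout_controller.responses``; repeated calls reuse the existing adapter.
+ """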
+ if getattr(rollout_controller, "openai_responses_adapter", None) is None:
+ rollout_controller.openai_responses_adapter = OpenAIResponsesAdapter(
+ rollout_controller.generate,
+ tokenizer=tokenizer,
+ default_model_name=default_model_name,
+ context_length=getattr(rollout_controller.config, "context_length", None),
+ capture_path=str(getattr(rollout_controller.config, "worker_log_dir", ".")) + "/gateway_capture.jsonl",
+ )
+ rollout_controller.responses = rollout_controller.openai_responses_adapter.responses
+ return rollout_controller
diff --git a/xtuner/v1/rl/rollout/chat_adapter/trace.py b/xtuner/v1/rl/rollout/chat_adapter/trace.py
new file mode 100644
index 000000000..23babfe0e
--- /dev/null
+++ b/xtuner/v1/rl/rollout/chat_adapter/trace.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import hashlib
+import json
+import threading
+from collections import OrderedDict
+from collections.abc import Sequence
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Any
+
+from pydantic import BaseModel
+
+from xtuner.v1.data_proto.rl_data import Status
+
+
+def normalize_trace_payload(value: Any) -> Any:
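+ """Recursively canonicalize a payload for hashing.
+
+ Pydantic models are dumped, dict keys are stringified and sorted,
+ ``None`` values are dropped, and non-string sequences become lists.
+ """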
+ if isinstance(value, BaseModel):
+ return normalize_trace_payload(value.model_dump(mode="python", exclude_none=True))
+ if isinstance(value, dict):
+ return {
+ str(key): normalize_trace_payload(val)
+ for key, val in sorted(value.items(), key=lambda item: str(item[0]))
+ if val is not None
+ }
+ if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
+ return [normalize_trace_payload(item) for item in value]
+ return value
+
+
+def build_trace_hash(request_snapshot: Any, response_snapshot: Any) -> str:
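+ """Hash a request/response pair deterministically.
+
+ Both snapshots are canonicalized and serialized with sorted keys and
+ fixed separators, so equal pairs always produce the same SHA-256 hex.
+ """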
+ payload = {
+ "request": normalize_trace_payload(request_snapshot),
+ "response": normalize_trace_payload(response_snapshot),
+ }
+ encoded = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":")).encode("utf-8")
+ return hashlib.sha256(encoded).hexdigest()
+
+
+def snapshot_routed_experts(routed_experts: Any) -> Any:
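+ """Snapshot routed-experts data for a trace record.
+
+ Ray ``ObjectRef`` handles are kept as-is (Ray objects are immutable);
+ anything else is deep-copied so later mutation cannot alter the trace.
+ """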
+ if routed_experts is None:
+ return None
+ try:
+ import ray
+
+ if isinstance(routed_experts, ray.ObjectRef):
+ return routed_experts
+ except Exception:
+ pass
+ return deepcopy(routed_experts)
+
+
+@dataclass
+class ChatTraceRecord:
+ response_hash: str
+ request_snapshot: dict[str, Any]
+ response_snapshot: dict[str, Any]
+ prompt_ids: list[int]
+ response_ids: list[int]
+ logprobs: list[float] | None
+ routed_experts: Any
+ finish_reason: str | None
+ status: Status
+ request_id: str | None = None
+
+
+class ChatTraceStore:
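+ """Thread-safe, LRU-bounded map from response hash to ChatTraceRecord.
+
+ ``put`` and ``get`` refresh recency via ``OrderedDict.move_to_end``;
+ once ``max_entries`` is exceeded, the oldest record is evicted with
+ ``popitem(last=False)``.
+ """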
+ def __init__(self, max_entries: int = 10000):
+ self._max_entries = max_entries
+ self._records: OrderedDict[str, ChatTraceRecord] = OrderedDict()
+ self._lock = threading.RLock()
+
+ def put(self, record: ChatTraceRecord) -> None:
+ with self._lock:
+ self._records[record.response_hash] = record
+ self._records.move_to_end(record.response_hash)
+ while len(self._records) > self._max_entries:
+ self._records.popitem(last=False)
+
+ def get(self, response_hash: str) -> ChatTraceRecord | None:
+ with self._lock:
+ record = self._records.get(response_hash)
+ if record is None:
+ return None
+ self._records.move_to_end(response_hash)
+ return record
+
+ def build_hash(self, request_snapshot: Any, response_snapshot: Any) -> str:
+ return build_trace_hash(request_snapshot=request_snapshot, response_snapshot=response_snapshot)
diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py
index b82d6428b..1928e72fb 100644
--- a/xtuner/v1/rl/rollout/controller.py
+++ b/xtuner/v1/rl/rollout/controller.py
@@ -1,11 +1,13 @@
import asyncio
import os
+import socket
import threading
from dataclasses import dataclass
-from typing import Dict, List, Optional, TypeAlias, TypedDict
+from typing import Any, Dict, List, Optional, TypeAlias, TypedDict
from uuid import uuid4
import ray
+import uvicorn
from ray.actor import ActorProxy
from ray.util.placement_group import PlacementGroup
@@ -13,6 +15,8 @@
from xtuner.v1.rl.utils import AutoAcceleratorWorkers
from xtuner.v1.utils import get_logger
+from .api_server import create_rollout_api_app
+from .chat_adapter import bind_anthropic_chat_interface, bind_openai_chat_interface, bind_openai_responses_interface
from .utils import ROLLOUT_RAY_GET_TIMEOUT, RolloutHealthChecker, SessionRouter
from .worker import RolloutConfig, RolloutWorker
@@ -53,6 +57,9 @@ class RolloutWorkerMetadata(TypedDict):
# Value: bool; True means the worker is active, False means it has failed or been deactivated
worker_server_urls_status: Dict[str, bool]
+ # URL of the Rollout Controller API server.
+ api_server_url: str
+
class RolloutController:
"""Controller for managing and coordinating multiple RolloutWorker
@@ -77,12 +84,25 @@ def __init__(
else self.config.tensor_parallel_size
)
self.logger = get_logger(log_dir=infer_config.worker_log_dir, tag="RolloutController")
+ self.api_server_url = ""
self.engine_rank_mesh_array: List[List[int]] = []
self.worker_server_urls_map: dict[str, List[str]] = {}
self.rank2info: dict[int, WorkerInfo] = {}
self.engine_rank_mesh_array, self.worker_server_urls_map, self.rank2info = self._init_workers(placement_group)
self.num_active_workers = len(self.rank2info)
self.worker_info_lock = threading.RLock()
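+ # Expose OpenAI chat/Responses and Anthropic-style HTTP interfaces directly on the controller.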
+ bind_openai_chat_interface(
+ self, default_model_name=self.config.model_name, tokenizer=self.config.tokenizer_path
+ )
+ bind_openai_responses_interface(
+ self, default_model_name=self.config.model_name, tokenizer=self.config.tokenizer_path
+ )
+ bind_anthropic_chat_interface(
+ self,
+ default_model_name=self.config.model_name,
+ tokenizer=self.config.tokenizer_path,
+ )
+ self._start_api_server()
# The timeout for the environment to wait for the rollout controller's response.
# This should be longer than the controller's internal timeout (`rollout_timeout`)
# to account for potential queuing delays and other overheads.
@@ -109,9 +129,19 @@ def get_rollout_metadata(self) -> RolloutWorkerMetadata:
"server_url_dict": self.worker_server_urls_map,
"rollout_config": self.config,
"worker_server_urls_status": worker_server_urls_status,
+ "api_server_url": self.api_server_url,
}
return rollout_metadata
+ def get_ready_status(self) -> tuple[bool, dict[str, Any]]:
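+ """Report readiness: True once at least one worker is active, plus active/total counts."""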
+ with self.worker_info_lock:
+ active_workers = sum(1 for info in self.rank2info.values() if info.is_active)
+ total_workers = len(self.rank2info)
+ return active_workers > 0, {
+ "active_workers": active_workers,
+ "total_workers": total_workers,
+ }
+
async def generate(self, rollout_state: RolloutState) -> RolloutState:
session_id = rollout_state.session_uid if rollout_state.session_uid else uuid4().int
worker = await self.router.get_worker(session_id)
@@ -322,6 +352,35 @@ def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_ser
active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval]
return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls))
+ @staticmethod
+ def _is_port_in_use(host: str, port: int) -> bool:
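+ # connect_ex returns 0 when the connection succeeds, i.e. something is already listening.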
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ sock.settimeout(0.2)
+ return sock.connect_ex((host, port)) == 0
+
+ def _start_api_server(self, host: str | None = None, port: int | None = None):
+ """Starts the API server to expose the rollout functionality."""
+ host = host or self.config.api_host
+ port = self.config.api_port if self.config.api_port else (port or 8000)
+
+ original_port = port
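+ # Probe upward from the configured port until a free one is found.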
+ while self._is_port_in_use(host, port):
+ self.logger.warning(f"Port {port} is in use, trying port {port + 1}")
+ port += 1
+
+ if original_port != port:
+ self.logger.info(f"API server will use port {port} instead of the originally configured {original_port}.")
+
+ app = create_rollout_api_app(self, self.logger)
+
+ config = uvicorn.Config(app, host=host, port=port)
+ server = uvicorn.Server(config)
+ server_thread = threading.Thread(target=server.run, daemon=True)
+ server_thread.start()
+ self.config.api_port = port
+ self.api_server_url = f"http://{host}:{port}"
+ self.logger.info(f"Rollout API server started at {self.api_server_url}")
+
def _init_workers(self, placement_group: PlacementGroup):
"""Initializes and configures the pool of RolloutWorker actors.
diff --git a/xtuner/v1/rl/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py
index 65f6fc3d8..c8453865f 100644
--- a/xtuner/v1/rl/rollout/lmdeploy.py
+++ b/xtuner/v1/rl/rollout/lmdeploy.py
@@ -93,12 +93,14 @@ def _get_request_payload(self, rollout_state: RolloutState) -> dict:
message = rollout_state.message
input_tokens = rollout_state.tokens
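+ # Only forward tools/tool_choice when they are set, rather than sending explicit nulls.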
+ optional_fields: dict[str, object] = {}
+ if tools is not None:
+ optional_fields["tools"] = tools
+ if tool_choice is not None:
+ optional_fields["tool_choice"] = tool_choice
+
if sample_params.return_token_ids:
- payload = {
- "model": self.model_name,
- "tools": tools,
- "tool_choice": tool_choice,
- }
+ payload = {"model": self.model_name, **optional_fields}
if input_tokens is not None:
payload["input_ids"] = input_tokens
else:
@@ -107,19 +109,27 @@ def _get_request_payload(self, rollout_state: RolloutState) -> dict:
payload["input_ids"] = prompt_token_ids
sample_params.return_routed_experts = True if self.enable_return_routed_experts else False
lmdeploy_sample_params = self._transform_sample_params(sample_params)
- payload.update(sample_params)
+ payload.update(lmdeploy_sample_params)
else:
payload = {
"model": self.model_name,
"messages": rollout_state.message,
- "tools": tools,
- "tool_choice": tool_choice,
+ **optional_fields,
}
- lmdeploy_sample_params = self._transform_sample_params(sample_params)
- lmdeploy_sample_params.pop("no_stop_trim", None)
- lmdeploy_sample_params.pop("return_logprob", None)
- lmdeploy_sample_params.pop("stop_token_ids", None)
- lmdeploy_sample_params["min_new_tokens"] = sample_params.min_tokens
+ lmdeploy_sample_params = {
+ "temperature": sample_params.temperature,
+ "top_p": sample_params.top_p,
+ "n": sample_params.n,
+ "stream": sample_params.stream,
+ "max_tokens": sample_params.max_tokens,
+ "repetition_penalty": sample_params.repetition_penalty,
+ "top_k": sample_params.top_k,
+ "skip_special_tokens": sample_params.skip_special_tokens,
+ }
+ if sample_params.stops:
+ lmdeploy_sample_params["stop"] = sample_params.stops
+ if sample_params.min_tokens > 0:
+ lmdeploy_sample_params["min_new_tokens"] = sample_params.min_tokens
payload.update(lmdeploy_sample_params)
return payload
diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py
index e19bb2864..2afe87113 100644
--- a/xtuner/v1/rl/rollout/utils.py
+++ b/xtuner/v1/rl/rollout/utils.py
@@ -5,10 +5,12 @@
from collections import OrderedDict
from itertools import cycle
from typing import TYPE_CHECKING, Any, Optional
+from uuid import uuid4
import httpx
import ray
+from xtuner.v1.data_proto.rl_data import RolloutState
from xtuner.v1.rl.utils import asyncio_run
from xtuner.v1.utils import get_logger
@@ -277,3 +279,13 @@ async def check_worker_health(
f"Exception during health check for worker {rank} at {url}: {e}. Failure count: {failing_count}"
)
return False
+
+
+def ensure_rollout_request_id(rollout_state: RolloutState) -> str:
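+ """Return the request id stored in ``extra_fields``, minting and persisting a new ``uuid4().hex`` when absent."""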
+ existing = rollout_state.extra_fields.get("request_id")
+ if existing:
+ return str(existing)
+
+ request_id = uuid4().hex
+ rollout_state.extra_fields["request_id"] = request_id
+ return request_id
diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py
index 156c9a054..f5fecb1aa 100644
--- a/xtuner/v1/rl/rollout/worker.py
+++ b/xtuner/v1/rl/rollout/worker.py
@@ -117,6 +117,10 @@ class RolloutConfig(BaseModel):
int,
Parameter(group=infer_group, help="Port number for the rollout API server. If not set, 8000 will be used."),
] = 8000
+ api_host: Annotated[
+ str,
+ Parameter(group=infer_group, help="Host for the rollout API server."),
+ ] = "0.0.0.0"
gpus_per_node: Annotated[int, Parameter(group=infer_group, help="Number of GPUs allocated per node.")] = 8
dtype: Annotated[
str,
@@ -313,7 +317,7 @@ def model_post_init(self, __context: Any) -> None:
while True:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
- s.bind(("localhost", port))
+ s.bind((self.api_host if self.api_host != "0.0.0.0" else "localhost", port))
break
except OSError:
port += 1
diff --git a/xtuner/v1/rl/utils/ray_utils.py b/xtuner/v1/rl/utils/ray_utils.py
index 987ba700f..14d94323d 100644
--- a/xtuner/v1/rl/utils/ray_utils.py
+++ b/xtuner/v1/rl/utils/ray_utils.py
@@ -180,6 +180,6 @@ def bind_train_rollout(
train_workers: A list of training worker actors.
rollout_controller: The rollout controller actor.
"""
- info_dict = ray.get(rollout_controller.get_rollout_info.remote()) # type: ignore[attr-defined]
+ info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) # type: ignore[attr-defined]
ray.get([worker.update_rollout_info.remote(**info_dict) for worker in train_workers]) # type: ignore[attr-defined]
return