diff --git a/tests/rl/test_camel_agent_loop.py b/tests/rl/test_camel_agent_loop.py new file mode 100644 index 000000000..8beb595fc --- /dev/null +++ b/tests/rl/test_camel_agent_loop.py @@ -0,0 +1,233 @@ +import json +import os +import tempfile +import time +import unittest +from types import SimpleNamespace +from unittest.mock import patch +from urllib import error as urllib_error +from urllib import request as urllib_request + +os.environ.setdefault("RAY_ENABLE_UV_RUN_RUNTIME_ENV", "0") + +import ray +import torch +from camel.toolkits import SearchToolkit + +from xtuner.v1.data_proto import RolloutState, SampleParams, Status +from xtuner.v1.rl.agent_loop import CamelAgentLoop, CamelAgentLoopConfig +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout.chat_adapter.collector import append_current_trace_rollout_state +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers + + +MODEL_PATH = os.environ.get("ROLLOUT_MODEL_PATH", "") +RESOURCE_MAP = { + "npu": "NPU", + "cuda": "GPU", +} + + +class DummyTokenizer: + def encode(self, text, add_special_tokens=False): + return [ord(char) for char in text] + + def decode(self, token_ids): + return "".join(chr(token_id) for token_id in token_ids) + + def __call__(self, text, add_special_tokens=False): + return {"input_ids": self.encode(text, add_special_tokens=add_special_tokens)} + + def apply_chat_template(self, messages, tools=None, add_generation_prompt=True, tokenize=False): + rendered = "\n".join(json.dumps(message, ensure_ascii=False, sort_keys=True) for message in messages) + if tools: + rendered += "\nTOOLS:" + json.dumps(tools, ensure_ascii=False, sort_keys=True) + if add_generation_prompt: + rendered += "\nassistant:" + if tokenize: + return self.encode(rendered, add_special_tokens=False) + return rendered + + +class TestCamelAgentLoopUnit(unittest.IsolatedAsyncioTestCase): + def _build_loop(self, agent, sample_params=None, tools=None, tool_choice=None): + tokenizer = DummyTokenizer() + with patch("xtuner.v1.rl.agent_loop.agent_loop.load_tokenizer", return_value=tokenizer), patch( + "xtuner.v1.rl.agent_loop.agent_loop.load_processor", return_value=None + ), patch.object(CamelAgentLoop, "init_agent", return_value=agent): + cfg = CamelAgentLoopConfig( + hf_checkpoint="dummy", + sample_params=sample_params or SampleParams(max_tokens=256, temperature=0.0), + tools=tools, + tool_choice=tool_choice, + ) + loop = cfg.build(rollout_controller=SimpleNamespace()) + return loop, tokenizer + + async def test_camel_generate_sample_returns_gateway_rollout_states(self): + class GatewayFakeAgent: + async def astep(self, content): + first_turn = RolloutState( + uid=101, + message=[{"role": "user", "content": content}], + prompt_ids=[11, 12], + tokens=[11, 12], + response="tool-call", + response_ids=[21, 22], + logprobs=[-0.1, -0.2], + response_mask=[1, 1], + finish_reason="tool_calls", + status=Status.COMPLETED, + extra_fields={ + "tool_calls": [ + { + "id": "call_search", + "type": "function", + "function": {"name": "search", "arguments": "{\"q\":\"camel\"}"}, + } + ] + }, + ) + second_turn = RolloutState( + uid=102, + message=[{"role": "user", "content": content}], + prompt_ids=[11, 12, 21, 22], + tokens=[11, 12, 21, 22], + response="final-answer", + response_ids=[31, 32, 33], + logprobs=[-0.3, -0.4, -0.5], + response_mask=[1, 1, 1], + finish_reason="stop", + status=Status.COMPLETED, + ) + append_current_trace_rollout_state(first_turn) + 
append_current_trace_rollout_state(second_turn) + return SimpleNamespace(info={"termination_reasons": ["stop"]}) + + loop, _ = self._build_loop(agent=GatewayFakeAgent()) + state = RolloutState( + message=[{"role": "user", "content": "Where is CAMEL on GitHub?"}], + sample_params=SampleParams(max_tokens=128, temperature=0.0), + ) + + result = await loop.generate_sample(state) + + self.assertEqual(len(result), 2) + self.assertEqual(result[0].status, Status.COMPLETED) + self.assertEqual(result[0].finish_reason, "tool_calls") + self.assertEqual(result[0].extra_fields["gateway_rollout_index"], 0) + self.assertEqual(result[1].finish_reason, "stop") + self.assertEqual(result[1].prompt_ids, [11, 12, 21, 22]) + self.assertEqual(result[1].tokens, [11, 12, 21, 22]) + self.assertEqual(result[1].response_ids, [31, 32, 33]) + self.assertEqual(result[1].response_mask, [1, 1, 1]) + self.assertEqual(result[1].logprobs, [-0.3, -0.4, -0.5]) + self.assertEqual(result[1].response, "final-answer") + self.assertEqual(result[1].message, state.message) + self.assertEqual(len(result[1].extra_fields["gateway_trace_records"]), 2) + self.assertEqual(result[1].extra_fields["gateway_trace_records"][0]["request_id"], "101") + self.assertEqual(result[1].extra_fields["gateway_trace_records"][1]["finish_reason"], "stop") + + +@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0" or not MODEL_PATH, "lmdeploy backend is not enabled") +class TestCamelAgentLoopIntegration(unittest.IsolatedAsyncioTestCase): + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def setUp(self): + os.environ.pop("RAY_ADDRESS", None) + ray.init(address="local", num_cpus=80, ignore_reinit_error=True) + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + self.resources_cfg = AcceleratorResourcesConfig( + accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type], + num_workers=1, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, + ) + self.context_length = 1024 + + def tearDown(self): + ray.shutdown() + self.temp_dir.cleanup() + + def _wait_until_ready(self, base_url: str): + deadline = time.time() + 1800 + last_error = None + while time.time() < deadline: + try: + with urllib_request.urlopen(f"{base_url}/healthz", timeout=10.0) as response: + if response.status == 200: + return + last_error = response.read().decode("utf-8", errors="ignore") + except urllib_error.URLError as exc: + last_error = repr(exc) + except Exception as exc: + last_error = repr(exc) + time.sleep(5) + raise RuntimeError(f"API server at {base_url} did not become ready in time: {last_error}") + + def _build_controller(self, port: int): + rollout_config = RolloutConfig( + env=f"test_camel_agent_loop_{port}", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + api_host="127.0.0.1", + api_port=port, + ) + pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg, name=f"camel_pg_{port}") + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + return rollout_controller, pg + + async def test_camel_tool_call_integration_returns_gateway_rollout_batch(self): + rollout_controller, pg = self._build_controller(port=28602) + try: + 
metadata = await rollout_controller.get_rollout_metadata.remote() + self._wait_until_ready(metadata["api_server_url"]) + search_tool = SearchToolkit().search_duckduckgo + + cfg = CamelAgentLoopConfig( + hf_checkpoint=MODEL_PATH, + sample_params=SampleParams(max_tokens=256, temperature=0.0), + system_message="You are a helpful assistant to do search task.", + tools=[search_tool], + tool_choice={"type": "function", "function": {"name": "search_duckduckgo"}}, + ) + loop = cfg.build(rollout_controller=rollout_controller) + state = RolloutState( + message=[{"role": "user", "content": "What is the Github link to CAMEL framework?"}], + sample_params=SampleParams(max_tokens=256, temperature=0.0), + ) + + result = await loop.generate_sample(state) + print(f"CAMEL_INTEGRATION_RESULT: {result}") + if not any((state.extra_fields or {}).get("tool_calls") for state in result): + self.skipTest("Current backend/model did not emit tool calls for the integration prompt.") + + self.assertGreaterEqual(len(result), 1) + self.assertEqual(result[-1].status, Status.COMPLETED) + self.assertIsNotNone(result[-1].response_ids) + self.assertIsNotNone(result[-1].response_mask) + self.assertIsNotNone(result[-1].response) + self.assertTrue(any(state.finish_reason == "tool_calls" for state in result)) + + finally: + try: + await rollout_controller.shutdown.remote() + finally: + ray.util.remove_placement_group(pg) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rl/test_evaluator.py b/tests/rl/test_evaluator.py deleted file mode 100644 index ffd52930a..000000000 --- a/tests/rl/test_evaluator.py +++ /dev/null @@ -1,121 +0,0 @@ -import os -import unittest -import ray -import tempfile -from transformers import AutoTokenizer - -from xtuner.v1.rl.rollout.worker import RolloutConfig -try: - from xtuner.v1.ray.judger.controller import JudgerConfig -except Exception: - class JudgerConfig: - def __init__(self, *args, **kwargs): - self.__dict__.update(kwargs) -from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers -try: - from xtuner.v1.ray.environment import SingleTurnEnvironment -except Exception: - SingleTurnEnvironment = None -from xtuner.v1.rl.evaluator import Evaluator, EvaluatorConfig -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig, DatasetConfig, OpenaiTokenizeFunctionConfig - - -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] - - -@unittest.skipIf(SingleTurnEnvironment is None, "ray environment unavailable") -class TestEvaluator(unittest.TestCase): - - @classmethod - def setUpClass(cls) -> None: - os.environ["XTUNER_USE_FA3"] = "1" - os.environ["LMD_SKIP_WARMUP"] = "1" - - @classmethod - def tearDownClass(cls) -> None: - del os.environ["XTUNER_USE_FA3"] - del os.environ["LMD_SKIP_WARMUP"] - - def init_config(self): - self.resources_cfg = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=8, - cpu_memory_per_worker=16 * 1024**3, # 16 GB - ) - self.max_prompt_length = 512 - self.max_response_length = 1024 - self.rollout_cfg = RolloutConfig( - env="test_rollout", - model_path=MODEL_PATH, - model_name=os.path.basename(MODEL_PATH).lower(), - tokenizer_path=MODEL_PATH, - tensor_parallel_size=8, - context_length=self.max_prompt_length + self.max_response_length, - worker_log_dir=self.worker_log_dir - ) - from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", 
judger_type="router") - self.judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config], - worker_log_dir=self.worker_log_dir - ) - self.eval_dataset_cfg = [ - { - "dataset": DatasetConfig(name="gsm8k", - anno_path=TEST_DATA_PATH, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length) - }, - ] - self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) - self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - self.test_env = SingleTurnEnvironment.remote( - "test_env", - self.pg, - self.rollout_cfg, - None, - self.judger_cfg - ) - self.sample_params = SampleParams( - top_p=1.0, - temperature=0.0, - max_tokens=self.max_response_length, - top_k=1 - ) - - def setUp(self): - ray.init(num_cpus=80) - self.model_path = MODEL_PATH - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - self.init_config() - - def tearDown(self): - ray.shutdown() - self.temp_dir.cleanup() - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_lmdeploy_evaluator(self): - def custom_compute_metric(samples): - return {"custom_accuracy": sum(s.env.judger.reward["score"] > 0 for s in samples) / len(samples)} - - evaluator_cfg = EvaluatorConfig( - dataset_cfg=self.eval_dataset_cfg, - tokenizer=self.tokenizer, - max_concurrent=16, - eval_sample_ratio=0.004, # generate 5 samples - compute_metric_func=custom_compute_metric, - sample_params=self.sample_params, - worker_log_dir=self.worker_log_dir - ) - evaluator = Evaluator.remote(evaluator_cfg, self.test_env) - try: - ray.get(evaluator.run.remote()) - except Exception as e: - self.fail(f"evaluator.run.remote() raised an exception: {e}") - -if __name__ == '__main__': - unittest.main() diff --git a/tests/rl/test_rl_train_with_sft.py b/tests/rl/test_rl_train_with_sft.py deleted file mode 100644 index e0476de71..000000000 --- a/tests/rl/test_rl_train_with_sft.py +++ /dev/null @@ -1,180 +0,0 @@ -import os -import unittest -from transformers import AutoTokenizer -import shutil -import tempfile -import json -import torch -from xtuner.v1.data_proto.sequence_context import SequenceContext -import ray -from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.rl.trainer import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker -from xtuner.v1.rl.loss import GRPOLossConfig as LossConfig -from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig -from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig -from xtuner.v1.loss import CELossConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.train.trainer import LoadCheckpointConfig - -QWEN3_PATH = os.environ["QWEN3_PATH"] -ALPACA_PATH = os.environ["ALPACA_PATH"] - - -class TestRLTrainWithSFT(unittest.TestCase): - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - - resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_accelerators_per_worker=1, - num_cpus_per_worker=8, - num_workers=8, - cpu_memory_per_worker=16 * 1024 ** 3, # 16 GB - ) - - pg = AutoAcceleratorWorkers.build_placement_group(resources) - self.pg = pg - - self.temp_dir = tempfile.mkdtemp() - tokenizer = AutoTokenizer.from_pretrained(QWEN3_PATH, trust_remote_code=True) - self.tokenizer = tokenizer - self.prompt_repeat_k = 8 - file = 
'./tests/ray/rollout_output.jsonl' - with open(file, 'r') as f: - data = [json.loads(line) for line in f] - data_groups = [data[i:i + self.prompt_repeat_k] for i in range(0, len(data), self.prompt_repeat_k)] - data_groups = data_groups[:8] - data_batches = [] - for group in data_groups: - prompt_ids = tokenizer(group[0]['prompt'], return_tensors='pt')['input_ids'].flatten().tolist() - rewards = [item['reward'] for item in group] - rewards = torch.tensor(rewards, dtype=torch.float32) - advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8) - - for i in range(self.prompt_repeat_k): - item = group[i] - response_ids = tokenizer(item['response'], return_tensors='pt')['input_ids'].flatten().tolist() - input_ids = prompt_ids + response_ids - shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + [-100] - input_ids = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0) - shifted_labels = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0) - data_batches.append( - dict( - seq_ctx=SequenceContext.from_input_ids((input_ids,), device="cpu"), - shifted_labels=shifted_labels, - advantage=advantages[i].item(), - ) - ) - self.data_batches = data_batches - - def tearDown(self): - shutil.rmtree(self.temp_dir) - ray.shutdown() - - def build_train_controller(self): - model_cfg = Qwen3Dense8BConfig() - optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False) - fsdp_cfg: FSDPConfig = FSDPConfig( - torch_compile=True, - cpu_offload=False, - ep_size=1, - ) - lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7) - - dataset_config = [] - _data_cfg = {"dataset": DatasetConfig(name='apach', - anno_path=ALPACA_PATH), - "tokenize_fn": OpenaiTokenizeFunctionConfig( - chat_template='qwen3', - max_length=32768 - ) - } - dataset_config.append(_data_cfg) - - sft_dataloader_cfg = DataloaderConfig( - dataset_config_list=dataset_config, - pack_max_length=32768, - pack_to_max_length=True, - num_workers=0, - ) - sft_global_batch_size = 8 - loss_reduction = "square" - sft_loss_cfg = CELossConfig(mode="chunk", chunk_size=1024, loss_reduction=loss_reduction) - - worker_cfg: WorkerConfig = WorkerConfig( - sft_dataloader_cfg=sft_dataloader_cfg, - sft_global_batch_size=sft_global_batch_size, - sft_loss_cfg=sft_loss_cfg, - seed=42, - model_cfg=model_cfg, - optim_cfg=optim_cfg, - loss_cfg=LossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="eager"), - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - load_from=QWEN3_PATH, - sp_size=1, - pack_max_length=8192, - ) - - TrainingWorker = ray.remote( - runtime_env={ - "env_vars": { - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", - "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", - } - }, - )(BaseTrainingWorker) - train_workers, _ = AutoAcceleratorWorkers.from_placement_group( - TrainingWorker, worker_cfg, self.pg - ) - futures = [worker.test_all_reduce.remote() for worker in train_workers] - print(ray.get(futures)) - train_controller = TrainingController.remote( - workers=train_workers, - ) - ray.get(train_controller.__ray_ready__.remote()) - return train_controller - - def test_rl_train_with_sft(self): - train_controller = self.build_train_controller() - - ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=0)) - ray.get(train_controller.save.remote(os.path.join(self.temp_dir, "save_test"), no_save_optimizer=True)) - - log_infos = 
ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1)) - efficient_attn_ratio_list = [] - for log_info in log_infos: - efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio']) - assert all([efficient_attn_ratio > 0 for efficient_attn_ratio in efficient_attn_ratio_list]) - - ray.kill(train_controller) - train_controller = self.build_train_controller() - load_checkpoint_cfg = LoadCheckpointConfig(checkpoint_path=os.path.join(self.temp_dir, "save_test"), - load_optimizer_states=False, - load_optimizer_args=False - ) - ray.get(train_controller.resume.remote(load_checkpoint_cfg)) - - log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1)) - new_efficient_attn_ratio_list = [] - for log_info in log_infos: - new_efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio']) - - efficient_attn_ratio_list.sort() - new_efficient_attn_ratio_list.sort() - self.assertEqual(efficient_attn_ratio_list, new_efficient_attn_ratio_list) diff --git a/tests/rl/test_rl_trainer.py b/tests/rl/test_rl_trainer.py deleted file mode 100644 index 82688151d..000000000 --- a/tests/rl/test_rl_trainer.py +++ /dev/null @@ -1,266 +0,0 @@ -import os -import tempfile -import unittest -from pathlib import Path - -import ray -import torch - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.rl.utils import AcceleratorResourcesConfig, CPUResourcesConfig -from xtuner.v1.rl.rollout.worker import RolloutConfig -try: - from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -except Exception: - class DataFlowConfig: - def __init__(self, *args, **kwargs): - self.__dict__.update(kwargs) - - class ReplayBufferConfig: - def __init__(self, *args, **kwargs): - self.__dict__.update(kwargs) -try: - from xtuner.v1.ray.judger.controller import JudgerConfig -except Exception: - class JudgerConfig: - def __init__(self, *args, **kwargs): - self.__dict__.update(kwargs) -from xtuner.v1.rl.trainer.worker import WorkerConfig -from xtuner.v1.rl.loss import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainer, RLTrainerConfig - - -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] -TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] -resource_map = { - "npu": "NPU", - "cuda": "GPU", -} - - -class TestRLTrainer(unittest.TestCase): - @classmethod - def setUpClass(cls): - os.environ["XTUNER_USE_FA3"] = "1" - os.environ["LMD_SKIP_WARMUP"] = "1" - - @classmethod - def tearDownClass(cls): - del os.environ["XTUNER_USE_FA3"] - del os.environ["LMD_SKIP_WARMUP"] - - def init_traine_worker_config(self, train_optimizer_steps, pack_max_length): - model_cfg = get_model_config_from_hf(Path(MODEL_PATH)) - optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False) - loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - clip_ratio_c=10.0, - log_prob_diff_min=-20.0, - log_prob_diff_max=20.0, - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, - ) - lr_cfg = 
LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) - fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) - train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=MODEL_PATH, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, - ) - return train_worker_cfg - - def init_replay_buffer_config(self, max_prompt_length): - train_dataset_cfg = [ - { - "dataset": DatasetConfig(name="gsm8k", anno_path=TRAIN_DATA_PATH, sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length), - }, - ] - dataloader_cfg = DataloaderConfig( - collator="fake_collator", - pack_level="none", - group_by_length=False, - ) - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) - replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_cfg, - tokenizer=tokenizer, - worker_log_dir=self.worker_log_dir, - ) - return replay_buffer_cfg - - def init_resources_config(self, num_workers, num_cpus_per_worker, cpu_memory_per_worker): - resources = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=num_workers, - num_cpus_per_worker=num_cpus_per_worker, - cpu_memory_per_worker=cpu_memory_per_worker, - ) - return resources - - def init_cpu_resources_config(self, num_cpus_per_worker, cpu_memory_per_worker): - cpu_resources = CPUResourcesConfig( - num_cpus_per_worker=num_cpus_per_worker, - cpu_memory_per_worker=cpu_memory_per_worker, - ) - return cpu_resources - - def init_rollout_config(self, max_prompt_length, max_response_length): - rollout_config = RolloutConfig( - env="test_rl_trainer", - model_path=MODEL_PATH, - worker_log_dir=self.worker_log_dir, - rollout_max_batch_size_per_instance=1024, - context_length=max_response_length + max_prompt_length, - ) - return rollout_config - - def init_dataflow_config(self, max_response_length, global_batch_size, prompt_repeat_k, enable_partial_rollout): - sample_params = SampleParams( - max_tokens=max_response_length, - ) - dataflow_config = DataFlowConfig( - env="test_rl_trainer", - global_batch_size=global_batch_size, - prompt_repeat_k=prompt_repeat_k, - worker_log_dir=self.worker_log_dir, - sample_params=sample_params, - enable_partial_rollout=enable_partial_rollout, - max_concurrent=1024, - ) - return dataflow_config - - def init_judger_config(self): - from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig - - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router") - judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config], worker_log_dir=self.worker_log_dir) - return judger_cfg - - def init_multi_judger_config(self): - from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig - - # 支持一个GSM8KJudgerConfig创建多个实例 - gsm8k_judger_config_1 = GSM8KJudgerConfig(judger_name="openai/gsm8k-1", judger_type="router") - gsm8k_judger_config_2 = GSM8KJudgerConfig(judger_name="openai/gsm8k-2", judger_type="router") - judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config_1, gsm8k_judger_config_2], - worker_log_dir=self.worker_log_dir, - ) - return judger_cfg - - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - - train_optimizer_steps = 2 - pack_max_length = 32768 - max_prompt_length = 
2048 - max_response_length = 1024 - global_batch_size = 4 - prompt_repeat_k = 4 - enable_partial_rollout = False - - self.train_worker_cfg = self.init_traine_worker_config(train_optimizer_steps, pack_max_length) - self.replay_buffer_cfg = self.init_replay_buffer_config(max_prompt_length) - self.resources_cfg = self.init_resources_config( - num_workers=8, num_cpus_per_worker=8, cpu_memory_per_worker=8 * 1024**3 - ) - self.cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=1, cpu_memory_per_worker=1 * 1024**3) - self.rollout_config = self.init_rollout_config( - max_response_length=max_response_length, max_prompt_length=max_prompt_length - ) - self.dataflow_config = self.init_dataflow_config( - max_response_length=max_response_length, - global_batch_size=global_batch_size, - prompt_repeat_k=prompt_repeat_k, - enable_partial_rollout=enable_partial_rollout, - ) - self.judger_config = self.init_judger_config() - - def tearDown(self): - self.temp_dir.cleanup() - ray.shutdown() - - def test_rl_trainer(self): - multi_judger_config = self.init_multi_judger_config() - cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=2, cpu_memory_per_worker=2 * 1024**3) - trainer_config = RLTrainerConfig( - load_from=MODEL_PATH, - resources=self.resources_cfg, - cpu_resources=cpu_resources, - rollout_config=self.rollout_config, - dataflow_config=self.dataflow_config, - judger_config=multi_judger_config, - replay_buffer_config=self.replay_buffer_cfg, - train_worker_config=self.train_worker_cfg, - work_dir=self.worker_log_dir, - tokenizer_path=MODEL_PATH, - total_epochs=1, - rollout_steps=1, - ) - trainer = RLTrainer.from_config(trainer_config) - self.assertIsNotNone(trainer, "Trainer should be created successfully") - try: - trainer.fit() - except Exception as e: - self.fail(f"trainer.fit() raised unexpected exception: {e}") - # assure all writers are closed before checking log files - del trainer - log_files = list(Path(self.worker_log_dir).rglob("*.log")) - self.assertGreater(len(log_files), 0, "Should generate log files") - trajectory_files = list(Path(self.worker_log_dir).rglob("*_trajectory.jsonl")) - self.assertGreater(len(trajectory_files), 0, "Should generate trajectory files") - - def test_judger_cpu_pg_creation_with_error(self): - """Test RLTrainer judger_cpu_pg creation.""" - multi_judger_config = self.init_multi_judger_config() - # error resource with multi-judger - cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=1, cpu_memory_per_worker=1 * 1024**3) - trainer_config = RLTrainerConfig( - load_from=MODEL_PATH, - resources=self.resources_cfg, - cpu_resources=cpu_resources, - rollout_config=self.rollout_config, - dataflow_config=self.dataflow_config, - judger_config=multi_judger_config, - replay_buffer_config=self.replay_buffer_cfg, - train_worker_config=self.train_worker_cfg, - work_dir=self.worker_log_dir, - tokenizer_path=MODEL_PATH, - total_epochs=1, - rollout_steps=1, - ) - with self.assertRaises(AssertionError) as cm: - trainer = RLTrainer.from_config(trainer_config) - - print(f"Expected AssertionError caught: {cm.exception}") - -if __name__ == "__main__": - unittest.main() diff --git a/tests/rl/test_rollout_api_server.py b/tests/rl/test_rollout_api_server.py new file mode 100644 index 000000000..b197f25d9 --- /dev/null +++ b/tests/rl/test_rollout_api_server.py @@ -0,0 +1,314 @@ +import os +import subprocess +import tempfile +import time +import unittest + +import httpx +import ray +import torch +from transformers import AutoTokenizer + +from 
xtuner.v1.data_proto.rl_data import RolloutState, SampleParams +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers + + +TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}] +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +MOE_MODEL_PATH = os.environ.get("QWEN3_MOE_PATH") or os.environ["QWEN30B_MODEL_PATH"] +RESOURCE_MAP = { + "npu": "NPU", + "cuda": "GPU", +} + + +@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") +class TestRolloutAPIServer(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def init_config(self): + self.resources_cfg = AcceleratorResourcesConfig( + accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, + ) + self.max_prompt_length = 512 + self.max_response_length = 1024 + self.context_length = self.max_prompt_length + self.max_response_length + self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + def setUp(self): + ray.init(num_cpus=80, ignore_reinit_error=True) + self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + self.init_config() + + def tearDown(self): + ray.shutdown() + self._cleanup_lmdeploy_ray_worker_wrapper() + self.temp_dir.cleanup() + + def _cleanup_lmdeploy_ray_worker_wrapper(self): + try: + result = subprocess.run( + ["pkill", "-f", "ray::RayWorkerWrapper*"], + capture_output=True, + text=True, + timeout=10, + check=False, + ) + if result.returncode != 0: + print( + f"pkill command failed with return code {result.returncode}: {result.stderr}." + " Maybe no lmdeploy ray::RayWorkerWrapper processes found." 
+ ) + except Exception as exc: + print(f"Error stopping ray::RayWorkerWrapper cluster: {exc}") + + def _wait_until_ready(self, base_url: str): + deadline = time.time() + 1800 + last_error = None + while time.time() < deadline: + try: + response = httpx.get(f"{base_url}/healthz", timeout=10.0) + if response.status_code == 200: + return + last_error = f"healthz returned {response.status_code}: {response.text}" + except httpx.HTTPError as exc: + last_error = repr(exc) + time.sleep(5) + raise RuntimeError(f"API server at {base_url} did not become ready in time: {last_error}") + + def test_dense_model(self): + resource_config = AcceleratorResourcesConfig( + accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=16, + cpu_memory_per_worker=8 * 1024**3, + ) + pg = AutoAcceleratorWorkers.build_placement_group(resource_config, name="dense_api_pg") + dense_worker_log_dir = os.path.join(self.worker_log_dir, "dense") + rollout_config = RolloutConfig( + env="test_rollout_api_server_dense", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + tensor_parallel_size=4, + expert_parallel_size=1, + context_length=self.context_length, + worker_log_dir=dense_worker_log_dir, + dist_port_base=38000, + api_host="127.0.0.1", + api_port=28000, + ) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + try: + metadata = ray.get(rollout_controller.get_rollout_metadata.remote(), timeout=1800) + base_url = metadata["api_server_url"] + self._wait_until_ready(base_url) + + text_prompt = self.tokenizer.apply_chat_template( + TEST_TEXT_MESSAGES, + tokenize=False, + add_generation_prompt=True, + ) + test_input_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"] + + with httpx.Client(timeout=300.0) as client: + generate = client.post( + f"{base_url}/generate", + json={ + "message": TEST_TEXT_MESSAGES, + "tokens": test_input_ids, + "sample_params": { + "return_token_ids": True, + "temperature": 0.0, + "top_k": 1, + "max_tokens": 16, + }, + }, + ) + self.assertEqual(generate.status_code, 200, generate.text) + generate_body = generate.json() + self.assertEqual(generate_body["status"], "completed") + self.assertIn(generate_body["finish_reason"], {"stop", "length"}) + self.assertTrue(generate_body["extra_fields"]["request_id"]) + self.assertGreater(len(generate_body["response_ids"]), 0) + self.assertIsInstance(generate_body["response"], str) + + chat = client.post( + f"{base_url}/v1/chat/completions", + json={ + "model": rollout_config.model_name, + "messages": TEST_TEXT_MESSAGES, + "temperature": 0.0, + "top_p": 1.0, + "max_tokens": 16, + }, + ) + self.assertEqual(chat.status_code, 200, chat.text) + chat_body = chat.json() + print("chat_body: ", chat_body) + self.assertEqual(chat_body["object"], "chat.completion") + self.assertEqual(chat_body["model"], rollout_config.model_name) + self.assertTrue(chat_body["id"].startswith("chatcmpl-")) + self.assertEqual(chat_body["choices"][0]["message"]["role"], "assistant") + self.assertTrue(chat_body["choices"][0]["message"]["content"]) + self.assertIn(chat_body["choices"][0]["finish_reason"], {"stop", "length"}) + self.assertGreater(chat_body["usage"]["prompt_tokens"], 0) + self.assertGreater(chat_body["usage"]["total_tokens"], chat_body["usage"]["completion_tokens"]) + + anthropic = client.post( + f"{base_url}/v1/messages", + json={ + "model": rollout_config.model_name, + "system": "You are helpful.", + "messages": 
TEST_TEXT_MESSAGES, + "max_tokens": 16, + "temperature": 0.0, + "top_p": 1.0, + }, + ) + self.assertEqual(anthropic.status_code, 200, anthropic.text) + anthropic_body = anthropic.json() + self.assertEqual(anthropic_body["type"], "message") + self.assertEqual(anthropic_body["role"], "assistant") + self.assertEqual(anthropic_body["model"], rollout_config.model_name) + self.assertTrue(anthropic_body["id"].startswith("msg_")) + self.assertTrue(anthropic_body["content"][0]["text"]) + self.assertIn(anthropic_body["stop_reason"], {"stop", "length"}) + self.assertGreater(anthropic_body["usage"]["input_tokens"], 0) + self.assertGreaterEqual(anthropic_body["usage"]["output_tokens"], 1) + + invalid_block = client.post( + f"{base_url}/v1/messages", + json={ + "model": rollout_config.model_name, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "look"}, + {"type": "image", "text": ""}, + ], + } + ], + "max_tokens": 8, + }, + timeout=30.0, + ) + self.assertEqual(invalid_block.status_code, 400, invalid_block.text) + self.assertEqual(invalid_block.json()["type"], "error") + self.assertEqual(invalid_block.json()["error"]["type"], "invalid_request_error") + + health = client.get(f"{base_url}/healthz", timeout=30.0) + meta = client.get(f"{base_url}/metadata", timeout=30.0) + self.assertEqual(health.status_code, 200, health.text) + self.assertEqual(health.json()["status"], "ok") + self.assertGreaterEqual(health.json()["active_workers"], 1) + self.assertEqual(meta.status_code, 200, meta.text) + self.assertEqual(meta.json()["api_server_url"], base_url) + self.assertEqual(metadata["api_server_url"], base_url) + self.assertEqual(meta.json()["api_server_url"].rsplit(":", 1)[-1], str(rollout_config.api_port)) + self.assertTrue(all(meta.json()["worker_server_urls_status"].values())) + + offload = client.post(f"{base_url}/offload", timeout=120.0) + self.assertEqual(offload.status_code, 200, offload.text) + self.assertEqual(offload.json()["action"], "offload") + + onload = client.post(f"{base_url}/onload", timeout=120.0) + self.assertEqual(onload.status_code, 200, onload.text) + self.assertEqual(onload.json()["action"], "onload") + + regenerated = client.post( + f"{base_url}/generate", + json={ + "message": TEST_TEXT_MESSAGES, + "sample_params": { + "return_token_ids": True, + "temperature": 0.0, + "top_k": 1, + "max_tokens": 8, + }, + }, + ) + self.assertEqual(regenerated.status_code, 200, regenerated.text) + self.assertEqual(regenerated.json()["status"], "completed") + finally: + try: + ray.get(rollout_controller.shutdown.remote(), timeout=300) + finally: + ray.util.remove_placement_group(pg) + + def test_moe_model(self): + resource_config = AcceleratorResourcesConfig( + accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=16, + cpu_memory_per_worker=8 * 1024**3, + ) + pg = AutoAcceleratorWorkers.build_placement_group(resource_config, name="moe_api_pg") + moe_worker_log_dir = os.path.join(self.worker_log_dir, "moe") + rollout_config = RolloutConfig( + env="test_rollout_api_server_moe", + model_path=MOE_MODEL_PATH, + model_name=os.path.basename(MOE_MODEL_PATH).lower(), + tokenizer_path=MOE_MODEL_PATH, + tensor_parallel_size=1, + expert_parallel_size=4, + context_length=self.context_length, + worker_log_dir=moe_worker_log_dir, + dist_port_base=38000 + 1024 * 4, + api_host="127.0.0.1", + api_port=29000, + enable_return_routed_experts=True, + ) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + try: + 
metadata = ray.get(rollout_controller.get_rollout_metadata.remote(), timeout=1800) + base_url = metadata["api_server_url"] + self._wait_until_ready(base_url) + + request = RolloutState( + message=[{"role": "user", "content": "Briefly explain what mixture of experts means."}], + sample_params=SampleParams( + return_token_ids=True, + return_logprob=False, + temperature=0.0, + top_k=1, + max_tokens=32, + ), + ) + with httpx.Client(timeout=300.0) as client: + response = client.post( + f"{base_url}/generate", + json=request.model_dump(mode="json"), + ) + meta = client.get(f"{base_url}/metadata", timeout=30.0) + + self.assertEqual(response.status_code, 200, response.text) + rollout_state = RolloutState.model_validate_json(response.text) + self.assertIsNotNone(rollout_state.routed_experts) + self.assertEqual(meta.status_code, 200, meta.text) + self.assertEqual(meta.json()["api_server_url"], base_url) + self.assertEqual(metadata["api_server_url"], base_url) + self.assertEqual(meta.json()["api_server_url"].rsplit(":", 1)[-1], str(rollout_config.api_port)) + finally: + try: + ray.get(rollout_controller.shutdown.remote(), timeout=300) + finally: + ray.util.remove_placement_group(pg) + +if __name__ == "__main__": + unittest.main() diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py index e8009eb34..0762b9744 100644 --- a/xtuner/v1/data_proto/rl_data.py +++ b/xtuner/v1/data_proto/rl_data.py @@ -1,10 +1,11 @@ from __future__ import annotations +import base64 from enum import Enum from typing import TYPE_CHECKING, Any, TypeAlias import torch -from pydantic import BaseModel, ConfigDict, field_serializer +from pydantic import BaseModel, ConfigDict, field_serializer, field_validator from typing_extensions import NotRequired, TypedDict # ==================================== @@ -76,11 +77,11 @@ class RolloutState(CacheObj, BaseModel): reward_model: dict[str, Any] | None = None num_tokens: int | None = None # 用于 cache 管理 - # --- InferEngine 输入 --- + # --- InferEngine 输入 ---å session_uid: int | None = None tokens: list[int] | None = None # 每一次推理引擎的实际输入 tools: list | None = None - tool_choice: str | None = None + tool_choice: str | dict[str, Any] | None = None sample_params: SampleParams = SampleParams() # --- InferEngine 输出 --- @@ -104,20 +105,44 @@ class RolloutState(CacheObj, BaseModel): extra_fields: dict[str, Any] = {} @field_serializer("routed_experts") - def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> list[int] | None: - """Dump 时跳过 ray.ObjectRef,序列化为 None,避免 PydanticSerializationError。""" + def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> list[int] | str | None: + """序列化 routed_experts 字段: + + - None -> None + - list[int] -> list[int](原样保留) + - RayObjectRef -> base64 编码的字符串(通过 ray.cloudpickle 序列化) + """ + import ray + if value is None: return None - try: - import ray - - if isinstance(value, ray.ObjectRef): - return None - except ImportError: - pass - if type(value).__name__ == "ObjectRef" and "ray" in getattr(type(value), "__module__", ""): + if isinstance(value, ray.ObjectRef): + data = ray.cloudpickle.dumps(value) + return base64.b64encode(data).decode("utf-8") + return value + + @field_validator("routed_experts", mode="before") + @classmethod + def _deserialize_routed_experts(cls, value: Any) -> list[int] | RayObjectRef | None: + """反序列化 routed_experts 字段: + + - None -> None + - list[int] -> list[int](原样保留) + - str(base64 编码)-> RayObjectRef(通过 ray.cloudpickle 反序列化) + - RayObjectRef -> RayObjectRef(原样保留) + """ + 
import ray + + if value is None: return None - return value # list[int] + if isinstance(value, ray.ObjectRef): + return value + if isinstance(value, str): + data = base64.b64decode(value) + return ray.cloudpickle.loads(data) + if isinstance(value, list): + return value + return value def update_status_from_finish_reason(finish_reason: str | None) -> Status: diff --git a/xtuner/v1/rl/agent_loop/camel_agent_loop.py b/xtuner/v1/rl/agent_loop/camel_agent_loop.py new file mode 100644 index 000000000..4abd88c47 --- /dev/null +++ b/xtuner/v1/rl/agent_loop/camel_agent_loop.py @@ -0,0 +1,190 @@ +import copy +import os +from typing import Any, Literal + +import ray +from camel.agents import ChatAgent +from camel.models import OpenAICompatibleModel +from camel.utils import BaseTokenCounter +from openai import AsyncOpenAI +from pydantic import ConfigDict + +from xtuner.v1.data_proto import RolloutState, SampleParams, Status +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout.chat_adapter.collector import reset_current_trace_collector, set_current_trace_collector +from xtuner.v1.rl.rollout.utils import ROLLOUT_RAY_GET_TIMEOUT + +from .agent_loop import AgentLoop, AgentLoopConfig + + +class XTunerCamelTokenCounter(BaseTokenCounter): + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def count_tokens_from_messages(self, messages: list[dict]) -> int: + return len(self.tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=True)) + + def encode(self, text: str) -> list[int]: + return list(self.tokenizer(text, add_special_tokens=False)["input_ids"]) + + def decode(self, token_ids: list[int]) -> str: + return self.tokenizer.decode(token_ids) + + +class CamelAgentLoopConfig(AgentLoopConfig): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + mode: Literal["single_turn"] = "single_turn" + context_length: int | None = None + system_message: str | None = None + tools: list[Any] | None = None + tool_choice: str | dict[str, Any] | None = None + + def build(self, rollout_controller, judger=None, logger=None) -> "CamelAgentLoop": + return CamelAgentLoop( + context_length=self.context_length, + system_message=self.system_message, + tools=self.tools, + tool_choice=self.tool_choice, + rollout_ctl=rollout_controller, + hf_checkpoint=self.hf_checkpoint, + sample_params=self.sample_params, + judger=judger, + logger=logger, + ) + + +class CamelAgentLoop(AgentLoop): + def __init__( + self, + context_length: int | None, + system_message: str | None, + tools: list[Any] | None, + tool_choice: str | dict[str, Any] | None, + rollout_ctl: RolloutController, + hf_checkpoint: str, + sample_params: SampleParams, + judger=None, + logger=None, + ) -> None: + super().__init__( + rollout_ctl=rollout_ctl, + hf_checkpoint=hf_checkpoint, + sample_params=sample_params, + judger=judger, + logger=logger, + ) + self.context_length = context_length + self.system_message = system_message + self.tools = tools + self.tool_choice = tool_choice + self._api_server_url: str | None = None + self._model_name: str | None = None + + def init_agent(self): + metadata = ray.get(self.rollout_ctl.get_rollout_metadata.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) + self._api_server_url = metadata["api_server_url"] + self._model_name = getattr(metadata.get("rollout_config"), "model_name", None) or "rollout-controller" + + client = AsyncOpenAI( + base_url=f"{self._api_server_url.rstrip('/')}/v1", + api_key=os.environ.get("OPENAI_API_KEY", "EMPTY"), + timeout=180.0, + 
max_retries=3, + ) + model = OpenAICompatibleModel( + model_type=self._model_name, + client=client, + async_client=client, + model_config_dict=self._build_camel_model_config(copy.deepcopy(self.sample_params)), + token_counter=XTunerCamelTokenCounter(self.tokenizer), + ) + return ChatAgent( + system_message=self.system_message, + model=model, + token_limit=self.context_length, + step_timeout=180.0, + tools=self.tools, + ) + + async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> list[RolloutState]: + try: + agent = self.init_agent() + messages = copy.deepcopy(rollout_state.message) + turn_rollout_states: list[RolloutState] = [] + collector_token = set_current_trace_collector(turn_rollout_states) + try: + await agent.astep(messages[-1]["content"]) + finally: + reset_current_trace_collector(collector_token) + + if not turn_rollout_states: + raise RuntimeError("no rollout states captured from gateway") + + normalized_rollout_states: list[RolloutState] = [] + trace_records: list[dict[str, Any]] = [] + for turn_index, turn_state in enumerate(turn_rollout_states): + if turn_state.prompt_ids is None: + raise RuntimeError(f"captured Camel rollout turn {turn_index} is missing prompt_ids") + if turn_state.response_ids is None: + raise RuntimeError(f"captured Camel rollout turn {turn_index} is missing response_ids") + normalized_turn_state = turn_state.model_copy(deep=True) + normalized_turn_state.message_uid = rollout_state.message_uid + normalized_turn_state.message = copy.deepcopy(rollout_state.message) + normalized_turn_state.data_source = copy.deepcopy(rollout_state.data_source) + normalized_turn_state.mm_info = copy.deepcopy(rollout_state.mm_info) + normalized_turn_state.reward_model = copy.deepcopy(rollout_state.reward_model) + normalized_turn_state.sample_params = copy.deepcopy(rollout_state.sample_params) + normalized_turn_state.task_name = rollout_state.task_name + normalized_turn_state.seq_staleness = rollout_state.seq_staleness + normalized_turn_state.extra_fields = { + **copy.deepcopy(rollout_state.extra_fields), + **copy.deepcopy(normalized_turn_state.extra_fields), + "gateway_rollout_index": turn_index, + } + normalized_rollout_states.append(normalized_turn_state) + trace_records.append( + { + "request_id": str(normalized_turn_state.uid) + if normalized_turn_state.uid is not None + else None, + "prompt_ids": list(normalized_turn_state.prompt_ids), + "response_ids": list(normalized_turn_state.response_ids), + "logprobs": None + if normalized_turn_state.logprobs is None + else list(normalized_turn_state.logprobs), + "routed_experts": normalized_turn_state.routed_experts, + "finish_reason": normalized_turn_state.finish_reason, + "status": normalized_turn_state.status, + } + ) + for normalized_turn_state in normalized_rollout_states: + normalized_turn_state.extra_fields["gateway_trace_records"] = copy.deepcopy(trace_records) + return normalized_rollout_states + except Exception as exc: + rollout_state.status = Status.FAILED + rollout_state.error_msg = f"Camel agent loop failed: {exc}" + return [rollout_state] + + def _build_camel_model_config(self, sample_params: SampleParams) -> dict: + model_config = { + "temperature": sample_params.temperature, + "top_p": sample_params.top_p, + "max_tokens": sample_params.max_tokens, + "stream": False, + } + if self.tool_choice is not None: + model_config["tool_choice"] = copy.deepcopy(self.tool_choice) + if sample_params.presence_penalty: + model_config["presence_penalty"] = sample_params.presence_penalty + if 
sample_params.frequency_penalty: + model_config["frequency_penalty"] = sample_params.frequency_penalty + if sample_params.stops: + model_config["stop"] = sample_params.stops + return model_config + + def _extract_finish_reason(self, response) -> str: + reasons = response.info.get("termination_reasons", []) if getattr(response, "info", None) else [] + if reasons and reasons[-1] in {"stop", "length", "tool_calls"}: + return reasons[-1] + return "stop" diff --git a/xtuner/v1/rl/rollout/__init__.py b/xtuner/v1/rl/rollout/__init__.py index 349cd2fad..57e7cf13e 100644 --- a/xtuner/v1/rl/rollout/__init__.py +++ b/xtuner/v1/rl/rollout/__init__.py @@ -1,5 +1,6 @@ import os +from .chat_adapter import AnthropicChatAdapter, OpenAIChatAdapter from .controller import RolloutController from .worker import RolloutWorker diff --git a/xtuner/v1/rl/rollout/api_server.py b/xtuner/v1/rl/rollout/api_server.py new file mode 100644 index 000000000..5a387b069 --- /dev/null +++ b/xtuner/v1/rl/rollout/api_server.py @@ -0,0 +1,419 @@ +from __future__ import annotations + +import json +from collections.abc import Iterator +from typing import TYPE_CHECKING, Any + +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from xtuner.v1.data_proto.rl_data import RolloutState, Status + +from .chat_adapter import ( + AnthropicChatAdapterError, + AnthropicCountTokensRequest, + AnthropicCountTokensResponse, + AnthropicMessagesRequest, + AnthropicMessagesResponse, + ChatCompletionRequest, + ChatCompletionResponse, + OpenAIChatAdapterError, + ResponsesRequest, + ResponsesResponse, +) +from .utils import ensure_rollout_request_id + + +if TYPE_CHECKING: + from .controller import RolloutControllerProxy + + +def _build_error_response( + status_code: int, + message: str, + error_type: str, + code: str | None = None, + request_id: str | None = None, + protocol: str = "openai", +) -> JSONResponse: + if protocol == "anthropic": + payload = { + "type": "error", + "error": { + "type": error_type, + "message": message, + }, + } + if request_id is not None: + payload["request_id"] = request_id + else: + payload = { + "error": { + "message": message, + "type": error_type, + "code": code, + "request_id": request_id, + } + } + return JSONResponse(status_code=status_code, content=payload) + + +def create_rollout_api_app( + rollout_controller: RolloutControllerProxy, + logger: Any, +) -> FastAPI: + """Build the rollout API app around the provided rollout controller.""" + app = FastAPI() + + @app.exception_handler(HTTPException) + async def handle_http_exception(request: Request, exc: HTTPException) -> JSONResponse: + request_id = request.headers.get("X-Request-Id") + if isinstance(exc.detail, dict) and "error" in exc.detail: + return JSONResponse(status_code=exc.status_code, content=exc.detail) + return _build_error_response( + status_code=exc.status_code, + message=str(exc.detail), + error_type="invalid_request_error" if exc.status_code < 500 else "server_error", + code="http_error", + request_id=request_id, + ) + + @app.post("/generate") + async def generate(request: RolloutState) -> RolloutState: + request_id = ensure_rollout_request_id(request) + try: + response = await rollout_controller.generate(request) + if not response.extra_fields.get("request_id"): + response.extra_fields["request_id"] = request_id + return response + except Exception as exc: + logger.error(f"Generate failed in API server for request_id={request_id}: {exc}") + request.status = Status.FAILED + 
request.error_msg = f"Generate failed in API server with error: {str(exc)}" + return request + + @app.post("/v1/chat/completions") + async def chat_completions(request: ChatCompletionRequest, http_request: Request) -> ChatCompletionResponse: + try: + return await rollout_controller.chat(request) + except OpenAIChatAdapterError as exc: + status_code = 400 if exc.error_type == "invalid_request_error" else 500 + raise HTTPException( + status_code=status_code, + detail={ + "error": { + "message": exc.message, + "type": exc.error_type, + "code": exc.code, + "request_id": exc.request_id, + } + }, + ) + + @app.post("/v1/responses", response_model=None) + async def responses(request: ResponsesRequest, http_request: Request): + try: + if request.stream: + non_stream_request = request.model_copy(update={"stream": False}) + response = await rollout_controller.responses(non_stream_request) + return StreamingResponse( + _iter_openai_responses_sse_events(response), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + return await rollout_controller.responses(request) + except OpenAIChatAdapterError as exc: + status_code = 400 if exc.error_type == "invalid_request_error" else 500 + raise HTTPException( + status_code=status_code, + detail={ + "error": { + "message": exc.message, + "type": exc.error_type, + "code": exc.code, + "request_id": exc.request_id, + } + }, + ) + + @app.post("/v1/messages", response_model=None) + async def anthropic_messages( + request: AnthropicMessagesRequest, http_request: Request + ): + try: + if request.stream: + non_stream_request = request.model_copy(update={"stream": False}) + response = await rollout_controller.anthropic_messages(non_stream_request) + return StreamingResponse( + _iter_anthropic_sse_events(response), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + return await rollout_controller.anthropic_messages(request) + except AnthropicChatAdapterError as exc: + status_code = 400 if exc.error_type == "invalid_request_error" else 500 + return _build_error_response( + status_code=status_code, + message=exc.message, + error_type=exc.error_type, + request_id=exc.request_id, + protocol="anthropic", + ) + + @app.post("/v1/messages/count_tokens") + async def anthropic_count_tokens(request: AnthropicCountTokensRequest) -> AnthropicCountTokensResponse: + return await rollout_controller.anthropic_count_tokens(request) + + @app.get("/healthz") + async def healthz(): + is_ready, payload = rollout_controller.get_ready_status() + if is_ready: + return {"status": "ok", **payload} + return JSONResponse(status_code=503, content={"status": "not_ready", **payload}) + + @app.get("/metadata") + async def metadata(): + return rollout_controller.get_rollout_metadata() + + @app.post("/pause") + async def pause(): + rollout_controller.pause_generation() + return {"status": "ok", "action": "pause"} + + @app.post("/continue") + async def continue_generation(): + rollout_controller.continue_generation() + return {"status": "ok", "action": "continue"} + + @app.post("/offload") + async def offload(): + rollout_controller.offload() + return {"status": "ok", "action": "offload"} + + @app.post("/onload") + async def onload(): + rollout_controller.onload() + return {"status": "ok", "action": "onload"} + + @app.post("/shutdown") + async def shutdown(): + rollout_controller.shutdown() + return {"status": "ok", "action": "shutdown"} + + return app + + +def 
_iter_anthropic_sse_events(response: AnthropicMessagesResponse) -> Iterator[str]: + output_tokens = response.usage.output_tokens + message_start = { + "type": "message_start", + "message": { + "id": response.id, + "type": response.type, + "role": response.role, + "content": [], + "model": response.model, + "stop_reason": None, + "stop_sequence": None, + "usage": { + "input_tokens": response.usage.input_tokens, + "output_tokens": 1 if output_tokens > 0 else 0, + }, + }, + } + yield _format_sse("message_start", message_start) + + chunk_size = 64 + for index, block in enumerate(response.content): + block_type = block.get("type") + if block_type == "text": + yield _format_sse( + "content_block_start", + {"type": "content_block_start", "index": index, "content_block": {"type": "text", "text": ""}}, + ) + text = str(block.get("text", "")) + for offset in range(0, len(text), chunk_size): + chunk = text[offset : offset + chunk_size] + yield _format_sse( + "content_block_delta", + {"type": "content_block_delta", "index": index, "delta": {"type": "text_delta", "text": chunk}}, + ) + yield _format_sse("content_block_stop", {"type": "content_block_stop", "index": index}) + elif block_type == "tool_use": + yield _format_sse( + "content_block_start", + { + "type": "content_block_start", + "index": index, + "content_block": { + "type": "tool_use", + "id": block["id"], + "name": block["name"], + "input": {}, + }, + }, + ) + input_json = json.dumps(block.get("input", {}), ensure_ascii=False) + for offset in range(0, len(input_json), chunk_size): + chunk = input_json[offset : offset + chunk_size] + yield _format_sse( + "content_block_delta", + { + "type": "content_block_delta", + "index": index, + "delta": {"type": "input_json_delta", "partial_json": chunk}, + }, + ) + yield _format_sse("content_block_stop", {"type": "content_block_stop", "index": index}) + + yield _format_sse( + "message_delta", + { + "type": "message_delta", + "delta": { + "stop_reason": response.stop_reason, + "stop_sequence": response.stop_sequence, + }, + "usage": { + "output_tokens": response.usage.output_tokens, + }, + }, + ) + yield _format_sse("message_stop", {"type": "message_stop"}) + + +def _format_sse(event: str, data: dict[str, Any]) -> str: + return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + + +def _iter_openai_responses_sse_events(response: ResponsesResponse) -> Iterator[str]: + sequence_number = 0 + response_snapshot = response.model_dump(mode="python") + in_progress_response = {**response_snapshot, "status": "in_progress"} + yield _format_openai_response_sse({"type": "response.created", "sequence_number": sequence_number, "response": in_progress_response}) + sequence_number += 1 + + for output_index, item in enumerate(response.output): + item_type = item.get("type") + if item_type == "message": + yield _format_openai_response_sse( + { + "type": "response.output_item.added", + "sequence_number": sequence_number, + "output_index": output_index, + "item": item, + } + ) + sequence_number += 1 + for content_index, part in enumerate(item.get("content", [])): + if part.get("type") != "output_text": + continue + yield _format_openai_response_sse( + { + "type": "response.content_part.added", + "sequence_number": sequence_number, + "output_index": output_index, + "item_id": item["id"], + "content_index": content_index, + "part": {"type": "output_text", "text": "", "annotations": []}, + } + ) + sequence_number += 1 + text = str(part.get("text", "")) + chunk_size = 64 + for offset in range(0, len(text), 
chunk_size): + yield _format_openai_response_sse( + { + "type": "response.output_text.delta", + "sequence_number": sequence_number, + "output_index": output_index, + "item_id": item["id"], + "content_index": content_index, + "delta": text[offset : offset + chunk_size], + } + ) + sequence_number += 1 + yield _format_openai_response_sse( + { + "type": "response.output_text.done", + "sequence_number": sequence_number, + "output_index": output_index, + "item_id": item["id"], + "content_index": content_index, + "text": text, + } + ) + sequence_number += 1 + yield _format_openai_response_sse( + { + "type": "response.content_part.done", + "sequence_number": sequence_number, + "output_index": output_index, + "item_id": item["id"], + "content_index": content_index, + "part": {"type": "output_text", "text": text, "annotations": part.get("annotations", [])}, + } + ) + sequence_number += 1 + yield _format_openai_response_sse( + { + "type": "response.output_item.done", + "sequence_number": sequence_number, + "output_index": output_index, + "item": item, + } + ) + sequence_number += 1 + elif item_type == "function_call": + added_item = {**item, "arguments": "", "status": "in_progress"} + yield _format_openai_response_sse( + { + "type": "response.output_item.added", + "sequence_number": sequence_number, + "output_index": output_index, + "item": added_item, + } + ) + sequence_number += 1 + arguments = str(item.get("arguments", "")) + chunk_size = 64 + for offset in range(0, len(arguments), chunk_size): + yield _format_openai_response_sse( + { + "type": "response.function_call_arguments.delta", + "sequence_number": sequence_number, + "output_index": output_index, + "item_id": item["id"], + "delta": arguments[offset : offset + chunk_size], + } + ) + sequence_number += 1 + yield _format_openai_response_sse( + { + "type": "response.function_call_arguments.done", + "sequence_number": sequence_number, + "output_index": output_index, + "item_id": item["id"], + "arguments": arguments, + "name": item.get("name"), + } + ) + sequence_number += 1 + yield _format_openai_response_sse( + { + "type": "response.output_item.done", + "sequence_number": sequence_number, + "output_index": output_index, + "item": item, + } + ) + sequence_number += 1 + + yield _format_openai_response_sse( + {"type": "response.completed", "sequence_number": sequence_number, "response": response_snapshot} + ) + yield "data: [DONE]\n\n" + + +def _format_openai_response_sse(data: dict[str, Any]) -> str: + return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" diff --git a/xtuner/v1/rl/rollout/chat_adapter/__init__.py b/xtuner/v1/rl/rollout/chat_adapter/__init__.py new file mode 100644 index 000000000..d1bd0189d --- /dev/null +++ b/xtuner/v1/rl/rollout/chat_adapter/__init__.py @@ -0,0 +1,41 @@ +from .anthropic import ( + AnthropicChatAdapter, + AnthropicChatAdapterError, + AnthropicCountTokensRequest, + AnthropicCountTokensResponse, + AnthropicMessagesRequest, + AnthropicMessagesResponse, + bind_anthropic_chat_interface, +) +from .base import BaseChatAPIAdapter +from .openai import ( + ChatCompletionRequest, + ChatCompletionResponse, + OpenAIChatAdapter, + OpenAIChatAdapterError, + bind_openai_chat_interface, +) +from .responses import ResponsesRequest, ResponsesResponse, bind_openai_responses_interface +from .trace import ChatTraceRecord, ChatTraceStore + + +__all__ = [ + "AnthropicChatAdapter", + "AnthropicChatAdapterError", + "AnthropicCountTokensRequest", + "AnthropicCountTokensResponse", + "AnthropicMessagesRequest", + 
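+ # Usage sketch (illustrative; `controller` stands for a RolloutController instance):
+ # the bind_* helpers exported here attach `chat`, `responses` and `anthropic_messages`
+ # coroutines onto the controller, which create_rollout_api_app then serves as
+ # /v1/chat/completions, /v1/responses and /v1/messages:
+ #     bind_openai_chat_interface(controller, tokenizer=controller.config.tokenizer_path)
+ #     bind_openai_responses_interface(controller, tokenizer=controller.config.tokenizer_path)
+ #     bind_anthropic_chat_interface(controller, tokenizer=controller.config.tokenizer_path)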
"AnthropicMessagesResponse", + "ChatCompletionRequest", + "ChatCompletionResponse", + "OpenAIChatAdapter", + "OpenAIChatAdapterError", + "ResponsesRequest", + "ResponsesResponse", + "BaseChatAPIAdapter", + "ChatTraceRecord", + "ChatTraceStore", + "bind_anthropic_chat_interface", + "bind_openai_chat_interface", + "bind_openai_responses_interface", +] diff --git a/xtuner/v1/rl/rollout/chat_adapter/anthropic.py b/xtuner/v1/rl/rollout/chat_adapter/anthropic.py new file mode 100644 index 000000000..28989138e --- /dev/null +++ b/xtuner/v1/rl/rollout/chat_adapter/anthropic.py @@ -0,0 +1,514 @@ +import json +import re +from typing import Any, Literal +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status + +from .base import BaseChatAPIAdapter +from .trace import normalize_trace_payload + + +class AnthropicTextContent(BaseModel): + model_config = ConfigDict(extra="allow") + + type: str = "text" + text: str + + +AnthropicContentBlock = dict[str, Any] + + +class AnthropicMessage(BaseModel): + model_config = ConfigDict(extra="allow") + + role: Literal["user", "assistant"] + content: str | list[AnthropicContentBlock] + + +class AnthropicMessagesRequest(BaseModel): + model_config = ConfigDict(extra="allow") + + session_uid: int | None = None + model: str | None = None + system: str | list[dict[str, Any]] | None = None + messages: list[AnthropicMessage] + max_tokens: int + stream: bool = False + temperature: float | None = None + top_p: float | None = None + stop_sequences: list[str] | None = None + tools: list[dict[str, Any]] | None = None + tool_choice: str | dict[str, Any] | None = None + + +class AnthropicCountTokensRequest(BaseModel): + model_config = ConfigDict(extra="allow") + + model: str | None = None + system: str | list[dict[str, Any]] | None = None + messages: list[AnthropicMessage] + tools: list[dict[str, Any]] | None = None + + +class AnthropicCountTokensResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + input_tokens: int + + +class AnthropicUsage(BaseModel): + model_config = ConfigDict(extra="allow") + + input_tokens: int + output_tokens: int + + +class AnthropicMessagesResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + id: str + type: Literal["message"] = "message" + role: Literal["assistant"] = "assistant" + content: list[dict[str, Any]] + model: str + stop_reason: str | None = None + stop_sequence: str | None = None + usage: AnthropicUsage + + +class AnthropicChatAdapterError(RuntimeError): + def __init__(self, message: str, error_type: str, request_id: str | None = None): + super().__init__(message) + self.message = message + self.error_type = error_type + self.request_id = request_id + + +class AnthropicChatAdapter(BaseChatAPIAdapter[AnthropicMessagesRequest, AnthropicMessagesResponse]): + _tool_call_pattern = re.compile(r"\n*(.*?)", re.DOTALL) + _qwen_function_pattern = re.compile(r"\n]+)>(.*?)", re.DOTALL) + _qwen_parameter_pattern = re.compile(r"\n]+)>(.*?)", re.DOTALL) + + def __init__( + self, + generate_handler, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str | None, + default_model_name: str | None = None, + context_length: int | None = None, + capture_path: str | None = None, + ): + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + super().__init__(generate_handler, tokenizer=tokenizer, 
capture_path=capture_path) + self._default_model_name = default_model_name + self._context_length = context_length + + async def messages(self, request: AnthropicMessagesRequest) -> AnthropicMessagesResponse: + return await self.handle_request(request) + + async def count_tokens(self, request: AnthropicCountTokensRequest) -> AnthropicCountTokensResponse: + internal_messages = self._build_internal_messages(request) + rollout_state = RolloutState(message=internal_messages) + tokenizer_tools = self._normalize_tools_for_backend(request.tools) + if self._tokenizer is not None: + raw_prompt_ids = self._tokenizer.apply_chat_template( + internal_messages, + tools=tokenizer_tools, + tokenize=True, + add_generation_prompt=True, + ) + rollout_state.prompt_ids = raw_prompt_ids.get("input_ids") if hasattr(raw_prompt_ids, "get") else list(raw_prompt_ids) + rollout_state.tokens = rollout_state.prompt_ids + return AnthropicCountTokensResponse(input_tokens=self._count_prompt_tokens(rollout_state)) + + def validate_request(self, request: AnthropicMessagesRequest) -> None: + if request.stream: + raise AnthropicChatAdapterError( + "stream=true is not supported yet", + "invalid_request_error", + ) + + def request_to_rollout_state(self, request: AnthropicMessagesRequest) -> RolloutState: + internal_messages = self._build_internal_messages(request) + tokenizer_tools = self._normalize_tools_for_backend(request.tools) + normalized_tool_choice = self._normalize_tool_choice_for_backend(request.tool_choice) + prompt_ids = None + if self._tokenizer is not None: + raw_prompt_ids = self._tokenizer.apply_chat_template( + internal_messages, + tools=tokenizer_tools, + tokenize=True, + add_generation_prompt=True, + ) + prompt_ids = raw_prompt_ids.get("input_ids") if hasattr(raw_prompt_ids, "get") else list(raw_prompt_ids) + max_tokens = self._fit_max_tokens_to_context(prompt_ids=prompt_ids, requested_max_tokens=request.max_tokens) + return RolloutState( + uid=uuid4().int, + message=internal_messages, + prompt_ids=prompt_ids, + tokens=prompt_ids, + session_uid=request.session_uid, + tools=tokenizer_tools, + tool_choice=normalized_tool_choice, + sample_params=self._build_sample_params(request, max_tokens=max_tokens), + ) + + def raise_for_failed_response(self, response: RolloutState, request_id: str) -> None: + if response.status == Status.FAILED: + raise AnthropicChatAdapterError( + response.error_msg or "Rollout generation failed", + "api_error", + request_id, + ) + + def normalize_request(self, request: AnthropicMessagesRequest) -> dict[str, Any]: + return normalize_trace_payload(request.model_dump(mode="python", exclude_none=True)) + + def normalize_response(self, response: AnthropicMessagesResponse) -> dict[str, Any]: + return normalize_trace_payload(response.model_dump(mode="python", exclude_none=True)) + + def rollout_state_to_response( + self, + rollout_state: RolloutState, + request: AnthropicMessagesRequest, + ) -> AnthropicMessagesResponse: + assert rollout_state.uid is not None, "uid should not be None when generating response" + request_id = str(rollout_state.uid) + model_name = request.model or self._default_model_name or "rollout-controller" + prompt_tokens = self._count_prompt_tokens(rollout_state) + completion_tokens = self._count_completion_tokens(rollout_state) + content_blocks = self._build_response_content_blocks(rollout_state) + stop_reason = "tool_use" if any(block.get("type") == "tool_use" for block in content_blocks) else rollout_state.finish_reason + + return AnthropicMessagesResponse( + 
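+ # Shape of the resulting payload (illustrative values; structure follows
+ # AnthropicMessagesResponse above):
+ #     {"id": "msg_1234", "type": "message", "role": "assistant",
+ #      "model": "my-rollout-model",
+ #      "content": [{"type": "text", "text": "Let me check."},
+ #                  {"type": "tool_use", "id": "toolu_...", "name": "get_weather",
+ #                   "input": {"city": "Paris"}}],
+ #      "stop_reason": "tool_use",
+ #      "usage": {"input_tokens": 42, "output_tokens": 17}}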
id=f"msg_{request_id}", + content=content_blocks, + model=model_name, + stop_reason=stop_reason, + usage=AnthropicUsage( + input_tokens=prompt_tokens, + output_tokens=completion_tokens, + ), + ) + + def _build_internal_messages(self, request: AnthropicMessagesRequest) -> list[dict[str, str]]: + messages: list[dict[str, Any]] = [] + + if request.system: + if isinstance(request.system, str): + system_text = request.system + else: + system_text = self._join_text_blocks(request.system, context="system") + messages.append({"role": "system", "content": system_text}) + + for message in request.messages: + if isinstance(message.content, str): + messages.append({"role": message.role, "content": message.content}) + else: + messages.extend(self._convert_content_blocks_to_backend_messages(message.role, message.content)) + + return messages + + def _join_text_blocks(self, blocks: list[dict[str, Any]], context: str) -> str: + unsupported_types = [block.get("type") for block in blocks if block.get("type") != "text"] + if unsupported_types: + unsupported_str = ", ".join(sorted(set(unsupported_types))) + raise AnthropicChatAdapterError( + f"Unsupported Anthropic content block type(s) in {context}: {unsupported_str}", + "invalid_request_error", + ) + return "\n".join(str(block.get("text", "")) for block in blocks) + + def _convert_content_blocks_to_backend_messages(self, role: str, blocks: list[dict[str, Any]]) -> list[dict[str, Any]]: + backend_messages: list[dict[str, Any]] = [] + text_chunks: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + + def flush_text_chunks(): + if text_chunks: + backend_messages.append({"role": role, "content": "\n".join(text_chunks)}) + text_chunks.clear() + + for block in blocks: + block_type = block.get("type") + if block_type == "text": + text_value = str(block.get("text", "")) + if role == "assistant": + text_value = self._sanitize_assistant_text(text_value) + text_chunks.append(text_value) + elif block_type == "tool_use": + tool_calls.append( + { + "id": block.get("id") or f"toolu_{uuid4().hex}", + "type": "function", + "function": { + "name": str(block.get("name", "")), + "arguments": normalize_trace_payload(block.get("input", {})), + }, + } + ) + elif block_type == "tool_result": + flush_text_chunks() + backend_messages.append( + { + "role": "tool", + "content": self._serialize_tool_result_content(block.get("content")), + "tool_call_id": block.get("tool_use_id"), + } + ) + else: + raise AnthropicChatAdapterError( + f"Unsupported Anthropic content block type in messages[{role}]: {block_type}", + "invalid_request_error", + ) + + if tool_calls: + backend_messages.append( + { + "role": role, + "content": "\n".join(text_chunks), + "tool_calls": tool_calls, + } + ) + text_chunks.clear() + flush_text_chunks() + return backend_messages + + def _serialize_tool_result_content(self, content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + if all(isinstance(item, dict) and item.get("type") == "text" for item in content): + return "\n".join(str(item.get("text", "")) for item in content) + return json.dumps(content, ensure_ascii=False) + if isinstance(content, dict): + return json.dumps(content, ensure_ascii=False) + return str(content) + + def _normalize_tools_for_backend(self, tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None: + if not tools: + return None + normalized_tools = [] + for tool in tools: + if tool.get("type") == "function": + 
normalized_tools.append(normalize_trace_payload(tool)) + else: + normalized_tools.append( + { + "type": "function", + "function": { + "name": tool["name"], + "description": tool.get("description", ""), + "parameters": tool["input_schema"], + }, + } + ) + return normalize_trace_payload(normalized_tools) + + def _normalize_tool_choice_for_backend(self, tool_choice: str | dict[str, Any] | None) -> str | dict[str, Any] | None: + if tool_choice is None: + return None + if isinstance(tool_choice, str): + return tool_choice + choice_type = tool_choice.get("type") + if choice_type == "auto": + return "auto" + if choice_type == "none": + return "none" + if choice_type == "any": + return "required" + if choice_type == "tool": + return { + "type": "function", + "function": { + "name": tool_choice.get("name"), + }, + } + return normalize_trace_payload(tool_choice) + + def _build_response_content_blocks(self, rollout_state: RolloutState) -> list[dict[str, Any]]: + raw_response = rollout_state.response or "" + tool_calls = rollout_state.extra_fields.get("tool_calls") + if not tool_calls: + text_blocks, parsed_tool_calls = self._parse_textual_tool_calls(raw_response) + if parsed_tool_calls: + tool_calls = parsed_tool_calls + rollout_state.extra_fields["tool_calls"] = parsed_tool_calls + response_text = "".join(block["text"] for block in text_blocks if block["type"] == "text") + rollout_state.response = self._sanitize_assistant_text(response_text) + else: + rollout_state.response = self._sanitize_assistant_text(raw_response) + + if not tool_calls: + return [{"type": "text", "text": rollout_state.response or ""}] + + content_blocks: list[dict[str, Any]] = [] + if rollout_state.response: + content_blocks.append({"type": "text", "text": rollout_state.response}) + for tool_call in tool_calls: + content_blocks.append( + { + "type": "tool_use", + "id": tool_call.get("id") or f"toolu_{uuid4().hex}", + "name": tool_call["function"]["name"], + "input": self._parse_tool_arguments(tool_call["function"].get("arguments")), + } + ) + return content_blocks + + def _parse_textual_tool_calls(self, text: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + if not text: + return [], [] + content_blocks: list[dict[str, Any]] = [] + tool_calls: list[dict[str, Any]] = [] + last_end = 0 + for match in self._tool_call_pattern.finditer(text): + if match.start() > last_end: + content_blocks.append({"type": "text", "text": text[last_end : match.start()]}) + raw_payload = match.group(1).strip() + parsed_tool_call = self._parse_single_textual_tool_call(raw_payload) + if parsed_tool_call is not None: + tool_calls.append(parsed_tool_call) + else: + content_blocks.append({"type": "text", "text": match.group(0)}) + last_end = match.end() + if last_end < len(text): + content_blocks.append({"type": "text", "text": text[last_end:]}) + return content_blocks, tool_calls + + def _parse_single_textual_tool_call(self, raw_payload: str) -> dict[str, Any] | None: + try: + parsed = json.loads(raw_payload) + return { + "id": f"call_{uuid4().hex}", + "type": "function", + "function": { + "name": parsed["name"], + "arguments": json.dumps(parsed.get("arguments", {}), ensure_ascii=False), + }, + } + except Exception: + pass + + function_match = self._qwen_function_pattern.search(raw_payload) + if function_match is None: + return None + function_name = function_match.group(1).strip() + function_body = function_match.group(2) + arguments: dict[str, Any] = {} + for parameter_match in self._qwen_parameter_pattern.finditer(function_body): + 
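+ # Fallback parsing sketch (assuming Qwen-style textual markup such as
+ #     <function=get_weather><parameter=city>Paris</parameter></function>
+ # inside a tool-call block): each captured parameter is coerced below with
+ # _coerce_parameter_value, producing a tool call like
+ #     {"type": "function", "function": {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"}}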
param_name = parameter_match.group(1).strip() + param_value = parameter_match.group(2).strip() + arguments[param_name] = self._coerce_parameter_value(param_value) + return { + "id": f"call_{uuid4().hex}", + "type": "function", + "function": { + "name": function_name, + "arguments": json.dumps(arguments, ensure_ascii=False), + }, + } + + def _parse_tool_arguments(self, arguments: Any) -> Any: + if isinstance(arguments, str): + try: + return json.loads(arguments) + except Exception: + return {"raw": arguments} + return arguments + + def _coerce_parameter_value(self, value: str) -> Any: + stripped = value.strip() + if not stripped: + return "" + try: + return json.loads(stripped) + except Exception: + return stripped + + def _sanitize_assistant_text(self, text: str) -> str: + cleaned = text.replace("<|im_end|>", "") + cleaned = cleaned.replace("", "") + cleaned = cleaned.replace("", "") + return cleaned.strip() + + def _build_sample_params(self, request: AnthropicMessagesRequest, max_tokens: int | None = None) -> SampleParams: + kwargs = { + "return_token_ids": True, + "return_logprob": False, + "stream": request.stream, + "max_tokens": max_tokens if max_tokens is not None else request.max_tokens, + "stops": request.stop_sequences or [], + } + if request.temperature is not None: + kwargs["temperature"] = request.temperature + if request.top_p is not None: + kwargs["top_p"] = request.top_p + return SampleParams(**kwargs) + + def _fit_max_tokens_to_context(self, prompt_ids: list[int] | None, requested_max_tokens: int) -> int: + if self._context_length is None or prompt_ids is None: + return requested_max_tokens + prompt_tokens = len(prompt_ids) + available_completion_tokens = self._context_length - prompt_tokens + if available_completion_tokens <= 0: + raise AnthropicChatAdapterError( + ( + f"Input is too long for this model deployment: prompt_tokens={prompt_tokens}, " + f"context_length={self._context_length}." 
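+ # Worked example (illustrative numbers): with context_length=1024 and a 900-token
+ # prompt, available_completion_tokens is 124, so a requested max_tokens of 512 is
+ # later clamped to min(512, 124) = 124; a prompt of 1024 tokens or more triggers
+ # this invalid_request_error instead of clamping.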
+ ), + "invalid_request_error", + ) + return min(requested_max_tokens, available_completion_tokens) + + def _count_prompt_tokens(self, rollout_state: RolloutState) -> int: + if rollout_state.tokens is not None: + return len(rollout_state.tokens) + if rollout_state.prompt_ids is not None: + return len(rollout_state.prompt_ids) + if self._tokenizer is not None and rollout_state.message: + text_prompt = self._tokenizer.apply_chat_template( + rollout_state.message, + tokenize=False, + add_generation_prompt=True, + ) + return len(self._tokenizer(text_prompt, add_special_tokens=False)["input_ids"]) + return 0 + + def _count_completion_tokens(self, rollout_state: RolloutState) -> int: + if rollout_state.response_ids is not None: + return len(rollout_state.response_ids) + if self._tokenizer is not None and rollout_state.response: + return len(self._tokenizer(rollout_state.response, add_special_tokens=False)["input_ids"]) + return 0 + + def build_output_message_list( + self, + rollout_state: RolloutState, + request: AnthropicMessagesRequest, + ) -> list[dict[str, Any]]: + return [{"role": "assistant", "content": self._build_response_content_blocks(rollout_state)}] + + +def bind_anthropic_chat_interface( + rollout_controller: Any, + default_model_name: str | None = None, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str | None = None, +) -> Any: + if getattr(rollout_controller, "anthropic_chat_adapter", None) is None: + rollout_controller.anthropic_chat_adapter = AnthropicChatAdapter( + rollout_controller.generate, + default_model_name=default_model_name, + tokenizer=tokenizer, + context_length=getattr(rollout_controller.config, "context_length", None), + capture_path=str(getattr(rollout_controller.config, "worker_log_dir", ".")) + "/gateway_capture.jsonl", + ) + rollout_controller.anthropic_messages = rollout_controller.anthropic_chat_adapter.messages + rollout_controller.anthropic_count_tokens = rollout_controller.anthropic_chat_adapter.count_tokens + return rollout_controller diff --git a/xtuner/v1/rl/rollout/chat_adapter/base.py b/xtuner/v1/rl/rollout/chat_adapter/base.py new file mode 100644 index 000000000..826210474 --- /dev/null +++ b/xtuner/v1/rl/rollout/chat_adapter/base.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +from typing import Any, Generic, TypeVar + +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto.rl_data import RolloutState + +from .capture import append_gateway_capture_record, render_blocks_as_text +from .collector import append_current_trace_rollout_state +from .trace import ChatTraceRecord, ChatTraceStore, snapshot_routed_experts + + +GenerateHandler = Callable[[RolloutState], Awaitable[RolloutState]] +RequestT = TypeVar("RequestT") +ResponseT = TypeVar("ResponseT") + + +class BaseChatAPIAdapter(ABC, Generic[RequestT, ResponseT]): + def __init__( + self, + generate_handler: GenerateHandler, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None, + *, + capture_path: str | None = None, + trace_store_max_entries: int = 10000, + ): + self._generate_handler = generate_handler + self._tokenizer = tokenizer + self._capture_path = capture_path + self._trace_store = ChatTraceStore(max_entries=trace_store_max_entries) + + async def handle_request(self, request: RequestT) -> ResponseT: + self.validate_request(request) + rollout_state = self.request_to_rollout_state(request) + if rollout_state.uid is None: + raise 
ValueError("request_to_rollout_state must assign rollout_state.uid before generate is called.") + request_id = str(rollout_state.uid) + rollout_state = await self._generate_handler(rollout_state) + append_current_trace_rollout_state(rollout_state) + + self.raise_for_failed_response(rollout_state, request_id) + response = self.rollout_state_to_response(rollout_state, request) + self._trace_store.put(self._build_trace_record(request, response, rollout_state, request_id)) + self._write_capture_record(request=request, response=response, rollout_state=rollout_state, request_id=request_id) + return response + + def get_trace_by_request_response(self, request: RequestT, response: ResponseT) -> ChatTraceRecord | None: + response_hash = self._trace_store.build_hash( + request_snapshot=self.normalize_request(request), + response_snapshot=self.normalize_response(response), + ) + return self._trace_store.get(response_hash) + + def get_trace_by_response_hash(self, response_hash: str) -> ChatTraceRecord | None: + return self._trace_store.get(response_hash) + + def _build_trace_record( + self, + request: RequestT, + response: ResponseT, + rollout_state: RolloutState, + request_id: str, + ) -> ChatTraceRecord: + request_snapshot = self.normalize_request(request) + response_snapshot = self.normalize_response(response) + response_hash = self._trace_store.build_hash( + request_snapshot=request_snapshot, + response_snapshot=response_snapshot, + ) + return ChatTraceRecord( + response_hash=response_hash, + request_snapshot=request_snapshot, + response_snapshot=response_snapshot, + prompt_ids=list(rollout_state.prompt_ids or []), + response_ids=list(rollout_state.response_ids or []), + logprobs=None if rollout_state.logprobs is None else list(rollout_state.logprobs), + routed_experts=snapshot_routed_experts(rollout_state.routed_experts), + finish_reason=rollout_state.finish_reason, + status=rollout_state.status, + request_id=request_id, + ) + + def _write_capture_record( + self, + request: RequestT, + response: ResponseT, + rollout_state: RolloutState, + request_id: str, + ) -> None: + if self._capture_path is None: + return + try: + response_snapshot = self.normalize_response(response) + response_finish_reason = ( + response_snapshot.get("stop_reason") + or response_snapshot.get("finish_reason") + or rollout_state.finish_reason + ) + output_messages = self.build_output_message_list(rollout_state, request) + append_gateway_capture_record( + self._capture_path, + { + "protocol": self.__class__.__name__, + "request_id": request_id, + "session_uid": rollout_state.session_uid, + "status": rollout_state.status.value, + "finish_reason": response_finish_reason, + "rollout_finish_reason": rollout_state.finish_reason, + "prompt_tokens": len(rollout_state.prompt_ids or []), + "completion_tokens": len(rollout_state.response_ids or []), + "request": self.normalize_request(request), + "response": response_snapshot, + "internal_messages": rollout_state.message, + "output_messages": output_messages, + "input_text": self._render_prompt_text(rollout_state), + "output_text": render_blocks_as_text(output_messages), + }, + ) + except Exception: + # Capturing should never block serving. 
+ return + + def _render_prompt_text(self, rollout_state: RolloutState) -> str: + if self._tokenizer is None: + return "" + try: + rendered = self._tokenizer.apply_chat_template( + rollout_state.message, + tools=rollout_state.tools, + tokenize=False, + add_generation_prompt=True, + ) + if isinstance(rendered, str): + return rendered + return render_blocks_as_text(rendered) + except Exception: + return "" + + @abstractmethod + def validate_request(self, request: RequestT) -> None: + return None + + @abstractmethod + def request_to_rollout_state(self, request: RequestT) -> RolloutState: + raise NotImplementedError + + @abstractmethod + def raise_for_failed_response(self, response: RolloutState, request_id: str) -> None: + raise NotImplementedError + + @abstractmethod + def normalize_request(self, request: RequestT) -> dict[str, Any]: + raise NotImplementedError + + @abstractmethod + def normalize_response(self, response: ResponseT) -> dict[str, Any]: + raise NotImplementedError + + @abstractmethod + def rollout_state_to_response( + self, + rollout_state: RolloutState, + request: RequestT, + ) -> ResponseT: + raise NotImplementedError + + @abstractmethod + def build_output_message_list( + self, + rollout_state: RolloutState, + request: RequestT, + ) -> list[dict[str, Any]]: + raise NotImplementedError diff --git a/xtuner/v1/rl/rollout/chat_adapter/capture.py b/xtuner/v1/rl/rollout/chat_adapter/capture.py new file mode 100644 index 000000000..0d6847382 --- /dev/null +++ b/xtuner/v1/rl/rollout/chat_adapter/capture.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import json +import threading +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + + +_CAPTURE_LOCK = threading.RLock() + + +def append_gateway_capture_record(path: str | Path, record: dict[str, Any]) -> None: + capture_path = Path(path) + capture_path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "type": "gateway_turn", + "timestamp": datetime.now(timezone.utc).isoformat(), + **record, + } + with _CAPTURE_LOCK: + with capture_path.open("a", encoding="utf-8") as f: + f.write(json.dumps(payload, ensure_ascii=False) + "\n") + + +def render_blocks_as_text(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, list): + rendered_parts = [render_blocks_as_text(item) for item in value] + return "\n".join(part for part in rendered_parts if part) + if isinstance(value, dict): + block_type = value.get("type") + if block_type == "text": + return str(value.get("text", "")) + if block_type == "tool_use": + name = value.get("name", "") + input_payload = json.dumps(value.get("input", {}), ensure_ascii=False, sort_keys=True) + return f"{input_payload}" + if block_type == "tool_result": + tool_use_id = value.get("tool_use_id", "") + content = render_blocks_as_text(value.get("content")) + return f"{content}" + if "content" in value: + return render_blocks_as_text(value["content"]) + return json.dumps(value, ensure_ascii=False, sort_keys=True) + return str(value) diff --git a/xtuner/v1/rl/rollout/chat_adapter/collector.py b/xtuner/v1/rl/rollout/chat_adapter/collector.py new file mode 100644 index 000000000..ff3ef7caa --- /dev/null +++ b/xtuner/v1/rl/rollout/chat_adapter/collector.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from contextvars import ContextVar, Token + +from xtuner.v1.data_proto.rl_data import RolloutState + + +_CURRENT_TRACE_COLLECTOR: ContextVar[list[RolloutState] | None] = ContextVar( + 
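+ # Usage sketch (illustrative): an agent loop installs a list before serving a
+ # multi-turn request; every adapter-handled turn is then deep-copied into it:
+ #     collected: list[RolloutState] = []
+ #     token = set_current_trace_collector(collected)
+ #     try:
+ #         ...  # drive the agent; each generate() call appends its RolloutState here
+ #     finally:
+ #         reset_current_trace_collector(token)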
"xtuner_rollout_trace_collector", + default=None, +) + + +def set_current_trace_collector(collector: list[RolloutState]) -> Token: + return _CURRENT_TRACE_COLLECTOR.set(collector) + + +def reset_current_trace_collector(token: Token) -> None: + _CURRENT_TRACE_COLLECTOR.reset(token) + + +def append_current_trace_rollout_state(rollout_state: RolloutState) -> None: + collector = _CURRENT_TRACE_COLLECTOR.get() + if collector is None: + return + collector.append(rollout_state.model_copy(deep=True)) diff --git a/xtuner/v1/rl/rollout/chat_adapter/openai.py b/xtuner/v1/rl/rollout/chat_adapter/openai.py new file mode 100644 index 000000000..70b081b84 --- /dev/null +++ b/xtuner/v1/rl/rollout/chat_adapter/openai.py @@ -0,0 +1,227 @@ +import time +from typing import Any +from uuid import uuid4 + +from lmdeploy.serve.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status + +from .base import BaseChatAPIAdapter +from .trace import normalize_trace_payload + + +class OpenAIChatAdapterError(RuntimeError): + def __init__( + self, + message: str, + error_type: str, + code: str, + request_id: str | None = None, + ): + super().__init__(message) + self.message = message + self.error_type = error_type + self.code = code + self.request_id = request_id + + +class OpenAIChatAdapter(BaseChatAPIAdapter[ChatCompletionRequest, ChatCompletionResponse]): + def __init__( + self, + generate_handler, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str, + default_model_name: str | None = None, + context_length: int | None = None, + capture_path: str | None = None, + trace_store_max_entries: int = 10000, + ): + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + super().__init__( + generate_handler, + tokenizer=tokenizer, + capture_path=capture_path, + trace_store_max_entries=trace_store_max_entries, + ) + self._default_model_name = default_model_name + self._context_length = context_length + + async def chat(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + return await self.handle_request(request) + + def validate_request(self, request: ChatCompletionRequest) -> None: + if request.stream: + raise OpenAIChatAdapterError( + "stream=true is not supported yet", + "invalid_request_error", + "stream_not_supported", + ) + + def request_to_rollout_state(self, request: ChatCompletionRequest) -> RolloutState: + normalized_messages = normalize_trace_payload(request.messages) + tokenizer_tools = self._normalize_tools_for_tokenizer(request.tools) + normalized_tool_choice = normalize_trace_payload(request.tool_choice) + prompt_ids = None + if self._tokenizer: + raw_prompt_ids = self._tokenizer.apply_chat_template( + normalized_messages, + tools=tokenizer_tools, + tokenize=True, + add_generation_prompt=True, + ) + if hasattr(raw_prompt_ids, "get"): + prompt_ids = raw_prompt_ids.get("input_ids") + else: + prompt_ids = list(raw_prompt_ids) + max_tokens = self._fit_max_tokens_to_context(prompt_ids=prompt_ids, requested_max_tokens=request.max_tokens) + + return RolloutState( + uid=uuid4().int, + message=normalized_messages, + prompt_ids=prompt_ids, + tokens=prompt_ids, + session_uid=getattr(request, "session_uid", getattr(request, "session_id", None)), + tools=tokenizer_tools, + 
tool_choice=normalized_tool_choice, + sample_params=self._build_sample_params(request, max_tokens=max_tokens), + ) + + def raise_for_failed_response(self, response: RolloutState, request_id: str) -> None: + if response.status == Status.FAILED: + raise OpenAIChatAdapterError( + response.error_msg or "Rollout generation failed", + "server_error", + "rollout_failed", + request_id, + ) + + def rollout_state_to_response( + self, + rollout_state: RolloutState, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + request_id = str(rollout_state.uid) + model_name = request.model or self._default_model_name or "rollout-controller" + assert rollout_state.response_ids is not None, "response_ids should not be None when generating response" + assert rollout_state.tokens is not None, "tokens should not be None when generating response" + prompt_tokens = len(rollout_state.tokens) + completion_tokens = len(rollout_state.response_ids) + tool_calls = rollout_state.extra_fields.get("tool_calls") + response_message = ChatMessage( + role="assistant", + content=None if tool_calls else rollout_state.response, + tool_calls=tool_calls, + ) + return ChatCompletionResponse( + id=request_id, + created=int(time.time()), + model=model_name, + choices=[ + ChatCompletionResponseChoice( + index=0, + message=response_message, + finish_reason=rollout_state.finish_reason, + ) + ], + usage=UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + def build_output_message_list( + self, + rollout_state: RolloutState, + request: ChatCompletionRequest, + ) -> list[dict[str, Any]]: + return [{"role": "assistant", "content": rollout_state.response or ""}] + + def normalize_request(self, request: ChatCompletionRequest) -> dict[str, Any]: + return normalize_trace_payload( + { + "messages": request.messages, + "tools": request.tools, + "tool_choice": request.tool_choice, + } + ) + + def normalize_response(self, response: ChatCompletionResponse) -> dict[str, Any]: + normalized_choices = [] + for choice in response.choices: + normalized_choices.append( + { + "message": getattr(choice.message, "model_dump", lambda **_: choice.message)( + mode="python", + exclude_none=True, + ) + if choice.message is not None + else None, + "finish_reason": choice.finish_reason, + } + ) + return normalize_trace_payload({"choices": normalized_choices}) + + def _normalize_tools_for_tokenizer(self, tools: Any) -> Any: + if tools is None: + return None + return normalize_trace_payload(tools) + + def _build_sample_params(self, request: ChatCompletionRequest, max_tokens: int | None = None) -> SampleParams: + stops = [] if request.stop is None else [request.stop] if isinstance(request.stop, str) else request.stop + kwargs = { + "stops": stops, + **{ + key: value + for key, value in { + "temperature": request.temperature, + "top_p": request.top_p, + "n": request.n, + "max_tokens": max_tokens if max_tokens is not None else request.max_tokens, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + }.items() + if value is not None + }, + } + return SampleParams(**kwargs) + + def _fit_max_tokens_to_context(self, prompt_ids: list[int] | None, requested_max_tokens: int | None) -> int | None: + if self._context_length is None or prompt_ids is None or requested_max_tokens is None: + return requested_max_tokens + prompt_tokens = len(prompt_ids) + available_completion_tokens = self._context_length - prompt_tokens + if 
available_completion_tokens <= 0: + raise OpenAIChatAdapterError( + ( + f"Input is too long for this model deployment: prompt_tokens={prompt_tokens}, " + f"context_length={self._context_length}." + ), + "invalid_request_error", + "context_length_exceeded", + ) + return min(requested_max_tokens, available_completion_tokens) + + +def bind_openai_chat_interface( + rollout_controller: Any, + tokenizer: str | PreTrainedTokenizer | PreTrainedTokenizerFast | None, + default_model_name: str | None = None, +) -> Any: + if getattr(rollout_controller, "openai_chat_adapter", None) is None: + rollout_controller.openai_chat_adapter = OpenAIChatAdapter( + rollout_controller.generate, + tokenizer=tokenizer, + default_model_name=default_model_name, + context_length=getattr(rollout_controller.config, "context_length", None), + capture_path=str(getattr(rollout_controller.config, "worker_log_dir", ".")) + "/gateway_capture.jsonl", + ) + rollout_controller.chat = rollout_controller.openai_chat_adapter.chat + return rollout_controller diff --git a/xtuner/v1/rl/rollout/chat_adapter/responses.py b/xtuner/v1/rl/rollout/chat_adapter/responses.py new file mode 100644 index 000000000..dd0cd5fb4 --- /dev/null +++ b/xtuner/v1/rl/rollout/chat_adapter/responses.py @@ -0,0 +1,551 @@ +from __future__ import annotations + +import json +import re +import shlex +import time +from typing import Any, Literal +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status + +from .base import BaseChatAPIAdapter +from .openai import OpenAIChatAdapterError +from .trace import normalize_trace_payload + + +class ResponsesRequest(BaseModel): + model_config = ConfigDict(extra="allow") + + session_uid: int | None = None + model: str | None = None + instructions: str | None = None + input: str | list[dict[str, Any]] | None = None + tools: list[dict[str, Any]] | None = None + tool_choice: str | dict[str, Any] | None = None + stream: bool = False + store: bool = False + parallel_tool_calls: bool | None = None + include: list[Any] | None = None + reasoning: dict[str, Any] | None = None + max_output_tokens: int | None = None + temperature: float | None = None + top_p: float | None = None + + +class ResponsesUsage(BaseModel): + model_config = ConfigDict(extra="allow") + + input_tokens: int + output_tokens: int + total_tokens: int + + +class ResponsesResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + id: str + object: Literal["response"] = "response" + created_at: int + status: Literal["completed"] = "completed" + model: str + output: list[dict[str, Any]] + output_text: str = "" + parallel_tool_calls: bool = False + store: bool = False + text: dict[str, Any] = {"format": {"type": "text"}} + usage: ResponsesUsage + + +class OpenAIResponsesAdapter(BaseChatAPIAdapter[ResponsesRequest, ResponsesResponse]): + _tool_call_pattern = re.compile(r"\n*(.*?)", re.DOTALL) + _qwen_function_pattern = re.compile(r"\n]+)>(.*?)", re.DOTALL) + _qwen_parameter_pattern = re.compile(r"\n]+)>(.*?)", re.DOTALL) + _xml_tag_pattern = re.compile(r"<([a-zA-Z_][^>\n/]*)>(.*?)", re.DOTALL) + _disabled_tool_names = { + "list_mcp_resources", + "list_mcp_resource_templates", + "read_mcp_resource", + "request_user_input", + } + _tool_name_aliases = { + "read_file_dd": "exec_command", + "read_file": "exec_command", + "readfile": "exec_command", + } + + def __init__( + self, + 
generate_handler, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str | None, + default_model_name: str | None = None, + context_length: int | None = None, + capture_path: str | None = None, + ): + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + super().__init__(generate_handler, tokenizer=tokenizer, capture_path=capture_path) + self._default_model_name = default_model_name + self._context_length = context_length + + async def responses(self, request: ResponsesRequest) -> ResponsesResponse: + return await self.handle_request(request) + + def validate_request(self, request: ResponsesRequest) -> None: + return None + + def request_to_rollout_state(self, request: ResponsesRequest) -> RolloutState: + internal_messages = self._build_internal_messages(request) + tokenizer_tools = self._normalize_tools_for_backend(request.tools) + normalized_tool_choice = self._normalize_tool_choice_for_backend(request.tool_choice) + + prompt_ids = None + if self._tokenizer is not None: + raw_prompt_ids = self._tokenizer.apply_chat_template( + internal_messages, + tools=tokenizer_tools, + tokenize=True, + add_generation_prompt=True, + ) + prompt_ids = raw_prompt_ids.get("input_ids") if hasattr(raw_prompt_ids, "get") else list(raw_prompt_ids) + + max_tokens = self._fit_max_tokens_to_context( + prompt_ids=prompt_ids, requested_max_tokens=request.max_output_tokens + ) + return RolloutState( + uid=uuid4().int, + message=internal_messages, + prompt_ids=prompt_ids, + tokens=prompt_ids, + session_uid=request.session_uid, + tools=tokenizer_tools, + tool_choice=normalized_tool_choice, + sample_params=self._build_sample_params(request, max_tokens=max_tokens), + ) + + def raise_for_failed_response(self, response: RolloutState, request_id: str) -> None: + if response.status == Status.FAILED: + raise OpenAIChatAdapterError( + response.error_msg or "Rollout generation failed", + "server_error", + "rollout_failed", + request_id, + ) + + def normalize_request(self, request: ResponsesRequest) -> dict[str, Any]: + return normalize_trace_payload(request.model_dump(mode="python", exclude_none=True)) + + def normalize_response(self, response: ResponsesResponse) -> dict[str, Any]: + return normalize_trace_payload(response.model_dump(mode="python", exclude_none=True)) + + def rollout_state_to_response( + self, + rollout_state: RolloutState, + request: ResponsesRequest, + ) -> ResponsesResponse: + request_id = str(rollout_state.uid) + model_name = request.model or self._default_model_name or "rollout-controller" + prompt_tokens = self._count_prompt_tokens(rollout_state) + completion_tokens = self._count_completion_tokens(rollout_state) + output_items, output_text = self._build_response_output_items(rollout_state) + return ResponsesResponse( + id=f"resp_{request_id}", + created_at=int(time.time()), + model=model_name, + output=output_items, + output_text=output_text, + parallel_tool_calls=bool(request.parallel_tool_calls), + store=bool(request.store), + usage=ResponsesUsage( + input_tokens=prompt_tokens, + output_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + def build_output_message_list( + self, + rollout_state: RolloutState, + request: ResponsesRequest, + ) -> list[dict[str, Any]]: + output_items, output_text = self._build_response_output_items(rollout_state) + content: list[dict[str, Any]] = [] + if output_text: + content.append({"type": "text", "text": output_text}) + for item in output_items: + if item.get("type") 
== "function_call": + try: + input_payload = json.loads(item.get("arguments", "{}")) + except Exception: + input_payload = {"raw": item.get("arguments", "")} + content.append({"type": "tool_use", "name": item.get("name", ""), "input": input_payload}) + return [{"role": "assistant", "content": content}] + + def _build_internal_messages(self, request: ResponsesRequest) -> list[dict[str, Any]]: + system_chunks: list[str] = [] + messages: list[dict[str, Any]] = [] + tool_name_by_call_id: dict[str, str] = {} + if request.instructions: + system_chunks.append(request.instructions) + + if request.input is None: + return self._prepend_system_message(messages, system_chunks) + if isinstance(request.input, str): + messages.append({"role": "user", "content": request.input}) + return self._prepend_system_message(messages, system_chunks) + + for item in request.input: + item_type = item.get("type", "message") + if item_type == "message": + role = self._normalize_input_role(item.get("role")) + if role == "system": + system_text = self._extract_message_item_text(item.get("content")) + if system_text: + system_chunks.append(system_text) + else: + messages.extend(self._convert_message_item_to_backend_messages(role, item.get("content"))) + elif item_type == "function_call": + call_id = item.get("call_id") or f"call_{uuid4().hex}" + tool_name_by_call_id[call_id] = str(item.get("name", "")) + messages.append( + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": call_id, + "type": "function", + "function": { + "name": str(item.get("name", "")), + "arguments": self._parse_json_string_or_mapping(item.get("arguments")), + }, + } + ], + } + ) + elif item_type == "function_call_output": + call_id = item.get("call_id") + messages.append( + { + "role": "tool", + "content": self._serialize_tool_output( + item.get("output"), + tool_name=tool_name_by_call_id.get(str(call_id or "")), + ), + "tool_call_id": call_id, + } + ) + elif item_type == "reasoning": + continue + return self._prepend_system_message(messages, system_chunks) + + def _prepend_system_message( + self, + messages: list[dict[str, Any]], + system_chunks: list[str], + ) -> list[dict[str, Any]]: + joined_system = "\n\n".join(chunk.strip() for chunk in system_chunks if chunk and chunk.strip()) + if not joined_system: + return messages + return [{"role": "system", "content": joined_system}, *messages] + + def _normalize_input_role(self, role: Any) -> str: + if role in {"developer", "system"}: + return "system" + if role in {"assistant", "tool"}: + return str(role) + return "user" + + def _convert_message_item_to_backend_messages(self, role: str, content: Any) -> list[dict[str, Any]]: + text = self._extract_message_item_text(content) + if not text: + return [] + if role == "assistant": + text = self._sanitize_assistant_text(text) + return [{"role": role, "content": text}] + + def _extract_message_item_text(self, content: Any) -> str: + if isinstance(content, str): + return content + if not isinstance(content, list): + return str(content) + + text_chunks: list[str] = [] + for part in content: + part_type = part.get("type") + if part_type in {"input_text", "output_text", "text", "summary_text", "reasoning_text"}: + text_chunks.append(str(part.get("text", ""))) + return "\n".join(chunk for chunk in text_chunks if chunk) + + def _serialize_tool_output(self, output: Any, tool_name: str | None = None) -> str: + if output is None: + return "" + if isinstance(output, str): + return self._sanitize_tool_output_text(output, tool_name=tool_name) + if 
isinstance(output, list): + text_chunks = [str(part.get("text", "")) for part in output if isinstance(part, dict) and "text" in part] + if text_chunks: + return self._sanitize_tool_output_text("\n".join(text_chunks), tool_name=tool_name) + return json.dumps(output, ensure_ascii=False) + if isinstance(output, dict): + return json.dumps(output, ensure_ascii=False) + return str(output) + + def _sanitize_tool_output_text(self, text: str, tool_name: str | None = None) -> str: + if tool_name not in {"exec_command", "write_stdin"}: + return text + + marker = "\nOutput:\n" + if marker in text: + prefix, body = text.split(marker, 1) + exit_code = self._extract_exec_exit_code(prefix) + body = body.strip() + if exit_code is None: + return body + if body: + return f"[exit_code={exit_code}]\n{body}" + return f"[exit_code={exit_code}]" + return text + + def _extract_exec_exit_code(self, text: str) -> int | None: + match = re.search(r"Process exited with code (\d+)", text) + if match is not None: + return int(match.group(1)) + return None + + def _normalize_tools_for_backend(self, tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None: + if not tools: + return None + normalized_tools = [] + for tool in tools: + tool_name = tool.get("name") + if tool_name in self._disabled_tool_names: + continue + if tool.get("type") != "function": + continue + normalized_tools.append( + { + "type": "function", + "function": { + "name": tool["name"], + "description": tool.get("description", ""), + "parameters": tool.get("parameters", {}), + }, + } + ) + return normalize_trace_payload(normalized_tools) or None + + def _normalize_tool_choice_for_backend(self, tool_choice: str | dict[str, Any] | None) -> str | dict[str, Any] | None: + if tool_choice is None: + return None + if isinstance(tool_choice, str): + return tool_choice + if tool_choice.get("type") == "function": + return {"type": "function", "function": {"name": tool_choice.get("name")}} + return normalize_trace_payload(tool_choice) + + def _build_response_output_items(self, rollout_state: RolloutState) -> tuple[list[dict[str, Any]], str]: + tool_calls = rollout_state.extra_fields.get("tool_calls") + response_text = rollout_state.response or "" + if not tool_calls: + text_blocks, parsed_tool_calls = self._parse_textual_tool_calls(response_text) + if parsed_tool_calls: + tool_calls = parsed_tool_calls + rollout_state.extra_fields["tool_calls"] = parsed_tool_calls + response_text = "".join(block["text"] for block in text_blocks if block["type"] == "text") + + response_text = self._sanitize_assistant_text(response_text) + + output_items: list[dict[str, Any]] = [] + if response_text: + output_items.append( + { + "id": f"msg_{uuid4().hex}", + "type": "message", + "status": "completed", + "role": "assistant", + "content": [{"type": "output_text", "text": response_text, "annotations": []}], + } + ) + + for tool_call in tool_calls or []: + call_id = tool_call.get("id") or f"call_{uuid4().hex}" + output_items.append( + { + "id": f"fc_{uuid4().hex}", + "type": "function_call", + "status": "completed", + "call_id": call_id, + "name": tool_call["function"]["name"], + "arguments": self._stringify_tool_arguments(tool_call["function"].get("arguments")), + } + ) + return output_items, response_text + + def _parse_textual_tool_calls(self, text: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + if not text: + return [], [] + content_blocks: list[dict[str, Any]] = [] + tool_calls: list[dict[str, Any]] = [] + last_end = 0 + for match in 
self._tool_call_pattern.finditer(text): + if match.start() > last_end: + content_blocks.append({"type": "text", "text": text[last_end : match.start()]}) + raw_payload = match.group(1).strip() + parsed_tool_call = self._parse_single_textual_tool_call(raw_payload) + if parsed_tool_call is not None: + tool_calls.append(parsed_tool_call) + else: + content_blocks.append({"type": "text", "text": match.group(0)}) + last_end = match.end() + if last_end < len(text): + content_blocks.append({"type": "text", "text": text[last_end:]}) + return content_blocks, tool_calls + + def _parse_single_textual_tool_call(self, raw_payload: str) -> dict[str, Any] | None: + try: + parsed = json.loads(raw_payload) + return { + "id": f"call_{uuid4().hex}", + "type": "function", + "function": { + "name": parsed["name"], + "arguments": json.dumps(parsed.get("arguments", {}), ensure_ascii=False), + }, + } + except Exception: + pass + + function_match = self._qwen_function_pattern.search(raw_payload) + if function_match is None: + return None + function_name = function_match.group(1).strip() + function_body = function_match.group(2) + arguments: dict[str, Any] = {} + for parameter_match in self._qwen_parameter_pattern.finditer(function_body): + param_name = parameter_match.group(1).strip() + param_value = parameter_match.group(2).strip() + arguments[param_name] = self._coerce_parameter_value(param_value) + if not arguments: + for tag_match in self._xml_tag_pattern.finditer(function_body): + tag_name = tag_match.group(1).strip() + if tag_name.startswith("function="): + continue + tag_value = tag_match.group(2).strip() + if tag_name in {"path", "file_path"}: + arguments[tag_name] = tag_value + else: + arguments[tag_name] = self._coerce_parameter_value(tag_value) + + function_name, arguments = self._normalize_tool_call(function_name, arguments) + return { + "id": f"call_{uuid4().hex}", + "type": "function", + "function": { + "name": function_name, + "arguments": json.dumps(arguments, ensure_ascii=False), + }, + } + + def _stringify_tool_arguments(self, arguments: Any) -> str: + if isinstance(arguments, str): + return arguments + return json.dumps(arguments or {}, ensure_ascii=False) + + def _parse_json_string_or_mapping(self, value: Any) -> Any: + if isinstance(value, str): + try: + return json.loads(value) + except Exception: + return {"raw": value} + return value or {} + + def _coerce_parameter_value(self, value: str) -> Any: + stripped = value.strip() + if not stripped: + return "" + try: + return json.loads(stripped) + except Exception: + return stripped + + def _normalize_tool_call(self, function_name: str, arguments: dict[str, Any]) -> tuple[str, dict[str, Any]]: + normalized_name = self._tool_name_aliases.get(function_name, function_name) + normalized_arguments = dict(arguments) + if normalized_name == "exec_command" and function_name in self._tool_name_aliases: + path = normalized_arguments.pop("path", None) or normalized_arguments.pop("file_path", None) + if path: + quoted_path = shlex.quote(str(path)) + normalized_arguments = {"cmd": f"cat {quoted_path}"} + return normalized_name, normalized_arguments + + def _sanitize_assistant_text(self, text: str) -> str: + cleaned = text.replace("<|im_end|>", "") + cleaned = cleaned.replace("", "") + cleaned = cleaned.replace("", "") + return cleaned.strip() + + def _build_sample_params(self, request: ResponsesRequest, max_tokens: int | None = None) -> SampleParams: + kwargs = { + "return_token_ids": True, + "return_logprob": False, + "stream": request.stream, + } + 
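+ # Mapping sketch (illustrative request): max_output_tokens=512, temperature=0.2 and
+ # top_p=0.9 become SampleParams(return_token_ids=True, return_logprob=False,
+ # stream=False, max_tokens=512, temperature=0.2, top_p=0.9); max_tokens may first be
+ # clamped by _fit_max_tokens_to_context when context_length is configured.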
effective_max_tokens = max_tokens if max_tokens is not None else request.max_output_tokens + if effective_max_tokens is not None: + kwargs["max_tokens"] = effective_max_tokens + if request.temperature is not None: + kwargs["temperature"] = request.temperature + if request.top_p is not None: + kwargs["top_p"] = request.top_p + return SampleParams(**kwargs) + + def _fit_max_tokens_to_context(self, prompt_ids: list[int] | None, requested_max_tokens: int | None) -> int | None: + if self._context_length is None or prompt_ids is None or requested_max_tokens is None: + return requested_max_tokens + prompt_tokens = len(prompt_ids) + available_completion_tokens = self._context_length - prompt_tokens + if available_completion_tokens <= 0: + raise OpenAIChatAdapterError( + ( + f"Input is too long for this model deployment: prompt_tokens={prompt_tokens}, " + f"context_length={self._context_length}." + ), + "invalid_request_error", + "context_length_exceeded", + ) + return min(requested_max_tokens, available_completion_tokens) + + def _count_prompt_tokens(self, rollout_state: RolloutState) -> int: + if rollout_state.tokens is not None: + return len(rollout_state.tokens) + if rollout_state.prompt_ids is not None: + return len(rollout_state.prompt_ids) + return 0 + + def _count_completion_tokens(self, rollout_state: RolloutState) -> int: + if rollout_state.response_ids is not None: + return len(rollout_state.response_ids) + if self._tokenizer is not None and rollout_state.response: + return len(self._tokenizer(rollout_state.response, add_special_tokens=False)["input_ids"]) + return 0 + + +def bind_openai_responses_interface( + rollout_controller: Any, + tokenizer: str | PreTrainedTokenizer | PreTrainedTokenizerFast | None, + default_model_name: str | None = None, +) -> Any: + if getattr(rollout_controller, "openai_responses_adapter", None) is None: + rollout_controller.openai_responses_adapter = OpenAIResponsesAdapter( + rollout_controller.generate, + tokenizer=tokenizer, + default_model_name=default_model_name, + context_length=getattr(rollout_controller.config, "context_length", None), + capture_path=str(getattr(rollout_controller.config, "worker_log_dir", ".")) + "/gateway_capture.jsonl", + ) + rollout_controller.responses = rollout_controller.openai_responses_adapter.responses + return rollout_controller diff --git a/xtuner/v1/rl/rollout/chat_adapter/trace.py b/xtuner/v1/rl/rollout/chat_adapter/trace.py new file mode 100644 index 000000000..23babfe0e --- /dev/null +++ b/xtuner/v1/rl/rollout/chat_adapter/trace.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import hashlib +import json +import threading +from collections import OrderedDict +from collections.abc import Sequence +from copy import deepcopy +from dataclasses import dataclass +from typing import Any + +from pydantic import BaseModel + +from xtuner.v1.data_proto.rl_data import Status + + +def normalize_trace_payload(value: Any) -> Any: + if isinstance(value, BaseModel): + return normalize_trace_payload(value.model_dump(mode="python", exclude_none=True)) + if isinstance(value, dict): + return { + str(key): normalize_trace_payload(val) + for key, val in sorted(value.items(), key=lambda item: str(item[0])) + if val is not None + } + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return [normalize_trace_payload(item) for item in value] + return value + + +def build_trace_hash(request_snapshot: Any, response_snapshot: Any) -> str: + payload = { + "request": 
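+ # Both snapshots are normalized (keys sorted, None values dropped) before hashing, so
+ # the same request/response pair always yields the same digest; for example
+ # build_trace_hash({"b": 1, "a": 2}, {"x": None, "y": 3}) equals
+ # build_trace_hash({"a": 2, "b": 1}, {"y": 3}).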
normalize_trace_payload(request_snapshot), + "response": normalize_trace_payload(response_snapshot), + } + encoded = json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":")).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + +def snapshot_routed_experts(routed_experts: Any) -> Any: + if routed_experts is None: + return None + try: + import ray + + if isinstance(routed_experts, ray.ObjectRef): + return routed_experts + except Exception: + pass + return deepcopy(routed_experts) + + +@dataclass +class ChatTraceRecord: + response_hash: str + request_snapshot: dict[str, Any] + response_snapshot: dict[str, Any] + prompt_ids: list[int] + response_ids: list[int] + logprobs: list[float] | None + routed_experts: Any + finish_reason: str | None + status: Status + request_id: str | None = None + + +class ChatTraceStore: + def __init__(self, max_entries: int = 10000): + self._max_entries = max_entries + self._records: OrderedDict[str, ChatTraceRecord] = OrderedDict() + self._lock = threading.RLock() + + def put(self, record: ChatTraceRecord) -> None: + with self._lock: + self._records[record.response_hash] = record + self._records.move_to_end(record.response_hash) + while len(self._records) > self._max_entries: + self._records.popitem(last=False) + + def get(self, response_hash: str) -> ChatTraceRecord | None: + with self._lock: + record = self._records.get(response_hash) + if record is None: + return None + self._records.move_to_end(response_hash) + return record + + def build_hash(self, request_snapshot: Any, response_snapshot: Any) -> str: + return build_trace_hash(request_snapshot=request_snapshot, response_snapshot=response_snapshot) diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py index b82d6428b..1928e72fb 100644 --- a/xtuner/v1/rl/rollout/controller.py +++ b/xtuner/v1/rl/rollout/controller.py @@ -1,11 +1,13 @@ import asyncio import os +import socket import threading from dataclasses import dataclass -from typing import Dict, List, Optional, TypeAlias, TypedDict +from typing import Any, Dict, List, Optional, TypeAlias, TypedDict from uuid import uuid4 import ray +import uvicorn from ray.actor import ActorProxy from ray.util.placement_group import PlacementGroup @@ -13,6 +15,8 @@ from xtuner.v1.rl.utils import AutoAcceleratorWorkers from xtuner.v1.utils import get_logger +from .api_server import create_rollout_api_app +from .chat_adapter import bind_anthropic_chat_interface, bind_openai_chat_interface, bind_openai_responses_interface from .utils import ROLLOUT_RAY_GET_TIMEOUT, RolloutHealthChecker, SessionRouter from .worker import RolloutConfig, RolloutWorker @@ -53,6 +57,9 @@ class RolloutWorkerMetadata(TypedDict): # 值:布尔值,True 表示该 worker 处于活跃状态,False 表示已失效或停用 worker_server_urls_status: Dict[str, bool] + # URL of the rollout controller API server. + api_server_url: str + + class RolloutController: """Controller for managing and coordinating multiple RolloutWorker @@ -77,12 +84,25 @@ def __init__( else self.config.tensor_parallel_size ) self.logger = get_logger(log_dir=infer_config.worker_log_dir, tag="RolloutController") + self.api_server_url = "" self.engine_rank_mesh_array: List[List[int]] = [] self.worker_server_urls_map: dict[str, List[str]] = {} self.rank2info: dict[int, WorkerInfo] = {} self.engine_rank_mesh_array, self.worker_server_urls_map, self.rank2info = self._init_workers(placement_group) self.num_active_workers = len(self.rank2info) self.worker_info_lock = threading.RLock() + bind_openai_chat_interface( + self,
default_model_name=self.config.model_name, tokenizer=self.config.tokenizer_path + ) + bind_openai_responses_interface( + self, default_model_name=self.config.model_name, tokenizer=self.config.tokenizer_path + ) + bind_anthropic_chat_interface( + self, + default_model_name=self.config.model_name, + tokenizer=self.config.tokenizer_path, + ) + self._start_api_server() # The timeout for the environment to wait for the rollout controller's response. # This should be longer than the controller's internal timeout (`rollout_timeout`) # to account for potential queuing delays and other overheads. @@ -109,9 +129,19 @@ def get_rollout_metadata(self) -> RolloutWorkerMetadata: "server_url_dict": self.worker_server_urls_map, "rollout_config": self.config, "worker_server_urls_status": worker_server_urls_status, + "api_server_url": self.api_server_url, } return rollout_metadata + def get_ready_status(self) -> tuple[bool, dict[str, Any]]: + with self.worker_info_lock: + active_workers = sum(1 for info in self.rank2info.values() if info.is_active) + total_workers = len(self.rank2info) + return active_workers > 0, { + "active_workers": active_workers, + "total_workers": total_workers, + } + async def generate(self, rollout_state: RolloutState) -> RolloutState: session_id = rollout_state.session_uid if rollout_state.session_uid else uuid4().int worker = await self.router.get_worker(session_id) @@ -322,6 +352,35 @@ def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_ser active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval] return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls)) + @staticmethod + def _is_port_in_use(host: str, port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(0.2) + return sock.connect_ex((host, port)) == 0 + + def _start_api_server(self, host: str | None = None, port: int | None = None): + """Starts the API server to expose the rollout functionality.""" + host = host or self.config.api_host + port = self.config.api_port if self.config.api_port else (port or 8000) + + original_port = port + while self._is_port_in_use(host, port): + self.logger.warning(f"Port {port} is in use, trying port {port + 1}") + port += 1 + + if original_port != port: + self.logger.info(f"API server will use port {port} instead of the originally configured {original_port}.") + + app = create_rollout_api_app(self, self.logger) + + config = uvicorn.Config(app, host=host, port=port) + server = uvicorn.Server(config) + server_thread = threading.Thread(target=server.run, daemon=True) + server_thread.start() + self.config.api_port = port + self.api_server_url = f"http://{host}:{port}" + self.logger.info(f"Rollout API server started at {self.api_server_url}") + def _init_workers(self, placement_group: PlacementGroup): """Initializes and configures the pool of RolloutWorker actors. 
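Note: with the adapter bindings and _start_api_server above, the controller publishes an HTTP gateway alongside the existing Ray interface. A minimal discovery sketch, assuming rollout_controller is a Ray actor handle for RolloutController (values in comments are illustrative):

import ray

# rollout_controller: an existing Ray actor handle for RolloutController (assumed).
# At least one active worker means the gateway can serve requests.
ready, details = ray.get(rollout_controller.get_ready_status.remote())
# e.g. ready=True, details={"active_workers": 1, "total_workers": 1}

# The gateway URL is published through the rollout metadata; the port can differ from the
# configured api_port if that port was already taken and the controller had to bump it.
metadata = ray.get(rollout_controller.get_rollout_metadata.remote())
base_url = metadata["api_server_url"]  # e.g. "http://0.0.0.0:8000"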
diff --git a/xtuner/v1/rl/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py index 65f6fc3d8..c8453865f 100644 --- a/xtuner/v1/rl/rollout/lmdeploy.py +++ b/xtuner/v1/rl/rollout/lmdeploy.py @@ -93,12 +93,14 @@ def _get_request_payload(self, rollout_state: RolloutState) -> dict: message = rollout_state.message input_tokens = rollout_state.tokens + optional_fields: dict[str, object] = {} + if tools is not None: + optional_fields["tools"] = tools + if tool_choice is not None: + optional_fields["tool_choice"] = tool_choice + if sample_params.return_token_ids: - payload = { - "model": self.model_name, - "tools": tools, - "tool_choice": tool_choice, - } + payload = {"model": self.model_name, **optional_fields} if input_tokens is not None: payload["input_ids"] = input_tokens else: @@ -107,19 +109,27 @@ def _get_request_payload(self, rollout_state: RolloutState) -> dict: payload["input_ids"] = prompt_token_ids sample_params.return_routed_experts = True if self.enable_return_routed_experts else False lmdeploy_sample_params = self._transform_sample_params(sample_params) - payload.update(sample_params) + payload.update(lmdeploy_sample_params) else: payload = { "model": self.model_name, "messages": rollout_state.message, - "tools": tools, - "tool_choice": tool_choice, + **optional_fields, } - lmdeploy_sample_params = self._transform_sample_params(sample_params) - lmdeploy_sample_params.pop("no_stop_trim", None) - lmdeploy_sample_params.pop("return_logprob", None) - lmdeploy_sample_params.pop("stop_token_ids", None) - lmdeploy_sample_params["min_new_tokens"] = sample_params.min_tokens + lmdeploy_sample_params = { + "temperature": sample_params.temperature, + "top_p": sample_params.top_p, + "n": sample_params.n, + "stream": sample_params.stream, + "max_tokens": sample_params.max_tokens, + "repetition_penalty": sample_params.repetition_penalty, + "top_k": sample_params.top_k, + "skip_special_tokens": sample_params.skip_special_tokens, + } + if sample_params.stops: + lmdeploy_sample_params["stop"] = sample_params.stops + if sample_params.min_tokens > 0: + lmdeploy_sample_params["min_new_tokens"] = sample_params.min_tokens payload.update(lmdeploy_sample_params) return payload diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py index e19bb2864..2afe87113 100644 --- a/xtuner/v1/rl/rollout/utils.py +++ b/xtuner/v1/rl/rollout/utils.py @@ -5,10 +5,12 @@ from collections import OrderedDict from itertools import cycle from typing import TYPE_CHECKING, Any, Optional +from uuid import uuid4 import httpx import ray +from xtuner.v1.data_proto.rl_data import RolloutState from xtuner.v1.rl.utils import asyncio_run from xtuner.v1.utils import get_logger @@ -277,3 +279,13 @@ async def check_worker_health( f"Exception during health check for worker {rank} at {url}: {e}. Failure count: {failing_count}" ) return False + + +def ensure_rollout_request_id(rollout_state: RolloutState) -> str: + request_id = str(rollout_state.extra_fields.get("request_id", "")) + if request_id: + return request_id + + request_id = uuid4().hex + rollout_state.extra_fields["request_id"] = request_id + return request_id diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index 156c9a054..f5fecb1aa 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -117,6 +117,10 @@ class RolloutConfig(BaseModel): int, Parameter(group=infer_group, help="Port number for the rollout API server. 
If not set, 8000 will be used."), ] = 8000 + api_host: Annotated[ + str, + Parameter(group=infer_group, help="Host for the rollout API server."), + ] = "0.0.0.0" gpus_per_node: Annotated[int, Parameter(group=infer_group, help="Number of GPUs allocated per node.")] = 8 dtype: Annotated[ str, @@ -313,7 +317,7 @@ def model_post_init(self, __context: Any) -> None: while True: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: - s.bind(("localhost", port)) + s.bind((self.api_host if self.api_host != "0.0.0.0" else "localhost", port)) break except OSError: port += 1 diff --git a/xtuner/v1/rl/utils/ray_utils.py b/xtuner/v1/rl/utils/ray_utils.py index 987ba700f..14d94323d 100644 --- a/xtuner/v1/rl/utils/ray_utils.py +++ b/xtuner/v1/rl/utils/ray_utils.py @@ -180,6 +180,6 @@ def bind_train_rollout( train_workers: A list of training worker actors. rollout_controller: The rollout controller actor. """ - info_dict = ray.get(rollout_controller.get_rollout_info.remote()) # type: ignore[attr-defined] + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) # type: ignore[attr-defined] ray.get([worker.update_rollout_info.remote(**info_dict) for worker in train_workers]) # type: ignore[attr-defined] return
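Note: ensure_rollout_request_id (added to rollout/utils.py above) returns the request_id already stored in extra_fields, or generates a uuid4 hex and stores it there, so repeated calls on the same RolloutState are stable. A minimal sketch (the message content is illustrative):

from xtuner.v1.data_proto import RolloutState
from xtuner.v1.rl.rollout.utils import ensure_rollout_request_id

state = RolloutState(message=[{"role": "user", "content": "hello"}])
request_id = ensure_rollout_request_id(state)

# The id is persisted on the state and reused on later calls.
assert state.extra_fields["request_id"] == request_id
assert ensure_rollout_request_id(state) == request_id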