forked from TransferQueue/verl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent_loop.py
More file actions
99 lines (81 loc) · 4.37 KB
/
agent_loop.py
File metadata and controls
99 lines (81 loc) · 4.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import numpy as np
import ray
from transfer_queue import BatchMeta
import verl.experimental.agent_loop.agent_loop as agent_loop
class AgentLoopManager(agent_loop.AgentLoopManager):
def generate_sequences(self, prompts: BatchMeta) -> BatchMeta:
"""Split input batch and dispatch to agent loop workers.
Args:
prompts (BatchMeta): Input batch.
Returns:
BatchMeta: Output batch metadata.
"""
if self.config.actor_rollout_ref.rollout.free_cache_engine:
self.wake_up()
if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine:
self.reward_model_manager.wake_up()
chunkes = prompts.chunk(len(self.agent_loop_workers))
outputs = ray.get(
[
worker.generate_sequences.remote(chunk)
for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True)
]
)
output = BatchMeta.concat(outputs)
if self.config.actor_rollout_ref.rollout.free_cache_engine:
self.sleep()
if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine:
self.reward_model_manager.sleep()
# calculate performance metrics
metrics = [output.extra_info.pop("metrics") for output in outputs] # List[List[Dict[str, str]]]
timing = self._performance_metrics(metrics, output)
output.set_extra_info("timing", timing)
return output
def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: BatchMeta) -> dict[str, float]:
timing = {}
t_generate_sequences = np.array([metric["generate_sequences"] for chunk in metrics for metric in chunk])
t_tool_calls = np.array([metric["tool_calls"] for chunk in metrics for metric in chunk])
timing["agent_loop/generate_sequences/min"] = t_generate_sequences.min()
timing["agent_loop/generate_sequences/max"] = t_generate_sequences.max()
timing["agent_loop/generate_sequences/mean"] = t_generate_sequences.mean()
timing["agent_loop/tool_calls/min"] = t_tool_calls.min()
timing["agent_loop/tool_calls/max"] = t_tool_calls.max()
timing["agent_loop/tool_calls/mean"] = t_tool_calls.mean()
# TODO (TQ): initialize tq during init when enable TQ switch is stable
tq_client = self._create_transferqueue_client()
# batch sequence generation is bounded by the slowest sample
slowest = np.argmax(t_generate_sequences + t_tool_calls)
attention_mask = asyncio.run(tq_client.async_get_data(output[slowest]))["attention_mask"]
prompt_length = output.samples[0].fields["prompts"].shape[0]
timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest]
timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest]
timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item()
timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item()
return timing
def create_transferqueue_client_for_workers(self):
# TODO (TQ): initialize tq during worker init when enable TQ switch is stable
ray.get([worker.create_transferqueue_client.remote() for worker in self.agent_loop_workers])
def _create_transferqueue_client(self):
"""Create a client for data system (TransferQueue)."""
from verl.single_controller.ray.base import get_random_string
from verl.utils.transferqueue_utils import create_transferqueue_client
client_name = get_random_string(length=6)
tq_client = create_transferqueue_client(
client_id=f"AgentLoopManager_{client_name}",
config=self.config.transfer_queue,
)
return tq_client