Skip to content

Commit ca1cf82

Browse files
committed
add interchange api
1 parent 311fea5 commit ca1cf82

File tree

28 files changed

+838
-200
lines changed

28 files changed

+838
-200
lines changed

ajet/backbone/main_trinity.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def patched_trainer_get_actor(cls, config: Config):
4040
Trainer.get_actor = classmethod(patched_trainer_get_actor)
4141

4242

43+
44+
4345
if __name__ == "__main__":
4446
patch_runtime_env_to_get_actor()
4547
main()

ajet/backbone/main_verl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ def run(self, config):
247247

248248
from ajet.backbone.trainer_verl import AjetRayPPOTrainer
249249

250+
if config.ajet.enable_experimental_reverse_proxy:
251+
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
252+
start_interchange_server()
253+
250254
# Initialize the PPO trainer.
251255
trainer = AjetRayPPOTrainer(
252256
config=config,

ajet/backbone/trainer_trinity.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from trinity.utils.monitor import MONITOR, Monitor
2020

2121
from ajet.backbone.warm_up import warm_up_process
22-
from ajet.context_tracker.agentscope_tracker.multiagent_tracking import (
22+
from ajet.context_tracker.multiagent_tracking import (
2323
MultiAgentContextTracker,
2424
)
2525
from ajet.schema.trajectory import Sample
@@ -116,6 +116,10 @@ def __init__(
116116

117117
async def run_async(self):
118118
ajet_config = get_ajet_config_from_trinity_side()
119+
if ajet_config.ajet.enable_experimental_reverse_proxy:
120+
raise NotImplementedError(
121+
"The experimental reverse proxy is not supported in Trinity backbone yet."
122+
)
119123
warm_up_process(ajet_config)
120124
tracker = await TrinityRolloutManager(
121125
is_eval=self.is_eval,

ajet/context_tracker/agentscope_tracker/multiagent_tracking.py renamed to ajet/context_tracker/multiagent_tracking.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from loguru import logger
99
from transformers.tokenization_utils import PreTrainedTokenizer
1010

11-
from ajet.context_tracker.agentscope_tracker.timeline_merging import (
11+
from ajet.context_tracker.timeline_merging.timeline_merging import (
1212
merge_tracker_timelines, is_timeline_mergeable
1313
)
1414
from ajet.context_tracker.basic_tracker import (
@@ -36,6 +36,8 @@ class ContextTrackerConfig:
3636
detect_timeline_snap: bool = False
3737

3838

39+
40+
3941
class MultiAgentContextTracker(BaseContextTracker):
4042
"""
4143
Context tracker is responsible to monitor and process LLM IO.
@@ -44,22 +46,22 @@ class MultiAgentContextTracker(BaseContextTracker):
4446

4547
def __init__(
4648
self,
47-
llm_inference_fn,
4849
tokenizer: PreTrainedTokenizer,
4950
config,
5051
should_interrupt_fn,
5152
generated_token_callback_fn,
53+
episode_uuid: str,
5254
**kwargs,
5355
):
5456
super().__init__(config, tokenizer, **kwargs)
55-
self.llm_inference_fn = llm_inference_fn
5657
self.tokenizer = tokenizer
5758
self.should_interrupt_fn = should_interrupt_fn
5859
self.generated_token_callback_fn = generated_token_callback_fn
5960
self.context_overflow = False
6061
self.output_kwargs = {}
6162
self.input_kwargs = {}
6263
self.timeline_cache = {}
64+
self.episode_uuid = episode_uuid
6365

6466

6567
def step_prepare(self, messages: List[dict], tools: List = [], timeline_uuid: str = ""):

ajet/context_tracker/agentscope_tracker/timeline_merging.py renamed to ajet/context_tracker/timeline_merging/timeline_merging.py

File renamed without changes.

ajet/default_config/ajet_default.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ ajet:
66
backbone: debug # `debug` or `trinity` or `verl`
77

88

9+
# the experimental reverse proxy feature that allows `tuner.as_oai_baseurl_apikey` feature
10+
enable_experimental_reverse_proxy: True
11+
912
model:
1013
# which model should be trained
1114
path: /path/to/model/such/as/Qwen/Qwen2___5-14B-Instruct

ajet/schema/convertion.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11

2+
import time
23
from openai.types.chat.chat_completion import ChatCompletion, Choice
34
from openai.types.chat.chat_completion_message import ChatCompletionMessage
5+
from agentscope.model import ChatResponse as AgentScopeChatResponse
46
from openai.types.completion_usage import CompletionUsage
5-
import time
7+
from typing import Any, Callable, Dict, List, Literal, Type, Union
8+
from agentscope.message import TextBlock, ToolUseBlock
9+
from agentscope._utils._common import _json_loads_with_repair
10+
from pydantic import BaseModel
11+
from agentscope.model import ChatResponse
612

713

814
def convert_llm_proxy_response_to_oai_response(llm_proxy_response):
@@ -40,6 +46,66 @@ def convert_llm_proxy_response_to_oai_response(llm_proxy_response):
4046
usage=usage,
4147
)
4248

49+
# copied from AgentScope's DashScopeChatModule
50+
def convert_llm_proxy_response_to_agentscope_response(
51+
message,
52+
structured_model: Type[BaseModel] | None = None,
53+
) -> AgentScopeChatResponse: # type: ignore
54+
content_blocks: List[TextBlock | ToolUseBlock] = []
55+
content = message.get("content")
56+
metadata: dict | None = None
57+
58+
if content not in [
59+
None,
60+
"",
61+
[],
62+
]:
63+
if isinstance(content, list):
64+
for item in content:
65+
if isinstance(item, dict) and "text" in item:
66+
content_blocks.append(
67+
TextBlock(
68+
type="text",
69+
text=item["text"],
70+
),
71+
)
72+
else:
73+
content_blocks.append(
74+
TextBlock(
75+
type="text",
76+
text=content,
77+
),
78+
)
79+
80+
if message.get("tool_calls"):
81+
for tool_call in message["tool_calls"]:
82+
input_ = _json_loads_with_repair(
83+
tool_call["function"].get(
84+
"arguments",
85+
"{}",
86+
)
87+
or "{}",
88+
)
89+
content_blocks.append(
90+
ToolUseBlock(
91+
type="tool_use",
92+
name=tool_call["function"]["name"],
93+
input=input_, # type: ignore
94+
id=tool_call["id"],
95+
),
96+
)
97+
98+
if structured_model:
99+
metadata = input_ # type: ignore
100+
101+
parsed_response = AgentScopeChatResponse(
102+
content=content_blocks,
103+
metadata=metadata,
104+
)
105+
106+
return parsed_response
107+
108+
43109

44110
def test_convert_llm_proxy_response_to_oai_response():
45111
"""Test the conversion from llm_proxy_response to OpenAI ChatCompletion format."""

ajet/schema/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class WorkflowTask(BaseModel):
2626
task_thread_index: int = Field(default=0)
2727
task_batch_index: int = Field(default=0)
2828
task_tag: str = Field(default="")
29-
task_env_uuid: str = Field(default="")
29+
episode_uuid: str = Field(default="")
3030
observation_window: dict = Field(default={})
3131
llm_inference_fn: Any = Field(default=None)
3232
tokenizer: Any = Field(default=None)

ajet/task_judge/env_service_as_judge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO
1010
raw_reward = 0
1111

1212
env = workflow_task.gym_env
13-
raw_reward = env.evaluate(workflow_task.task_env_uuid, params={"sparse": False})
13+
raw_reward = env.evaluate(workflow_task.episode_uuid, params={"sparse": False})
1414
if raw_reward >= 1:
1515
is_success = True
1616
else:

0 commit comments

Comments
 (0)