Skip to content

Commit 311fea5

Browse files
committed
discard old tuner and embrace new tuner
1 parent aac1463 commit 311fea5

File tree

26 files changed

+351
-360
lines changed

26 files changed

+351
-360
lines changed

ajet/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from ajet.copilot.job import AgentJetJob
22
from ajet.schema.task import WorkflowOutput, WorkflowTask
3-
from ajet.tuner import ModelTuner
3+
from ajet.tuner import AjetTuner
44
from ajet.workflow import Workflow
55
from ajet.utils.vsdb import vscode_conditional_breakpoint as bp
66

77
__all__ = [
88
"Workflow",
99
"WorkflowTask",
1010
"WorkflowOutput",
11-
"ModelTuner",
11+
"AjetTuner",
1212
"AgentJetJob",
1313
"bp",
1414
]

ajet/task_runner/general_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22

3-
from ajet.tuner_v2 import TunerV2 as ModelTuner
3+
from ajet import AjetTuner
44
from ajet import Workflow, WorkflowOutput
55
from ajet.context_tracker.agentscope_tracker.multiagent_tracking import (
66
MultiAgentContextTracker,
@@ -38,7 +38,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker:
3838
task_id=task_id,
3939
**hooks,
4040
)
41-
m_tuner = ModelTuner(
41+
m_tuner = AjetTuner(
4242
context_tracker=context_tracker,
4343
llm_inference_fn=self.llm_inference_fn,
4444
tokenizer=self.tokenizer,

ajet/tuner.py

Lines changed: 77 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,17 @@
1-
from typing import TYPE_CHECKING, Any, Literal, Type
2-
3-
from agentscope._utils._common import _create_tool_from_base_model
4-
from agentscope.model import ChatModelBase, ChatResponse, DashScopeChatModel
5-
from loguru import logger
6-
from pydantic import BaseModel
1+
from typing import TYPE_CHECKING, Any, Literal, Type, Union
72

83
from ajet.context_tracker.agentscope_tracker.multiagent_tracking import (
94
MultiAgentContextTracker,
105
)
11-
from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker
126

7+
from ajet.tuner_lib.weight_tuner import AgentScopeModelTuner
8+
from ajet.tuner_lib.weight_tuner import OpenaiClientModelTuner
139
if TYPE_CHECKING:
1410
from ajet import Workflow
1511

12+
TunerTypeUnion = Union[AgentScopeModelTuner, OpenaiClientModelTuner]
1613

17-
class Agent2Proxy(DashScopeChatModel):
18-
"""
19-
Handler for **NAMED** agent trainning targets.
20-
It stores the target name, and a reference to the ModelTuner.
21-
When request comes, it switches between default model (dashscope or openai models) and ModelTuner
22-
"""
23-
24-
def __init__(self, name: str, tuner: "ModelTuner", default_model: ChatModelBase):
25-
self.name = name
26-
self.tuner = tuner
27-
self.default_model = default_model
28-
super().__init__(
29-
model_name="ajet",
30-
api_key="dummy-api-key",
31-
stream=False,
32-
)
33-
34-
def __call__(self, *args, **kwargs):
35-
if not self.tuner.is_trainable(self.name):
36-
# [DO-NOT-TRAIN] if `trainable_targets` is non-empty,
37-
# and self.name is not in it, use default model
38-
return self.default_model(*args, **kwargs)
39-
else:
40-
# [TRAIN]
41-
return self.tuner(*args, **kwargs)
42-
43-
44-
class ModelTuner(DashScopeChatModel):
45-
"""
46-
ModelTuner for Agentscope workflow.
47-
It keeps record of all registered agent types (by their target names),
48-
And when request comes, it calls `self.llm_proxy` to handle the request.
49-
"""
14+
class AjetTuner(object):
5015

5116
def __init__(
5217
self,
@@ -56,19 +21,72 @@ def __init__(
5621
**kwargs,
5722
) -> None:
5823
self.config = config
24+
self.workflow = user_workflow
5925
self.context_tracker = context_tracker
60-
self.user_workflow = user_workflow
61-
self.target2proxy_registry: dict[str, Agent2Proxy] = {}
62-
self.llm_proxy = OpenaiLlmProxyWithTracker(
63-
context_tracker=context_tracker, config=config, **kwargs
26+
self.target2proxy_registry: dict[str, dict[str,TunerTypeUnion]] = {}
27+
self.kwargs = kwargs
28+
29+
30+
def as_agentscope_model(
31+
self,
32+
agent_name="default_agent_name",
33+
target_tag="default_target_tag",
34+
debug_model=None
35+
) -> "AgentScopeModelTuner":
36+
"""Convert to ModelTuner instance for Agentscope workflow.
37+
Returns:
38+
AgentScopeModelTuner:
39+
The AgentScopeModelTuner instance for the AgentScope workflow.
40+
"""
41+
explicit_tuner_as_modelscope_model = AgentScopeModelTuner(
42+
config=self.config,
43+
context_tracker=self.context_tracker,
44+
user_workflow=self.workflow,
45+
agent_name=agent_name,
46+
debug_model=debug_model,
47+
use_debug_model=(not self._is_target_trainable(target_tag)),
48+
**self.kwargs,
6449
)
65-
super().__init__(
66-
model_name="ajet",
67-
api_key="dummy-api-key",
68-
stream=False,
50+
self._register(target_tag, agent_name, explicit_tuner_as_modelscope_model)
51+
return explicit_tuner_as_modelscope_model
52+
53+
54+
def as_raw_openai_sdk_client(
55+
self,
56+
agent_name="default_agent_name",
57+
target_tag="default_target_tag",
58+
debug_model='gpt-4o',
59+
) -> OpenaiClientModelTuner:
60+
"""Convert to raw OpenAI SDK client for advanced usage.
61+
Returns:
62+
OpenaiClientModelTuner:
63+
The raw OpenAI SDK client.
64+
"""
65+
explicit_tuner_as_oai_client = OpenaiClientModelTuner(
66+
config=self.config,
67+
context_tracker=self.context_tracker,
68+
workflow=self.workflow,
69+
agent_name=agent_name,
70+
debug_model=debug_model,
71+
use_debug_model=(not self._is_target_trainable(target_tag)),
72+
**self.kwargs,
6973
)
74+
self._register(target_tag, agent_name, explicit_tuner_as_oai_client)
75+
return explicit_tuner_as_oai_client
76+
77+
78+
def __call__(self, **kwargs):
79+
"""This method is **deprecated**.
80+
The current behavior of this method is to pretend to be an agentscope model.
81+
"""
82+
raise RuntimeError("This method is deprecated. Please use `as_agentscope_model` / `as_raw_openai_sdk_client` first.")
83+
84+
85+
# ------------------------------------------------------------------------
86+
# other helper methods
87+
# ------------------------------------------------------------------------
7088

71-
def register_model(self, target_name: str, default_model: ChatModelBase) -> Agent2Proxy:
89+
def _register(self, target_name: str, agent_name: str, explicit_tuner: TunerTypeUnion) -> TunerTypeUnion:
7290
"""Register an agent type.
7391
Args:
7492
target_name (`str`):
@@ -79,116 +97,29 @@ def register_model(self, target_name: str, default_model: ChatModelBase) -> Agen
7997
Agent2Proxy:
8098
The agent type instance corresponding to the provided name.
8199
"""
82-
if target_name in self.target2proxy_registry:
83-
if (
84-
default_model.model_name
85-
!= self.target2proxy_registry[target_name].default_model.model_name
86-
):
87-
raise ValueError(
88-
f"Agent proxy `{target_name}` is already registered with a different model_name.\nWAS [{self.target2proxy_registry[target_name].default_model.model_name}]\nNOW [{default_model.model_name}]."
89-
)
90-
self.target2proxy_registry[target_name] = Agent2Proxy(target_name, self, default_model)
91-
return self.get_model(target_name)
92-
93-
def get_model(self, target_name: str) -> Agent2Proxy:
94-
"""Get the proxy instance by target_name.
95-
Args:
96-
target_name (`str`):
97-
The name of the agent proxy to retrieve.
98-
Returns:
99-
Agent2Proxy:
100-
The agent proxy corresponding to the provided target_name.
101-
"""
102100
if target_name not in self.target2proxy_registry:
103-
raise ValueError(f"Agent proxy '{target_name}' is not registered.")
104-
else:
105-
return self.target2proxy_registry[target_name]
101+
self.target2proxy_registry[target_name] = {}
102+
self.target2proxy_registry[target_name][agent_name] = explicit_tuner
103+
return explicit_tuner
106104

107-
async def __call__(
108-
self,
109-
messages: list[dict[str, Any]],
110-
tools: list[dict] | None = None,
111-
tool_choice: Literal["auto", "none", "any", "required"] | str | None = None,
112-
structured_model: Type[BaseModel] | None = None,
113-
**kwargs: Any,
114-
) -> ChatResponse:
115-
# For qvq and qwen-vl models, the content field cannot be `None` or
116-
# `[{"text": None}]`, so we need to convert it to an empty list.
117-
if self.model_name.startswith("qvq") or "-vl" in self.model_name:
118-
raise NotImplementedError("Not implemented for qvq and qwen-vl models yet.")
119-
120-
kwargs = {
121-
"messages": messages,
122-
"model": self.model_name,
123-
"stream": self.stream,
124-
**self.generate_kwargs,
125-
**kwargs,
126-
"result_format": "message",
127-
# In agentscope, the `incremental_output` must be `True` when
128-
# `self.stream` is True
129-
"incremental_output": self.stream,
130-
}
131-
132-
if tools:
133-
kwargs["tools"] = self._format_tools_json_schemas(tools)
134-
135-
if tool_choice:
136-
self._validate_tool_choice(tool_choice, tools)
137-
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
138-
139-
if self.enable_thinking is not None and "enable_thinking" not in kwargs:
140-
kwargs["enable_thinking"] = self.enable_thinking
141-
142-
if structured_model:
143-
if tools or tool_choice:
144-
logger.warning(
145-
"structured_model is provided. Both 'tools' and "
146-
"'tool_choice' parameters will be overridden and "
147-
"ignored. The model will only perform structured output "
148-
"generation without calling any other tools.",
149-
)
150-
format_tool = _create_tool_from_base_model(structured_model)
151-
kwargs["tools"] = self._format_tools_json_schemas(
152-
[format_tool],
153-
)
154-
kwargs["tool_choice"] = self._format_tool_choice(
155-
format_tool["function"]["name"],
156-
)
157-
158-
# call llm model ✨
159-
response_gen = await self.llm_proxy(
160-
api_key=self.api_key,
161-
structured_model=structured_model,
162-
**kwargs,
163-
)
164-
165-
# Return the AsyncGenerator directly
166-
return response_gen
167-
168-
def is_trainable(self, target_name) -> bool:
169-
if self.user_workflow.trainable_targets is None:
105+
def _is_target_trainable(self, target_name) -> bool:
106+
"""Determine whether user have used `trainable_targets` to explicitly control training targets.
107+
"""
108+
if self.workflow.trainable_targets is None:
170109
# always assume trainable when user has never changed trainable_targets
171110
return True
172-
if not self.user_workflow.trainable_targets:
111+
if not self.workflow.trainable_targets:
173112
# always assume trainable when trainable_targets is []
174113
return True
175-
if target_name in self.user_workflow.trainable_targets:
114+
if target_name in self.workflow.trainable_targets:
176115
return True
177116
else:
178117
return False
179118

180-
def get_llm_proxy(self) -> OpenaiLlmProxyWithTracker:
181-
"""Get the LlmProxyForAgentScope instance.
182-
Returns:
183-
LlmProxyForAgentScope:
184-
The LlmProxyForAgentScope instance used by the ModelTuner.
185-
"""
186-
return self.llm_proxy
187-
188119
def get_context_tracker(self) -> MultiAgentContextTracker:
189120
"""Get the context tracker instance.
190121
Returns:
191122
MultiAgentContextTracker:
192123
The context tracker instance used by the ModelTuner.
193124
"""
194-
return self.context_tracker
125+
return self.context_tracker

0 commit comments

Comments
 (0)