1- from typing import TYPE_CHECKING , Any , Literal , Type
2-
3- from agentscope ._utils ._common import _create_tool_from_base_model
4- from agentscope .model import ChatModelBase , ChatResponse , DashScopeChatModel
5- from loguru import logger
6- from pydantic import BaseModel
1+ from typing import TYPE_CHECKING , Any , Literal , Type , Union
72
83from ajet .context_tracker .agentscope_tracker .multiagent_tracking import (
94 MultiAgentContextTracker ,
105)
11- from ajet .task_rollout .async_llm_bridge import OpenaiLlmProxyWithTracker
126
7+ from ajet .tuner_lib .weight_tuner import AgentScopeModelTuner
8+ from ajet .tuner_lib .weight_tuner import OpenaiClientModelTuner
139if TYPE_CHECKING :
1410 from ajet import Workflow
1511
12+ TunerTypeUnion = Union [AgentScopeModelTuner , OpenaiClientModelTuner ]
1613
17- class Agent2Proxy (DashScopeChatModel ):
18- """
19- Handler for **NAMED** agent trainning targets.
20- It stores the target name, and a reference to the ModelTuner.
21- When request comes, it switches between default model (dashscope or openai models) and ModelTuner
22- """
23-
24- def __init__ (self , name : str , tuner : "ModelTuner" , default_model : ChatModelBase ):
25- self .name = name
26- self .tuner = tuner
27- self .default_model = default_model
28- super ().__init__ (
29- model_name = "ajet" ,
30- api_key = "dummy-api-key" ,
31- stream = False ,
32- )
33-
34- def __call__ (self , * args , ** kwargs ):
35- if not self .tuner .is_trainable (self .name ):
36- # [DO-NOT-TRAIN] if `trainable_targets` is non-empty,
37- # and self.name is not in it, use default model
38- return self .default_model (* args , ** kwargs )
39- else :
40- # [TRAIN]
41- return self .tuner (* args , ** kwargs )
42-
43-
44- class ModelTuner (DashScopeChatModel ):
45- """
46- ModelTuner for Agentscope workflow.
47- It keeps record of all registered agent types (by their target names),
48- And when request comes, it calls `self.llm_proxy` to handle the request.
49- """
14+ class AjetTuner (object ):
5015
5116 def __init__ (
5217 self ,
@@ -56,19 +21,72 @@ def __init__(
5621 ** kwargs ,
5722 ) -> None :
5823 self .config = config
24+ self .workflow = user_workflow
5925 self .context_tracker = context_tracker
60- self .user_workflow = user_workflow
61- self .target2proxy_registry : dict [str , Agent2Proxy ] = {}
62- self .llm_proxy = OpenaiLlmProxyWithTracker (
63- context_tracker = context_tracker , config = config , ** kwargs
26+ self .target2proxy_registry : dict [str , dict [str ,TunerTypeUnion ]] = {}
27+ self .kwargs = kwargs
28+
29+
30+ def as_agentscope_model (
31+ self ,
32+ agent_name = "default_agent_name" ,
33+ target_tag = "default_target_tag" ,
34+ debug_model = None
35+ ) -> "AgentScopeModelTuner" :
36+ """Convert to ModelTuner instance for Agentscope workflow.
37+ Returns:
38+ ModelTuner:
39+ The ModelTuner instance for Agentscope workflow.
40+ """
41+ explicit_tuner_as_modelscope_model = AgentScopeModelTuner (
42+ config = self .config ,
43+ context_tracker = self .context_tracker ,
44+ user_workflow = self .workflow ,
45+ agent_name = agent_name ,
46+ debug_model = debug_model ,
47+ use_debug_model = (not self ._is_target_trainable (target_tag )),
48+ ** self .kwargs ,
6449 )
65- super ().__init__ (
66- model_name = "ajet" ,
67- api_key = "dummy-api-key" ,
68- stream = False ,
50+ self ._register (target_tag , agent_name , explicit_tuner_as_modelscope_model )
51+ return explicit_tuner_as_modelscope_model
52+
53+
54+ def as_raw_openai_sdk_client (
55+ self ,
56+ agent_name = "default_agent_name" ,
57+ target_tag = "default_target_tag" ,
58+ debug_model = 'gpt-4o' ,
59+ ) -> OpenaiClientModelTuner :
60+ """Convert to raw OpenAI SDK client for advanced usage.
61+ Returns:
62+ Any:
63+ The raw OpenAI SDK client.
64+ """
65+ explicit_tuner_as_oai_client = OpenaiClientModelTuner (
66+ config = self .config ,
67+ context_tracker = self .context_tracker ,
68+ workflow = self .workflow ,
69+ agent_name = agent_name ,
70+ debug_model = debug_model ,
71+ use_debug_model = (not self ._is_target_trainable (target_tag )),
72+ ** self .kwargs ,
6973 )
74+ self ._register (target_tag , agent_name , explicit_tuner_as_oai_client )
75+ return explicit_tuner_as_oai_client
76+
77+
78+ def __call__ (self , ** kwargs ):
79+ """This method is **deprecated**.
80+ The current behavior of this method is pretend as a agentscope model
81+ """
82+ raise RuntimeError ("This method is deprecated. Please use `as_agentscope_model` / `as_raw_openai_sdk_client` first." )
83+
84+
85+ # ------------------------------------------------------------------------
86+ # other helper methods
87+ # ------------------------------------------------------------------------
7088
71- def register_model (self , target_name : str , default_model : ChatModelBase ) -> Agent2Proxy :
89+ def _register (self , target_name : str , agent_name : str , explicit_tuner : TunerTypeUnion ) -> TunerTypeUnion :
7290 """Register an agent type.
7391 Args:
7492 target_name (`str`):
@@ -79,116 +97,29 @@ def register_model(self, target_name: str, default_model: ChatModelBase) -> Agen
7997 Agent2Proxy:
8098 The agent type instance corresponding to the provided name.
8199 """
82- if target_name in self .target2proxy_registry :
83- if (
84- default_model .model_name
85- != self .target2proxy_registry [target_name ].default_model .model_name
86- ):
87- raise ValueError (
88- f"Agent proxy `{ target_name } ` is already registered with a different model_name.\n WAS [{ self .target2proxy_registry [target_name ].default_model .model_name } ]\n NOW [{ default_model .model_name } ]."
89- )
90- self .target2proxy_registry [target_name ] = Agent2Proxy (target_name , self , default_model )
91- return self .get_model (target_name )
92-
93- def get_model (self , target_name : str ) -> Agent2Proxy :
94- """Get the proxy instance by target_name.
95- Args:
96- target_name (`str`):
97- The name of the agent proxy to retrieve.
98- Returns:
99- Agent2Proxy:
100- The agent proxy corresponding to the provided target_name.
101- """
102100 if target_name not in self .target2proxy_registry :
103- raise ValueError ( f"Agent proxy ' { target_name } ' is not registered." )
104- else :
105- return self . target2proxy_registry [ target_name ]
101+ self . target2proxy_registry [ target_name ] = {}
102+ self . target2proxy_registry [ target_name ][ agent_name ] = explicit_tuner
103+ return explicit_tuner
106104
107- async def __call__ (
108- self ,
109- messages : list [dict [str , Any ]],
110- tools : list [dict ] | None = None ,
111- tool_choice : Literal ["auto" , "none" , "any" , "required" ] | str | None = None ,
112- structured_model : Type [BaseModel ] | None = None ,
113- ** kwargs : Any ,
114- ) -> ChatResponse :
115- # For qvq and qwen-vl models, the content field cannot be `None` or
116- # `[{"text": None}]`, so we need to convert it to an empty list.
117- if self .model_name .startswith ("qvq" ) or "-vl" in self .model_name :
118- raise NotImplementedError ("Not implemented for qvq and qwen-vl models yet." )
119-
120- kwargs = {
121- "messages" : messages ,
122- "model" : self .model_name ,
123- "stream" : self .stream ,
124- ** self .generate_kwargs ,
125- ** kwargs ,
126- "result_format" : "message" ,
127- # In agentscope, the `incremental_output` must be `True` when
128- # `self.stream` is True
129- "incremental_output" : self .stream ,
130- }
131-
132- if tools :
133- kwargs ["tools" ] = self ._format_tools_json_schemas (tools )
134-
135- if tool_choice :
136- self ._validate_tool_choice (tool_choice , tools )
137- kwargs ["tool_choice" ] = self ._format_tool_choice (tool_choice )
138-
139- if self .enable_thinking is not None and "enable_thinking" not in kwargs :
140- kwargs ["enable_thinking" ] = self .enable_thinking
141-
142- if structured_model :
143- if tools or tool_choice :
144- logger .warning (
145- "structured_model is provided. Both 'tools' and "
146- "'tool_choice' parameters will be overridden and "
147- "ignored. The model will only perform structured output "
148- "generation without calling any other tools." ,
149- )
150- format_tool = _create_tool_from_base_model (structured_model )
151- kwargs ["tools" ] = self ._format_tools_json_schemas (
152- [format_tool ],
153- )
154- kwargs ["tool_choice" ] = self ._format_tool_choice (
155- format_tool ["function" ]["name" ],
156- )
157-
158- # call llm model ✨
159- response_gen = await self .llm_proxy (
160- api_key = self .api_key ,
161- structured_model = structured_model ,
162- ** kwargs ,
163- )
164-
165- # Return the AsyncGenerator directly
166- return response_gen
167-
168- def is_trainable (self , target_name ) -> bool :
169- if self .user_workflow .trainable_targets is None :
105+ def _is_target_trainable (self , target_name ) -> bool :
106+ """Determine whether user have used `trainable_targets` to explicitly control training targets.
107+ """
108+ if self .workflow .trainable_targets is None :
170109 # always assume trainable when user has never changed trainable_targets
171110 return True
172- if not self .user_workflow .trainable_targets :
111+ if not self .workflow .trainable_targets :
173112 # always assume trainable when trainable_targets is []
174113 return True
175- if target_name in self .user_workflow .trainable_targets :
114+ if target_name in self .workflow .trainable_targets :
176115 return True
177116 else :
178117 return False
179118
180- def get_llm_proxy (self ) -> OpenaiLlmProxyWithTracker :
181- """Get the LlmProxyForAgentScope instance.
182- Returns:
183- LlmProxyForAgentScope:
184- The LlmProxyForAgentScope instance used by the ModelTuner.
185- """
186- return self .llm_proxy
187-
188119 def get_context_tracker (self ) -> MultiAgentContextTracker :
189120 """Get the context tracker instance.
190121 Returns:
191122 LlmProxyForAgentScope:
192123 The context tracker instance used by the ModelTuner.
193124 """
194- return self .context_tracker
125+ return self .context_tracker
0 commit comments