Skip to content

Commit 5103ff8

Browse files
authored
Merge pull request #423 from ZQlQZ/fix/ark_rl_tools
fix tool in ark rl
2 parents eb77b3d + 8890df9 commit 5103ff8

File tree

1 file changed

+2
-33
lines changed

1 file changed

+2
-33
lines changed

veadk/cli/templates/rl/ark/plugins/raw_async_veadk_rollout.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import json
1616
import asyncio
1717
from typing import Optional, Any, cast
18-
from loguru import logger
1918
from ark_sdk.resources.pipeline_plugin import rollout
2019
from ark_sdk.types.pipeline_plugin import PluginInstance, Runtime
2120
from ark_sdk.types.pipeline_plugin.pipeline_plugin import PluginContext
@@ -25,17 +24,16 @@
2524
ChatCompletionResponse,
2625
RolloutInferenceProxy,
2726
RolloutResult,
28-
PluginStatus,
2927
)
3028
from veadk.agent import Agent
3129
from veadk.memory.short_term_memory import ShortTermMemory
3230
from veadk.runner import Runner
3331
from veadk.tracing.telemetry.opentelemetry_tracer import OpentelemetryTracer
3432
from veadk.tracing.telemetry.exporters.cozeloop_exporter import CozeloopExporter
3533
from veadk.tracing.telemetry.exporters.cozeloop_exporter import CozeloopExporterConfig
34+
from veadk.tools.demo_tools import get_city_weather
3635
from google.adk.models.lite_llm import LiteLLMClient, LiteLlm
3736
from litellm import ModelResponse
38-
from cozeloop.decorator import observe
3937

4038
# BASE_MODEL 格式 : "{model_provider}/{model_name}"
4139
BASE_MODEL = "openai/doubao-seed-1-6-flash-250615"
@@ -50,13 +48,6 @@
5048
tracer = OpentelemetryTracer(exporters=cast(Any, exporters))
5149

5250

53-
@observe()
54-
def get_current_weather(location: str, unit="摄氏度"):
55-
# 实际调用天气查询 API 的逻辑
56-
# 此处为示例,返回模拟的天气数据
57-
return f"{location}今天天气晴朗,温度 25 {unit}。"
58-
59-
6051
class RecordingLiteLlm(LiteLlm):
6152
"""
6253
在调用 LiteLlm 的 completion/acompletion 时,拦截并记录原始 ModelResponse。
@@ -171,6 +162,7 @@ async def demo_veadk_rollout(
171162
model_provider="openai",
172163
model_api_key=proxy.jwt_token,
173164
tracers=[tracer],
165+
tools=[get_city_weather],
174166
model=model_instance,
175167
)
176168

@@ -224,29 +216,6 @@ async def demo_veadk_rollout(
224216
if model_response.choices[0].finish_reason != "tool_calls":
225217
# 模型最终总结,没有调用工具意愿
226218
break
227-
tool_calls = model_response.choices[0].message.tool_calls
228-
for tool_call in tool_calls or []:
229-
tool_name = tool_call.function.name
230-
if tool_name == "get_current_weather":
231-
try:
232-
args = json.loads(tool_call.function.arguments)
233-
tool_result = get_current_weather(**args)
234-
except Exception as e:
235-
logger.error(f"get_current_weather error: {e}")
236-
return RolloutResult(
237-
status=PluginStatus.SUCCESS,
238-
extra={
239-
"reward": -1,
240-
},
241-
)
242-
# 将工具结果加入消息列表
243-
messages.append(
244-
{
245-
"role": "tool",
246-
"content": tool_result,
247-
"tool_call_id": tool_call.id,
248-
}
249-
)
250219
# 默认return None则视为rollout成功
251220
return None
252221

0 commit comments

Comments
 (0)