|
15 | 15 | import json |
16 | 16 | import asyncio |
17 | 17 | from typing import Optional, Any, cast |
18 | | -from loguru import logger |
19 | 18 | from ark_sdk.resources.pipeline_plugin import rollout |
20 | 19 | from ark_sdk.types.pipeline_plugin import PluginInstance, Runtime |
21 | 20 | from ark_sdk.types.pipeline_plugin.pipeline_plugin import PluginContext |
|
25 | 24 | ChatCompletionResponse, |
26 | 25 | RolloutInferenceProxy, |
27 | 26 | RolloutResult, |
28 | | - PluginStatus, |
29 | 27 | ) |
30 | 28 | from veadk.agent import Agent |
31 | 29 | from veadk.memory.short_term_memory import ShortTermMemory |
32 | 30 | from veadk.runner import Runner |
33 | 31 | from veadk.tracing.telemetry.opentelemetry_tracer import OpentelemetryTracer |
34 | 32 | from veadk.tracing.telemetry.exporters.cozeloop_exporter import CozeloopExporter |
35 | 33 | from veadk.tracing.telemetry.exporters.cozeloop_exporter import CozeloopExporterConfig |
| 34 | +from veadk.tools.demo_tools import get_city_weather |
36 | 35 | from google.adk.models.lite_llm import LiteLLMClient, LiteLlm |
37 | 36 | from litellm import ModelResponse |
38 | | -from cozeloop.decorator import observe |
39 | 37 |
|
40 | 38 | # BASE_MODEL 格式 : "{model_provider}/{model_name}" |
41 | 39 | BASE_MODEL = "openai/doubao-seed-1-6-flash-250615" |
|
50 | 48 | tracer = OpentelemetryTracer(exporters=cast(Any, exporters)) |
51 | 49 |
|
52 | 50 |
|
53 | | -@observe() |
54 | | -def get_current_weather(location: str, unit="摄氏度"): |
55 | | - # 实际调用天气查询 API 的逻辑 |
56 | | - # 此处为示例,返回模拟的天气数据 |
57 | | - return f"{location}今天天气晴朗,温度 25 {unit}。" |
58 | | - |
59 | | - |
60 | 51 | class RecordingLiteLlm(LiteLlm): |
61 | 52 | """ |
62 | 53 | 在调用 LiteLlm 的 completion/acompletion 时,拦截并记录原始 ModelResponse。 |
@@ -171,6 +162,7 @@ async def demo_veadk_rollout( |
171 | 162 | model_provider="openai", |
172 | 163 | model_api_key=proxy.jwt_token, |
173 | 164 | tracers=[tracer], |
| 165 | + tools=[get_city_weather], |
174 | 166 | model=model_instance, |
175 | 167 | ) |
176 | 168 |
|
@@ -224,29 +216,6 @@ async def demo_veadk_rollout( |
224 | 216 | if model_response.choices[0].finish_reason != "tool_calls": |
225 | 217 | # 模型最终总结,没有调用工具意愿 |
226 | 218 | break |
227 | | - tool_calls = model_response.choices[0].message.tool_calls |
228 | | - for tool_call in tool_calls or []: |
229 | | - tool_name = tool_call.function.name |
230 | | - if tool_name == "get_current_weather": |
231 | | - try: |
232 | | - args = json.loads(tool_call.function.arguments) |
233 | | - tool_result = get_current_weather(**args) |
234 | | - except Exception as e: |
235 | | - logger.error(f"get_current_weather error: {e}") |
236 | | - return RolloutResult( |
237 | | - status=PluginStatus.SUCCESS, |
238 | | - extra={ |
239 | | - "reward": -1, |
240 | | - }, |
241 | | - ) |
242 | | - # 将工具结果加入消息列表 |
243 | | - messages.append( |
244 | | - { |
245 | | - "role": "tool", |
246 | | - "content": tool_result, |
247 | | - "tool_call_id": tool_call.id, |
248 | | - } |
249 | | - ) |
250 | 219 | # 默认return None则视为rollout成功 |
251 | 220 | return None |
252 | 221 |
|
|
0 commit comments