|
| 1 | +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# |
| 6 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +# |
| 8 | +# Unless required by applicable law or agreed to in writing, software |
| 9 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 10 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 11 | +# See the License for the specific language governing permissions and |
| 12 | +# limitations under the License. |
| 13 | +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= |
| 14 | + |
| 15 | +import asyncio |
| 16 | +import contextvars |
| 17 | +import logging |
| 18 | +import uuid |
| 19 | +from threading import Lock |
| 20 | +from typing import Any, Callable |
| 21 | + |
| 22 | +from app.agent.listen_chat_agent import ListenChatAgent, logger |
| 23 | +from app.model.chat import AgentModelConfig, Chat |
| 24 | +from app.service.task import ActionCreateAgentData, Agents, get_task_lock |
| 25 | +from camel.messages import BaseMessage |
| 26 | +from camel.models import ModelFactory |
| 27 | +from camel.toolkits import FunctionTool, RegisteredAgentToolkit |
| 28 | +from camel.types import ModelPlatformType |
| 29 | + |
| 30 | +# Thread-safe reference to main event loop using contextvars |
| 31 | +# This ensures each request has its own event loop reference, |
| 32 | +# avoiding race conditions |
| 33 | +_main_event_loop_var: contextvars.ContextVar[asyncio.AbstractEventLoop |
| 34 | + | None] = contextvars.ContextVar( |
| 35 | + "_main_event_loop", |
| 36 | + default=None) |
| 37 | + |
| 38 | +# Global fallback for main event loop reference |
| 39 | +# Used when contextvars don't propagate to worker threads |
| 40 | +# (e.g., asyncio.to_thread) |
| 41 | +_GLOBAL_MAIN_LOOP: asyncio.AbstractEventLoop | None = None |
| 42 | +_GLOBAL_MAIN_LOOP_LOCK = Lock() |
| 43 | + |
| 44 | + |
| 45 | +def set_main_event_loop(loop: asyncio.AbstractEventLoop | None): |
| 46 | + """Set the main event loop reference for thread-safe task scheduling. |
| 47 | +
|
| 48 | + This should be called from the main async context before spawning threads |
| 49 | + that need to schedule async tasks. Uses both contextvars (for request |
| 50 | + isolation) and a global fallback (for thread pool workers where |
| 51 | + contextvars may not propagate). |
| 52 | + """ |
| 53 | + global _GLOBAL_MAIN_LOOP |
| 54 | + _main_event_loop_var.set(loop) |
| 55 | + with _GLOBAL_MAIN_LOOP_LOCK: |
| 56 | + _GLOBAL_MAIN_LOOP = loop |
| 57 | + |
| 58 | + |
| 59 | +def _schedule_async_task(coro): |
| 60 | + """Schedule an async coroutine as a task, thread-safe. |
| 61 | +
|
| 62 | + This function handles scheduling from both the main event loop thread |
| 63 | + and from worker threads (e.g., when using asyncio.to_thread). |
| 64 | + """ |
| 65 | + try: |
| 66 | + # Try to get the running loop (works in main event loop thread) |
| 67 | + loop = asyncio.get_running_loop() |
| 68 | + loop.create_task(coro) |
| 69 | + except RuntimeError: |
| 70 | + # No running loop in this thread (we're in a worker thread) |
| 71 | + # First try contextvars, then fallback to global reference |
| 72 | + main_loop = _main_event_loop_var.get() |
| 73 | + if main_loop is None: |
| 74 | + with _GLOBAL_MAIN_LOOP_LOCK: |
| 75 | + main_loop = _GLOBAL_MAIN_LOOP |
| 76 | + if main_loop is not None and main_loop.is_running(): |
| 77 | + asyncio.run_coroutine_threadsafe(coro, main_loop) |
| 78 | + else: |
| 79 | + # This should not happen in normal operation - log error and skip |
| 80 | + logging.error("No event loop available for async task " |
| 81 | + "scheduling, task skipped. Ensure " |
| 82 | + "set_main_event_loop() is called " |
| 83 | + "before parallel agent creation.") |
| 84 | + |
| 85 | + |
| 86 | +def agent_model( |
| 87 | + agent_name: str, |
| 88 | + system_message: str | BaseMessage, |
| 89 | + options: Chat, |
| 90 | + tools: list[FunctionTool | Callable] | None = None, |
| 91 | + prune_tool_calls_from_memory: bool = False, |
| 92 | + tool_names: list[str] | None = None, |
| 93 | + toolkits_to_register_agent: list[RegisteredAgentToolkit] | None = None, |
| 94 | + enable_snapshot_clean: bool = False, |
| 95 | + custom_model_config: AgentModelConfig | None = None, |
| 96 | +): |
| 97 | + task_lock = get_task_lock(options.project_id) |
| 98 | + agent_id = str(uuid.uuid4()) |
| 99 | + logger.info(f"Creating agent: {agent_name} with id: {agent_id} " |
| 100 | + f"for project: {options.project_id}") |
| 101 | + # Use thread-safe scheduling to support parallel agent creation |
| 102 | + _schedule_async_task( |
| 103 | + task_lock.put_queue( |
| 104 | + ActionCreateAgentData( |
| 105 | + data={ |
| 106 | + "agent_name": agent_name, |
| 107 | + "agent_id": agent_id, |
| 108 | + "tools": tool_names or [], |
| 109 | + }))) |
| 110 | + |
| 111 | + # Determine model configuration - use custom config if provided, |
| 112 | + # otherwise use task defaults |
| 113 | + config_attrs = ["model_platform", "model_type", "api_key", "api_url"] |
| 114 | + effective_config = {} |
| 115 | + |
| 116 | + if custom_model_config and custom_model_config.has_custom_config(): |
| 117 | + for attr in config_attrs: |
| 118 | + effective_config[attr] = getattr(custom_model_config, attr, |
| 119 | + None) or getattr(options, attr) |
| 120 | + extra_params = (custom_model_config.extra_params |
| 121 | + or options.extra_params or {}) |
| 122 | + logger.info(f"Agent {agent_name} using custom model config: " |
| 123 | + f"platform={effective_config['model_platform']}, " |
| 124 | + f"type={effective_config['model_type']}") |
| 125 | + else: |
| 126 | + for attr in config_attrs: |
| 127 | + effective_config[attr] = getattr(options, attr) |
| 128 | + extra_params = options.extra_params or {} |
| 129 | + init_param_keys = { |
| 130 | + "api_version", |
| 131 | + "azure_ad_token", |
| 132 | + "azure_ad_token_provider", |
| 133 | + "max_retries", |
| 134 | + "timeout", |
| 135 | + "client", |
| 136 | + "async_client", |
| 137 | + "azure_deployment_name", |
| 138 | + } |
| 139 | + |
| 140 | + init_params = {} |
| 141 | + model_config: dict[str, Any] = {} |
| 142 | + |
| 143 | + if options.is_cloud(): |
| 144 | + model_config["user"] = str(options.project_id) |
| 145 | + |
| 146 | + excluded_keys = {"model_platform", "model_type", "api_key", "url"} |
| 147 | + |
| 148 | + # Distribute extra_params between init_params and model_config |
| 149 | + for k, v in extra_params.items(): |
| 150 | + if k in excluded_keys: |
| 151 | + continue |
| 152 | + # Skip empty values |
| 153 | + if v is None or (isinstance(v, str) and not v.strip()): |
| 154 | + continue |
| 155 | + |
| 156 | + if k in init_param_keys: |
| 157 | + init_params[k] = v |
| 158 | + else: |
| 159 | + model_config[k] = v |
| 160 | + |
| 161 | + if agent_name == Agents.task_agent: |
| 162 | + model_config["stream"] = True |
| 163 | + if agent_name == Agents.browser_agent: |
| 164 | + try: |
| 165 | + model_platform_enum = ModelPlatformType( |
| 166 | + effective_config["model_platform"].lower()) |
| 167 | + if model_platform_enum in { |
| 168 | + ModelPlatformType.OPENAI, |
| 169 | + ModelPlatformType.AZURE, |
| 170 | + ModelPlatformType.OPENAI_COMPATIBLE_MODEL, |
| 171 | + ModelPlatformType.LITELLM, |
| 172 | + ModelPlatformType.OPENROUTER, |
| 173 | + }: |
| 174 | + model_config["parallel_tool_calls"] = False |
| 175 | + except (ValueError, AttributeError): |
| 176 | + logging.error( |
| 177 | + f"Invalid model platform for browser agent: " |
| 178 | + f"{effective_config['model_platform']}", |
| 179 | + exc_info=True, |
| 180 | + ) |
| 181 | + model_platform_enum = None |
| 182 | + |
| 183 | + model = ModelFactory.create( |
| 184 | + model_platform=effective_config["model_platform"], |
| 185 | + model_type=effective_config["model_type"], |
| 186 | + api_key=effective_config["api_key"], |
| 187 | + url=effective_config["api_url"], |
| 188 | + model_config_dict=model_config or None, |
| 189 | + timeout=600, # 10 minutes |
| 190 | + **init_params, |
| 191 | + ) |
| 192 | + |
| 193 | + return ListenChatAgent( |
| 194 | + options.project_id, |
| 195 | + agent_name, |
| 196 | + system_message, |
| 197 | + model=model, |
| 198 | + tools=tools, |
| 199 | + agent_id=agent_id, |
| 200 | + prune_tool_calls_from_memory=prune_tool_calls_from_memory, |
| 201 | + toolkits_to_register_agent=toolkits_to_register_agent, |
| 202 | + enable_snapshot_clean=enable_snapshot_clean, |
| 203 | + stream_accumulate=False, |
| 204 | + ) |
0 commit comments