Skip to content

Commit 2a37cc0

Browse files
committed
FEAT: tool to launch provided sub-task managers
1 parent 228aa75 commit 2a37cc0

3 files changed

Lines changed: 319 additions & 3 deletions

File tree

packages/eaa-core/src/eaa_core/task_manager/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ class BaseTaskManager:
114114
----------
115115
llm_config : LLMConfig, optional
116116
Configuration used to build the chat model.
117+
name : str, optional
118+
Human-readable task manager name used when registering or launching
119+
task managers from other agents.
117120
memory_config : Optional[MemoryManagerConfig], optional
118121
Configuration for the long-term memory store.
119122
tools : list[BaseTool], optional
@@ -142,6 +145,7 @@ class BaseTaskManager:
142145
def __init__(
143146
self,
144147
llm_config: LLMConfig = None,
148+
name: Optional[str] = None,
145149
memory_config: Optional[MemoryManagerConfig] = None,
146150
tools: list[BaseTool] = (),
147151
skill_dirs: Optional[Sequence[str]] = None,
@@ -168,6 +172,7 @@ def __init__(
168172
"`transcript_db_path` for WebUI transcript persistence and "
169173
"`checkpoint_db_path` for LangGraph checkpoints."
170174
)
175+
self.name = self.resolve_name(name)
171176
self.chat_state = ChatGraphState()
172177
self.task_state = TaskManagerState()
173178
self.active_state: TaskManagerState = self.task_state
@@ -239,6 +244,19 @@ def __init__(
239244
if build:
240245
self.build()
241246

247+
@classmethod
248+
def get_default_name(cls) -> str:
249+
"""Return the default task manager name."""
250+
return BaseTool.camel_to_snake(cls.__name__)
251+
252+
@classmethod
253+
def resolve_name(cls, name: Optional[str]) -> str:
254+
"""Return a validated task manager name."""
255+
resolved_name = cls.get_default_name() if name is None else str(name).strip()
256+
if not resolved_name:
257+
raise ValueError("`name` must be a non-empty string when provided.")
258+
return resolved_name
259+
242260
def get_default_system_prompt(self) -> str:
243261
"""Return the default system prompt for the task manager."""
244262
prompt = render_prompt_template(

packages/eaa-core/src/eaa_core/tool/subagent.py

Lines changed: 175 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
from __future__ import annotations
22

3+
import inspect
34
import re
4-
from typing import Annotated, Any
5+
from typing import Annotated, Any, ClassVar
56

67
from eaa_core.message_proc import extract_message_text
8+
from eaa_core.task_manager.persistence import SQLiteTranscriptStore
79
from eaa_core.tool.base import BaseTool, tool
810

911

1012
class SubagentTool(BaseTool):
1113
"""Launch subordinate task managers for delegated agent work."""
1214

1315
name: str = "subagent"
16+
registered_task_managers: ClassVar[dict[str, Any]] = {}
1417

1518
def __init__(self, task_manager: Any, *args: Any, **kwargs: Any) -> None:
1619
"""Initialize the subagent launcher.
@@ -27,6 +30,52 @@ def __init__(self, task_manager: Any, *args: Any, **kwargs: Any) -> None:
2730
self.task_manager = task_manager
2831
super().__init__(*args, **kwargs)
2932

33+
@classmethod
34+
def add_task_managers(cls, task_managers: Any | list[Any]) -> None:
35+
"""Register one or more task manager objects for subtask launches.
36+
37+
Parameters
38+
----------
39+
task_managers : Any or list[Any]
40+
Task manager object or objects with a unique ``name`` and callable
41+
``run`` method.
42+
"""
43+
if isinstance(task_managers, (list, tuple)):
44+
normalized_task_managers = list(task_managers)
45+
else:
46+
normalized_task_managers = [task_managers]
47+
for task_manager in normalized_task_managers:
48+
name = cls._get_task_manager_name(task_manager)
49+
run_method = getattr(task_manager, "run", None)
50+
if not callable(run_method):
51+
raise ValueError(
52+
f"Registered task manager {name!r} must define a callable "
53+
"`run` method."
54+
)
55+
existing_task_manager = cls.registered_task_managers.get(name)
56+
if (
57+
existing_task_manager is not None
58+
and existing_task_manager is not task_manager
59+
):
60+
raise ValueError(
61+
f"A task manager named {name!r} is already registered."
62+
)
63+
cls.registered_task_managers[name] = task_manager
64+
65+
@staticmethod
66+
def _get_task_manager_name(task_manager: Any) -> str:
67+
"""Return a validated task manager registry name."""
68+
name = str(getattr(task_manager, "name", "")).strip()
69+
if not name:
70+
raise ValueError("Registered task managers must have a non-empty `name`.")
71+
return name
72+
73+
@staticmethod
74+
def _transcript_table_name_for_conversation(conversation_id: str) -> str:
75+
"""Return the transcript table name for a subordinate conversation."""
76+
table_suffix = re.sub(r"[^A-Za-z0-9_]", "_", conversation_id)
77+
return f"transcript_messages_{table_suffix}"
78+
3079
@tool(name="subagent_tool.launch_subagent")
3180
def launch_subagent(
3281
self,
@@ -56,8 +105,9 @@ def launch_subagent(
56105
conversation_id = str(conversation["id"])
57106
transcript_table_name = "transcript_messages"
58107
if conversation_id is not None:
59-
table_suffix = re.sub(r"[^A-Za-z0-9_]", "_", conversation_id)
60-
transcript_table_name = f"transcript_messages_{table_suffix}"
108+
transcript_table_name = self._transcript_table_name_for_conversation(
109+
conversation_id
110+
)
61111
inherited_tools = [
62112
tool_object
63113
for tool_object in self.task_manager.tool_manager
@@ -93,3 +143,125 @@ def launch_subagent(
93143
conversation_id,
94144
message="Subagent terminated",
95145
)
146+
147+
@tool(name="subagent_tool.list_registered_task_managers")
148+
def list_registered_task_managers(self) -> list[dict[str, str]]:
149+
"""Return task managers available for ``launch_subtask_manager``.
150+
151+
Returns
152+
-------
153+
list[dict[str, str]]
154+
Registered task manager specs.
155+
"""
156+
specs: list[dict[str, str]] = []
157+
for name, task_manager in self.registered_task_managers.items():
158+
run_method = getattr(task_manager, "run", None)
159+
specs.append(
160+
{
161+
"name": name,
162+
"task_class_name": type(task_manager).__name__,
163+
"run_method_docstring": inspect.getdoc(run_method) or "",
164+
}
165+
)
166+
return specs
167+
168+
@tool(name="subagent_tool.launch_subtask_manager")
169+
def launch_subtask_manager(
170+
self,
171+
task_manager_name: Annotated[
172+
str,
173+
"Name of the registered task manager to launch.",
174+
],
175+
task_manager_kwargs: Annotated[
176+
dict,
177+
"Keyword arguments to pass to the selected task manager's run method.",
178+
],
179+
) -> dict[str, Any]:
180+
"""Run a registered task manager and return its run result.
181+
182+
Parameters
183+
----------
184+
task_manager_name : str
185+
Name of the registered task manager to launch.
186+
task_manager_kwargs : dict
187+
Keyword arguments passed to ``matched_task_manager.run``.
188+
189+
Returns
190+
-------
191+
dict[str, Any]
192+
The selected task manager's ``run`` return value.
193+
"""
194+
if not isinstance(task_manager_kwargs, dict):
195+
raise ValueError("`task_manager_kwargs` must be a dictionary.")
196+
task_manager_name = str(task_manager_name).strip()
197+
matched_task_manager = self.registered_task_managers.get(task_manager_name)
198+
if matched_task_manager is None:
199+
available_names = ", ".join(self.registered_task_managers) or "none"
200+
raise ValueError(
201+
f"No registered task manager named {task_manager_name!r}. "
202+
f"Available task managers: {available_names}."
203+
)
204+
205+
runtime_controller = getattr(self.task_manager, "runtime_controller", None)
206+
conversation_id = self._get_task_manager_name(matched_task_manager)
207+
if runtime_controller is not None:
208+
conversation = runtime_controller.create_conversation(
209+
label=conversation_id,
210+
kind="subagent",
211+
)
212+
conversation_id = str(conversation["id"])
213+
transcript_table_name = self._transcript_table_name_for_conversation(
214+
conversation_id
215+
)
216+
217+
saved_state = {
218+
"checkpoint_db_path": getattr(
219+
matched_task_manager,
220+
"checkpoint_db_path",
221+
None,
222+
),
223+
"transcript_db_path": getattr(
224+
matched_task_manager,
225+
"transcript_db_path",
226+
None,
227+
),
228+
"transcript_table_name": getattr(
229+
matched_task_manager,
230+
"transcript_table_name",
231+
None,
232+
),
233+
"transcript_store": getattr(matched_task_manager, "transcript_store", None),
234+
"use_webui": getattr(matched_task_manager, "use_webui", None),
235+
"runtime_controller": getattr(
236+
matched_task_manager,
237+
"runtime_controller",
238+
None,
239+
),
240+
"runtime_conversation_id": getattr(
241+
matched_task_manager,
242+
"runtime_conversation_id",
243+
None,
244+
),
245+
}
246+
matched_task_manager.checkpoint_db_path = None
247+
matched_task_manager.transcript_db_path = self.task_manager.transcript_db_path
248+
matched_task_manager.transcript_table_name = transcript_table_name
249+
matched_task_manager.transcript_store = SQLiteTranscriptStore(
250+
self.task_manager.transcript_db_path,
251+
table_name=transcript_table_name,
252+
)
253+
matched_task_manager.use_webui = self.task_manager.use_webui
254+
matched_task_manager.runtime_controller = runtime_controller
255+
matched_task_manager.runtime_conversation_id = conversation_id
256+
257+
try:
258+
result = matched_task_manager.run(**task_manager_kwargs)
259+
return {"result": result}
260+
finally:
261+
for attribute, value in saved_state.items():
262+
setattr(matched_task_manager, attribute, value)
263+
if runtime_controller is not None:
264+
runtime_controller.terminate_conversation(
265+
conversation_id,
266+
message="Subtask manager terminated",
267+
)

0 commit comments

Comments
 (0)