Skip to content

Commit d135838

Browse files
feat: add callback system comparable to pytorch lightning
1 parent 4029647 commit d135838

9 files changed

Lines changed: 251 additions & 3 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ Ready to build your first agent? Check out our documentation:
123123
9. **[Reporting](docs/08_reporting.md)** - Obtain agent logs as execution reports and summaries as test reports
124124
10. **[Observability](docs/09_observability_telemetry_tracing.md)** - Monitor and debug agents
125125
11. **[Extracting Data](docs/10_extracting_data.md)** - Extracting structured data from screenshots and files
126+
12. **[Callbacks](docs/11_callbacks.md)** - Inject custom logic into the control loop
126127

127128
**Official documentation:** [docs.askui.com](https://docs.askui.com)
128129

docs/00_overview.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ Understand what data is collected and how to opt out.
8787
### 10 - Extracting Data
8888
**Topics**: Using `get()`, file support (PDF. Excel, Word, CSV), structured data extraction, response schemas
8989

90+
### 11 - Callbacks
91+
**Topics**: Inject custom logic at different positions of the control loop through callbacks
92+
9093
Extract information from screens and files using the `get()` method with Pydantic models.
9194

9295
## Additional Resources

docs/11_callbacks.md

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Callbacks
2+
3+
Callbacks provide hooks into the agent's conversation lifecycle, similar to PyTorch Lightning's callback system. Use them for logging, monitoring, custom metrics, or extending agent behavior.
4+
5+
## Usage
6+
7+
Subclass `ConversationCallback` and override the hooks you need:
8+
9+
```python
10+
from askui import ComputerAgent, ConversationCallback
11+
12+
class MetricsCallback(ConversationCallback):
13+
def on_step_start(self, conversation, step_index):
14+
print(f"Step {step_index} starting...")
15+
16+
def on_step_end(self, conversation, step_index, result):
17+
print(f"Step {step_index} finished: {result.status}")
18+
19+
with ComputerAgent(callbacks=[MetricsCallback()]) as agent:
20+
agent.act("Open the settings menu")
21+
```
22+
23+
## Available Hooks
24+
25+
| Hook | When Called | Parameters |
26+
|------|-------------|------------|
27+
| `on_conversation_start` | After setup, before control loop | `conversation` |
28+
| `on_conversation_end` | After control loop, before cleanup | `conversation` |
29+
| `on_control_loop_start` | Before the iteration loop begins | `conversation` |
30+
| `on_control_loop_end` | After the iteration loop ends | `conversation` |
31+
| `on_step_start` | Before each step execution | `conversation`, `step_index` |
32+
| `on_step_end` | After each step execution | `conversation`, `step_index`, `result` |
33+
| `on_tool_execution_start` | Before tools are executed | `conversation`, `tool_names` |
34+
| `on_tool_execution_end` | After tools are executed | `conversation`, `tool_names` |
35+
36+
### Parameters
37+
38+
- **`conversation`**: The `Conversation` instance with access to messages, settings, and state
39+
- **`step_index`**: Zero-based index of the current step
40+
- **`result`**: `SpeakerResult` containing `status`, `messages_to_add`, and `usage`
41+
- **`tool_names`**: List of tool names being executed
42+
43+
## Example: Timing Callback
44+
45+
```python
46+
import time
47+
from askui import ComputerAgent, ConversationCallback
48+
49+
class TimingCallback(ConversationCallback):
50+
def __init__(self):
51+
self.start_time = None
52+
self.step_times = []
53+
54+
def on_conversation_start(self, conversation):
55+
self.start_time = time.time()
56+
57+
def on_step_start(self, conversation, step_index):
58+
self._step_start = time.time()
59+
60+
def on_step_end(self, conversation, step_index, result):
61+
elapsed = time.time() - self._step_start
62+
self.step_times.append(elapsed)
63+
print(f"Step {step_index}: {elapsed:.2f}s")
64+
65+
def on_conversation_end(self, conversation):
66+
total = time.time() - self.start_time
67+
print(f"Total: {total:.2f}s across {len(self.step_times)} steps")
68+
69+
with ComputerAgent(callbacks=[TimingCallback()]) as agent:
70+
agent.act("Search for documents")
71+
```
72+
73+
## Multiple Callbacks
74+
75+
Pass multiple callbacks to combine behaviors:
76+
77+
```python
78+
with ComputerAgent(callbacks=[TimingCallback(), MetricsCallback()]) as agent:
79+
agent.act("Complete the form")
80+
```
81+
82+
Callbacks are called in the order they are provided.

src/askui/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ToolUseBlockParam,
3131
UrlImageSourceParam,
3232
)
33+
from .models.shared.conversation_callback import ConversationCallback
3334
from .models.shared.settings import (
3435
DEFAULT_GET_RESOLUTION,
3536
DEFAULT_LOCATE_RESOLUTION,
@@ -76,6 +77,7 @@
7677
"CitationPageLocationParam",
7778
"ConfigurableRetry",
7879
"ContentBlockParam",
80+
"ConversationCallback",
7981
"DEFAULT_GET_RESOLUTION",
8082
"DEFAULT_LOCATE_RESOLUTION",
8183
"GetSettings",

src/askui/agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from askui.container import telemetry
1010
from askui.locators.locators import Locator
1111
from askui.models.models import Point
12+
from askui.models.shared.conversation_callback import ConversationCallback
1213
from askui.models.shared.settings import ActSettings, LocateSettings, MessageSettings
1314
from askui.models.shared.tools import Tool
1415
from askui.prompts.act_prompts import (
@@ -67,7 +68,7 @@ class ComputerAgent(Agent):
6768
```
6869
"""
6970

70-
@telemetry.record_call(exclude={"reporters", "tools", "act_tools"})
71+
@telemetry.record_call(exclude={"reporters", "tools", "act_tools", "callbacks"})
7172
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
7273
def __init__(
7374
self,
@@ -77,6 +78,7 @@ def __init__(
7778
settings: AgentSettings | None = None,
7879
retry: Retry | None = None,
7980
act_tools: list[Tool] | None = None,
81+
callbacks: list[ConversationCallback] | None = None,
8082
) -> None:
8183
reporter = CompositeReporter(reporters=reporters)
8284
self.tools = tools or AgentToolbox(
@@ -109,6 +111,7 @@ def __init__(
109111
+ (act_tools or []),
110112
agent_os=self.tools.os,
111113
settings=settings,
114+
callbacks=callbacks,
112115
)
113116
self.act_agent_os_facade: ComputerAgentOsFacade = ComputerAgentOsFacade(
114117
self.tools.os

src/askui/agent_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from askui.locators.locators import Locator
1515
from askui.models.shared.agent_message_param import MessageParam
1616
from askui.models.shared.conversation import Conversation, Speakers
17+
from askui.models.shared.conversation_callback import ConversationCallback
1718
from askui.models.shared.settings import (
1819
ActSettings,
1920
CacheWritingSettings,
@@ -58,6 +59,7 @@ def __init__(
5859
tools: list[Tool] | None = None,
5960
agent_os: AgentOs | AndroidAgentOs | None = None,
6061
settings: AgentSettings | None = None,
62+
callbacks: list[ConversationCallback] | None = None,
6163
) -> None:
6264
load_dotenv()
6365
self._reporter: Reporter = reporter or CompositeReporter(reporters=None)
@@ -79,6 +81,7 @@ def __init__(
7981
image_qa_provider=self._image_qa_provider,
8082
detection_provider=self._detection_provider,
8183
reporter=self._reporter,
84+
callbacks=callbacks,
8285
)
8386

8487
# Provider-based tools

src/askui/android_agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from askui.container import telemetry
1010
from askui.locators.locators import Locator
1111
from askui.models.models import Point
12+
from askui.models.shared.conversation_callback import ConversationCallback
1213
from askui.models.shared.settings import ActSettings, MessageSettings
1314
from askui.models.shared.tools import Tool
1415
from askui.prompts.act_prompts import create_android_agent_prompt
@@ -63,7 +64,7 @@ class AndroidAgent(Agent):
6364
```
6465
"""
6566

66-
@telemetry.record_call(exclude={"reporters", "tools", "act_tools"})
67+
@telemetry.record_call(exclude={"reporters", "tools", "act_tools", "callbacks"})
6768
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
6869
def __init__(
6970
self,
@@ -72,6 +73,7 @@ def __init__(
7273
settings: AgentSettings | None = None,
7374
retry: Retry | None = None,
7475
act_tools: list[Tool] | None = None,
76+
callbacks: list[ConversationCallback] | None = None,
7577
) -> None:
7678
reporter = CompositeReporter(reporters=reporters)
7779
self.os = PpadbAgentOs(device_identifier=device, reporter=reporter)
@@ -98,6 +100,7 @@ def __init__(
98100
+ (act_tools or []),
99101
agent_os=self.os,
100102
settings=settings,
103+
callbacks=callbacks,
101104
)
102105
self.act_tool_collection.add_agent_os(self.act_agent_os_facade)
103106
# Override default act settings with Android-specific settings

src/askui/models/shared/conversation.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from askui.speaker.speaker import SpeakerResult, Speakers
2626

2727
if TYPE_CHECKING:
28+
from askui.models.shared.conversation_callback import ConversationCallback
2829
from askui.utils.caching.cache_manager import CacheManager
2930

3031
logger = logging.getLogger(__name__)
@@ -58,6 +59,7 @@ class Conversation:
5859
reporter: Reporter for logging messages and actions
5960
cache_manager: Cache manager for recording/playback (optional)
6061
truncation_strategy_factory: Factory for creating truncation strategies
62+
callbacks: List of callbacks for conversation lifecycle hooks (optional)
6163
"""
6264

6365
def __init__(
@@ -69,6 +71,7 @@ def __init__(
6971
reporter: Reporter = NULL_REPORTER,
7072
cache_manager: "CacheManager | None" = None,
7173
truncation_strategy_factory: TruncationStrategyFactory | None = None,
74+
callbacks: "list[ConversationCallback] | None" = None,
7275
) -> None:
7376
"""Initialize conversation with speakers and model providers."""
7477
if not speakers:
@@ -92,18 +95,33 @@ def __init__(
9295
truncation_strategy_factory or SimpleTruncationStrategyFactory()
9396
)
9497
self._truncation_strategy: TruncationStrategy | None = None
98+
self._callbacks: "list[ConversationCallback]" = callbacks or []
9599

96100
# State for current execution (set in start())
97101
self.settings: ActSettings = ActSettings()
98102
self.tools: ToolCollection = ToolCollection()
99103
self._reporters: list[Reporter] = []
104+
self._step_index: int = 0
100105

101106
# Cache execution context (for communication between tools and CacheExecutor)
102107
self.cache_execution_context: dict[str, Any] = {}
103108

104109
# Track if cache execution was used (to prevent recording during playback)
105110
self._executed_from_cache: bool = False
106111

112+
def _call_callbacks(self, method_name: str, *args: Any, **kwargs: Any) -> None:
113+
"""Call a method on all registered callbacks.
114+
115+
Args:
116+
method_name: Name of the callback method to call
117+
*args: Positional arguments to pass to the callback
118+
**kwargs: Keyword arguments to pass to the callback
119+
"""
120+
for callback in self._callbacks:
121+
method = getattr(callback, method_name, None)
122+
if method and callable(method):
123+
method(self, *args, **kwargs)
124+
107125
@tracer.start_as_current_span("conversation")
108126
def execute_conversation(
109127
self,
@@ -119,7 +137,6 @@ def execute_conversation(
119137
120138
Args:
121139
messages: Initial message history
122-
on_message: Optional callback for each message
123140
tools: Available tools
124141
settings: Agent settings
125142
reporters: Optional list of additional reporters for this conversation
@@ -128,7 +145,11 @@ def execute_conversation(
128145
logger.info(msg)
129146

130147
self._setup_control_loop(messages, tools, settings, reporters)
148+
149+
self._call_callbacks("on_conversation_start")
131150
self._execute_control_loop()
151+
self._call_callbacks("on_conversation_end")
152+
132153
self._conclude_control_loop()
133154

134155
@tracer.start_as_current_span("setup_control_loop")
@@ -162,9 +183,12 @@ def _setup_control_loop(
162183

163184
@tracer.start_as_current_span("control_loop")
164185
def _execute_control_loop(self) -> None:
186+
self._call_callbacks("on_control_loop_start")
187+
self._step_index = 0
165188
continue_execution = True
166189
while continue_execution:
167190
continue_execution = self._execute_step()
191+
self._call_callbacks("on_control_loop_end")
168192

169193
@tracer.start_as_current_span("finish_control_loop")
170194
def _conclude_control_loop(self) -> None:
@@ -189,6 +213,7 @@ def _execute_step(self) -> bool:
189213
Returns:
190214
True if loop should continue, False if done
191215
"""
216+
self._call_callbacks("on_step_start", self._step_index)
192217

193218
# 1. Infer next speaker
194219
speaker = self.current_speaker
@@ -226,6 +251,9 @@ def _execute_step(self) -> bool:
226251
if result.usage:
227252
self._accumulate_usage(result.usage)
228253

254+
self._call_callbacks("on_step_end", self._step_index, result)
255+
self._step_index += 1
256+
229257
return continue_loop
230258

231259
@tracer.start_as_current_span("execute_tool_call")
@@ -255,8 +283,11 @@ def _execute_tools_if_present(self, message: MessageParam) -> MessageParam | Non
255283
return None
256284

257285
# Execute tools
286+
tool_names = [block.name for block in tool_use_blocks]
258287
logger.debug("Executing %d tool(s)", len(tool_use_blocks))
288+
self._call_callbacks("on_tool_execution_start", tool_names)
259289
tool_results = self.tools.run(tool_use_blocks)
290+
self._call_callbacks("on_tool_execution_end", tool_names)
260291

261292
if not tool_results:
262293
return None

0 commit comments

Comments
 (0)