Skip to content

Commit 644fdc4

Browse files
authored
Implement Hooks as middlewares (#80)
This change: - Makes the hook decorator return an AgentMiddleware - Changes the hook function input params to match Middlewares - Removes the underlying code, that implements hooks, as these are now middlewares. - Removes use of LC middlewares in debug logging
1 parent 04dfeb6 commit 644fdc4

File tree

8 files changed

+196
-345
lines changed

8 files changed

+196
-345
lines changed

.basedpyright/baseline.json

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30354,24 +30354,6 @@
3035430354
}
3035530355
}
3035630356
],
30357-
"./tests/integration/ai/test_hooks.py": [
30358-
{
30359-
"code": "reportOptionalMemberAccess",
30360-
"range": {
30361-
"startColumn": 26,
30362-
"endColumn": 30,
30363-
"lineCount": 1
30364-
}
30365-
},
30366-
{
30367-
"code": "reportOptionalMemberAccess",
30368-
"range": {
30369-
"startColumn": 26,
30370-
"endColumn": 30,
30371-
"lineCount": 1
30372-
}
30373-
}
30374-
],
3037530357
"./tests/integration/ai/test_registry.py": [
3037630358
{
3037730359
"code": "reportUnknownArgumentType",

splunklib/ai/README.md

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -592,12 +592,15 @@ They differ by the point in the execution flow where they are invoked:
592592
- before_agent: once per agent invocation, before any model calls
593593
- after_agent: once per agent invocation, after all model calls
594594

595+
Hooks implement the same interface as middlewares, which allows them to be supplied
596+
directly as middleware instances in the Agent constructor.
597+
595598
Example hook that logs token usage after each model call:
596599

597600
```py
598601
from splunklib.ai import Agent, OpenAIModel
599602
from splunklib.ai.hooks import after_model
600-
from splunklib.ai.middleware import AgentState
603+
from splunklib.ai.middleware import ModelResponse
601604
from splunklib.client import connect
602605

603606
import logging
@@ -608,42 +611,15 @@ model = OpenAIModel(...)
608611
service = connect(...)
609612

610613
@after_model
611-
def log_token_usage(state: AgentState) -> None:
612-
logger.debug(f"Model used {state.token_count} tokens up to this point")
613-
614-
615-
async with Agent(
616-
model=model,
617-
service=service,
618-
system_prompt="..." ,
619-
hooks=[log_token_usage],
620-
) as agent: ...
621-
```
614+
def log_model_response(req: ModelResponse) -> None:
615+
logger.debug(f"Model response {req.message.content}")
622616

623-
The same hook can be defined as a class. It needs to provide the type and name attributes, and implement the `__call__` method:
624-
625-
```py
626-
from typing import final, override
627-
from splunklib.ai.hooks import AgentHook
628-
from splunklib.ai.middleware import AgentState
629-
import logging
630-
631-
logger = logging.getLogger(__name__)
632-
633-
@final
634-
class LoggingHook(AgentHook):
635-
type = "before_model"
636-
name = "test_hook"
637-
638-
@override
639-
def __call__(self, state: AgentState) -> None:
640-
logger.debug(f"Model used {state.token_count} tokens up to this point")
641617

642618
async with Agent(
643619
model=model,
644620
service=service,
645621
system_prompt="..." ,
646-
hooks=[LoggingHook()],
622+
middleware=[log_model_response],
647623
) as agent: ...
648624
```
649625

@@ -652,25 +628,25 @@ The logic of the hook can be more advanced and include multiple conditions, for
652628

653629
```py
654630
from splunklib.ai import Agent, OpenAIModel
655-
from splunklib.ai.hooks import before_model, AgentHook
656-
from splunklib.ai.middleware import AgentState
631+
from splunklib.ai.hooks import before_model
632+
from splunklib.ai.middleware import AgentMiddleware, ModelRequest
657633
from time import monotonic
658634

659-
def timeout_or_token_limit(seconds_limit: float, token_limit: float) -> AgentHook:
635+
def timeout_or_token_limit(seconds_limit: float, token_limit: float) -> AgentMiddleware:
660636
now = monotonic()
661637
timeout = now + seconds_limit
662638

663639
@before_model
664-
def _limit_hook(state: AgentState) -> None:
665-
if state.token_count > token_limit or monotonic() >= timeout:
640+
def _limit_hook(req: ModelRequest) -> None:
641+
if req.state.token_count > token_limit or monotonic() >= timeout:
666642
raise Exception("Stopping Agentic Loop")
667643

668644
return _limit_hook
669645

670646

671647
async with Agent(
672648
...,
673-
hooks=[timeout_or_token_limit(seconds_limit=10.0, token_limit=10000)],
649+
middleware=[timeout_or_token_limit(seconds_limit=10.0, token_limit=10000)],
674650
) as agent: ...
675651
```
676652

splunklib/ai/agent.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from splunklib.ai.base_agent import BaseAgent
2424
from splunklib.ai.core.backend import AgentImpl
2525
from splunklib.ai.core.backend_registry import get_backend
26-
from splunklib.ai.hooks import AgentHook
2726
from splunklib.ai.messages import AgentResponse, BaseMessage, OutputT
2827
from splunklib.ai.middleware import AgentMiddleware
2928
from splunklib.ai.model import PredefinedModel
@@ -94,11 +93,6 @@ class Agent(BaseAgent[OutputT]):
9493
used as a *subagent*. The supervisor agent uses this schema to
9594
understand how to call the subagent and how to format its inputs.
9695
97-
hooks:
98-
Optional sequence of `AgentHook`. Hooks are user-defined callback
99-
functions that can be registered to execute at specific points
100-
during the agent's operation.
101-
10296
name:
10397
Name of the agent when used as a subagent. This is
10498
surfaced to the supervisor and used to decide whether this agent
@@ -130,7 +124,6 @@ def __init__(
130124
agents: Sequence[BaseAgent[BaseModel | None]] | None = None,
131125
output_schema: type[OutputT] | None = None,
132126
input_schema: type[BaseModel] | None = None, # Only used by Subagents
133-
hooks: Sequence[AgentHook] | None = None,
134127
middleware: Sequence[AgentMiddleware] | None = None,
135128
name: str = "", # Only used by Subagents
136129
description: str = "", # Only used by Subagents
@@ -144,7 +137,6 @@ def __init__(
144137
agents=agents,
145138
input_schema=input_schema,
146139
output_schema=output_schema,
147-
hooks=hooks,
148140
middleware=middleware,
149141
logger=logger,
150142
)

splunklib/ai/base_agent.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from pydantic import BaseModel
2222

23-
from splunklib.ai.hooks import AgentHook
2423
from splunklib.ai.messages import AgentResponse, BaseMessage, OutputT
2524
from splunklib.ai.middleware import AgentMiddleware
2625
from splunklib.ai.model import PredefinedModel
@@ -36,7 +35,6 @@ class BaseAgent(Generic[OutputT], ABC):
3635
_description: str = ""
3736
_input_schema: type[BaseModel] | None = None
3837
_output_schema: type[OutputT] | None = None
39-
_hooks: Sequence[AgentHook] | None = None
4038
_middleware: Sequence[AgentMiddleware] | None = None
4139
_trace_id: str
4240
_logger: logging.Logger
@@ -51,7 +49,6 @@ def __init__(
5149
agents: Sequence["BaseAgent[BaseModel | None]"] | None = None,
5250
input_schema: type[BaseModel] | None = None,
5351
output_schema: type[OutputT] | None = None,
54-
hooks: Sequence[AgentHook] | None = None,
5552
middleware: Sequence[AgentMiddleware] | None = None,
5653
logger: logging.Logger | None = None,
5754
) -> None:
@@ -63,7 +60,6 @@ def __init__(
6360
self._agents = tuple(agents) if agents else ()
6461
self._input_schema = input_schema
6562
self._output_schema = output_schema
66-
self._hooks = tuple(hooks) if hooks else ()
6763
self._middleware = tuple(middleware) if middleware else ()
6864
self._trace_id = secrets.token_hex(16) # 32 Hex characters
6965

@@ -112,10 +108,6 @@ def input_schema(self) -> type[BaseModel] | None:
112108
def output_schema(self) -> type[OutputT] | None:
113109
return self._output_schema
114110

115-
@property
116-
def hooks(self) -> Sequence[AgentHook] | None:
117-
return self._hooks
118-
119111
@property
120112
def middleware(self) -> Sequence[AgentMiddleware] | None:
121113
return self._middleware

0 commit comments

Comments
 (0)