|
28 | 28 | TokenLimitMiddleware, |
29 | 29 | ) |
30 | 30 | from splunklib.ai.messages import AIMessage, AgentResponse |
31 | | -from splunklib.ai.middleware import AgentMiddleware, AgentRequest, AgentState, ModelRequest, ModelResponse |
| 31 | +from splunklib.ai.middleware import ( |
| 32 | + AgentMiddleware, |
| 33 | + AgentRequest, |
| 34 | + AgentState, |
| 35 | + ModelRequest, |
| 36 | + ModelResponse, |
| 37 | +) |
32 | 38 | from splunklib.ai.model import OpenAIModel |
33 | 39 | from splunklib.client import Service |
34 | 40 |
|
@@ -125,22 +131,22 @@ async def test_deadline_reset_on_each_invoke(self) -> None: |
125 | 131 | request = _make_agent_request() |
126 | 132 |
|
127 | 133 | await mw.agent_middleware(request, _noop_agent_handler) |
128 | | - first_deadline = mw._deadline # pyright: ignore[reportPrivateUsage] |
| 134 | + first_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage] |
129 | 135 |
|
130 | 136 | await mw.agent_middleware(request, _noop_agent_handler) |
131 | | - second_deadline = mw._deadline # pyright: ignore[reportPrivateUsage] |
| 137 | + second_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage] |
132 | 138 |
|
133 | 139 | assert first_deadline is not None |
134 | 140 | assert second_deadline is not None |
135 | 141 | assert second_deadline >= first_deadline |
136 | 142 |
|
137 | 143 | async def test_deadline_is_none_before_first_invoke(self) -> None: |
138 | 144 | mw = TimeoutLimitMiddleware(60.0) |
139 | | - assert mw._deadline is None # pyright: ignore[reportPrivateUsage] |
| 145 | + assert mw._deadline_per_thread_id["foo"] is None # pyright: ignore[reportPrivateUsage] |
140 | 146 |
|
141 | 147 | async def test_timeout_fires_when_deadline_exceeded(self) -> None: |
142 | 148 | mw = TimeoutLimitMiddleware(60.0) |
143 | | - mw._deadline = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past |
| 149 | + mw._deadline_per_thread_id["foo"] = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past |
144 | 150 |
|
145 | 151 | state = AgentState(messages=[], total_steps=0, token_count=0, thread_id="foo") |
146 | 152 | request = ModelRequest(system_message="", state=state) |
|
0 commit comments