@@ -61,8 +61,36 @@ def _make_provider(store=None):
6161 ), store
6262
6363
64+ def _make_context (messages = None , response_text = None ):
65+ """Create a mock SessionContext for before_run/after_run calls."""
66+ ctx = MagicMock ()
67+ ctx .get_messages = MagicMock (return_value = messages or [])
68+ ctx .extend_instructions = MagicMock ()
69+ if response_text is not None :
70+ ctx .response = MagicMock ()
71+ ctx .response .text = response_text
72+ else :
73+ ctx .response = None
74+ return ctx
75+
76+
77+ async def _call_before_run (provider , messages ):
78+ """Helper to call before_run and return the instructions that were injected."""
79+ ctx = _make_context (messages = messages )
80+ await provider .before_run (agent = MagicMock (), session = MagicMock (), context = ctx , state = {})
81+ if ctx .extend_instructions .called :
82+ return ctx .extend_instructions .call_args [0 ][1 ] # second positional arg = instructions
83+ return None
84+
85+
86+ async def _call_after_run (provider , response_text ):
87+ """Helper to call after_run with a response."""
88+ ctx = _make_context (response_text = response_text )
89+ await provider .after_run (agent = MagicMock (), session = MagicMock (), context = ctx , state = {})
90+
91+
6492# ---------------------------------------------------------------------------
65- # invoking () — Pre-LLM memory injection
93+ # before_run () — Pre-LLM memory injection
6694# ---------------------------------------------------------------------------
6795
6896
@@ -75,11 +103,11 @@ async def _run():
75103 ]
76104 messages = [_make_chat_message ("How should we handle storage configuration?" )]
77105
78- context = await provider . invoking ( messages )
106+ instructions = await _call_before_run ( provider , messages )
79107
80- assert context . instructions is not None
81- assert "GKE Filestore CSI" in context . instructions
82- assert "Azure Files for AKS" in context . instructions
108+ assert instructions is not None
109+ assert "GKE Filestore CSI" in instructions
110+ assert "Azure Files for AKS" in instructions
83111 store .search .assert_called_once ()
84112
85113 asyncio .run (_run ())
@@ -88,9 +116,8 @@ async def _run():
88116def test_invoking_empty_messages_returns_empty ():
89117 async def _run ():
90118 provider , _ = _make_provider ()
91- context = await provider .invoking ([])
92- assert context .instructions is None
93- assert getattr (context , "messages" , []) == []
119+ instructions = await _call_before_run (provider , [])
120+ assert instructions is None
94121
95122 asyncio .run (_run ())
96123
@@ -101,8 +128,8 @@ async def _run():
101128 store .search .return_value = []
102129 messages = [_make_chat_message ("What is the overall migration plan for AKS?" )]
103130
104- context = await provider . invoking ( messages )
105- assert context . instructions is None
131+ instructions = await _call_before_run ( provider , messages )
132+ assert instructions is None
106133
107134 asyncio .run (_run ())
108135
@@ -113,8 +140,8 @@ async def _run():
113140 store .search .side_effect = Exception ("search failed" )
114141 messages = [_make_chat_message ("What is the networking plan for AKS?" )]
115142
116- context = await provider . invoking ( messages )
117- assert context . instructions is None
143+ instructions = await _call_before_run ( provider , messages )
144+ assert instructions is None
118145
119146 asyncio .run (_run ())
120147
@@ -125,7 +152,7 @@ async def _run():
125152 long_text = "x" * 5000
126153 messages = [_make_chat_message (long_text )]
127154
128- await provider . invoking ( messages )
155+ await _call_before_run ( provider , messages )
129156
130157 query = store .search .call_args .kwargs ["query" ]
131158 assert len (query ) <= 2000
@@ -142,7 +169,7 @@ async def _run():
142169 _make_chat_message ("Latest question about storage" ),
143170 ]
144171
145- await provider . invoking ( messages )
172+ await _call_before_run ( provider , messages )
146173
147174 query = store .search .call_args .kwargs ["query" ]
148175 assert "Latest question about storage" in query
@@ -159,10 +186,10 @@ async def _run():
159186 store .search .return_value = large_memories
160187 messages = [_make_chat_message ("What storage configuration should we use for persistent volumes?" )]
161188
162- context = await provider . invoking ( messages )
189+ instructions = await _call_before_run ( provider , messages )
163190
164- assert context . instructions is not None
165- assert len (context . instructions ) <= MAX_MEMORY_CONTEXT_CHARS + 200
191+ assert instructions is not None
192+ assert len (instructions ) <= MAX_MEMORY_CONTEXT_CHARS + 200
166193
167194 asyncio .run (_run ())
168195
@@ -175,10 +202,10 @@ async def _run():
175202 ]
176203 messages = [_make_chat_message ("What storage class should we choose for the cluster?" )]
177204
178- context = await provider . invoking ( messages )
205+ instructions = await _call_before_run ( provider , messages )
179206
180- assert "Chief Architect" in context . instructions
181- assert "design" in context . instructions
207+ assert "Chief Architect" in instructions
208+ assert "design" in instructions
182209
183210 asyncio .run (_run ())
184211
@@ -189,26 +216,25 @@ async def _run():
189216 store .search .return_value = [_make_memory_entry ("some memory" )]
190217 single = _make_chat_message ("What about networking configuration for AKS?" )
191218
192- context = await provider . invoking ( single )
219+ instructions = await _call_before_run ( provider , [ single ] )
193220
194- assert context . instructions is not None
221+ assert instructions is not None
195222 store .search .assert_called_once ()
196223
197224 asyncio .run (_run ())
198225
199226
200227# ---------------------------------------------------------------------------
201- # invoked () — Post-LLM memory storage
228+ # after_run () — Post-LLM memory storage
202229# ---------------------------------------------------------------------------
203230
204231
205232def test_invoked_stores_response ():
206233 async def _run ():
207234 provider , store = _make_provider ()
208- request = [_make_chat_message ("What is the networking plan for AKS?" )]
209- response = [_make_chat_message ("We should use Azure CNI for networking configuration in the AKS cluster" )]
235+ response_text = "We should use Azure CNI for networking configuration in the AKS cluster"
210236
211- await provider . invoked ( request , response )
237+ await _call_after_run ( provider , response_text )
212238 await provider .flush ()
213239
214240 store .add .assert_called_once ()
@@ -222,10 +248,9 @@ async def _run():
222248def test_invoked_skips_on_exception ():
223249 async def _run ():
224250 provider , store = _make_provider ()
225- request = [_make_chat_message ("Q" )]
226- response = [_make_chat_message ("A" * 100 )]
227-
228- await provider .invoked (request , response , invoke_exception = Exception ("fail" ))
251+ # after_run with no response simulates exception path
252+ ctx = _make_context (response_text = None )
253+ await provider .after_run (agent = MagicMock (), session = MagicMock (), context = ctx , state = {})
229254 store .add .assert_not_called ()
230255
231256 asyncio .run (_run ())
@@ -234,9 +259,8 @@ async def _run():
234259def test_invoked_skips_none_response ():
235260 async def _run ():
236261 provider , store = _make_provider ()
237- request = [_make_chat_message ("Q" )]
238-
239- await provider .invoked (request , None )
262+ ctx = _make_context (response_text = None )
263+ await provider .after_run (agent = MagicMock (), session = MagicMock (), context = ctx , state = {})
240264 store .add .assert_not_called ()
241265
242266 asyncio .run (_run ())
@@ -245,10 +269,8 @@ async def _run():
245269def test_invoked_skips_short_response ():
246270 async def _run ():
247271 provider , store = _make_provider ()
248- request = [_make_chat_message ("Q" )]
249- short = [_make_chat_message ("x" * (MIN_CONTENT_LENGTH_TO_STORE - 1 ))]
250-
251- await provider .invoked (request , short )
272+ short_text = "x" * (MIN_CONTENT_LENGTH_TO_STORE - 1 )
273+ await _call_after_run (provider , short_text )
252274 store .add .assert_not_called ()
253275
254276 asyncio .run (_run ())
@@ -257,10 +279,8 @@ async def _run():
257279def test_invoked_stores_long_response ():
258280 async def _run ():
259281 provider , store = _make_provider ()
260- request = [_make_chat_message ("Q" )]
261- long_resp = [_make_chat_message ("x" * (MIN_CONTENT_LENGTH_TO_STORE + 1 ))]
262-
263- await provider .invoked (request , long_resp )
282+ long_text = "x" * (MIN_CONTENT_LENGTH_TO_STORE + 1 )
283+ await _call_after_run (provider , long_text )
264284 await provider .flush ()
265285 store .add .assert_called_once ()
266286
@@ -270,11 +290,10 @@ async def _run():
270290def test_invoked_increments_turn_counter ():
271291 async def _run ():
272292 provider , store = _make_provider ()
273- request = [_make_chat_message ("Q" )]
274- response = [_make_chat_message ("A" * 100 )]
293+ response_text = "A" * 100
275294
276- await provider . invoked ( request , response )
277- await provider . invoked ( request , response )
295+ await _call_after_run ( provider , response_text )
296+ await _call_after_run ( provider , response_text )
278297 assert provider ._turn_counter == 2
279298
280299 asyncio .run (_run ())
@@ -284,10 +303,9 @@ def test_invoked_store_failure_does_not_raise():
284303 async def _run ():
285304 provider , store = _make_provider ()
286305 store .add .side_effect = Exception ("store failed" )
287- request = [_make_chat_message ("Q" )]
288- response = [_make_chat_message ("A" * 100 )]
306+ response_text = "A" * 100
289307
290- await provider . invoked ( request , response )
308+ await _call_after_run ( provider , response_text )
291309 await provider .flush () # Should not raise
292310
293311 asyncio .run (_run ())
@@ -296,10 +314,9 @@ async def _run():
296314def test_invoked_with_single_message ():
297315 async def _run ():
298316 provider , store = _make_provider ()
299- request = _make_chat_message ("What is the question about networking?" )
300- response = _make_chat_message ("We should use Azure CNI Overlay for the networking configuration in AKS" )
317+ response_text = "We should use Azure CNI Overlay for the networking configuration in AKS"
301318
302- await provider . invoked ( request , response )
319+ await _call_after_run ( provider , response_text )
303320 await provider .flush ()
304321 store .add .assert_called_once ()
305322
0 commit comments