1616
1717from abc import ABC
1818import asyncio
19+ import inspect
1920import logging
2021from typing import AsyncGenerator
2122from typing import cast
@@ -199,7 +200,7 @@ def get_author(llm_response):
199200 return "user"
200201 else :
201202 return invocation_context .agent .name
202-
203+
203204 assert invocation_context .live_request_queue
204205 try :
205206 while True :
@@ -447,7 +448,7 @@ async def _call_llm_async(
447448 model_response_event : Event ,
448449 ) -> AsyncGenerator [LlmResponse , None ]:
449450 # Runs before_model_callback if it exists.
450- if response := self ._handle_before_model_callback (
451+ if response := await self ._handle_before_model_callback (
451452 invocation_context , llm_request , model_response_event
452453 ):
453454 yield response
@@ -460,7 +461,7 @@ async def _call_llm_async(
460461 invocation_context .live_request_queue = LiveRequestQueue ()
461462 async for llm_response in self .run_live (invocation_context ):
462463 # Runs after_model_callback if it exists.
463- if altered_llm_response := self ._handle_after_model_callback (
464+ if altered_llm_response := await self ._handle_after_model_callback (
464465 invocation_context , llm_response , model_response_event
465466 ):
466467 llm_response = altered_llm_response
@@ -489,14 +490,14 @@ async def _call_llm_async(
489490 llm_response ,
490491 )
491492 # Runs after_model_callback if it exists.
492- if altered_llm_response := self ._handle_after_model_callback (
493+ if altered_llm_response := await self ._handle_after_model_callback (
493494 invocation_context , llm_response , model_response_event
494495 ):
495496 llm_response = altered_llm_response
496497
497498 yield llm_response
498499
499- def _handle_before_model_callback (
500+ async def _handle_before_model_callback (
500501 self ,
501502 invocation_context : InvocationContext ,
502503 llm_request : LlmRequest ,
@@ -514,11 +515,16 @@ def _handle_before_model_callback(
514515 callback_context = CallbackContext (
515516 invocation_context , event_actions = model_response_event .actions
516517 )
517- return agent .before_model_callback (
518+ before_model_callback_content = agent .before_model_callback (
518519 callback_context = callback_context , llm_request = llm_request
519520 )
520521
521- def _handle_after_model_callback (
522+ if inspect .isawaitable (before_model_callback_content ):
523+ before_model_callback_content = await before_model_callback_content
524+
525+ return before_model_callback_content
526+
527+ async def _handle_after_model_callback (
522528 self ,
523529 invocation_context : InvocationContext ,
524530 llm_response : LlmResponse ,
@@ -536,10 +542,15 @@ def _handle_after_model_callback(
536542 callback_context = CallbackContext (
537543 invocation_context , event_actions = model_response_event .actions
538544 )
539- return agent .after_model_callback (
545+ after_model_callback_content = agent .after_model_callback (
540546 callback_context = callback_context , llm_response = llm_response
541547 )
542548
549+ if inspect .isawaitable (after_model_callback_content ):
550+ after_model_callback_content = await after_model_callback_content
551+
552+ return after_model_callback_content
553+
543554 def _finalize_model_response_event (
544555 self ,
545556 llm_request : LlmRequest ,
0 commit comments