@@ -691,46 +691,86 @@ def __init__(
691691 self ._retry_config = retry_config or RateLimitRetryConfig .from_env ()
692692 self ._context_trim_config = ContextTrimConfig .from_env ()
693693
def _inner_get_response(
    self,
    *,
    messages: MutableSequence[Any],
    options: Any = None,
    stream: bool = False,
    **kwargs: Any,
) -> Any:
    """Dispatch a chat request to the streaming or the retrying code path.

    The messages are first passed through ``self._maybe_trim_messages``.
    Non-streaming requests are then handed to
    ``self._non_streaming_with_retry``; streaming requests go straight to
    the parent implementation with the (possibly trimmed) messages.

    This is intentionally a plain ``def`` rather than ``async def``: the
    parent returns a different kind of object depending on *stream*
    (an awaitable ``ChatResponse`` when ``stream=False``, an async-iterable
    ``ResponseStream`` when ``stream=True``), and whatever is produced here
    is returned to the caller as-is.

    Args:
        messages: Chat messages to send.
        options: Passed through unchanged to the selected code path.
        stream: Selects the streaming path when truthy.
        **kwargs: Passed through unchanged to the selected code path.

    Returns:
        The un-awaited result of ``self._non_streaming_with_retry`` for
        non-streaming requests, otherwise whatever the parent
        ``_inner_get_response`` returns for ``stream=True``.
    """
    outgoing = self._maybe_trim_messages(messages)
    if not outgoing:
        logger.warning(
            "[AOAI_RETRY] empty messages list received; using original messages"
        )
        outgoing = messages

    if not stream:
        return self._non_streaming_with_retry(
            effective_messages=outgoing,
            original_messages=messages,
            options=options,
            **kwargs,
        )

    # Streaming: hand back the parent's own ResponseStream object. The
    # framework checks isinstance(result, ResponseStream), and an async
    # generator wrapper would fail that check.
    parent_call = super(AzureOpenAIChatClientWithRetry, self)._inner_get_response
    return parent_call(messages=outgoing, options=options, stream=True, **kwargs)
729+
def _maybe_trim_messages(
    self, messages: MutableSequence[Any]
) -> MutableSequence[Any] | list[Any]:
    """Apply pre-call context trimming if enabled and over budget.

    Args:
        messages: Chat messages about to be sent.

    Returns:
        ``messages`` itself when trimming is disabled, when no positive
        ``max_total_chars`` budget is configured, when the estimated size
        is within that budget, or when trimming would leave nothing to
        send; otherwise the result of ``_trim_messages``.
    """
    cfg = self._context_trim_config
    # Check the cheap flags first: a non-positive budget can never trigger
    # a trim, so the size estimate over every message is skipped entirely.
    if not cfg.enabled or cfg.max_total_chars <= 0:
        return messages

    approx_chars = sum(len(_estimate_message_text(m)) for m in messages)
    if approx_chars <= cfg.max_total_chars:
        return messages

    trimmed = _trim_messages(messages, cfg=cfg)
    if not trimmed:
        # Sending nothing is worse than sending too much; fall back.
        logger.warning(
            "[AOAI_CTX_TRIM] trimming would remove all messages; keeping originals"
        )
        return messages

    logger.warning(
        "[AOAI_CTX_TRIM] pre-trimmed chat request messages: approx_chars=%s -> %s; count=%s -> %s",
        approx_chars,
        sum(len(_estimate_message_text(m)) for m in trimmed),
        len(messages),
        len(trimmed),
    )
    return trimmed
756+
757+ async def _non_streaming_with_retry (
758+ self ,
759+ * ,
760+ effective_messages : MutableSequence [Any ] | list [Any ],
761+ original_messages : MutableSequence [Any ],
762+ options : Any = None ,
763+ ** kwargs : Any ,
764+ ) -> Any :
765+ """Non-streaming path: full retry + context-trim fallback."""
766+ parent_inner = super (
767+ AzureOpenAIChatClientWithRetry , self
768+ )._inner_get_response
769+
730770 try :
731771 return await _retry_call (
732- lambda : parent_inner_get_response (
733- messages = effective_messages , options = options , ** kwargs
772+ lambda : parent_inner (
773+ messages = effective_messages , options = options , stream = False , ** kwargs
734774 ),
735775 config = self ._retry_config ,
736776 )
@@ -742,20 +782,48 @@ async def _inner_get_response(
742782 ):
743783 raise
744784
745- trimmed = _trim_messages (messages , cfg = self ._context_trim_config )
785+ trimmed = _trim_messages (
786+ original_messages ,
787+ cfg = ContextTrimConfig (
788+ enabled = True ,
789+ max_total_chars = max (
790+ 50_000 , self ._context_trim_config .max_total_chars - 80_000
791+ ),
792+ max_message_chars = max (
793+ 3_000 , self ._context_trim_config .max_message_chars - 6_000
794+ ),
795+ keep_last_messages = max (
796+ 6 , self ._context_trim_config .keep_last_messages - 12
797+ ),
798+ keep_head_chars = max (
799+ 1_000 , self ._context_trim_config .keep_head_chars - 4_000
800+ ),
801+ keep_tail_chars = self ._context_trim_config .keep_tail_chars ,
802+ keep_system_messages = True ,
803+ retry_on_context_error = True ,
804+ ),
805+ )
746806 if not trimmed :
747807 logger .warning (
748- "[AOAI_CTX_TRIM] trim would remove all messages; re-raising original error"
808+ "[AOAI_CTX_TRIM] aggressive trim would remove all messages; re-raising original error"
749809 )
750810 raise
751811 logger .warning (
752812 "[AOAI_CTX_TRIM] retrying chat after context-length error; count=%s -> %s" ,
753- len (messages ),
813+ len (original_messages ),
754814 len (trimmed ),
755815 )
816+ trim_delay = min (
817+ self ._retry_config .base_delay_seconds ,
818+ self ._retry_config .max_delay_seconds ,
819+ )
820+ logger .info (
821+ "[AOAI_CTX_TRIM] sleeping %ss before retry" , round (trim_delay , 1 )
822+ )
823+ await asyncio .sleep (trim_delay )
756824 return await _retry_call (
757- lambda : parent_inner_get_response (
758- messages = trimmed , options = options , ** kwargs
825+ lambda : parent_inner (
826+ messages = trimmed , options = options , stream = False , ** kwargs
759827 ),
760828 config = self ._retry_config ,
761829 )
0 commit comments