6060 'object' ,
6161)
6262
63+ _REFUSAL_PREFIX = '[[REFUSAL]]: '
64+
6365
6466class ApigeeLlm (Gemini ):
6567 """A BaseLlm implementation for calling Apigee proxy.
@@ -658,11 +660,14 @@ def _content_to_messages(
658660
659661 tool_calls = []
660662 content_parts = []
663+ refusals : list [str ] = []
661664
662665 function_responses = []
663666
664667 for part in content .parts or []:
665- self ._process_content_part (content , part , tool_calls , content_parts )
668+ self ._process_content_part (
669+ content , part , tool_calls , content_parts , refusals
670+ )
666671 if part .function_response :
667672 function_responses .append ({
668673 'role' : 'tool' ,
@@ -673,6 +678,8 @@ def _content_to_messages(
673678 return function_responses
674679
675680 message = {'role' : role }
681+ if refusals :
682+ message ['refusal' ] = '\n ' .join (refusals )
676683 if tool_calls :
677684 message ['tool_calls' ] = tool_calls
678685 if not content_parts :
@@ -691,6 +698,7 @@ def _process_content_part(
691698 part : types .Part ,
692699 tool_calls : list [dict [str , Any ]],
693700 content_parts : list [dict [str , Any ]],
701+ refusals : list [str ],
694702 ) -> None :
695703 """Processes a single Part and updates tool_calls or content_parts."""
696704 if content .role != 'user' and (
@@ -731,7 +739,14 @@ def _process_content_part(
731739 # Handled in the loop to return immediately
732740 pass
733741 elif part .text :
734- content_parts .append ({'type' : 'text' , 'text' : part .text })
742+ if part .text .startswith (_REFUSAL_PREFIX ):
743+ refusals .append (part .text .removeprefix (_REFUSAL_PREFIX ))
744+ else :
745+ before , sep , after = part .text .partition ('\n ' + _REFUSAL_PREFIX )
746+ if sep :
747+ refusals .append (after )
748+ if before :
749+ content_parts .append ({'type' : 'text' , 'text' : before })
735750 elif part .inline_data :
736751 mime_type = part .inline_data .mime_type
737752 data = base64 .b64encode (part .inline_data .data ).decode ('utf-8' )
@@ -843,6 +858,7 @@ def __init__(self):
843858 self .usage = {}
844859 self .logprobs = {}
845860 self .custom_metadata = {}
861+ self ._refusal_started = False
846862
847863 def process_response (self , response : dict [str , Any ]) -> LlmResponse :
848864 """Processes a complete non-streaming response."""
@@ -989,19 +1005,49 @@ def _accumulate_logprobs(self, logprobs_chunk: dict[str, Any]) -> None:
9891005 self .logprobs ['refusal' ] = []
9901006 self .logprobs ['refusal' ].extend (logprobs_chunk ['refusal' ])
9911007
992- def _append_content (self , content : str , refusal : str ) -> str :
993- if content and refusal :
994- content += '\n '
995- content += refusal
996- elif refusal :
997- content = refusal
1008+ def _accumulate_content (self , choice : dict [str , Any ]) -> str :
1009+ """Processes a message or delta chunk to accumulate content and refusals.
1010+
1011+ This method extracts 'content' and 'refusal' from the chunk, updates the
1012+ accumulated state (self.content_parts), and returns the text content for
1013+ this chunk (handling prefixes and newlines if it's a refusal).
1014+
1015+ Args:
1016+ choice: A dictionary representing a message choice or a streaming delta.
1017+
1018+ Returns:
1019+ The text content to be appended or yielded for this chunk.
1020+ """
1021+ content = choice .get ('content' , '' )
1022+ refusal = choice .get ('refusal' , '' )
1023+
1024+ if content and self ._refusal_started :
1025+ logging .warning (
1026+ 'Received content after refusal has started. Dropping content.'
1027+ )
1028+ content = ''
1029+
1030+ chunk_text = ''
9981031 if content :
999- self .content_parts += content
1000- return content
1032+ chunk_text += content
1033+
1034+ if refusal and not self ._refusal_started :
1035+ self ._refusal_started = True
1036+ if self .content_parts or chunk_text :
1037+ chunk_text += '\n '
1038+ chunk_text += _REFUSAL_PREFIX
1039+
1040+ if refusal :
1041+ chunk_text += refusal
1042+
1043+ if chunk_text :
1044+ self .content_parts += chunk_text
1045+
1046+ return chunk_text
10011047
10021048 def _add_chat_completion_chunk_delta (
10031049 self , delta : dict [str , Any ]
1004- ) -> ( list [types .Part ], str ) :
1050+ ) -> tuple [ list [types .Part ], str ] :
10051051 """Adds a chunk delta from a streaming chat completions response.
10061052
10071053 This method processes a single delta chunk from a streaming chat completions
@@ -1021,9 +1067,7 @@ def _add_chat_completion_chunk_delta(
10211067 for tool_call in delta .get ('tool_calls' , []):
10221068 chunk_part = self ._upsert_tool_call (tool_call )
10231069 parts .append (chunk_part )
1024- content = delta .get ('content' )
1025- refusal = delta .get ('refusal' )
1026- merged_content = self ._append_content (content , refusal )
1070+ merged_content = self ._accumulate_content (delta )
10271071 if merged_content :
10281072 parts .append (types .Part .from_text (text = merged_content ))
10291073
@@ -1057,9 +1101,7 @@ def _add_chat_completion_message(
10571101 'type' : 'function' ,
10581102 'function' : function_call ,
10591103 })
1060- content = message .get ('content' )
1061- refusal = message .get ('refusal' )
1062- self ._append_content (content , refusal )
1104+ self ._accumulate_content (message )
10631105
10641106 self ._get_or_create_role (message .get ('role' , 'model' ))
10651107 return self ._get_content_parts (), self .role
0 commit comments