|
1 | 1 | from fastapi import APIRouter, HTTPException, Request |
2 | 2 | from fastapi.responses import JSONResponse, StreamingResponse |
| 3 | +import asyncio |
3 | 4 | import json |
4 | 5 | import logging |
5 | 6 | import time |
@@ -125,63 +126,82 @@ async def chat_completions(request: Request): |
125 | 126 |
|
126 | 127 | if standard_request.stream: |
async def generate():
    """Stream the completion to the client chunk by chunk.

    A background ``producer`` task runs the completion bridge while holding
    the per-session lock and pushes each translated chunk onto an in-memory
    queue. This generator drains the queue and yields every chunk as soon as
    it arrives. A ``None`` item on the queue marks the end of the stream.

    Yields:
        str: chunks taken from ``translator.pending_chunks`` and
        ``translator.finalize(...)``, or one ``data: {"error": ...}`` event
        when the completion raises.
    """
    queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def producer() -> None:
        """Run the completion under the session lock and feed ``queue``."""
        try:
            async with app.state.session_locks.hold(session_key):
                try:
                    update_request_context(stream_attempt=1)
                    translator = OpenAIStreamTranslator(
                        completion_id=completion_id,
                        created=created,
                        model_name=model_name,
                        client_profile=standard_request.client_profile,
                        build_final_directive=lambda answer_text: build_tool_directive(
                            standard_request,
                            RuntimeAttemptState(answer_text=answer_text),
                        ),
                        allowed_tool_names=standard_request.tool_names,
                    )

                    async def on_delta(
                        evt: dict[str, Any],
                        text_chunk: str | None,
                        tool_calls: list[dict[str, Any]] | None,
                    ) -> None:
                        # Translate the upstream event, then forward whatever
                        # chunks the translator made available right away.
                        translator.on_delta(evt, text_chunk, tool_calls)
                        while translator.pending_chunks:
                            await queue.put(translator.pending_chunks.pop(0))

                    result = await run_retryable_completion_bridge(
                        client=client,
                        standard_request=standard_request,
                        prompt=prompt,
                        users_db=users_db,
                        token=token,
                        history_messages=history_messages,
                        max_attempts=request_max_attempts(standard_request),
                        usage_delta_factory=build_usage_delta_factory(prompt),
                        allow_after_visible_output=True,
                        capture_events=False,
                        on_delta=on_delta,
                    )
                    execution = result.execution
                    directive = result.directive or build_tool_directive(standard_request, execution.state)
                    assistant_message = build_openai_assistant_history_message(
                        execution=execution,
                        request=standard_request,
                        directive=directive,
                    )
                    await persist_session_turn(
                        app=app,
                        request=standard_request,
                        surface="openai",
                        execution=execution,
                        assistant_message=assistant_message,
                    )
                    # Fall back to "stop" when the execution reports no
                    # finish reason of its own.
                    final_finish_reason = (
                        "tool_calls"
                        if directive.stop_reason == "tool_use"
                        else (execution.state.finish_reason or "stop")
                    )
                    for chunk in translator.finalize(final_finish_reason):
                        await queue.put(chunk)
                except HTTPException as he:
                    await clear_invalidated_session_chat(app=app, request=standard_request)
                    await queue.put(f"data: {json.dumps({'error': he.detail})}\n\n")
                except Exception as e:
                    await clear_invalidated_session_chat(app=app, request=standard_request)
                    await queue.put(f"data: {json.dumps({'error': str(e)})}\n\n")
        finally:
            # Send the end-of-stream marker only after the session lock has
            # been released, and on every exit path (including a failure to
            # enter the lock or a cancellation), so the consumer can never
            # block forever and never cancels this task mid-release. The
            # queue is unbounded, so put_nowait cannot raise QueueFull.
            queue.put_nowait(None)

    producer_task = asyncio.create_task(producer())
    try:
        while True:
            chunk = await queue.get()
            if chunk is None:
                break
            yield chunk
    finally:
        # The consumer stopped early (client disconnect, generator closed):
        # stop the producer so it does not keep running unobserved.
        if not producer_task.done():
            producer_task.cancel()
        try:
            await producer_task
        except asyncio.CancelledError:
            # Expected after cancel(). CancelledError derives from
            # BaseException, so `except Exception` would let it escape.
            # Any exception already in flight resumes after this block.
            pass
        except Exception:
            # Only failures that escaped the producer's own handlers land
            # here; the stream is already over, so record them and move on.
            logging.getLogger(__name__).exception(
                "streaming producer failed outside its own error handling"
            )
185 | 205 |
|
186 | 206 | return StreamingResponse( |
187 | 207 | generate(), |
|
0 commit comments