|
| 1 | +import asyncio |
1 | 2 | import json |
2 | 3 | import logging |
3 | 4 | import os |
4 | 5 | from typing import Any, Dict, List, Optional |
5 | 6 |
|
6 | | -import openai_multi_tool_use_parallel_patch # type: ignore # noqa: F401 |
| 7 | +# todo: need to support this for multi tool use, maybe upstream package has it fixed now. |
| 8 | +# commented out because it's not working with streams |
| 9 | +# import openai_multi_tool_use_parallel_patch # type: ignore # noqa: F401 |
7 | 10 | from abcs.llm import LLM |
8 | | -from abcs.models import PromptResponse, UsageStats |
| 11 | +from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats |
9 | 12 | from openai import OpenAI |
10 | 13 | from tools.tool_manager import ToolManager |
11 | 14 |
|
@@ -188,3 +191,59 @@ def _translate_response(self, response) -> PromptResponse: |
188 | 191 | # logger.error("An error occurred while translating OpenAI response: %s", e, exc_info=True) |
189 | 192 | logger.exception(f"error: {e}\nresponse: {response}") |
190 | 193 | raise e |
| 194 | + |
| 195 | + # https://cookbook.openai.com/examples/how_to_stream_completions |
| 196 | + async def generate_text_stream( |
| 197 | + self, |
| 198 | + prompt: str, |
| 199 | + past_messages: List[Dict[str, str]], |
| 200 | + tools: Optional[List[Dict[str, Any]]] = None, |
| 201 | + **kwargs, |
| 202 | + ) -> StreamingPromptResponse: |
| 203 | + system_message = [{"role": "system", "content": self.system_prompt}] if self.system_prompt else [] |
| 204 | + combined_history = system_message + past_messages + [{"role": "user", "content": prompt}] |
| 205 | + |
| 206 | + try: |
| 207 | + stream = self.client.chat.completions.create( |
| 208 | + model=self.model, |
| 209 | + messages=combined_history, |
| 210 | + tools=tools, |
| 211 | + stream=True, |
| 212 | + ) |
| 213 | + |
| 214 | + async def content_generator(): |
| 215 | + for event in stream: |
| 216 | + # print("HERE\n"*30) |
| 217 | + # print(event) |
| 218 | + if event.choices[0].delta.content is not None: |
| 219 | + yield event.choices[0].delta.content |
| 220 | + # Small delay to allow for cooperative multitasking |
| 221 | + await asyncio.sleep(0) |
| 222 | + |
| 223 | + # # After the stream is complete, you might want to handle tool calls here |
| 224 | + # # This is a simplification and may need to be adjusted based on your needs |
| 225 | + # if tools and collected_content.strip().startswith('{"function":'): |
| 226 | + # # Handle tool calls (simplified example) |
| 227 | + # tool_response = await self.handle_tool_call(collected_content, combined_history, tools) |
| 228 | + # yield tool_response |
| 229 | + |
| 230 | + return StreamingPromptResponse( |
| 231 | + content=content_generator(), |
| 232 | + raw_response=stream, |
| 233 | + error={}, |
| 234 | + usage=UsageStats( |
| 235 | + input_tokens=0, # These will need to be updated after streaming |
| 236 | + output_tokens=0, |
| 237 | + extra={}, |
| 238 | + ), |
| 239 | + ) |
| 240 | + except Exception as e: |
| 241 | + logger.error("Error generating text stream: %s", e, exc_info=True) |
| 242 | + raise e |
| 243 | + |
| 244 | + async def handle_tool_call(self, collected_content, combined_history, tools): |
| 245 | + # This is a placeholder for handling tool calls in streaming context |
| 246 | + # You'll need to implement the logic to parse the tool call, execute it, |
| 247 | + # and generate a response based on the tool's output |
| 248 | + # This might involve breaking the streaming and making a new API call |
| 249 | + pass |
0 commit comments