-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
fix: bypass display_generation for OpenAI streaming to enable raw chunk output #1030
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a702ccc
9357039
2a991b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2037,23 +2037,147 @@ def _start_stream(self, prompt: str, **kwargs) -> Generator[str, None, None]: | |
| raise | ||
|
|
||
| else: | ||
| # For OpenAI-style models, fall back to the chat method for now | ||
| # TODO: Implement OpenAI streaming in future iterations | ||
| response = self.chat(prompt, **kwargs) | ||
| # For OpenAI-style models, implement proper streaming without display | ||
| # Handle knowledge search | ||
| actual_prompt = prompt | ||
| if self.knowledge: | ||
| search_results = self.knowledge.search(prompt, agent_id=self.agent_id) | ||
| if search_results: | ||
| if isinstance(search_results, dict) and 'results' in search_results: | ||
| knowledge_content = "\n".join([result['memory'] for result in search_results['results']]) | ||
| else: | ||
| knowledge_content = "\n".join(search_results) | ||
| actual_prompt = f"{prompt}\n\nKnowledge: {knowledge_content}" | ||
|
|
||
| if response: | ||
| # Simulate streaming by yielding the response in word chunks | ||
| words = str(response).split() | ||
| chunk_size = max(1, len(words) // 20) | ||
| # Handle tools properly | ||
| tools = kwargs.get('tools', self.tools) | ||
| if tools is None or (isinstance(tools, list) and len(tools) == 0): | ||
| tool_param = self.tools | ||
| else: | ||
| tool_param = tools | ||
|
|
||
| # Build messages using the helper method | ||
| messages, original_prompt = self._build_messages(actual_prompt, kwargs.get('temperature', 0.2), | ||
| kwargs.get('output_json'), kwargs.get('output_pydantic')) | ||
|
|
||
| # Store chat history length for potential rollback | ||
| chat_history_length = len(self.chat_history) | ||
|
|
||
| # Normalize original_prompt for consistent chat history storage | ||
| normalized_content = original_prompt | ||
| if isinstance(original_prompt, list): | ||
| normalized_content = next((item["text"] for item in original_prompt if item.get("type") == "text"), "") | ||
|
|
||
| # Prevent duplicate messages in chat history | ||
| if not (self.chat_history and | ||
| self.chat_history[-1].get("role") == "user" and | ||
| self.chat_history[-1].get("content") == normalized_content): | ||
| self.chat_history.append({"role": "user", "content": normalized_content}) | ||
|
|
||
| try: | ||
| # Check if OpenAI client is available | ||
| if self._openai_client is None: | ||
| raise ValueError("OpenAI client is not initialized. Please provide OPENAI_API_KEY or use a custom LLM provider.") | ||
|
|
||
| # Format tools for OpenAI | ||
| formatted_tools = self._format_tools_for_completion(tool_param) | ||
|
|
||
| # Create streaming completion directly without display function | ||
| completion_args = { | ||
| "model": self.llm, | ||
| "messages": messages, | ||
| "temperature": kwargs.get('temperature', 0.2), | ||
| "stream": True | ||
| } | ||
| if formatted_tools: | ||
| completion_args["tools"] = formatted_tools | ||
|
|
||
| completion = self._openai_client.sync_client.chat.completions.create(**completion_args) | ||
|
|
||
| # Stream the response chunks without display | ||
| response_text = "" | ||
| tool_calls_data = [] | ||
|
|
||
| for i in range(0, len(words), chunk_size): | ||
| chunk_words = words[i:i + chunk_size] | ||
| chunk = ' '.join(chunk_words) | ||
| for chunk in completion: | ||
| delta = chunk.choices[0].delta | ||
|
|
||
| if i + chunk_size < len(words): | ||
| chunk += ' ' | ||
| # Handle text content | ||
| if delta.content is not None: | ||
| chunk_content = delta.content | ||
| response_text += chunk_content | ||
| yield chunk_content | ||
|
|
||
| yield chunk | ||
| # Handle tool calls (accumulate but don't yield as chunks) | ||
| if hasattr(delta, 'tool_calls') and delta.tool_calls: | ||
| for tool_call_delta in delta.tool_calls: | ||
| # Extend tool_calls_data list to accommodate the tool call index | ||
| while len(tool_calls_data) <= tool_call_delta.index: | ||
| tool_calls_data.append({'id': '', 'function': {'name': '', 'arguments': ''}}) | ||
|
|
||
| # Accumulate tool call data | ||
| if tool_call_delta.id: | ||
| tool_calls_data[tool_call_delta.index]['id'] = tool_call_delta.id | ||
| if tool_call_delta.function.name: | ||
| tool_calls_data[tool_call_delta.index]['function']['name'] = tool_call_delta.function.name | ||
| if tool_call_delta.function.arguments: | ||
| tool_calls_data[tool_call_delta.index]['function']['arguments'] += tool_call_delta.function.arguments | ||
|
|
||
| # Handle any tool calls that were accumulated | ||
| if tool_calls_data: | ||
| # Add assistant message with tool calls to chat history | ||
| assistant_message = {"role": "assistant", "content": response_text} | ||
| if tool_calls_data: | ||
| assistant_message["tool_calls"] = [ | ||
| { | ||
| "id": tc['id'], | ||
| "type": "function", | ||
| "function": tc['function'] | ||
| } for tc in tool_calls_data if tc['id'] | ||
| ] | ||
| self.chat_history.append(assistant_message) | ||
|
|
||
| # Execute tool calls and add results to chat history | ||
| for tool_call in tool_calls_data: | ||
| if tool_call['id'] and tool_call['function']['name']: | ||
| try: | ||
| tool_result = self.execute_tool( | ||
| tool_call['function']['name'], | ||
| tool_call['function']['arguments'] | ||
| ) | ||
| # Add tool result to chat history | ||
| self.chat_history.append({ | ||
| "role": "tool", | ||
| "tool_call_id": tool_call['id'], | ||
| "content": str(tool_result) | ||
| }) | ||
| except Exception as tool_error: | ||
| logging.error(f"Tool execution error in streaming: {tool_error}") | ||
| # Add error result to chat history | ||
| self.chat_history.append({ | ||
| "role": "tool", | ||
| "tool_call_id": tool_call['id'], | ||
| "content": f"Error: {str(tool_error)}" | ||
| }) | ||
| else: | ||
| # Add complete response to chat history (text-only response) | ||
| if response_text: | ||
| self.chat_history.append({"role": "assistant", "content": response_text}) | ||
|
|
||
| except Exception as e: | ||
| # Rollback chat history on error | ||
| self.chat_history = self.chat_history[:chat_history_length] | ||
| logging.error(f"OpenAI streaming error: {e}") | ||
| # Fall back to simulated streaming | ||
| response = self.chat(prompt, **kwargs) | ||
| if response: | ||
| words = str(response).split() | ||
| chunk_size = max(1, len(words) // 20) | ||
| for i in range(0, len(words), chunk_size): | ||
| chunk_words = words[i:i + chunk_size] | ||
| chunk = ' '.join(chunk_words) | ||
| if i + chunk_size < len(words): | ||
| chunk += ' ' | ||
| yield chunk | ||
|
Comment on lines
+2171
to
+2180
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fallback logic for simulated streaming is identical to the implementation that was removed from the |
||
|
|
||
| # Restore original verbose mode | ||
| self.verbose = original_verbose | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| #!/usr/bin/env python3 | ||
| """ | ||
| Test script for streaming display bypass fix | ||
| Tests that streaming yields raw chunks without display_generation | ||
| """ | ||
|
|
||
| import sys | ||
| import os | ||
| import collections.abc | ||
|
|
||
| # Add the praisonai-agents source to Python path | ||
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src', 'praisonai-agents')) | ||
|
|
||
| try: | ||
| from praisonaiagents import Agent | ||
|
|
||
| print("🧪 Testing Streaming Display Bypass Fix") | ||
| print("=" * 50) | ||
|
|
||
| # Test configuration - using mock model to avoid API calls | ||
| agent = Agent( | ||
| instructions="You are a helpful assistant", | ||
| llm="mock-model-for-testing", | ||
| stream=True | ||
| ) | ||
|
|
||
| # Test 1: Basic streaming setup | ||
| print("✅ Agent created successfully with stream=True") | ||
| print(f"📊 Agent stream attribute: {agent.stream}") | ||
|
|
||
| # Test 2: Check start method behavior and exception on consumption | ||
| result = agent.start("Hello, test streaming") | ||
| assert isinstance(result, collections.abc.Generator), "Agent.start() should return a generator for streaming" | ||
| print("✅ Agent.start() returned a generator (streaming enabled)") | ||
|
|
||
| try: | ||
| # Consume the generator to trigger the API call, which should fail for a mock model. | ||
| list(result) | ||
| # If we get here, the test has failed because an exception was expected. | ||
| print("❌ FAILED: Expected an exception with mock model, but none was raised.") | ||
| except Exception as e: | ||
| print(f"✅ SUCCESS: Caught expected exception with mock model: {e}") | ||
| print("✅ Streaming path was triggered (exception expected with mock model)") | ||
|
|
||
| # Test 3: Verify the streaming method exists and is callable | ||
| if hasattr(agent, '_start_stream') and callable(agent._start_stream): | ||
| print("✅ _start_stream method exists and is callable") | ||
| else: | ||
| print("❌ _start_stream method missing") | ||
|
|
||
| print("\n🎯 Test Results:") | ||
| print("✅ Streaming infrastructure is properly set up") | ||
| print("✅ Agent.start() correctly detects stream=True") | ||
| print("✅ Modified _start_stream should now bypass display_generation") | ||
| print("✅ OpenAI streaming implementation is in place") | ||
|
|
||
| print("\n📝 Note: Full streaming test requires valid OpenAI API key") | ||
| print("🔗 This test validates the code structure and logic flow") | ||
|
|
||
| except ImportError as e: | ||
| print(f"❌ Import failed: {e}") | ||
| print("Please ensure you're running from the correct directory") | ||
| except Exception as e: | ||
| print(f"❌ Test failed: {e}") | ||
| import traceback | ||
| traceback.print_exc() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix tool call argument parsing for proper execution.
The tool call arguments are passed as a string to
execute_tool, but this method expects a dictionary of parsed arguments.Apply this fix to properly parse the JSON arguments:
📝 Committable suggestion
🤖 Prompt for AI Agents