55from typing import Any , Awaitable , Callable , Dict , List , Mapping , Optional , Sequence , Union , cast
66
77from autogen_core import AgentRuntime , Component , ComponentModel
8- from autogen_core .models import AssistantMessage , ChatCompletionClient , ModelFamily , SystemMessage , UserMessage
8+ from autogen_core .models import (
9+ AssistantMessage ,
10+ ChatCompletionClient ,
11+ CreateResult ,
12+ ModelFamily ,
13+ SystemMessage ,
14+ UserMessage ,
15+ )
916from pydantic import BaseModel
1017from typing_extensions import Self
1118
1623 BaseAgentEvent ,
1724 BaseChatMessage ,
1825 MessageFactory ,
26+ ModelClientStreamingChunkEvent ,
27+ SelectorEvent ,
1928)
2029from ...state import SelectorManagerState
2130from ._base_group_chat import BaseGroupChat
@@ -56,6 +65,7 @@ def __init__(
5665 max_selector_attempts : int ,
5766 candidate_func : Optional [CandidateFuncType ],
5867 emit_team_events : bool ,
68+ model_client_streaming : bool = False ,
5969 ) -> None :
6070 super ().__init__ (
6171 name ,
@@ -79,6 +89,7 @@ def __init__(
7989 self ._max_selector_attempts = max_selector_attempts
8090 self ._candidate_func = candidate_func
8191 self ._is_candidate_func_async = iscoroutinefunction (self ._candidate_func )
92+ self ._model_client_streaming = model_client_streaming
8293
8394 async def validate_group_state (self , messages : List [BaseChatMessage ] | None ) -> None :
8495 pass
@@ -194,7 +205,26 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st
194205 num_attempts = 0
195206 while num_attempts < max_attempts :
196207 num_attempts += 1
197- response = await self ._model_client .create (messages = select_speaker_messages )
208+ if self ._model_client_streaming :
209+ chunk : CreateResult | str = ""
210+ async for _chunk in self ._model_client .create_stream (messages = select_speaker_messages ):
211+ chunk = _chunk
212+ if self ._emit_team_events :
213+ if isinstance (chunk , str ):
214+ await self ._output_message_queue .put (
215+ ModelClientStreamingChunkEvent (content = cast (str , _chunk ), source = self ._name )
216+ )
217+ else :
218+ assert isinstance (chunk , CreateResult )
219+ assert isinstance (chunk .content , str )
220+ await self ._output_message_queue .put (
221+ SelectorEvent (content = chunk .content , source = self ._name )
222+ )
223+ # The last chunk must be CreateResult.
224+ assert isinstance (chunk , CreateResult )
225+ response = chunk
226+ else :
227+ response = await self ._model_client .create (messages = select_speaker_messages )
198228 assert isinstance (response .content , str )
199229 select_speaker_messages .append (AssistantMessage (content = response .content , source = "selector" ))
200230 # NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed.
@@ -281,6 +311,7 @@ class SelectorGroupChatConfig(BaseModel):
281311 # selector_func: ComponentModel | None
282312 max_selector_attempts : int = 3
283313 emit_team_events : bool = False
314+ model_client_streaming : bool = False
284315
285316
286317class SelectorGroupChat (BaseGroupChat , Component [SelectorGroupChatConfig ]):
@@ -311,6 +342,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
311342 selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
312343 This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
313344 emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
345+ model_client_streaming (bool, optional): Whether to use streaming for the model client when selecting the next speaker. This is useful for reasoning models like QwQ. Defaults to False.
314346
315347 Raises:
316348 ValueError: If the number of participants is less than two or if the selector prompt is invalid.
@@ -453,6 +485,7 @@ def __init__(
453485 candidate_func : Optional [CandidateFuncType ] = None ,
454486 custom_message_types : List [type [BaseAgentEvent | BaseChatMessage ]] | None = None ,
455487 emit_team_events : bool = False ,
488+ model_client_streaming : bool = False ,
456489 ):
457490 super ().__init__ (
458491 participants ,
@@ -473,6 +506,7 @@ def __init__(
473506 self ._selector_func = selector_func
474507 self ._max_selector_attempts = max_selector_attempts
475508 self ._candidate_func = candidate_func
509+ self ._model_client_streaming = model_client_streaming
476510
477511 def _create_group_chat_manager_factory (
478512 self ,
@@ -505,6 +539,7 @@ def _create_group_chat_manager_factory(
505539 self ._max_selector_attempts ,
506540 self ._candidate_func ,
507541 self ._emit_team_events ,
542+ self ._model_client_streaming ,
508543 )
509544
510545 def _to_config (self ) -> SelectorGroupChatConfig :
@@ -518,6 +553,7 @@ def _to_config(self) -> SelectorGroupChatConfig:
518553 max_selector_attempts = self ._max_selector_attempts ,
519554 # selector_func=self._selector_func.dump_component() if self._selector_func else None,
520555 emit_team_events = self ._emit_team_events ,
556+ model_client_streaming = self ._model_client_streaming ,
521557 )
522558
523559 @classmethod
@@ -536,4 +572,5 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
536572 # if config.selector_func
537573 # else None,
538574 emit_team_events = config .emit_team_events ,
575+ model_client_streaming = config .model_client_streaming ,
539576 )
0 commit comments