Skip to content

Commit 8918757

Browse files
author
Murat Kaan Meral
committed
fix: fix bidi tests
1 parent 6c2cbf5 commit 8918757

4 files changed

Lines changed: 30 additions & 13 deletions

File tree

src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None:
282282

283283
# Queue tool requests for concurrent execution
284284
# Check for ToolUseStreamEvent (standard agent event)
285-
if "current_tool_use" in strands_event:
285+
if event_type == "tool_use_stream":
286286
tool_use = strands_event.get("current_tool_use")
287287
if tool_use:
288288
tool_name = tool_use.get("name")
@@ -297,9 +297,9 @@ async def _process_model_events(session: BidirectionalConnection) -> None:
297297

298298
# Update Agent conversation history for user transcripts
299299
if event_type == "bidirectional_transcript_stream":
300-
source = strands_event.get("source")
300+
role = strands_event.get("role")
301301
text = strands_event.get("text", "")
302-
if source == "user" and text.strip():
302+
if role == "user" and text.strip():
303303
user_message = {"role": "user", "content": text}
304304
session.agent.messages.append(user_message)
305305
logger.debug("User transcript added to history")

src/strands/experimental/bidirectional_streaming/models/gemini_live.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,12 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic
219219
# Check if the transcription object has text content
220220
if hasattr(input_transcript, 'text') and input_transcript.text:
221221
transcription_text = input_transcript.text
222+
role = getattr(input_transcript, 'role', 'user')
222223
logger.debug(f"Input transcription detected: {transcription_text}")
223224
return BidiTranscriptStreamEvent(
224225
delta={"text": transcription_text},
225226
text=transcription_text,
226-
role="user",
227+
role=role.lower() if isinstance(role, str) else "user",
227228
is_final=True,
228229
current_transcript=transcription_text
229230
)
@@ -234,22 +235,24 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dic
234235
# Check if the transcription object has text content
235236
if hasattr(output_transcript, 'text') and output_transcript.text:
236237
transcription_text = output_transcript.text
238+
role = getattr(output_transcript, 'role', 'assistant')
237239
logger.debug(f"Output transcription detected: {transcription_text}")
238240
return BidiTranscriptStreamEvent(
239241
delta={"text": transcription_text},
240242
text=transcription_text,
241-
role="assistant",
243+
role=role.lower() if isinstance(role, str) else "assistant",
242244
is_final=True,
243245
current_transcript=transcription_text
244246
)
245247

246248
# Handle text output from model
247249
if message.text:
250+
role = getattr(message, 'role', 'assistant')
248251
logger.debug(f"Text output as transcript: {message.text}")
249252
return BidiTranscriptStreamEvent(
250253
delta={"text": message.text},
251254
text=message.text,
252-
role="assistant",
255+
role=role.lower() if isinstance(role, str) else "assistant",
253256
is_final=True,
254257
current_transcript=message.text
255258
)

src/strands/experimental/bidirectional_streaming/models/novasonic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None:
552552
elif "textOutput" in nova_event:
553553
text_content = nova_event["textOutput"]["content"]
554554
# Use stored role from contentStart event, fallback to event role
555-
role = getattr(self, "_current_role", nova_event["textOutput"].get("role", "assistant"))
555+
role = getattr(self, "_current_role", None) or nova_event["textOutput"].get("role", "assistant")
556556

557557
# Check for Nova Sonic interruption pattern
558558
if '{ "interrupted" : true }' in text_content:
@@ -562,7 +562,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> OutputEvent | None:
562562
return BidiTranscriptStreamEvent(
563563
delta={"text": text_content},
564564
text=text_content,
565-
role="user" if role == "USER" else "assistant",
565+
role=role.lower() if isinstance(role, str) else "assistant",
566566
is_final=True,
567567
current_transcript=text_content
568568
)

src/strands/experimental/bidirectional_streaming/models/openai.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,22 @@ def _require_active(self) -> bool:
174174
return self._active
175175

176176
def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent:
177-
"""Create standardized transcript event."""
177+
"""Create standardized transcript event.
178+
179+
Args:
180+
text: The transcript text
181+
role: The role (will be normalized to lowercase)
182+
is_final: Whether this is the final transcript
183+
"""
184+
# Normalize role to lowercase and ensure it's either "user" or "assistant"
185+
normalized_role = role.lower() if isinstance(role, str) else "assistant"
186+
if normalized_role not in ["user", "assistant"]:
187+
normalized_role = "assistant"
188+
178189
return BidiTranscriptStreamEvent(
179190
delta={"text": text},
180191
text=text,
181-
role="user" if role == "user" else "assistant",
192+
role=normalized_role,
182193
is_final=is_final,
183194
current_transcript=text if is_final else None
184195
)
@@ -326,20 +337,23 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> list[OutputEven
326337

327338
# Assistant text output events - combine multiple similar events
328339
elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]:
329-
return [self._create_text_event(openai_event["delta"], "assistant")]
340+
role = openai_event.get("role", "assistant")
341+
return [self._create_text_event(openai_event["delta"], role.lower() if isinstance(role, str) else "assistant")]
330342

331343
# User transcription events - combine multiple similar events
332344
elif event_type in ["conversation.item.input_audio_transcription.delta",
333345
"conversation.item.input_audio_transcription.completed"]:
334346
text_key = "delta" if "delta" in event_type else "transcript"
335347
text = openai_event.get(text_key, "")
348+
role = openai_event.get("role", "user")
336349
is_final = "completed" in event_type
337-
return [self._create_text_event(text, "user", is_final=is_final)] if text.strip() else None
350+
return [self._create_text_event(text, role.lower() if isinstance(role, str) else "user", is_final=is_final)] if text.strip() else None
338351

339352
elif event_type == "conversation.item.input_audio_transcription.segment":
340353
segment_data = openai_event.get("segment", {})
341354
text = segment_data.get("text", "")
342-
return [self._create_text_event(text, "user")] if text.strip() else None
355+
role = segment_data.get("role", "user")
356+
return [self._create_text_event(text, role.lower() if isinstance(role, str) else "user")] if text.strip() else None
343357

344358
elif event_type == "conversation.item.input_audio_transcription.failed":
345359
error_info = openai_event.get("error", {})

0 commit comments

Comments
 (0)