Skip to content

Commit b2a6d09

Browse files
authored
chore: drop redacted thinking support for AmazonBedrockChatGenerator (#2998)
* chore: simplify reasoning parsing logic and update integration models for AmazonBedrockChatGenerator * fix mypy * safer condition * simplify * fix format_reasoning_content * format * fix tests * fix parse_completion_response * format * apply feedback * fmt * apply feedback * fmt
1 parent 0dacd9e commit b2a6d09

3 files changed

Lines changed: 137 additions & 398 deletions

File tree

integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py

Lines changed: 56 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def _format_tool_call_message(tool_call_message: ChatMessage) -> dict[str, Any]:
173173

174174
# tool call messages can contain reasoning content
175175
if reasoning_content := tool_call_message.reasoning:
176-
content.extend(_format_reasoning_content(reasoning_content=reasoning_content))
176+
content.append(_format_reasoning_content(reasoning_content=reasoning_content))
177177

178178
# Tool call message can contain text
179179
if tool_call_message.text:
@@ -291,22 +291,27 @@ def _repair_tool_result_messages(bedrock_formatted_messages: list[dict[str, Any]
291291
return [msg for _, msg in repaired_bedrock_formatted_messages]
292292

293293

294-
def _format_reasoning_content(reasoning_content: ReasoningContent) -> list[dict[str, Any]]:
294+
def _format_reasoning_content(reasoning_content: ReasoningContent) -> dict[str, Any]:
295295
"""
296296
Format ReasoningContent to match Bedrock's expected structure.
297297
298298
:param reasoning_content: ReasoningContent object containing reasoning contents to format.
299-
:returns: List of formatted reasoning content dictionaries for Bedrock.
299+
:returns: Dictionary representing the formatted reasoning content for Bedrock.
300+
300301
"""
301-
formatted_contents = []
302-
for content in reasoning_content.extra.get("reasoning_contents", []):
303-
formatted_content = {"reasoningContent": content["reasoning_content"]}
304-
if reasoning_text := formatted_content["reasoningContent"].pop("reasoning_text", None):
305-
formatted_content["reasoningContent"]["reasoningText"] = reasoning_text
306-
if redacted_content := formatted_content["reasoningContent"].pop("redacted_content", None):
307-
formatted_content["reasoningContent"]["redactedContent"] = redacted_content
308-
formatted_contents.append(formatted_content)
309-
return formatted_contents
302+
formatted_content = {
303+
"reasoningContent": {
304+
"reasoningText": {
305+
"text": reasoning_content.reasoning_text,
306+
**(
307+
{"signature": reasoning_content.extra["signature"]}
308+
if reasoning_content.extra.get("signature")
309+
else {}
310+
),
311+
}
312+
}
313+
}
314+
return formatted_content
310315

311316

312317
def _format_user_message(message: ChatMessage) -> dict[str, Any]:
@@ -345,7 +350,7 @@ def _format_textual_assistant_message(message: ChatMessage) -> dict[str, Any]:
345350
bedrock_content_blocks: list[dict[str, Any]] = []
346351
# Add reasoning content if available as the first content block
347352
if message.reasoning:
348-
bedrock_content_blocks.extend(_format_reasoning_content(reasoning_content=message.reasoning))
353+
bedrock_content_blocks.append(_format_reasoning_content(reasoning_content=message.reasoning))
349354

350355
for part in content_parts:
351356
if isinstance(part, TextContent):
@@ -462,7 +467,7 @@ def _parse_completion_response(response_body: dict[str, Any], model: str) -> lis
462467
# Process all content blocks and combine them into a single message
463468
text_content = []
464469
tool_calls = []
465-
reasoning_contents = []
470+
reasoning_content = None
466471
for content_block in content_blocks:
467472
if "text" in content_block:
468473
text_content.append(content_block["text"])
@@ -477,12 +482,6 @@ def _parse_completion_response(response_body: dict[str, Any], model: str) -> lis
477482
tool_calls.append(tool_call)
478483
elif "reasoningContent" in content_block:
479484
reasoning_content = content_block["reasoningContent"]
480-
# If reasoningText is present, replace it with reasoning_text
481-
if "reasoningText" in reasoning_content:
482-
reasoning_content["reasoning_text"] = reasoning_content.pop("reasoningText")
483-
if "redactedContent" in reasoning_content:
484-
reasoning_content["redacted_content"] = reasoning_content.pop("redactedContent")
485-
reasoning_contents.append({"reasoning_content": reasoning_content})
486485
elif "citationsContent" in content_block:
487486
citations_content = content_block["citationsContent"]
488487
meta["citations"] = citations_content
@@ -492,23 +491,23 @@ def _parse_completion_response(response_body: dict[str, Any], model: str) -> lis
492491
if text.strip():
493492
text_content.append(text)
494493

494+
reasoning_extra = {}
495495
reasoning_text = ""
496-
for content in reasoning_contents:
497-
if "reasoning_text" in content["reasoning_content"]:
498-
reasoning_text += content["reasoning_content"]["reasoning_text"]["text"]
499-
elif "redacted_content" in content["reasoning_content"]:
500-
reasoning_text += "[REDACTED]"
496+
if reasoning_content:
497+
if "reasoningText" in reasoning_content:
498+
reasoning_text = reasoning_content["reasoningText"].get("text", "")
499+
signature = reasoning_content["reasoningText"].get("signature")
500+
if signature:
501+
reasoning_extra["signature"] = signature
501502

502503
# Create a single ChatMessage with combined text and tool calls
503504
replies.append(
504505
ChatMessage.from_assistant(
505506
"".join(text_content),
506507
tool_calls=tool_calls,
507508
meta=meta,
508-
reasoning=ReasoningContent(
509-
reasoning_text=reasoning_text, extra={"reasoning_contents": reasoning_contents}
510-
)
511-
if reasoning_contents
509+
reasoning=ReasoningContent(reasoning_text=reasoning_text, extra=reasoning_extra)
510+
if reasoning_text or reasoning_extra
512511
else None,
513512
)
514513
)
@@ -583,15 +582,18 @@ def _convert_event_to_streaming_chunk(
583582
# This is for accumulating reasoning content deltas
584583
elif "reasoningContent" in delta:
585584
reasoning_content = delta["reasoningContent"]
586-
if "redactedContent" in reasoning_content:
587-
reasoning_content["redacted_content"] = reasoning_content.pop("redactedContent")
588-
reasoning_text = reasoning_content.get("text", "")
585+
reasoning_text = ""
586+
extra = {}
587+
if "text" in reasoning_content:
588+
reasoning_text = reasoning_content["text"]
589+
if "signature" in reasoning_content:
590+
extra["signature"] = reasoning_content["signature"]
589591
streaming_chunk = StreamingChunk(
590592
content="",
591593
index=block_idx,
592594
reasoning=ReasoningContent(
593595
reasoning_text=reasoning_text,
594-
extra={"reasoning_contents": [{"index": block_idx, "reasoning_content": reasoning_content}]},
596+
extra=extra,
595597
),
596598
meta=base_meta,
597599
)
@@ -630,86 +632,6 @@ def _convert_event_to_streaming_chunk(
630632
return streaming_chunk
631633

632634

633-
def _process_reasoning_contents(chunks: list[StreamingChunk]) -> ReasoningContent | None:
634-
"""
635-
Process reasoning contents from a list of StreamingChunk objects into the Bedrock expected format.
636-
637-
:param chunks: List of StreamingChunk objects potentially containing reasoning contents.
638-
639-
:returns: List of Bedrock formatted reasoning content dictionaries
640-
"""
641-
formatted_reasoning_contents = []
642-
current_index = None
643-
reasoning_text = ""
644-
reasoning_signature = None
645-
redacted_content = None
646-
for chunk in chunks:
647-
if chunk.reasoning and chunk.reasoning.extra:
648-
reasoning_contents = chunk.reasoning.extra.get("reasoning_contents", [])
649-
else:
650-
reasoning_contents = []
651-
652-
for reasoning_content in reasoning_contents:
653-
content_block_index = reasoning_content["index"]
654-
655-
# Start new group when index changes
656-
if current_index is not None and content_block_index != current_index:
657-
# Finalize current group
658-
if reasoning_text:
659-
formatted_reasoning_contents.append(
660-
{
661-
"reasoning_content": {
662-
"reasoning_text": {"text": reasoning_text, "signature": reasoning_signature},
663-
}
664-
}
665-
)
666-
if redacted_content:
667-
formatted_reasoning_contents.append({"reasoning_content": {"redacted_content": redacted_content}})
668-
669-
# Reset accumulators for new group
670-
reasoning_text = ""
671-
reasoning_signature = None
672-
redacted_content = None
673-
674-
# Accumulate content for current index
675-
current_index = content_block_index
676-
reasoning_text += reasoning_content["reasoning_content"].get("text", "")
677-
if "redacted_content" in reasoning_content["reasoning_content"]:
678-
redacted_content = reasoning_content["reasoning_content"]["redacted_content"]
679-
if "signature" in reasoning_content["reasoning_content"]:
680-
reasoning_signature = reasoning_content["reasoning_content"]["signature"]
681-
682-
# Finalize the last group
683-
if current_index is not None:
684-
if reasoning_text:
685-
formatted_reasoning_contents.append(
686-
{
687-
"reasoning_content": {
688-
"reasoning_text": {"text": reasoning_text, "signature": reasoning_signature},
689-
}
690-
}
691-
)
692-
if redacted_content:
693-
formatted_reasoning_contents.append({"reasoning_content": {"redacted_content": redacted_content}})
694-
695-
# Combine all reasoning texts into a single string for the main reasoning_text field
696-
final_reasoning_text = ""
697-
for content in formatted_reasoning_contents:
698-
if "reasoning_text" in content["reasoning_content"]:
699-
# mypy somehow thinks that content["reasoning_content"]["reasoning_text"]["text"] can be of type None
700-
final_reasoning_text += content["reasoning_content"]["reasoning_text"]["text"] # type: ignore[operator]
701-
elif "redacted_content" in content["reasoning_content"]:
702-
final_reasoning_text += "[REDACTED]"
703-
704-
return (
705-
ReasoningContent(
706-
reasoning_text=final_reasoning_text, extra={"reasoning_contents": formatted_reasoning_contents}
707-
)
708-
if formatted_reasoning_contents
709-
else None
710-
)
711-
712-
713635
def _parse_streaming_response(
714636
response_stream: EventStream,
715637
streaming_callback: SyncStreamingCallbackT,
@@ -736,21 +658,26 @@ def _parse_streaming_response(
736658
streaming_callback(streaming_chunk)
737659
chunks.append(streaming_chunk)
738660

661+
replies = _convert_chunks_to_messages(chunks)
662+
return replies
663+
664+
665+
def _convert_chunks_to_messages(chunks: list[StreamingChunk]) -> list[ChatMessage]:
739666
reply = _convert_streaming_chunks_to_chat_message(chunks=chunks)
740667

741-
# both the reasoning content and the trace are ignored in _convert_streaming_chunks_to_chat_message
668+
# reasoning signatures are ignored in _convert_streaming_chunks_to_chat_message
742669
# so we need to process them separately
743-
reasoning_content = _process_reasoning_contents(chunks=chunks)
744-
if chunks[-1].meta and "trace" in chunks[-1].meta:
745-
reply.meta["trace"] = chunks[-1].meta["trace"]
746-
747-
reply = ChatMessage.from_assistant(
748-
text=reply.text,
749-
meta=reply.meta,
750-
name=reply.name,
751-
tool_calls=reply.tool_calls,
752-
reasoning=reasoning_content,
753-
)
670+
if reply.reasoning:
671+
for chunk in reversed(chunks):
672+
if chunk.reasoning and chunk.reasoning.extra and "signature" in chunk.reasoning.extra:
673+
reply.reasoning.extra["signature"] = chunk.reasoning.extra["signature"]
674+
break
675+
676+
# the trace is ignored in _convert_streaming_chunks_to_chat_message
677+
# so we need to process them separately
678+
last_chunk = chunks[-1] if chunks else None
679+
if last_chunk and last_chunk.meta and "trace" in last_chunk.meta:
680+
reply.meta["trace"] = last_chunk.meta["trace"]
754681

755682
return [reply]
756683

@@ -780,16 +707,10 @@ async def _parse_streaming_response_async(
780707
content_block_idxs.add(content_block_idx)
781708
await streaming_callback(streaming_chunk)
782709
chunks.append(streaming_chunk)
783-
reply = _convert_streaming_chunks_to_chat_message(chunks=chunks)
784-
reasoning_content = _process_reasoning_contents(chunks=chunks)
785-
reply = ChatMessage.from_assistant(
786-
text=reply.text,
787-
meta=reply.meta,
788-
name=reply.name,
789-
tool_calls=reply.tool_calls,
790-
reasoning=reasoning_content,
791-
)
792-
return [reply]
710+
711+
replies = _convert_chunks_to_messages(chunks)
712+
713+
return replies
793714

794715

795716
def _validate_guardrail_config(guardrail_config: dict[str, str] | None = None, streaming: bool = False) -> None:

0 commit comments

Comments
 (0)