Skip to content

Commit b64e992

Browse files
committed
fix: Bedrock extended thinking configuration
1 parent e03ee3a commit b64e992

1 file changed

Lines changed: 105 additions & 9 deletions

File tree

go/adk/pkg/models/bedrock.go

Lines changed: 105 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,10 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me
243243
var finishReason genai.FinishReason
244244
var usageMetadata *genai.GenerateContentResponseUsageMetadata
245245

246-
// Track tool calls during streaming
247-
// Map of content block index -> tool call being built
246+
// Track tool calls and reasoning blocks during streaming.
247+
// Maps of content block index -> in-flight block being built.
248248
toolCalls := make(map[int32]*streamingToolCall)
249+
reasoningBlocks := make(map[int32]*streamingReasoningBlock)
249250
var completedToolCalls []*genai.Part
250251

251252
// Get the event stream and read events from the channel
@@ -254,19 +255,20 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me
254255

255256
// Read events from the channel
256257
for event := range stream.Events() {
257-
// Handle content block start (tool use start)
258+
// Handle content block start (tool use or reasoning start)
258259
if start, ok := event.(*types.ConverseStreamOutputMemberContentBlockStart); ok {
260+
blockIdx := aws.ToInt32(start.Value.ContentBlockIndex)
259261
if toolStart, ok := start.Value.Start.(*types.ContentBlockStartMemberToolUse); ok {
260-
// A new tool use block is starting - initialize tracking
261-
blockIdx := aws.ToInt32(start.Value.ContentBlockIndex)
262262
toolCalls[blockIdx] = &streamingToolCall{
263263
ID: aws.ToString(toolStart.Value.ToolUseId),
264264
Name: aws.ToString(toolStart.Value.Name),
265265
}
266266
}
267+
// Reasoning blocks have no start payload; we initialize on first delta.
268+
_ = blockIdx
267269
}
268270

269-
// Handle content block delta (streaming text or tool input)
271+
// Handle content block delta (streaming text, tool input, or reasoning)
270272
if chunk, ok := event.(*types.ConverseStreamOutputMemberContentBlockDelta); ok {
271273
blockIdx := aws.ToInt32(chunk.Value.ContentBlockIndex)
272274

@@ -295,10 +297,25 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me
295297
if tc, ok := toolCalls[blockIdx]; ok && delta.Value.Input != nil {
296298
tc.InputJSON += aws.ToString(delta.Value.Input)
297299
}
300+
301+
case *types.ContentBlockDeltaMemberReasoningContent:
302+
// Reasoning (thinking) delta — accumulate text and signature.
303+
if _, exists := reasoningBlocks[blockIdx]; !exists {
304+
reasoningBlocks[blockIdx] = &streamingReasoningBlock{}
305+
}
306+
rb := reasoningBlocks[blockIdx]
307+
switch rd := delta.Value.(type) {
308+
case *types.ReasoningContentBlockDeltaMemberText:
309+
rb.Text += rd.Value
310+
case *types.ReasoningContentBlockDeltaMemberSignature:
311+
rb.Signature = rd.Value
312+
case *types.ReasoningContentBlockDeltaMemberRedactedContent:
313+
rb.Redacted = rd.Value
314+
}
298315
}
299316
}
300317

301-
// Handle content block stop (tool use complete)
318+
// Handle content block stop (tool use or reasoning complete)
302319
if stop, ok := event.(*types.ConverseStreamOutputMemberContentBlockStop); ok {
303320
blockIdx := aws.ToInt32(stop.Value.ContentBlockIndex)
304321
if tc, ok := toolCalls[blockIdx]; ok {
@@ -316,8 +333,10 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me
316333
Args: args,
317334
}
318335
completedToolCalls = append(completedToolCalls, &genai.Part{FunctionCall: functionCall})
319-
delete(toolCalls, blockIdx) // Clean up
336+
delete(toolCalls, blockIdx)
320337
}
338+
// Reasoning blocks are finalized at message-stop (collected in reasoningBlocks map).
339+
_ = stop
321340
}
322341

323342
// Handle message stop (includes stop reason)
@@ -337,8 +356,15 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me
337356
}
338357
}
339358

340-
// Build final response
359+
// Build final response — reasoning parts first so they precede toolUse blocks,
360+
// matching the order Bedrock requires when echoing them back in tool-result turns.
341361
finalParts := []*genai.Part{}
362+
for _, rb := range reasoningBlocks {
363+
part := rb.toPart()
364+
if part != nil {
365+
finalParts = append(finalParts, part)
366+
}
367+
}
342368
text := aggregatedText.String()
343369
if text != "" {
344370
finalParts = append(finalParts, &genai.Part{Text: text})
@@ -379,6 +405,27 @@ func (tc *streamingToolCall) parseArgs() map[string]any {
379405
return args
380406
}
381407

408+
// streamingReasoningBlock accumulates the pieces of a single reasoning
// (thinking) content block while a Bedrock Converse stream is being read.
// One instance is kept per content block index until the final response
// parts are assembled.
type streamingReasoningBlock struct {
	// Text is the reasoning text; each text delta is appended to it.
	Text string
	// Signature is the value of the most recent signature delta, if any.
	Signature string
	// Redacted is the payload of the most recent redacted-content delta, if any.
	Redacted []byte
}
414+
415+
func (rb *streamingReasoningBlock) toPart() *genai.Part {
416+
if len(rb.Redacted) > 0 {
417+
return &genai.Part{Thought: true, ThoughtSignature: rb.Redacted}
418+
}
419+
if rb.Text == "" && rb.Signature == "" {
420+
return nil
421+
}
422+
part := &genai.Part{Thought: true, Text: rb.Text}
423+
if rb.Signature != "" {
424+
part.ThoughtSignature = []byte(rb.Signature)
425+
}
426+
return part
427+
}
428+
382429
// generateNonStreaming handles non-streaming responses from Bedrock Converse.
383430
// reverseNameMap maps sanitized Bedrock tool names back to their original names.
384431
func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, reverseNameMap map[string]string, yield func(*model.LLMResponse, error) bool) {
@@ -403,6 +450,26 @@ func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string,
403450
parts := []*genai.Part{}
404451
if message, ok := output.Output.(*types.ConverseOutputMemberMessage); ok {
405452
for _, block := range message.Value.Content {
453+
// Handle reasoning (thinking) content — must be preserved and echoed back
454+
// in subsequent tool-result turns or Bedrock returns ValidationException.
455+
if reasoningBlock, ok := block.(*types.ContentBlockMemberReasoningContent); ok {
456+
if textBlock, ok := reasoningBlock.Value.(*types.ReasoningContentBlockMemberReasoningText); ok {
457+
part := &genai.Part{
458+
Thought: true,
459+
Text: aws.ToString(textBlock.Value.Text),
460+
}
461+
if textBlock.Value.Signature != nil {
462+
part.ThoughtSignature = []byte(aws.ToString(textBlock.Value.Signature))
463+
}
464+
parts = append(parts, part)
465+
} else if redacted, ok := reasoningBlock.Value.(*types.ReasoningContentBlockMemberRedactedContent); ok {
466+
parts = append(parts, &genai.Part{
467+
Thought: true,
468+
ThoughtSignature: redacted.Value,
469+
})
470+
}
471+
continue
472+
}
406473
// Handle text content
407474
if textBlock, ok := block.(*types.ContentBlockMemberText); ok {
408475
parts = append(parts, &genai.Part{Text: textBlock.Value})
@@ -517,6 +584,35 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap ma
517584
continue
518585
}
519586

587+
// Handle reasoning (thinking) parts — echo them back unmodified so Bedrock
588+
// can maintain reasoning continuity across tool-result turns.
589+
// AWS docs: "you must pass thinking blocks back to the API for the last
590+
// assistant message … include the complete unmodified block."
591+
if part.Thought {
592+
if len(part.ThoughtSignature) > 0 && part.Text == "" {
593+
// Redacted block
594+
contentBlocks = append(contentBlocks, &types.ContentBlockMemberReasoningContent{
595+
Value: &types.ReasoningContentBlockMemberRedactedContent{
596+
Value: part.ThoughtSignature,
597+
},
598+
})
599+
} else {
600+
textBlock := &types.ReasoningTextBlock{
601+
Text: aws.String(part.Text),
602+
}
603+
if len(part.ThoughtSignature) > 0 {
604+
sig := string(part.ThoughtSignature)
605+
textBlock.Signature = aws.String(sig)
606+
}
607+
contentBlocks = append(contentBlocks, &types.ContentBlockMemberReasoningContent{
608+
Value: &types.ReasoningContentBlockMemberReasoningText{
609+
Value: *textBlock,
610+
},
611+
})
612+
}
613+
continue
614+
}
615+
520616
// Handle function call (tool use in Bedrock terminology).
521617
// Use the sanitized name from nameMap so Bedrock can correlate the
522618
// tool call with the tool spec sent in the same request.

0 commit comments

Comments
 (0)