Skip to content

Commit 0d47609

Browse files
committed
fix: update token count types from int to int32 across Bedrock client implementations
1 parent 40bed58 commit 0d47609

8 files changed

Lines changed: 78 additions & 72 deletions

File tree

llms/bedrock/internal/bedrockclient/bedrockclient_converse.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -495,9 +495,15 @@ func (c *ConverseClient) convertConverseResponse(response *bedrockruntime.Conver
495495

496496
// Add usage information
497497
if response.Usage != nil {
498-
choice.GenerationInfo["input_tokens"] = response.Usage.InputTokens
499-
choice.GenerationInfo["output_tokens"] = response.Usage.OutputTokens
500-
choice.GenerationInfo["total_tokens"] = response.Usage.TotalTokens
498+
if response.Usage.InputTokens != nil {
499+
choice.GenerationInfo["input_tokens"] = *response.Usage.InputTokens
500+
}
501+
if response.Usage.OutputTokens != nil {
502+
choice.GenerationInfo["output_tokens"] = *response.Usage.OutputTokens
503+
}
504+
if response.Usage.TotalTokens != nil {
505+
choice.GenerationInfo["total_tokens"] = *response.Usage.TotalTokens
506+
}
501507
}
502508

503509
return &llms.ContentResponse{

llms/bedrock/internal/bedrockclient/bedrockclient_integration_test.go

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ func TestClient_CreateCompletion(t *testing.T) {
118118
require.Len(t, resp.Choices, 1)
119119
assert.Equal(t, "Hello! How can I help you?", resp.Choices[0].Content)
120120
assert.Equal(t, Ai21CompletionReasonStop, resp.Choices[0].StopReason)
121-
assert.Equal(t, 5, resp.Choices[0].GenerationInfo["input_tokens"])
122-
assert.Equal(t, 7, resp.Choices[0].GenerationInfo["output_tokens"])
121+
assert.Equal(t, int32(5), resp.Choices[0].GenerationInfo["input_tokens"])
122+
assert.Equal(t, int32(7), resp.Choices[0].GenerationInfo["output_tokens"])
123123
},
124124
},
125125
{
@@ -135,7 +135,7 @@ func TestClient_CreateCompletion(t *testing.T) {
135135
mockResponse: amazonTextGenerationOutput{
136136
InputTextTokenCount: 4,
137137
Results: []struct {
138-
TokenCount int `json:"tokenCount"`
138+
TokenCount int32 `json:"tokenCount"`
139139
OutputText string `json:"outputText"`
140140
CompletionReason string `json:"completionReason"`
141141
}{
@@ -150,8 +150,8 @@ func TestClient_CreateCompletion(t *testing.T) {
150150
require.Len(t, resp.Choices, 1)
151151
assert.Equal(t, "Hello! I'm Amazon Titan.", resp.Choices[0].Content)
152152
assert.Equal(t, AmazonCompletionReasonFinish, resp.Choices[0].StopReason)
153-
assert.Equal(t, 4, resp.Choices[0].GenerationInfo["input_tokens"])
154-
assert.Equal(t, 8, resp.Choices[0].GenerationInfo["output_tokens"])
153+
assert.Equal(t, int32(4), resp.Choices[0].GenerationInfo["input_tokens"])
154+
assert.Equal(t, int32(8), resp.Choices[0].GenerationInfo["output_tokens"])
155155
},
156156
},
157157
{
@@ -176,8 +176,8 @@ func TestClient_CreateCompletion(t *testing.T) {
176176
},
177177
StopReason: AnthropicCompletionReasonEndTurn,
178178
Usage: struct {
179-
InputTokens int `json:"input_tokens"`
180-
OutputTokens int `json:"output_tokens"`
179+
InputTokens int32 `json:"input_tokens"`
180+
OutputTokens int32 `json:"output_tokens"`
181181
}{
182182
InputTokens: 10,
183183
OutputTokens: 5,
@@ -187,8 +187,8 @@ func TestClient_CreateCompletion(t *testing.T) {
187187
require.Len(t, resp.Choices, 1)
188188
assert.Equal(t, "Hello! I'm Claude.", resp.Choices[0].Content)
189189
assert.Equal(t, AnthropicCompletionReasonEndTurn, resp.Choices[0].StopReason)
190-
assert.Equal(t, 10, resp.Choices[0].GenerationInfo["input_tokens"])
191-
assert.Equal(t, 5, resp.Choices[0].GenerationInfo["output_tokens"])
190+
assert.Equal(t, int32(10), resp.Choices[0].GenerationInfo["input_tokens"])
191+
assert.Equal(t, int32(5), resp.Choices[0].GenerationInfo["output_tokens"])
192192
},
193193
},
194194
{
@@ -245,8 +245,8 @@ func TestClient_CreateCompletion(t *testing.T) {
245245
require.Len(t, resp.Choices, 1)
246246
assert.Equal(t, "Hello! I'm LLaMA 2.", resp.Choices[0].Content)
247247
assert.Equal(t, MetaCompletionReasonStop, resp.Choices[0].StopReason)
248-
assert.Equal(t, 3, resp.Choices[0].GenerationInfo["input_tokens"])
249-
assert.Equal(t, 6, resp.Choices[0].GenerationInfo["output_tokens"])
248+
assert.Equal(t, int32(3), resp.Choices[0].GenerationInfo["input_tokens"])
249+
assert.Equal(t, int32(6), resp.Choices[0].GenerationInfo["output_tokens"])
250250
},
251251
},
252252
{
@@ -441,16 +441,16 @@ func TestClient_CreateCompletion_Streaming(t *testing.T) {
441441
StopReason any `json:"stop_reason"`
442442
StopSequence any `json:"stop_sequence"`
443443
Usage struct {
444-
InputTokens int `json:"input_tokens"`
445-
OutputTokens int `json:"output_tokens"`
444+
InputTokens int32 `json:"input_tokens"`
445+
OutputTokens int32 `json:"output_tokens"`
446446
} `json:"usage"`
447447
}{
448448
ID: "msg-123",
449449
Type: "message",
450450
Role: "assistant",
451451
Usage: struct {
452-
InputTokens int `json:"input_tokens"`
453-
OutputTokens int `json:"output_tokens"`
452+
InputTokens int32 `json:"input_tokens"`
453+
OutputTokens int32 `json:"output_tokens"`
454454
}{
455455
InputTokens: 10,
456456
},
@@ -503,7 +503,7 @@ func TestClient_CreateCompletion_Streaming(t *testing.T) {
503503
StopReason: AnthropicCompletionReasonEndTurn,
504504
},
505505
Usage: struct {
506-
OutputTokens int `json:"output_tokens"`
506+
OutputTokens int32 `json:"output_tokens"`
507507
}{
508508
OutputTokens: 15,
509509
},
@@ -538,8 +538,8 @@ func TestClient_CreateCompletion_Streaming(t *testing.T) {
538538
require.Len(t, resp.Choices, 1)
539539
assert.Equal(t, "Once upon a time, there was a brave knight.", resp.Choices[0].Content)
540540
assert.Equal(t, AnthropicCompletionReasonEndTurn, resp.Choices[0].StopReason)
541-
assert.Equal(t, 10, resp.Choices[0].GenerationInfo["input_tokens"])
542-
assert.Equal(t, 15, resp.Choices[0].GenerationInfo["output_tokens"])
541+
assert.Equal(t, int32(10), resp.Choices[0].GenerationInfo["input_tokens"])
542+
assert.Equal(t, int32(15), resp.Choices[0].GenerationInfo["output_tokens"])
543543

544544
// Validate streamed content
545545
assert.Equal(t, []string{"Once upon a time, ", "there was a brave knight."}, streamedContent)
@@ -643,7 +643,7 @@ func TestClient_CreateCompletion_EdgeCases(t *testing.T) {
643643
},
644644
mockResponse: amazonTextGenerationOutput{
645645
Results: []struct {
646-
TokenCount int `json:"tokenCount"`
646+
TokenCount int32 `json:"tokenCount"`
647647
OutputText string `json:"outputText"`
648648
CompletionReason string `json:"completionReason"`
649649
}{},
@@ -879,8 +879,8 @@ func testCreateAi21CompletionWithMock(ctx context.Context, client *mockBedrockCl
879879
Content: c.Data.Text,
880880
StopReason: c.FinishReason.Reason,
881881
GenerationInfo: map[string]any{
882-
"input_tokens": len(output.Prompt.Tokens),
883-
"output_tokens": len(c.Data.Tokens),
882+
"input_tokens": int32(len(output.Prompt.Tokens)),
883+
"output_tokens": int32(len(c.Data.Tokens)),
884884
},
885885
}
886886
}

llms/bedrock/internal/bedrockclient/bedrockclient_test.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ func TestAmazonResponseParsing(t *testing.T) {
580580
output := amazonTextGenerationOutput{
581581
InputTextTokenCount: 5,
582582
Results: []struct {
583-
TokenCount int `json:"tokenCount"`
583+
TokenCount int32 `json:"tokenCount"`
584584
OutputText string `json:"outputText"`
585585
CompletionReason string `json:"completionReason"`
586586
}{
@@ -599,9 +599,9 @@ func TestAmazonResponseParsing(t *testing.T) {
599599
err = json.Unmarshal(data, &parsed)
600600
require.NoError(t, err)
601601

602-
require.Equal(t, 5, parsed.InputTextTokenCount)
602+
require.Equal(t, int32(5), parsed.InputTextTokenCount)
603603
require.Len(t, parsed.Results, 1)
604-
require.Equal(t, 15, parsed.Results[0].TokenCount)
604+
require.Equal(t, int32(15), parsed.Results[0].TokenCount)
605605
require.Equal(t, "AI is transforming the world.", parsed.Results[0].OutputText)
606606
require.Equal(t, AmazonCompletionReasonFinish, parsed.Results[0].CompletionReason)
607607
}
@@ -654,8 +654,8 @@ func TestMetaResponseParsing(t *testing.T) {
654654
require.NoError(t, err)
655655

656656
require.Equal(t, output.Generation, parsed.Generation)
657-
require.Equal(t, 7, parsed.PromptTokenCount)
658-
require.Equal(t, 12, parsed.GenerationTokenCount)
657+
require.Equal(t, int32(7), parsed.PromptTokenCount)
658+
require.Equal(t, int32(12), parsed.GenerationTokenCount)
659659
require.Equal(t, MetaCompletionReasonStop, parsed.StopReason)
660660
}
661661

@@ -672,8 +672,8 @@ func TestAnthropicResponseParsing(t *testing.T) {
672672
StopReason: AnthropicCompletionReasonEndTurn,
673673
StopSequence: "",
674674
Usage: struct {
675-
InputTokens int `json:"input_tokens"`
676-
OutputTokens int `json:"output_tokens"`
675+
InputTokens int32 `json:"input_tokens"`
676+
OutputTokens int32 `json:"output_tokens"`
677677
}{
678678
InputTokens: 10,
679679
OutputTokens: 15,
@@ -693,8 +693,8 @@ func TestAnthropicResponseParsing(t *testing.T) {
693693
require.Equal(t, "text", parsed.Content[0].Type)
694694
require.Equal(t, "Hello! I'm Claude, an AI assistant.", parsed.Content[0].Text)
695695
require.Equal(t, AnthropicCompletionReasonEndTurn, parsed.StopReason)
696-
require.Equal(t, 10, parsed.Usage.InputTokens)
697-
require.Equal(t, 15, parsed.Usage.OutputTokens)
696+
require.Equal(t, int32(10), parsed.Usage.InputTokens)
697+
require.Equal(t, int32(15), parsed.Usage.OutputTokens)
698698
}
699699

700700
// Edge case tests
@@ -703,7 +703,7 @@ func TestEmptyResponses(t *testing.T) {
703703
output := amazonTextGenerationOutput{
704704
InputTextTokenCount: 5,
705705
Results: []struct {
706-
TokenCount int `json:"tokenCount"`
706+
TokenCount int32 `json:"tokenCount"`
707707
OutputText string `json:"outputText"`
708708
CompletionReason string `json:"completionReason"`
709709
}{},
@@ -745,17 +745,17 @@ func TestAnthropicStreamingResponseChunk(t *testing.T) {
745745
StopReason any `json:"stop_reason"`
746746
StopSequence any `json:"stop_sequence"`
747747
Usage struct {
748-
InputTokens int `json:"input_tokens"`
749-
OutputTokens int `json:"output_tokens"`
748+
InputTokens int32 `json:"input_tokens"`
749+
OutputTokens int32 `json:"output_tokens"`
750750
} `json:"usage"`
751751
}{
752752
ID: "msg-123",
753753
Type: "message",
754754
Role: "assistant",
755755
Model: "claude-3",
756756
Usage: struct {
757-
InputTokens int `json:"input_tokens"`
758-
OutputTokens int `json:"output_tokens"`
757+
InputTokens int32 `json:"input_tokens"`
758+
OutputTokens int32 `json:"output_tokens"`
759759
}{
760760
InputTokens: 25,
761761
},
@@ -795,7 +795,7 @@ func TestAnthropicStreamingResponseChunk(t *testing.T) {
795795
StopReason: AnthropicCompletionReasonEndTurn,
796796
},
797797
Usage: struct {
798-
OutputTokens int `json:"output_tokens"`
798+
OutputTokens int32 `json:"output_tokens"`
799799
}{
800800
OutputTokens: 12,
801801
},

llms/bedrock/internal/bedrockclient/provider_ai21.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ type ai21JambaOutput struct {
7171
FinishReason string `json:"finish_reason"`
7272
} `json:"choices"`
7373
Usage struct {
74-
PromptTokens int `json:"prompt_tokens"`
75-
CompletionTokens int `json:"completion_tokens"`
76-
TotalTokens int `json:"total_tokens"`
74+
PromptTokens int32 `json:"prompt_tokens"`
75+
CompletionTokens int32 `json:"completion_tokens"`
76+
TotalTokens int32 `json:"total_tokens"`
7777
} `json:"usage"`
7878
Meta any `json:"meta"`
7979
Model string `json:"model"`
@@ -85,8 +85,8 @@ type ai21StreamingResponseChunk struct {
8585
FinishReason string `json:"finish_reason,omitempty"`
8686
Index int `json:"index,omitempty"`
8787
Usage struct {
88-
PromptTokens int `json:"prompt_tokens"`
89-
CompletionTokens int `json:"completion_tokens"`
88+
PromptTokens int32 `json:"prompt_tokens"`
89+
CompletionTokens int32 `json:"completion_tokens"`
9090
} `json:"usage,omitempty"`
9191
}
9292

@@ -218,8 +218,8 @@ func createAi21Completion(ctx context.Context, client *bedrockruntime.Client, mo
218218
StopReason: completion.FinishReason.Reason,
219219
GenerationInfo: map[string]any{
220220
"id": output.ID,
221-
"input_tokens": len(output.Prompt.Tokens),
222-
"output_tokens": len(completion.Data.Tokens),
221+
"input_tokens": int32(len(output.Prompt.Tokens)),
222+
"output_tokens": int32(len(completion.Data.Tokens)),
223223
},
224224
}
225225
}

llms/bedrock/internal/bedrockclient/provider_amazon.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ type amazonTextGenerationInput struct {
4040
// amazonTextGenerationOutput is the output for the text generation for Amazon Models.
4141
type amazonTextGenerationOutput struct {
4242
// The number of tokens in the prompt
43-
InputTextTokenCount int `json:"inputTextTokenCount"`
43+
InputTextTokenCount int32 `json:"inputTextTokenCount"`
4444
// The results of the request
4545
Results []struct {
4646
// The number of tokens in the response
47-
TokenCount int `json:"tokenCount"`
47+
TokenCount int32 `json:"tokenCount"`
4848
// The generated text
4949
OutputText string `json:"outputText"`
5050
// The reason for the completion of the generation
@@ -59,8 +59,8 @@ type amazonStreamingResponseChunk struct {
5959
Index int `json:"index"`
6060
TotalOutputTextCount int `json:"totalOutputTextCount"`
6161
CompletionReason string `json:"completionReason"`
62-
InputTextTokenCount int `json:"inputTextTokenCount"`
63-
OutputTextTokenCount int `json:"outputTextTokenCount"`
62+
InputTextTokenCount int32 `json:"inputTextTokenCount"`
63+
OutputTextTokenCount int32 `json:"outputTextTokenCount"`
6464
}
6565

6666
// Finish reason for the completion of the generation for Amazon Models.

llms/bedrock/internal/bedrockclient/provider_anthropic.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ type anthropicTextGenerationOutput struct {
114114
// Which custom stop sequence was matched, if any.
115115
StopSequence string `json:"stop_sequence"`
116116
Usage struct {
117-
InputTokens int `json:"input_tokens"`
118-
OutputTokens int `json:"output_tokens"`
117+
InputTokens int32 `json:"input_tokens"`
118+
OutputTokens int32 `json:"output_tokens"`
119119
} `json:"usage"`
120120
}
121121

@@ -302,13 +302,13 @@ type streamingCompletionResponseChunk struct {
302302
Input map[string]any `json:"input"`
303303
} `json:"content_block"`
304304
AmazonBedrockInvocationMetrics struct {
305-
InputTokenCount int `json:"inputTokenCount"`
306-
OutputTokenCount int `json:"outputTokenCount"`
307-
InvocationLatency int `json:"invocationLatency"`
308-
FirstByteLatency int `json:"firstByteLatency"`
305+
InputTokenCount int32 `json:"inputTokenCount"`
306+
OutputTokenCount int32 `json:"outputTokenCount"`
307+
InvocationLatency int32 `json:"invocationLatency"`
308+
FirstByteLatency int32 `json:"firstByteLatency"`
309309
} `json:"amazon-bedrock-invocationMetrics"`
310310
Usage struct {
311-
OutputTokens int `json:"output_tokens"`
311+
OutputTokens int32 `json:"output_tokens"`
312312
} `json:"usage"`
313313
Message struct {
314314
ID string `json:"id"`
@@ -319,8 +319,8 @@ type streamingCompletionResponseChunk struct {
319319
StopReason any `json:"stop_reason"`
320320
StopSequence any `json:"stop_sequence"`
321321
Usage struct {
322-
InputTokens int `json:"input_tokens"`
323-
OutputTokens int `json:"output_tokens"`
322+
InputTokens int32 `json:"input_tokens"`
323+
OutputTokens int32 `json:"output_tokens"`
324324
} `json:"usage"`
325325
} `json:"message"`
326326
}

llms/bedrock/internal/bedrockclient/provider_meta.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ type metaTextGenerationOutput struct {
3535
// The generated text.
3636
Generation string `json:"generation"`
3737
// The number of tokens in the prompt.
38-
PromptTokenCount int `json:"prompt_token_count"`
38+
PromptTokenCount int32 `json:"prompt_token_count"`
3939
// The number of tokens in the generated text.
40-
GenerationTokenCount int `json:"generation_token_count"`
40+
GenerationTokenCount int32 `json:"generation_token_count"`
4141
// The reason why the response stopped generating text.
4242
// One of: ["stop", "length"]
4343
StopReason string `json:"stop_reason"`
@@ -46,13 +46,13 @@ type metaTextGenerationOutput struct {
4646
// Meta streaming response structure
4747
type metaStreamingResponseChunk struct {
4848
Generation string `json:"generation"`
49-
PromptTokenCount int `json:"prompt_token_count,omitempty"`
50-
GenerationTokenCount int `json:"generation_token_count,omitempty"`
49+
PromptTokenCount int32 `json:"prompt_token_count,omitempty"`
50+
GenerationTokenCount int32 `json:"generation_token_count,omitempty"`
5151
StopReason string `json:"stop_reason,omitempty"`
5252
Amazon struct {
5353
BedrockInvocationMetrics struct {
54-
InputTokenCount int `json:"inputTokenCount"`
55-
OutputTokenCount int `json:"outputTokenCount"`
54+
InputTokenCount int32 `json:"inputTokenCount"`
55+
OutputTokenCount int32 `json:"outputTokenCount"`
5656
} `json:"bedrock-invocationMetrics"`
5757
} `json:"amazon-bedrock-invocationMetrics,omitempty"`
5858
}

0 commit comments

Comments (0)