Skip to content

Commit fc40103

Browse files
authored
feat: add CacheReadInputTokens and CacheWriteInputTokens to TokenUsageRecord (#229)
* feat: add CacheReadInputTokens and CacheWriteInputTokens to TokenUsageRecord * remove cached token recording from ExtraTokenTypes * add back cached token values to ExtraTokenTypes
1 parent 519b082 commit fc40103

15 files changed

Lines changed: 224 additions & 101 deletions

File tree

example/recorder.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ func NewSQLiteRecorder(db *sql.DB, logger slog.Logger) (*SQLiteRecorder, error)
4040
}
4141

4242
r.stmtInsertTokenUsage, err = db.Prepare(`
43-
INSERT INTO aibridge_token_usages (id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at)
44-
VALUES (?, ?, ?, ?, ?, ?, ?)`)
43+
INSERT INTO aibridge_token_usages (id, interception_id, provider_response_id, input_tokens, output_tokens, cache_read_input_tokens, cache_write_input_tokens, metadata, created_at)
44+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`)
4545
if err != nil {
4646
return nil, err
4747
}
@@ -103,7 +103,7 @@ func (r *SQLiteRecorder) RecordTokenUsage(ctx context.Context, req *aibridge.Tok
103103
metadata, _ := json.Marshal(merged)
104104

105105
_, err := r.stmtInsertTokenUsage.ExecContext(ctx,
106-
uuid.NewString(), req.InterceptionID, req.MsgID, req.Input, req.Output, string(metadata), req.CreatedAt,
106+
uuid.NewString(), req.InterceptionID, req.MsgID, req.Input, req.Output, req.CacheReadInputTokens, req.CacheWriteInputTokens, string(metadata), req.CreatedAt,
107107
)
108108
if err != nil {
109109
r.logger.Warn(ctx, "failed to record token usage", slog.Error(err))

example/schema.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ func initSchema(db *sql.DB) error {
2222
provider_response_id TEXT NOT NULL,
2323
input_tokens INTEGER NOT NULL,
2424
output_tokens INTEGER NOT NULL,
25+
cache_read_input_tokens INTEGER NOT NULL DEFAULT 0,
26+
cache_write_input_tokens INTEGER NOT NULL DEFAULT 0,
2527
metadata TEXT,
2628
created_at DATETIME NOT NULL,
2729
FOREIGN KEY (interception_id) REFERENCES aibridge_interceptions(id)

intercept/chatcompletions/base.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage {
216216
// calculateActualInputTokenUsage accounts for cached tokens which are included in [openai.CompletionUsage].PromptTokens.
217217
func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
218218
// Input *includes* the cached tokens, so we subtract them here to reflect actual input token usage.
219-
// The original value can be reconstructed by referencing the "prompt_cached" field in metadata.
219+
// The original value can be reconstructed by adding CachedTokens back to Input.
220220
// See https://platform.openai.com/docs/api-reference/usage/completions_object#usage/completions_object-input_tokens.
221221
return in.PromptTokens /* The aggregated number of text input tokens used, including cached tokens. */ -
222222
in.PromptTokensDetails.CachedTokens /* The aggregated number of text input tokens that has been cached from previous requests. */

intercept/chatcompletions/blocking.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,14 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
114114
cumulativeUsage = sumUsage(cumulativeUsage, completion.Usage)
115115

116116
_ = i.recorder.RecordTokenUsage(ctx, &recorder.TokenUsageRecord{
117-
InterceptionID: i.ID().String(),
118-
MsgID: completion.ID,
119-
Input: calculateActualInputTokenUsage(lastUsage),
120-
Output: lastUsage.CompletionTokens,
117+
InterceptionID: i.ID().String(),
118+
MsgID: completion.ID,
119+
Input: calculateActualInputTokenUsage(lastUsage),
120+
Output: lastUsage.CompletionTokens,
121+
CacheReadInputTokens: lastUsage.PromptTokensDetails.CachedTokens,
121122
ExtraTokenTypes: map[string]int64{
122123
"prompt_audio": lastUsage.PromptTokensDetails.AudioTokens,
123-
"prompt_cached": lastUsage.PromptTokensDetails.CachedTokens,
124+
"prompt_cached": lastUsage.PromptTokensDetails.CachedTokens, // TODO: remove from ExtraTokenTypes (https://github.com/coder/aibridge/issues/243)
124125
"completion_accepted_prediction": lastUsage.CompletionTokensDetails.AcceptedPredictionTokens,
125126
"completion_rejected_prediction": lastUsage.CompletionTokensDetails.RejectedPredictionTokens,
126127
"completion_audio": lastUsage.CompletionTokensDetails.AudioTokens,

intercept/chatcompletions/streaming.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,14 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
214214
// If the usage information is set, track it.
215215
// The API will send usage information when the response terminates, which will happen if a tool call is invoked.
216216
_ = i.recorder.RecordTokenUsage(streamCtx, &recorder.TokenUsageRecord{
217-
InterceptionID: i.ID().String(),
218-
MsgID: processor.getMsgID(),
219-
Input: calculateActualInputTokenUsage(lastUsage),
220-
Output: lastUsage.CompletionTokens,
217+
InterceptionID: i.ID().String(),
218+
MsgID: processor.getMsgID(),
219+
Input: calculateActualInputTokenUsage(lastUsage),
220+
Output: lastUsage.CompletionTokens,
221+
CacheReadInputTokens: lastUsage.PromptTokensDetails.CachedTokens,
221222
ExtraTokenTypes: map[string]int64{
222223
"prompt_audio": lastUsage.PromptTokensDetails.AudioTokens,
223-
"prompt_cached": lastUsage.PromptTokensDetails.CachedTokens,
224+
"prompt_cached": lastUsage.PromptTokensDetails.CachedTokens, // TODO: remove from ExtraTokenTypes (https://github.com/coder/aibridge/issues/243)
224225
"completion_accepted_prediction": lastUsage.CompletionTokensDetails.AcceptedPredictionTokens,
225226
"completion_rejected_prediction": lastUsage.CompletionTokensDetails.RejectedPredictionTokens,
226227
"completion_audio": lastUsage.CompletionTokensDetails.AudioTokens,

intercept/messages/blocking.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,16 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
128128
}
129129

130130
_ = i.recorder.RecordTokenUsage(ctx, &recorder.TokenUsageRecord{
131-
InterceptionID: i.ID().String(),
132-
MsgID: resp.ID,
133-
Input: resp.Usage.InputTokens,
134-
Output: resp.Usage.OutputTokens,
131+
InterceptionID: i.ID().String(),
132+
MsgID: resp.ID,
133+
Input: resp.Usage.InputTokens,
134+
Output: resp.Usage.OutputTokens,
135+
CacheReadInputTokens: resp.Usage.CacheReadInputTokens,
136+
CacheWriteInputTokens: resp.Usage.CacheCreationInputTokens,
135137
ExtraTokenTypes: map[string]int64{
136138
"web_search_requests": resp.Usage.ServerToolUse.WebSearchRequests,
137-
"cache_creation_input": resp.Usage.CacheCreationInputTokens,
138-
"cache_read_input": resp.Usage.CacheReadInputTokens,
139+
"cache_creation_input": resp.Usage.CacheCreationInputTokens, // TODO: remove from ExtraTokenTypes (https://github.com/coder/aibridge/issues/243)
140+
"cache_read_input": resp.Usage.CacheReadInputTokens, // TODO: remove from ExtraTokenTypes (https://github.com/coder/aibridge/issues/243)
139141
"cache_ephemeral_1h_input": resp.Usage.CacheCreation.Ephemeral1hInputTokens,
140142
"cache_ephemeral_5m_input": resp.Usage.CacheCreation.Ephemeral5mInputTokens,
141143
},

intercept/messages/streaming.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,16 @@ newStream:
205205
accumulateUsage(&cumulativeUsage, start.Message.Usage)
206206

207207
_ = i.recorder.RecordTokenUsage(streamCtx, &recorder.TokenUsageRecord{
208-
InterceptionID: i.ID().String(),
209-
MsgID: message.ID,
210-
Input: start.Message.Usage.InputTokens,
211-
Output: start.Message.Usage.OutputTokens,
208+
InterceptionID: i.ID().String(),
209+
MsgID: message.ID,
210+
Input: start.Message.Usage.InputTokens,
211+
Output: start.Message.Usage.OutputTokens,
212+
CacheReadInputTokens: start.Message.Usage.CacheReadInputTokens,
213+
CacheWriteInputTokens: start.Message.Usage.CacheCreationInputTokens,
212214
ExtraTokenTypes: map[string]int64{
213215
"web_search_requests": start.Message.Usage.ServerToolUse.WebSearchRequests,
214-
"cache_creation_input": start.Message.Usage.CacheCreationInputTokens,
215-
"cache_read_input": start.Message.Usage.CacheReadInputTokens,
216+
"cache_creation_input": start.Message.Usage.CacheCreationInputTokens, // TODO: remove from ExtraTokenTypes (https://github.com/coder/aibridge/issues/243)
217+
"cache_read_input": start.Message.Usage.CacheReadInputTokens, // TODO: remove from ExtraTokenTypes (https://github.com/coder/aibridge/issues/243)
216218
"cache_ephemeral_1h_input": start.Message.Usage.CacheCreation.Ephemeral1hInputTokens,
217219
"cache_ephemeral_5m_input": start.Message.Usage.CacheCreation.Ephemeral5mInputTokens,
218220
},

intercept/responses/base.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,13 @@ func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, respon
240240
inputNonCacheTokens := usage.InputTokens - usage.InputTokensDetails.CachedTokens
241241

242242
if err := i.recorder.RecordTokenUsage(ctx, &recorder.TokenUsageRecord{
243-
InterceptionID: i.ID().String(),
244-
MsgID: response.ID,
245-
Input: inputNonCacheTokens,
246-
Output: usage.OutputTokens,
243+
InterceptionID: i.ID().String(),
244+
MsgID: response.ID,
245+
Input: inputNonCacheTokens,
246+
Output: usage.OutputTokens,
247+
CacheReadInputTokens: usage.InputTokensDetails.CachedTokens,
247248
ExtraTokenTypes: map[string]int64{
248-
"input_cached": usage.InputTokensDetails.CachedTokens,
249+
"input_cached": usage.InputTokensDetails.CachedTokens, // TODO: remove from ExtraTokenTypes (https://github.com/coder/aibridge/issues/243)
249250
"output_reasoning": usage.OutputTokensDetails.ReasoningTokens,
250251
"total_tokens": usage.TotalTokens,
251252
},

intercept/responses/base_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,11 @@ func TestRecordTokenUsage(t *testing.T) {
297297
},
298298
},
299299
expected: &recorder.TokenUsageRecord{
300-
InterceptionID: id.String(),
301-
MsgID: "resp_full",
302-
Input: 5, // 10 input - 5 cached
303-
Output: 20,
300+
InterceptionID: id.String(),
301+
MsgID: "resp_full",
302+
Input: 5, // 10 input - 5 cached
303+
Output: 20,
304+
CacheReadInputTokens: 5,
304305
ExtraTokenTypes: map[string]int64{
305306
"input_cached": 5,
306307
"output_reasoning": 5,

internal/integrationtest/bridge_test.go

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,25 +44,31 @@ func TestAnthropicMessages(t *testing.T) {
4444
t.Parallel()
4545

4646
cases := []struct {
47-
name string
48-
streaming bool
49-
expectedInputTokens int
50-
expectedOutputTokens int
51-
expectedToolCallID string
47+
name string
48+
streaming bool
49+
expectedInputTokens int
50+
expectedOutputTokens int
51+
expectedCacheReadInputTokens int
52+
expectedCacheWriteInputTokens int
53+
expectedToolCallID string
5254
}{
5355
{
54-
name: "streaming",
55-
streaming: true,
56-
expectedInputTokens: 2,
57-
expectedOutputTokens: 66,
58-
expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo",
56+
name: "streaming",
57+
streaming: true,
58+
expectedInputTokens: 2,
59+
expectedOutputTokens: 66,
60+
expectedCacheReadInputTokens: 13993,
61+
expectedCacheWriteInputTokens: 22,
62+
expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo",
5963
},
6064
{
61-
name: "non-streaming",
62-
streaming: false,
63-
expectedInputTokens: 5,
64-
expectedOutputTokens: 84,
65-
expectedToolCallID: "toolu_01AusGgY5aKFhzWrFBv9JfHq",
65+
name: "non-streaming",
66+
streaming: false,
67+
expectedInputTokens: 5,
68+
expectedOutputTokens: 84,
69+
expectedCacheReadInputTokens: 23490,
70+
expectedCacheWriteInputTokens: 0,
71+
expectedToolCallID: "toolu_01AusGgY5aKFhzWrFBv9JfHq",
6672
},
6773
}
6874

@@ -104,6 +110,8 @@ func TestAnthropicMessages(t *testing.T) {
104110

105111
assert.EqualValues(t, tc.expectedInputTokens, bridgeServer.Recorder.TotalInputTokens(), "input tokens miscalculated")
106112
assert.EqualValues(t, tc.expectedOutputTokens, bridgeServer.Recorder.TotalOutputTokens(), "output tokens miscalculated")
113+
assert.EqualValues(t, tc.expectedCacheReadInputTokens, bridgeServer.Recorder.TotalCacheReadInputTokens(), "cache read input tokens miscalculated")
114+
assert.EqualValues(t, tc.expectedCacheWriteInputTokens, bridgeServer.Recorder.TotalCacheWriteInputTokens(), "cache write input tokens miscalculated")
107115

108116
toolUsages := bridgeServer.Recorder.RecordedToolUsages()
109117
require.Len(t, toolUsages, 1)

0 commit comments

Comments
 (0)