Skip to content

Commit e559e5e

Browse files
authored
feat: add client session tracking (#198)
## Refactor client detection and add session tracking This change extracts client detection logic into a dedicated module and introduces session tracking capabilities for AI bridge interceptions. Closes #166 Required for coder/internal#1336 ### Changes - **Extract client detection**: Move `guessClient` function and client constants from `bridge.go` to new `client.go` file - **Add Client type**: Introduce `Client` type alias for better type safety and rename `ClientClaude` to `ClientClaudeCode` - **Implement session tracking**: Add `guessSessionID` function in `session.go` to extract session identifiers from requests - **Update interception records**: Add `ClientSessionID` field to `InterceptionRecord` and call session detection before creating interceptors - **Add comprehensive tests**: Include test coverage for both client detection and session ID extraction The session detection supports as many of our supported clients as possible.
1 parent bcc636a commit e559e5e

10 files changed

Lines changed: 568 additions & 155 deletions

File tree

bridge.go

Lines changed: 7 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,6 @@ const (
2929
recordingTimeout = time.Second * 5
3030
)
3131

32-
const (
33-
// Possible values for the "client" field in interception records.
34-
// Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44
35-
ClientClaude = "Claude Code"
36-
ClientCodex = "Codex"
37-
ClientCursor = "Cursor"
38-
ClientCopilotVSC = "GitHub Copilot (VS Code)"
39-
ClientCopilotCLI = "GitHub Copilot (CLI)"
40-
ClientKilo = "Kilo Code"
41-
ClientMux = "Mux"
42-
ClientRoo = "Roo Code"
43-
ClientZed = "Zed"
44-
ClientUnknown = "Unknown"
45-
)
46-
4732
// RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs;
4833
// specifically, OpenAI's & Anthropic's at present.
4934
// RequestBridge intercepts requests to - and responses from - these upstream services to provide
@@ -167,6 +152,11 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
167152
ctx, span := tracer.Start(r.Context(), "Intercept")
168153
defer span.End()
169154

155+
// We execute this before CreateInterceptor since the interceptors
156+
// read the request body and don't reset them.
157+
client := guessClient(r)
158+
sessionID := guessSessionID(client, r)
159+
170160
interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer)
171161
if err != nil {
172162
span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err))
@@ -203,13 +193,14 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
203193
interceptor.Setup(logger, asyncRecorder, mcpProxy)
204194

205195
if err := rec.RecordInterception(ctx, &recorder.InterceptionRecord{
206-
Client: guessClient(r),
207196
ID: interceptor.ID().String(),
208197
InitiatorID: actor.ID,
209198
Metadata: actor.Metadata,
210199
Model: interceptor.Model(),
211200
Provider: p.Name(),
212201
UserAgent: r.UserAgent(),
202+
Client: string(client),
203+
ClientSessionID: sessionID,
213204
CorrelatingToolCallID: interceptor.CorrelatingToolCallID(),
214205
}); err != nil {
215206
span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err))
@@ -338,34 +329,3 @@ func mergeContexts(base, other context.Context) context.Context {
338329
}()
339330
return ctx
340331
}
341-
342-
// guessClient attempts to guess the client application from the request headers.
343-
// Not all clients set proper user agent headers, so this is a best-effort approach.
344-
// Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101.
345-
func guessClient(r *http.Request) string {
346-
userAgent := strings.ToLower(r.UserAgent())
347-
originator := r.Header.Get("originator")
348-
349-
// Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44
350-
switch {
351-
case strings.HasPrefix(userAgent, "mux/"):
352-
return ClientMux
353-
case strings.HasPrefix(userAgent, "claude"):
354-
return ClientClaude
355-
case strings.HasPrefix(userAgent, "codex"):
356-
return ClientCodex
357-
case strings.HasPrefix(userAgent, "zed/"):
358-
return ClientZed
359-
case strings.HasPrefix(userAgent, "githubcopilotchat/"):
360-
return ClientCopilotVSC
361-
case strings.HasPrefix(userAgent, "copilot/"):
362-
return ClientCopilotCLI
363-
case strings.HasPrefix(userAgent, "kilo-code/") || originator == "kilo-code":
364-
return ClientKilo
365-
case strings.HasPrefix(userAgent, "roo-code/") || originator == "roo-code":
366-
return ClientRoo
367-
case r.Header.Get("x-cursor-client-version") != "":
368-
return ClientCursor
369-
}
370-
return ClientUnknown
371-
}

bridge_integration_test.go

Lines changed: 124 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ func TestSimple(t *testing.T) {
560560
createRequest func(*testing.T, string, []byte) *http.Request
561561
expectedMsgID string
562562
userAgent string
563-
expectedClient string
563+
expectedClient aibridge.Client
564564
}{
565565
{
566566
name: config.ProviderAnthropic,
@@ -572,7 +572,7 @@ func TestSimple(t *testing.T) {
572572
createRequest: createAnthropicMessagesReq,
573573
expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn",
574574
userAgent: "claude-cli/2.0.67 (external, cli)",
575-
expectedClient: aibridge.ClientClaude,
575+
expectedClient: aibridge.ClientClaudeCode,
576576
},
577577
{
578578
name: config.ProviderOpenAI,
@@ -682,7 +682,7 @@ func TestSimple(t *testing.T) {
682682
interceptions := recorderClient.RecordedInterceptions()
683683
require.Len(t, interceptions, 1, "expected exactly one interception, got: %v", interceptions)
684684
assert.Equal(t, tc.userAgent, interceptions[0].UserAgent)
685-
assert.Equal(t, tc.expectedClient, interceptions[0].Client)
685+
assert.Equal(t, string(tc.expectedClient), interceptions[0].Client)
686686

687687
recorderClient.VerifyAllInterceptionsEnded(t)
688688
})
@@ -691,6 +691,127 @@ func TestSimple(t *testing.T) {
691691
}
692692
}
693693

694+
func TestSessionIDTracking(t *testing.T) {
695+
t.Parallel()
696+
697+
testCases := []struct {
698+
name string
699+
fixture []byte
700+
expectedClient aibridge.Client
701+
sessionID string
702+
configureFunc func(*testing.T, string, aibridge.Recorder) (*aibridge.RequestBridge, error)
703+
createRequest func(t *testing.T, baseURL string, body []byte) *http.Request
704+
}{
705+
// Session in header.
706+
{
707+
name: "mux",
708+
fixture: fixtures.AntSimple,
709+
expectedClient: aibridge.ClientMux,
710+
sessionID: "mux-workspace-321",
711+
configureFunc: func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
712+
t.Helper()
713+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
714+
providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)}
715+
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
716+
},
717+
createRequest: func(t *testing.T, baseURL string, body []byte) *http.Request {
718+
t.Helper()
719+
req := createAnthropicMessagesReq(t, baseURL, body)
720+
req.Header.Set("User-Agent", "mux/1.0.0")
721+
req.Header.Set("X-Mux-Workspace-Id", "mux-workspace-321")
722+
return req
723+
},
724+
},
725+
// Session in body.
726+
{
727+
name: "claude_code",
728+
fixture: fixtures.AntSimple,
729+
expectedClient: aibridge.ClientClaudeCode,
730+
sessionID: "f47ac10b-58cc-4372-a567-0e02b2c3d479",
731+
configureFunc: func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
732+
t.Helper()
733+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
734+
providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)}
735+
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
736+
},
737+
createRequest: func(t *testing.T, baseURL string, body []byte) *http.Request {
738+
t.Helper()
739+
// Claude Code embeds the session ID in metadata.user_id within the body.
740+
body, err := sjson.SetBytes(body, "metadata.user_id",
741+
"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479")
742+
require.NoError(t, err)
743+
req := createAnthropicMessagesReq(t, baseURL, body)
744+
req.Header.Set("User-Agent", "claude-cli/2.0.67 (external, cli)")
745+
return req
746+
},
747+
},
748+
// No session.
749+
{
750+
name: "zed",
751+
fixture: fixtures.AntSimple,
752+
expectedClient: aibridge.ClientZed,
753+
configureFunc: func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
754+
t.Helper()
755+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
756+
providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)}
757+
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
758+
},
759+
createRequest: func(t *testing.T, baseURL string, body []byte) *http.Request {
760+
t.Helper()
761+
req := createAnthropicMessagesReq(t, baseURL, body)
762+
req.Header.Set("User-Agent", "Zed/0.219.4+stable.119.abc123 (macos; aarch64)")
763+
return req
764+
},
765+
},
766+
}
767+
768+
for _, tc := range testCases {
769+
t.Run(tc.name, func(t *testing.T) {
770+
t.Parallel()
771+
772+
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
773+
t.Cleanup(cancel)
774+
775+
fix := fixtures.Parse(t, tc.fixture)
776+
upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix))
777+
778+
recorderClient := &testutil.MockRecorder{}
779+
780+
b, err := tc.configureFunc(t, upstream.URL, recorderClient)
781+
require.NoError(t, err)
782+
mockSrv := httptest.NewUnstartedServer(b)
783+
t.Cleanup(mockSrv.Close)
784+
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
785+
return aibcontext.AsActor(ctx, userID, nil)
786+
}
787+
mockSrv.Start()
788+
789+
req := tc.createRequest(t, mockSrv.URL, fix.Request())
790+
resp, err := http.DefaultClient.Do(req)
791+
require.NoError(t, err)
792+
require.Equal(t, http.StatusOK, resp.StatusCode)
793+
defer resp.Body.Close()
794+
795+
// Drain the body to let the stream complete.
796+
_, err = io.ReadAll(resp.Body)
797+
require.NoError(t, err)
798+
799+
interceptions := recorderClient.RecordedInterceptions()
800+
require.Len(t, interceptions, 1, "expected exactly one interception")
801+
assert.Equal(t, string(tc.expectedClient), interceptions[0].Client)
802+
803+
if tc.sessionID == "" {
804+
assert.Nil(t, interceptions[0].ClientSessionID, "expected nil session ID for %s", tc.name)
805+
} else {
806+
require.NotNil(t, interceptions[0].ClientSessionID, "expected non-nil session ID for %s", tc.name)
807+
assert.Equal(t, tc.sessionID, *interceptions[0].ClientSessionID)
808+
}
809+
810+
recorderClient.VerifyAllInterceptionsEnded(t)
811+
})
812+
}
813+
}
814+
694815
func TestFallthrough(t *testing.T) {
695816
t.Parallel()
696817

bridge_test.go

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -104,103 +104,3 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
104104
})
105105
}
106106
}
107-
108-
func TestGuessClient(t *testing.T) {
109-
t.Parallel()
110-
111-
tests := []struct {
112-
name string
113-
userAgent string
114-
headers map[string]string
115-
wantClient string
116-
}{
117-
{
118-
name: "mux",
119-
userAgent: "mux/0.19.0-next.2.gcceff159 ai-sdk/openai/3.0.36 ai-sdk/provider-utils/4.0.15 runtime/node.js/22",
120-
wantClient: ClientMux,
121-
},
122-
{
123-
name: "claude_code",
124-
userAgent: "claude-cli/2.0.67 (external, cli)",
125-
wantClient: ClientClaude,
126-
},
127-
{
128-
name: "codex_cli",
129-
userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64) ghostty/1.3.0-main_250877ef",
130-
wantClient: ClientCodex,
131-
},
132-
{
133-
name: "zed",
134-
userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)",
135-
wantClient: ClientZed,
136-
},
137-
{
138-
name: "github_copilot_vsc",
139-
userAgent: "GitHubCopilotChat/0.37.2026011603",
140-
wantClient: ClientCopilotVSC,
141-
},
142-
{
143-
name: "github_copilot_cli",
144-
userAgent: "copilot/0.0.403 (client/cli linux v24.11.1)",
145-
wantClient: ClientCopilotCLI,
146-
},
147-
{
148-
name: "kilo_code_user_agent",
149-
userAgent: "kilo-code/5.1.0 (darwin 25.2.0; arm64) node/22.21.1",
150-
wantClient: ClientKilo,
151-
},
152-
{
153-
name: "kilo_code_originator",
154-
headers: map[string]string{"Originator": "kilo-code"},
155-
wantClient: ClientKilo,
156-
},
157-
{
158-
name: "roo_code_user_agent",
159-
userAgent: "roo-code/3.45.0 (darwin 25.2.0; arm64) node/22.21.1",
160-
wantClient: ClientRoo,
161-
},
162-
{
163-
name: "roo_code_originator",
164-
headers: map[string]string{"Originator": "roo-code"},
165-
wantClient: ClientRoo,
166-
},
167-
{
168-
name: "cursor_x_cursor_client_version",
169-
userAgent: "connect-es/1.6.1",
170-
headers: map[string]string{"X-Cursor-client-version": "0.50.0"},
171-
wantClient: ClientCursor,
172-
},
173-
{
174-
name: "cursor_x_cursor_some_other_header",
175-
headers: map[string]string{"x-cursor-client-version": "abc123"},
176-
wantClient: ClientCursor,
177-
},
178-
{
179-
name: "unknown_client",
180-
userAgent: "ccclaude-cli/calude-with-wrong-prefix",
181-
wantClient: ClientUnknown,
182-
},
183-
{
184-
name: "empty_user_agent",
185-
userAgent: "",
186-
wantClient: ClientUnknown,
187-
},
188-
}
189-
190-
for _, tt := range tests {
191-
t.Run(tt.name, func(t *testing.T) {
192-
t.Parallel()
193-
194-
req, err := http.NewRequest(http.MethodGet, "", nil)
195-
require.NoError(t, err)
196-
197-
req.Header.Set("User-Agent", tt.userAgent)
198-
for key, value := range tt.headers {
199-
req.Header.Set(key, value)
200-
}
201-
202-
got := guessClient(req)
203-
require.Equal(t, tt.wantClient, got)
204-
})
205-
}
206-
}

client.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package aibridge
2+
3+
import (
4+
"net/http"
5+
"strings"
6+
)
7+
8+
type Client string
9+
10+
const (
11+
// Possible values for the "client" field in interception records.
12+
// Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44
13+
ClientClaudeCode Client = "Claude Code"
14+
ClientCodex Client = "Codex"
15+
ClientZed Client = "Zed"
16+
ClientCopilotVSC Client = "GitHub Copilot (VS Code)"
17+
ClientCopilotCLI Client = "GitHub Copilot (CLI)"
18+
ClientKilo Client = "Kilo Code"
19+
ClientMux Client = "Mux"
20+
ClientRoo Client = "Roo Code"
21+
ClientCursor Client = "Cursor"
22+
ClientUnknown Client = "Unknown"
23+
)
24+
25+
// guessClient attempts to guess the client application from the request headers.
26+
// Not all clients set proper user agent headers, so this is a best-effort approach.
27+
// Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101.
28+
func guessClient(r *http.Request) Client {
29+
userAgent := strings.ToLower(r.UserAgent())
30+
originator := r.Header.Get("originator")
31+
32+
// Must be kept in sync with documentation: https://github.com/coder/coder/blob/90c11f3386578da053ec5cd9f1475835b980e7c7/docs/ai-coder/ai-bridge/monitoring.md?plain=1#L36-L44
33+
switch {
34+
case strings.HasPrefix(userAgent, "mux/"):
35+
return ClientMux
36+
case strings.HasPrefix(userAgent, "claude"):
37+
return ClientClaudeCode
38+
case strings.HasPrefix(userAgent, "codex"):
39+
return ClientCodex
40+
case strings.HasPrefix(userAgent, "zed/"):
41+
return ClientZed
42+
case strings.HasPrefix(userAgent, "githubcopilotchat/"):
43+
return ClientCopilotVSC
44+
case strings.HasPrefix(userAgent, "copilot/"):
45+
return ClientCopilotCLI
46+
case strings.HasPrefix(userAgent, "kilo-code/") || originator == "kilo-code":
47+
return ClientKilo
48+
case strings.HasPrefix(userAgent, "roo-code/") || originator == "roo-code":
49+
return ClientRoo
50+
case r.Header.Get("x-cursor-client-version") != "":
51+
return ClientCursor
52+
}
53+
return ClientUnknown
54+
}

0 commit comments

Comments
 (0)