Skip to content

Commit 914bc09

Browse files
committed
fix(proxy): classify token-endpoint 403 invalid_grant as auth-failure; audit detected protocol
1 parent 0176068 commit 914bc09

3 files changed

Lines changed: 98 additions & 28 deletions

File tree

internal/proxy/pool_failover.go

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,14 @@ func classifyFailover(statusCode int, body []byte, isTokenEndpoint bool) (class
8686
if bodyContainsAny(body, "insufficient_quota", "quota_exceeded", "quota exhausted", "rate_limit_exceeded") {
8787
return failoverRateLimited, ""
8888
}
89-
return failoverNone, ""
89+
// NOT a quota signal: do not early-return. A 403 is still a non-2xx
90+
// status, so a real token-endpoint body of invalid_grant/invalid_token
91+
// must classify as auth-failure (consistent with the 400/401 path).
92+
// The shared non-2xx token-endpoint check below handles it; a 403 from
93+
// a non-token-endpoint with an unrelated body still resolves to
94+
// failoverNone there (the body is only trusted on a real token URL).
9095
}
91-
// Non-4xx-status path. Only a real token-endpoint body may be classified
96+
// Non-2xx-status path. Only a real token-endpoint body may be classified
9297
// (invalid_grant/invalid_token), and only when the status is not a 2xx
9398
// success. A 2xx token response is a healthy refresh, never a failover.
9499
if isTokenEndpoint && (statusCode < 200 || statusCode > 299) {
@@ -133,21 +138,26 @@ type FailoverEvent struct {
133138
// poolForResponse maps a response's CONNECT destination back to a pooled
134139
// binding and returns the pool name + the member that was active for this
135140
// request. Returns ok=false when the destination is not bound to a pool.
136-
func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember string, pr *vault.PoolResolver, ok bool) {
141+
//
142+
// proto is the protocol detected for THIS request (the same value used for
143+
// the protocol-scoped binding lookup). The caller threads it into the
144+
// cred_failover audit event so the audit records the protocol actually
145+
// detected for the request (grpc / http2 / etc.), not a hardcoded "https".
146+
func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember, proto string, pr *vault.PoolResolver, ok bool) {
137147
if a.poolResolver == nil || a.resolver == nil {
138-
return "", "", nil, false
148+
return "", "", "", nil, false
139149
}
140150
pr = a.poolResolver.Load()
141151
if pr == nil {
142-
return "", "", nil, false
152+
return "", "", "", nil, false
143153
}
144154
res := a.resolver.Load()
145155
if res == nil {
146-
return "", "", nil, false
156+
return "", "", "", nil, false
147157
}
148158
host, port := connectTargetForFlow(a, f)
149159
if host == "" {
150-
return "", "", nil, false
160+
return "", "", "", nil, false
151161
}
152162
// Finding 3: the failover binding lookup MUST use the same protocol the
153163
// request-side injection (injectHeaders / buildPhantomPairs) used, not a
@@ -157,7 +167,7 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str
157167
// detectRequestProtocol mirrors the injection path exactly (URL scheme
158168
// then header refinement); for the common unscoped-binding case the
159169
// result is still https-equivalent so behavior is unchanged.
160-
proto := a.detectRequestProtocol(f, port).String()
170+
proto = a.detectRequestProtocol(f, port).String()
161171
for _, boundName := range res.CredentialsForDestination(host, port, proto) {
162172
if !pr.IsPool(boundName) {
163173
continue
@@ -176,15 +186,15 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str
176186
// member of this pool (a membership change could have
177187
// raced); otherwise fall through to ResolveActive.
178188
if pr.PoolForMember(injected) == boundName {
179-
return boundName, injected, pr, true
189+
return boundName, injected, proto, pr, true
180190
}
181191
}
182192
}
183193
member, mok := pr.ResolveActive(boundName)
184194
if !mok || member == "" {
185195
continue
186196
}
187-
return boundName, member, pr, true
197+
return boundName, member, proto, pr, true
188198
}
189199

190200
// Token-endpoint path. An OAuth refresh hits the credential's token-URL
@@ -246,7 +256,7 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str
246256
realRefresh := extractRequestRefreshToken(f.Request.Body, reqCT)
247257
if owner, ok := a.refreshAttr.Peek(realRefresh); ok && owner != "" {
248258
if ownerPool := pr.PoolForMember(owner); ownerPool != "" {
249-
return ownerPool, owner, pr, true
259+
return ownerPool, owner, proto, pr, true
250260
}
251261
// owner is no longer in any pool (membership change
252262
// raced the failure); fall through to the active-member
@@ -261,20 +271,20 @@ func (a *SluiceAddon) poolForResponse(f *mitmproxy.Flow) (pool, activeMember str
261271
log.Printf("[POOL-FAILOVER] pool %q: could not attribute "+
262272
"token-endpoint failure via injected refresh token; "+
263273
"falling back to active member %q", pool, active)
264-
return pool, active, pr, true
274+
return pool, active, proto, pr, true
265275
}
266276
// Last resort: a pooled index match if any (preserves prior
267277
// behavior when even ResolveActive cannot decide; better than
268278
// no attribution at all).
269279
for _, c := range matches {
270280
if pr.PoolForMember(c) != "" {
271-
return pool, c, pr, true
281+
return pool, c, proto, pr, true
272282
}
273283
}
274-
return pool, matched, pr, true
284+
return pool, matched, proto, pr, true
275285
}
276286
}
277-
return "", "", nil, false
287+
return "", "", "", nil, false
278288
}
279289

280290
// handlePoolFailover is the Phase 2 entry point invoked from Response for
@@ -302,7 +312,7 @@ func (a *SluiceAddon) handlePoolFailover(f *mitmproxy.Flow) {
302312
if f == nil || f.Response == nil || f.Request == nil {
303313
return
304314
}
305-
pool, from, pr, ok := a.poolForResponse(f)
315+
pool, from, proto, pr, ok := a.poolForResponse(f)
306316
if !ok {
307317
return
308318
}
@@ -352,11 +362,14 @@ func (a *SluiceAddon) handlePoolFailover(f *mitmproxy.Flow) {
352362
evt := audit.Event{
353363
Destination: host,
354364
Port: port,
355-
Protocol: "https",
356-
Verdict: "failover",
357-
Action: "cred_failover",
358-
Reason: fmt.Sprintf("%s:%s->%s:%s", pool, from, to, tag),
359-
Credential: from,
365+
// Same protocol used for the protocol-scoped binding lookup in
366+
// poolForResponse, NOT a hardcoded "https". For a grpc/http2
367+
// scoped pooled binding the audit must record the real protocol.
368+
Protocol: proto,
369+
Verdict: "failover",
370+
Action: "cred_failover",
371+
Reason: fmt.Sprintf("%s:%s->%s:%s", pool, from, to, tag),
372+
Credential: from,
360373
}
361374
if err := a.auditLog.Log(evt); err != nil {
362375
log.Printf("[POOL-FAILOVER] audit log error: %v", err)

internal/proxy/pool_failover_test.go

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,17 @@ func TestClassifyFailover(t *testing.T) {
5757
{"403 insufficient_quota", 403, `{"error":"insufficient_quota"}`, false, failoverRateLimited, "403"},
5858
{"403 quota_exceeded", 403, `{"error":{"code":"quota_exceeded"}}`, false, failoverRateLimited, "403"},
5959
{"403 unrelated -> noop", 403, `{"error":"forbidden: bad scope"}`, false, failoverNone, ""},
60+
// Finding 1: a token-endpoint 403 carrying invalid_grant/invalid_token
61+
// is an auth failure (consistent with the 400/401 token-endpoint path).
62+
// The old code early-returned failoverNone in the 403 branch before the
63+
// token-endpoint body check ever ran.
64+
{"403 token-endpoint invalid_grant -> auth", 403, `{"error":"invalid_grant"}`, true, failoverAuthFailure, "invalid_grant"},
65+
{"403 token-endpoint invalid_token -> auth", 403, `{"error":"invalid_token"}`, true, failoverAuthFailure, "invalid_token"},
66+
// 403 + quota signal stays rate-limited (unchanged).
67+
{"403 insufficient_quota (tokenEP) stays rate-limited", 403, `{"error":"insufficient_quota"}`, true, failoverRateLimited, "403"},
68+
// 403 + invalid_grant but NOT a real token endpoint -> still noop
69+
// (the body is only trusted on a real token URL).
70+
{"403 invalid_grant but NOT token endpoint -> noop", 403, `{"error":"invalid_grant"}`, false, failoverNone, ""},
6071
{"401 auth failure", 401, "", false, failoverAuthFailure, "401"},
6172
{"token-endpoint invalid_grant", 400, `{"error":"invalid_grant"}`, true, failoverAuthFailure, "invalid_grant"},
6273
{"token-endpoint invalid_token", 400, `{"error":"invalid_token"}`, true, failoverAuthFailure, "invalid_token"},
@@ -307,7 +318,7 @@ func TestPoolForResponseResolvesActiveMember(t *testing.T) {
307318
client := setupAddonConn(addon, "auth.example.com:443")
308319
f := newPoolRespFlow(client, 429, nil)
309320

310-
pool, member, pr, ok := addon.poolForResponse(f)
321+
pool, member, _, pr, ok := addon.poolForResponse(f)
311322
if !ok {
312323
t.Fatal("poolForResponse: expected a pooled destination match")
313324
}
@@ -392,7 +403,7 @@ func TestTokenEndpointHostFailoverOnPooledMember(t *testing.T) {
392403
// (this is exactly the gap CRITICAL-2 describes). poolForResponse must
393404
// still succeed via the token-URL index path.
394405
f := newPoolRespFlow(client, 400, []byte(`{"error":"invalid_grant"}`))
395-
pool, member, _, ok := addon.poolForResponse(f)
406+
pool, member, _, _, ok := addon.poolForResponse(f)
396407
if !ok {
397408
t.Fatal("poolForResponse: token-endpoint response on a pooled member must be attributed (CRITICAL-2 fix); got ok=false")
398409
}
@@ -497,7 +508,7 @@ func TestTokenEndpointFailoverAttributesInjectedMemberNotFirstIndex(t *testing.T
497508
// poolForResponse must now attribute the failure to memB (the injected
498509
// member), NOT memA (the first index entry).
499510
f := newPoolRespFlowBody(client, 400, "B-refresh-old", []byte(`{"error":"invalid_grant"}`))
500-
pool, member, _, ok := addon.poolForResponse(f)
511+
pool, member, _, _, ok := addon.poolForResponse(f)
501512
if !ok {
502513
t.Fatal("poolForResponse: token-endpoint failure on a pooled member must be attributed")
503514
}
@@ -614,7 +625,7 @@ func TestTokenEndpointFailover3MemberAttributesMiddleMember(t *testing.T) {
614625
addon.refreshAttr.Tag("memB-refresh", "memB")
615626

616627
f := newPoolRespFlowBody(client, 401, "memB-refresh", []byte(`{"error":"invalid_token"}`))
617-
pool, member, _, ok := addon.poolForResponse(f)
628+
pool, member, _, _, ok := addon.poolForResponse(f)
618629
if !ok || pool != "codex_pool" || member != "memB" {
619630
t.Fatalf("poolForResponse got ok=%v pool=%q member=%q, want codex_pool/memB", ok, pool, member)
620631
}
@@ -650,7 +661,7 @@ func TestTokenEndpointFailoverFallsBackToActiveMember(t *testing.T) {
650661
}
651662

652663
f := newPoolRespFlowBody(client, 400, "untagged-refresh", []byte(`{"error":"invalid_grant"}`))
653-
pool, member, _, ok := addon.poolForResponse(f)
664+
pool, member, _, _, ok := addon.poolForResponse(f)
654665
if !ok {
655666
t.Fatal("poolForResponse: expected attribution via active-member fallback")
656667
}

internal/proxy/pool_splithost_test.go

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package proxy
22

33
import (
4+
"encoding/json"
5+
"os"
6+
"path/filepath"
47
"strings"
58
"sync/atomic"
69
"testing"
710
"time"
811

12+
"github.com/nemirovsky/sluice/internal/audit"
913
"github.com/nemirovsky/sluice/internal/store"
1014
"github.com/nemirovsky/sluice/internal/vault"
1115
)
@@ -241,7 +245,7 @@ func TestSplitHost_TokenEndpointFailoverWithPlainCredSortingFirst(t *testing.T)
241245
// poolForResponse MUST attribute the failure to memB (the injected
242246
// member), not return ok=false because the plain credential sorted first.
243247
f := newPoolRespFlowBody(client, 400, "B-refresh-old", []byte(`{"error":"invalid_grant"}`))
244-
pool, member, _, ok := addon.poolForResponse(f)
248+
pool, member, _, _, ok := addon.poolForResponse(f)
245249
if !ok {
246250
t.Fatal("Finding 2: token-endpoint failure on a pooled member must be attributed even when a plain cred sorts first; got ok=false")
247251
}
@@ -355,13 +359,27 @@ func TestFinding3_ProtocolScopedPooledBindingFailoverLookup(t *testing.T) {
355359
t.Fatal("precondition: a 'https' lookup must NOT match the grpc-scoped binding (this is the Finding 3 bug)")
356360
}
357361

358-
pool2, member, _, ok := addon.poolForResponse(f)
362+
pool2, member, detProto, _, ok := addon.poolForResponse(f)
359363
if !ok {
360364
t.Fatal("Finding 3: protocol-scoped (grpc) pooled binding must be recognized on the failover path; got ok=false")
361365
}
362366
if pool2 != poolName || member != "gA" {
363367
t.Fatalf("Finding 3: got pool=%q member=%q, want %s/gA", pool2, member, poolName)
364368
}
369+
if detProto != ProtoGRPC.String() {
370+
t.Fatalf("Finding 2: poolForResponse detected protocol = %q, want %q", detProto, ProtoGRPC.String())
371+
}
372+
373+
// Finding 2: the cred_failover audit event must record the SAME
374+
// protocol that drove the binding lookup (grpc here), not a hardcoded
375+
// "https". Wire a real audit logger and assert the persisted Protocol.
376+
dir := t.TempDir()
377+
logPath := filepath.Join(dir, "audit.log")
378+
logger, lerr := audit.NewFileLogger(logPath)
379+
if lerr != nil {
380+
t.Fatalf("NewFileLogger: %v", lerr)
381+
}
382+
addon.auditLog = logger
365383

366384
var got FailoverEvent
367385
gotCalled := make(chan struct{}, 1)
@@ -382,4 +400,32 @@ func TestFinding3_ProtocolScopedPooledBindingFailoverLookup(t *testing.T) {
382400
if got.From != "gA" || got.Pool != poolName || got.Reason != "429" {
383401
t.Fatalf("FailoverEvent = %+v, want from=gA pool=%s reason=429", got, poolName)
384402
}
403+
404+
if cerr := logger.Close(); cerr != nil {
405+
t.Fatalf("logger close: %v", cerr)
406+
}
407+
data, rerr := os.ReadFile(logPath)
408+
if rerr != nil {
409+
t.Fatalf("read audit log: %v", rerr)
410+
}
411+
var foundFailover bool
412+
for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") {
413+
if line == "" {
414+
continue
415+
}
416+
var evt audit.Event
417+
if uerr := json.Unmarshal([]byte(line), &evt); uerr != nil {
418+
t.Fatalf("unmarshal audit line %q: %v", line, uerr)
419+
}
420+
if evt.Action != "cred_failover" {
421+
continue
422+
}
423+
foundFailover = true
424+
if evt.Protocol != ProtoGRPC.String() {
425+
t.Fatalf("Finding 2: cred_failover audit Protocol = %q, want %q (must match the detected request protocol, not hardcoded https)", evt.Protocol, ProtoGRPC.String())
426+
}
427+
}
428+
if !foundFailover {
429+
t.Fatalf("no cred_failover audit event found in:\n%s", data)
430+
}
385431
}

0 commit comments

Comments
 (0)