Skip to content

Commit 4b27075

Browse files
authored
Merge pull request #8 from nnemirovsky/fix/dns-and-channel-gate
fix(proxy): forward DNS for all non-denied domains, remove ChannelGateMiddleware
2 parents 1edb258 + a586275 commit 4b27075

6 files changed

Lines changed: 54 additions & 73 deletions

File tree

cmd/sluice/main.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -861,11 +861,10 @@ func startAPIServer(addr string, apiSrv *api.Server, st *store.Store, mcpHandler
861861
}
862862
// oapi-codegen wraps handlers bottom-up: last middleware in the slice
863863
// becomes the outermost layer. List channel gate first, then auth, so
864-
// the request hits auth before the channel gate. This ensures bad tokens
865-
// get 401 before the channel gate reveals whether HTTP channel is enabled.
864+
// Bearer token auth protects all /api/* routes. The API is accessible
865+
// whenever SLUICE_API_TOKEN is set, regardless of which channels are enabled.
866866
apiHandler := api.HandlerWithOptions(apiSrv, api.ChiServerOptions{
867867
Middlewares: []api.MiddlewareFunc{
868-
api.ChannelGateMiddleware(st),
869868
api.BearerAuthMiddleware,
870869
},
871870
})

internal/api/server_test.go

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,10 @@ func enableHTTPChannel(t *testing.T, st *store.Store) {
4545
// oapi-codegen wraps handlers bottom-up: last middleware in the slice becomes
4646
// outermost. Channel gate goes first (innermost), auth second (outermost),
4747
// so auth rejects before channel gate reveals channel state.
48-
func newTestHandler(t *testing.T, srv *api.Server, st *store.Store) http.Handler {
48+
func newTestHandler(t *testing.T, srv *api.Server, _ *store.Store) http.Handler {
4949
t.Helper()
5050
return api.HandlerWithOptions(srv, api.ChiServerOptions{
5151
Middlewares: []api.MiddlewareFunc{
52-
api.ChannelGateMiddleware(st),
5352
api.BearerAuthMiddleware,
5453
},
5554
})
@@ -210,53 +209,6 @@ func TestAuth_BadTokenBeforeChannelCheck(t *testing.T) {
210209

211210
// --- Channel gate middleware tests ---
212211

213-
func TestChannelGate_Disabled(t *testing.T) {
214-
st := newTestStore(t)
215-
// Default store has only Telegram channel (type=0), no HTTP channel
216-
srv := api.NewServer(st, nil, nil, "")
217-
218-
t.Setenv("SLUICE_API_TOKEN", "secret")
219-
handler := newTestHandler(t, srv, st)
220-
221-
req := httptest.NewRequest("GET", "/api/status", nil)
222-
req.Header.Set("Authorization", "Bearer secret")
223-
rec := httptest.NewRecorder()
224-
handler.ServeHTTP(rec, req)
225-
226-
if rec.Code != http.StatusForbidden {
227-
t.Errorf("expected 403, got %d", rec.Code)
228-
}
229-
230-
var resp api.ErrorResponse
231-
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
232-
t.Fatalf("decode: %v", err)
233-
}
234-
if resp.Error != "HTTP channel is not enabled" {
235-
t.Errorf("unexpected error: %q", resp.Error)
236-
}
237-
if resp.Code == nil || *resp.Code != "channel_disabled" {
238-
t.Errorf("unexpected code: %v", resp.Code)
239-
}
240-
}
241-
242-
func TestChannelGate_Enabled(t *testing.T) {
243-
st := newTestStore(t)
244-
enableHTTPChannel(t, st)
245-
srv := api.NewServer(st, nil, nil, "")
246-
247-
t.Setenv("SLUICE_API_TOKEN", "secret")
248-
handler := newTestHandler(t, srv, st)
249-
250-
req := httptest.NewRequest("GET", "/api/status", nil)
251-
req.Header.Set("Authorization", "Bearer secret")
252-
rec := httptest.NewRecorder()
253-
handler.ServeHTTP(rec, req)
254-
255-
if rec.Code != http.StatusOK {
256-
t.Errorf("expected 200, got %d", rec.Code)
257-
}
258-
}
259-
260212
// --- Approval endpoint tests ---
261213

262214
func TestGetApiApprovals_Empty(t *testing.T) {

internal/policy/engine.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,25 @@ func (e *Engine) IsDenied(dest string, port int) bool {
381381
return matchRules(e.compiled.denyRules, dest, port)
382382
}
383383

384+
// IsDeniedDomain checks if a domain matches any deny rule regardless of port
385+
// or protocol. Used by the DNS interceptor to block resolution for explicitly
386+
// denied domains while allowing all others through (policy enforcement happens
387+
// at the SOCKS5 CONNECT level).
388+
func (e *Engine) IsDeniedDomain(dest string) bool {
389+
dest = normalizeDestination(dest)
390+
e.mu.RLock()
391+
defer e.mu.RUnlock()
392+
if e.compiled == nil {
393+
return false
394+
}
395+
for _, r := range e.compiled.denyRules {
396+
if r.glob.Match(dest) {
397+
return true
398+
}
399+
}
400+
return false
401+
}
402+
384403
// IsRestricted checks whether a destination and port match any explicit deny
385404
// or ask rule. Unlike Evaluate, this does not fall back to the default verdict.
386405
// Used for DNS rebinding checks where the original FQDN was already allowed

internal/proxy/dns.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,18 @@ func (d *DNSInterceptor) HandleQuery(query []byte) ([]byte, error) {
239239
}
240240

241241
domain := questions[0].Name
242-
verdict := d.evaluate(domain)
242+
243+
// Only block DNS for explicitly denied domains. All other verdicts
244+
// (allow, ask, default) are forwarded. Policy enforcement for the
245+
// actual connection happens at the SOCKS5 CONNECT level, not DNS.
246+
// Blocking DNS for "ask" or default-deny domains would prevent the
247+
// connection from ever reaching the approval flow.
248+
eng := d.engine.Load()
249+
denied := eng.IsDeniedDomain(domain)
243250

244251
if d.audit != nil {
245252
verdictStr := "allow"
246-
if verdict != policy.Allow {
253+
if denied {
247254
verdictStr = "deny"
248255
}
249256
if logErr := d.audit.Log(audit.Event{
@@ -257,7 +264,7 @@ func (d *DNSInterceptor) HandleQuery(query []byte) ([]byte, error) {
257264
}
258265
}
259266

260-
if verdict != policy.Allow {
267+
if denied {
261268
return BuildNXDOMAIN(query)
262269
}
263270

@@ -275,17 +282,10 @@ func (d *DNSInterceptor) HandleQuery(query []byte) ([]byte, error) {
275282
return resp, nil
276283
}
277284

278-
// evaluate checks the DNS domain against the policy engine. Uses
279-
// EvaluateWithProtocol with protocol "dns" so dns-specific rules match.
280-
func (d *DNSInterceptor) evaluate(domain string) policy.Verdict {
281-
eng := d.engine.Load()
282-
// Use EvaluateWithProtocol with "dns" so protocol-scoped rules match.
283-
// DNS follows the same deny-then-allow-then-default semantics as
284-
// regular evaluation, not the UDP default-deny semantics, because
285-
// DNS queries are a known protocol with meaningful domain-level policy.
286-
v := eng.EvaluateWithProtocol(domain, 53, ProtoDNS.String())
287-
return v
288-
}
285+
// NOTE: the old evaluate() method has been removed. DNS resolution is now
286+
// allowed for all domains except those with explicit deny rules. Policy
287+
// enforcement (allow/ask/deny) happens at the SOCKS5 CONNECT level where
288+
// the approval broker can send Telegram notifications for "ask" verdicts.
289289

290290
// forwardToResolver sends the query to the upstream DNS resolver and returns
291291
// the response.

internal/proxy/dns_test.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,9 @@ func TestDNSInterceptor_DeniedDomain(t *testing.T) {
447447
eng, err := policy.LoadFromBytes([]byte(`
448448
[policy]
449449
default = "deny"
450+
451+
[[deny]]
452+
destination = "denied.example.com"
450453
`))
451454
if err != nil {
452455
t.Fatal(err)
@@ -464,7 +467,7 @@ default = "deny"
464467
t.Fatalf("HandleQuery: %v", err)
465468
}
466469

467-
// Response should be NXDOMAIN.
470+
// Response should be NXDOMAIN (explicit deny rule).
468471
respID := binary.BigEndian.Uint16(resp[0:2])
469472
if respID != 0x2222 {
470473
t.Errorf("response ID = 0x%04x, want 0x2222", respID)
@@ -543,16 +546,18 @@ protocols = ["dns"]
543546
t.Errorf("evil.google.com RCODE = %d, want %d (NXDOMAIN)", rcode, dnsRcodeNXDomain)
544547
}
545548

546-
// Unmatched domain (default deny).
549+
// Unmatched domain (default deny). DNS should still resolve so the
550+
// connection reaches the SOCKS5 layer where the ask/deny flow runs.
551+
// Only explicitly denied domains get NXDOMAIN at the DNS level.
547552
query = buildDNSQuery(0x5555, "other.com", dnsTypeA)
548553
resp, err = interceptor.HandleQuery(query)
549554
if err != nil {
550555
t.Fatalf("HandleQuery(other.com): %v", err)
551556
}
552557
flags = binary.BigEndian.Uint16(resp[2:4])
553558
rcode = flags & 0x000F
554-
if rcode != dnsRcodeNXDomain {
555-
t.Errorf("other.com RCODE = %d, want %d (NXDOMAIN)", rcode, dnsRcodeNXDomain)
559+
if rcode == dnsRcodeNXDomain {
560+
t.Errorf("other.com RCODE = %d (NXDOMAIN), want forwarded (non-denied domains resolve via DNS)", rcode)
556561
}
557562
}
558563

@@ -877,10 +882,13 @@ func TestParseDNSNameEdgeCases(t *testing.T) {
877882
})
878883
}
879884

880-
// TestHandleQueryDenyRule tests that a deny rule causes NXDOMAIN.
885+
// TestHandleQueryDenyRule tests that an explicit deny rule causes NXDOMAIN.
881886
func TestHandleQueryDenyRule(t *testing.T) {
882887
eng, _ := policy.LoadFromBytes([]byte(`[policy]
883888
default = "deny"
889+
890+
[[deny]]
891+
destination = "blocked.example.com"
884892
`))
885893

886894
var enginePtr atomic.Pointer[policy.Engine]

internal/proxy/server_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,8 +1203,8 @@ protocols = ["dns"]
12031203
}
12041204

12051205
func TestUDPAssociateDNSInterceptionNXDOMAIN(t *testing.T) {
1206-
// DNS interceptor should return NXDOMAIN for denied domains without
1207-
// contacting the upstream resolver.
1206+
// DNS interceptor should return NXDOMAIN for explicitly denied domains
1207+
// without contacting the upstream resolver.
12081208
eng, err := policy.LoadFromBytes([]byte(`
12091209
[policy]
12101210
default = "deny"
@@ -1213,6 +1213,9 @@ default = "deny"
12131213
destination = "allowed.example.com"
12141214
ports = [53]
12151215
protocols = ["dns"]
1216+
1217+
[[deny]]
1218+
destination = "denied.example.com"
12161219
`))
12171220
if err != nil {
12181221
t.Fatal(err)

0 commit comments

Comments
 (0)