Skip to content

Commit 52f5127

Browse files
committed
fix(proxy): pass WithProtocol on non-TLS deferred ask checks and prevent port overwrite
All CheckAndConsume calls in byte-detection and direct-connect paths now pass WithProtocol so protocol-scoped rules (ssh, http, imap, smtp) are matched correctly. Prevent TlsEstablishedServer from overwriting a port already captured by ServerConnected with a recovered/defaulted value.
1 parent 07a8e24 commit 52f5127

2 files changed

Lines changed: 18 additions & 8 deletions

File tree

internal/proxy/addon.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,17 @@ func (a *SluiceAddon) storeConnectTarget(clientID uuid.UUID, host string, port i
336336
return
337337
}
338338
cs := v.(*connState)
339-
cs.connectHost = host
340-
cs.connectPort = port
339+
// Only update host/port if not already set by a prior callback
340+
// (ServerConnected fires before TlsEstablishedServer). This prevents
341+
// a host-only TLS callback from overwriting a correct port with a
342+
// recovered/defaulted value.
343+
if cs.connectHost == "" {
344+
cs.connectHost = host
345+
cs.connectPort = port
346+
} else if host != "" && host != cs.connectHost {
347+
cs.connectHost = host
348+
cs.connectPort = port
349+
}
341350

342351
// Consume any pending checker for this destination.
343352
dest := net.JoinHostPort(host, strconv.Itoa(port))

internal/proxy/server.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,8 @@ func (s *Server) dial(ctx context.Context, network, addr string) (net.Conn, erro
955955
host, _, _ := net.SplitHostPort(addr)
956956
fqdn = host
957957
}
958-
if v, _ := perReqChecker.CheckAndConsume(fqdn, port); v != policy.Allow {
958+
proto := DetectProtocol(port)
959+
if v, _ := perReqChecker.CheckAndConsume(fqdn, port, WithProtocol(proto.String())); v != policy.Allow {
959960
log.Printf("[DIAL-DENY] %s deferred ask denied", addr)
960961
return nil, fmt.Errorf("connection denied by policy")
961962
}
@@ -1063,7 +1064,7 @@ func (s *Server) handleWithDetection(
10631064
// Plain HTTP: no MITM handler. Check connection-level policy
10641065
// before relaying if a checker was deferred from Allow().
10651066
if checker != nil {
1066-
if v, _ := checker.CheckAndConsume(fqdn, port); v != policy.Allow {
1067+
if v, _ := checker.CheckAndConsume(fqdn, port, WithProtocol(proto.String())); v != policy.Allow {
10671068
log.Printf("[DETECT-DENY] %s:%d plain HTTP blocked by deferred policy", fqdn, port)
10681069
return
10691070
}
@@ -1072,7 +1073,7 @@ func (s *Server) handleWithDetection(
10721073
return
10731074
case ProtoSSH:
10741075
if checker != nil {
1075-
if v, _ := checker.CheckAndConsume(fqdn, port); v != policy.Allow {
1076+
if v, _ := checker.CheckAndConsume(fqdn, port, WithProtocol(proto.String())); v != policy.Allow {
10761077
log.Printf("[DETECT-DENY] %s:%d SSH blocked by deferred policy", fqdn, port)
10771078
return
10781079
}
@@ -1085,7 +1086,7 @@ func (s *Server) handleWithDetection(
10851086
}
10861087
case ProtoIMAP, ProtoSMTP:
10871088
if checker != nil {
1088-
if v, _ := checker.CheckAndConsume(fqdn, port); v != policy.Allow {
1089+
if v, _ := checker.CheckAndConsume(fqdn, port, WithProtocol(proto.String())); v != policy.Allow {
10891090
log.Printf("[DETECT-DENY] %s:%d %s blocked by deferred policy", fqdn, port, proto)
10901091
return
10911092
}
@@ -1106,7 +1107,7 @@ func (s *Server) handleWithDetection(
11061107
// without a binding to inject.
11071108
if n == 0 && binding != nil && s.mailProxy != nil {
11081109
if checker != nil {
1109-
if v, _ := checker.CheckAndConsume(fqdn, port); v != policy.Allow {
1110+
if v, _ := checker.CheckAndConsume(fqdn, port, WithProtocol(proto.String())); v != policy.Allow {
11101111
log.Printf("[DETECT-DENY] %s:%d server-first blocked by deferred policy", fqdn, port)
11111112
return
11121113
}
@@ -1117,7 +1118,7 @@ func (s *Server) handleWithDetection(
11171118

11181119
// Generic fallback: direct relay. Check deferred policy first.
11191120
if checker != nil {
1120-
if v, _ := checker.CheckAndConsume(fqdn, port); v != policy.Allow {
1121+
if v, _ := checker.CheckAndConsume(fqdn, port, WithProtocol(proto.String())); v != policy.Allow {
11211122
log.Printf("[DETECT-DENY] %s:%d blocked by deferred policy", fqdn, port)
11221123
return
11231124
}

0 commit comments

Comments
 (0)