diff --git a/go.mod b/go.mod index 5a6fc4989..b8db5b1f1 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( github.com/gorilla/websocket v1.5.3 // indirect github.com/jstemmer/go-junit-report/v2 v2.1.0 // indirect github.com/kisielk/errcheck v1.9.0 // indirect + github.com/pires/go-proxyproto v0.8.1 // indirect github.com/openai/openai-go/v3 v3.17.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect diff --git a/go.sum b/go.sum index 46a684abe..79936f65d 100644 --- a/go.sum +++ b/go.sum @@ -88,6 +88,8 @@ github.com/openai/openai-go/v3 v3.17.0 h1:CfTkmQoItolSyW+bHOUF190KuX5+1Zv6MC0Gb4 github.com/openai/openai-go/v3 v3.17.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= +github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= diff --git a/proxy/servertcp.go b/proxy/servertcp.go index ad7e557d3..892f33533 100644 --- a/proxy/servertcp.go +++ b/proxy/servertcp.go @@ -16,12 +16,14 @@ import ( "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/syncutil" "github.com/miekg/dns" + + proxyproto "github.com/pires/go-proxyproto" ) // initTCPListeners initializes TCP listeners with configured addresses. func (p *Proxy) initTCPListeners(ctx context.Context) (err error) { for _, addr := range p.TCPListenAddr { - var ln *net.TCPListener + var ln net.Listener ln, err = p.listenTCP(ctx, addr) if err != nil { return fmt.Errorf("listening on tcp addr %s: %w", addr, err) @@ -34,7 +36,7 @@ func (p *Proxy) initTCPListeners(ctx context.Context) (err error) { } // listenTCP returns a new TCP listener listening on addr. -func (p *Proxy) listenTCP(ctx context.Context, addr *net.TCPAddr) (ln *net.TCPListener, err error) { +func (p *Proxy) listenTCP(ctx context.Context, addr *net.TCPAddr) (ln net.Listener, err error) { addrStr := addr.String() p.logger.InfoContext(ctx, "creating tcp server socket", "addr", addrStr) @@ -60,7 +62,29 @@ func (p *Proxy) listenTCP(ctx context.Context, addr *net.TCPAddr) (ln *net.TCPLi p.logger.InfoContext(ctx, "listening to tcp", "addr", ln.Addr()) - return ln, nil + return p.wrapProxyListener(ln), nil +} + +// wrapProxyListener wraps a net.Listener with a proxyproto.Listener that +// implements the ConnPolicy callback. If the upstream address is in +// p.TrustedProxies, it returns proxyproto.USE; otherwise, it returns +// proxyproto.REJECT. +func (p *Proxy) wrapProxyListener(ln net.Listener) net.Listener { + return &proxyproto.Listener{ + Listener: ln, + ConnPolicy: func(options proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) { + if p.TrustedProxies != nil && p.TrustedProxies.Contains(netutil.NetAddrToAddrPort(options.Upstream).Addr()) { + // If a proxyproto header is present, use it to determine the + // upstream address. + return proxyproto.USE, nil + } + + // Reject connections if the proxyproto header is present, + // with reason (will be logged): + // proxyproto: upstream connection sent PROXY header but isn't allowed to send one + return proxyproto.REJECT, nil + }, + } } // initTLSListeners initializes TLS listeners with configured addresses. @@ -78,7 +102,9 @@ func (p *Proxy) initTLSListeners(ctx context.Context) (err error) { return fmt.Errorf("listening on tls addr %s: %w", addr, err) } - l := tls.NewListener(tcpListen, p.TLSConfig) + proxyListen := p.wrapProxyListener(tcpListen) + + l := tls.NewListener(proxyListen, p.TLSConfig) p.tlsListen = append(p.tlsListen, l) p.logger.InfoContext(ctx, "listening to tls", "addr", l.Addr()) diff --git a/proxy/serverudp.go b/proxy/serverudp.go index 32ee9effe..d9ecbbec8 100644 --- a/proxy/serverudp.go +++ b/proxy/serverudp.go @@ -1,8 +1,11 @@ package proxy import ( + "bufio" + "bytes" "context" "fmt" + "io" "log/slog" "net" "net/netip" @@ -14,6 +17,7 @@ import ( "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/syncutil" "github.com/miekg/dns" + proxyproto "github.com/pires/go-proxyproto" ) // initUDPListeners initializes UDP listeners with configured addresses. @@ -144,6 +148,8 @@ func (p *Proxy) udpHandlePacket( p.logger.DebugContext(ctx, "handling new udp packet", "raddr", remoteAddr) + packet, remoteAddr = p.parseUDPProxyHeader(packet, remoteAddr) + req := &dns.Msg{} err := req.Unpack(packet) if err != nil { @@ -162,6 +168,42 @@ func (p *Proxy) udpHandlePacket( } } +// parseUDPProxyHeader attempts to parse a proxy protocol header from a UDP +// packet. If the remote address is in p.TrustedProxies and a valid proxy +// protocol header is present, it returns the remaining packet data and the +// source address from the header. Otherwise, it returns the original packet +// and remote address unchanged. +func (p *Proxy) parseUDPProxyHeader(packet []byte, remoteAddr *net.UDPAddr) ([]byte, *net.UDPAddr) { + if p.TrustedProxies == nil || !p.TrustedProxies.Contains(netutil.NetAddrToAddrPort(remoteAddr).Addr()) { + return packet, remoteAddr + } + + reader := bufio.NewReader(bytes.NewReader(packet)) + header, err := proxyproto.Read(reader) + if err != nil { + // No proxy protocol header found; return packet as-is. + return packet, remoteAddr + } + + // Read the remaining bytes after the proxy protocol header; these are the + // actual DNS payload. + remaining, err := io.ReadAll(reader) + if err != nil { + p.logger.Error("reading remaining udp data after proxy header", slogutil.KeyError, err) + + return packet, remoteAddr + } + + srcUDPAddr, ok := header.SourceAddr.(*net.UDPAddr) + if ok { + return remaining, srcUDPAddr + } + + p.logger.Debug("proxy protocol header has unsupported source address type", "addr", header.SourceAddr) + + return remaining, remoteAddr +} + // Writes a response to the UDP client func (p *Proxy) respondUDP(d *DNSContext) error { resp := d.Res