From 7115f91c3a5c1454a10f2adab4dedb84f28212dc Mon Sep 17 00:00:00 2001 From: Donovan Hubbard <37090676+donovanhubbard@users.noreply.github.com> Date: Sat, 16 May 2026 15:39:02 -0700 Subject: [PATCH 1/3] Adding support for HA Proxy's PROXY protocol --- .../ssh-proxy-protocol/proxy_protocol.go | 22 ++++ go.mod | 5 +- go.sum | 2 + options.go | 8 ++ server.go | 10 ++ server_test.go | 108 ++++++++++++++++++ 6 files changed, 153 insertions(+), 2 deletions(-) create mode 100644 _examples/ssh-proxy-protocol/proxy_protocol.go diff --git a/_examples/ssh-proxy-protocol/proxy_protocol.go b/_examples/ssh-proxy-protocol/proxy_protocol.go new file mode 100644 index 00000000..d4d507ff --- /dev/null +++ b/_examples/ssh-proxy-protocol/proxy_protocol.go @@ -0,0 +1,22 @@ +package main + +import ( + "fmt" + "io" + "log" + + "github.com/donovanhubbard/ssh" +) + +const ( + ADDR = "0.0.0.0:4444" +) + +func main() { + ssh.Handle(func(s ssh.Session) { + io.WriteString(s, fmt.Sprintf("Your address is %s\n", s.RemoteAddr())) + }) + + log.Println("starting ssh server on " + ADDR) + log.Fatal(ssh.ListenAndServe(ADDR, nil, ssh.EnableProxyProtocol())) +} diff --git a/go.mod b/go.mod index 3be09170..585a7249 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,15 @@ module github.com/charmbracelet/ssh -go 1.23.0 +go 1.25 -toolchain go1.24.1 +toolchain go1.25.0 require ( github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/charmbracelet/x/conpty v0.1.0 github.com/charmbracelet/x/termios v0.1.0 github.com/creack/pty v1.1.21 + github.com/pires/go-proxyproto v0.12.0 golang.org/x/crypto v0.37.0 golang.org/x/sys v0.32.0 ) diff --git a/go.sum b/go.sum index f379dfed..09b07fda 100644 --- a/go.sum +++ b/go.sum @@ -14,3 +14,5 @@ golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= +github.com/pires/go-proxyproto v0.12.0 h1:TTCxD66dU898tahivkqc3hoceZp7P44FnorWyo9d5vM= +github.com/pires/go-proxyproto v0.12.0/go.mod h1:qUvfqUMEoX7T8g0q7TQLDnhMjdTrxnG0hvpMn+7ePNI= diff --git a/options.go b/options.go index fa87e198..96f3b63f 100644 --- a/options.go +++ b/options.go @@ -110,3 +110,11 @@ func AllocatePty() Option { return nil } } + +// EnableProxyProtocol returns a functional option that sets EnableProxyProtocol on the server +func EnableProxyProtocol() Option { + return func(srv *Server) error { + srv.enableProxyProtocol = true + return nil + } +} diff --git a/server.go b/server.go index b1335da3..a6c7d6e0 100644 --- a/server.go +++ b/server.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/pires/go-proxyproto" gossh "golang.org/x/crypto/ssh" ) @@ -87,6 +88,8 @@ type Server struct { conns map[*gossh.ServerConn]struct{} connWg sync.WaitGroup doneChan chan struct{} + + enableProxyProtocol bool // Enable support for HA Proxy's and NGinx's PROXY protocol } func (srv *Server) ensureHostSigner() error { @@ -273,6 +276,13 @@ func (srv *Server) Shutdown(ctx context.Context) error { // // Serve always returns a non-nil error. func (srv *Server) Serve(l net.Listener) error { + if srv.enableProxyProtocol { + _, ok := l.(*proxyproto.Listener) + if !ok { + l = &proxyproto.Listener{Listener: l} + } + } + srv.ensureHandlers() defer l.Close() if err := srv.ensureHostSigner(); err != nil { diff --git a/server_test.go b/server_test.go index 8028a3aa..a23c2ae2 100644 --- a/server_test.go +++ b/server_test.go @@ -3,9 +3,17 @@ package ssh import ( "bytes" "context" + "errors" + "fmt" "io" + "net" + "strconv" + "strings" "testing" "time" + + "github.com/pires/go-proxyproto" + gossh "golang.org/x/crypto/ssh" ) func TestAddHostKey(t *testing.T) { @@ -124,3 +132,103 @@ func TestServerClose(t *testing.T) { return } } + +func TestProxyProtocol(t *testing.T) { + const ( + CORRECT_IP = "1.1.1.1" + CORRECT_PORT = 55555 + ) + handlerDone := make(chan struct{}) + var testResult error + + handler := func(sess Session) { + defer close(handlerDone) + sourceAddress := sess.RemoteAddr() + + index := strings.Index(sourceAddress.String(), ":") + ip := sourceAddress.String()[:index] + portStr := sourceAddress.String()[index+1:] + + if ip != CORRECT_IP { + errorMsg := fmt.Sprintf("Expected source address '%s' but got '%s'", CORRECT_IP, ip) + testResult = errors.Join(testResult, fmt.Errorf("%s", errorMsg)) + } + port, err := strconv.Atoi(portStr) + if err != nil { + testResult = errors.Join(testResult, fmt.Errorf("%s", err)) + } else if port != CORRECT_PORT { + errorMsg := fmt.Sprintf("Expected source port '%d' but got '%d'", CORRECT_PORT, port) + testResult = errors.Join(testResult, fmt.Errorf("%s", errorMsg)) + } + } + + // Bind the port before starting the goroutine so net.Dial never races + // with the server not yet listening. + l := newLocalListener() + srv := &Server{Handler: handler} + srv.SetOption(EnableProxyProtocol()) + + serverDone := make(chan error, 1) + + go func() { + serverDone <- srv.Serve(l) + }() + + defer func() { + srv.Close() + if err := <-serverDone; err != nil && err != ErrServerClosed { + t.Error(err) + } + }() + + serverIP, serverPortStr, _ := net.SplitHostPort(l.Addr().String()) + serverPort, _ := strconv.Atoi(serverPortStr) + conn, err := net.Dial("tcp", l.Addr().String()) + + if err != nil { + t.Fatal(err) + } + + header := &proxyproto.Header{ + Version: 1, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP(CORRECT_IP), + Port: CORRECT_PORT, + }, + DestinationAddr: &net.TCPAddr{ + IP: net.ParseIP(serverIP), + Port: serverPort, + }, + } + + // Writes the PROXY header to the TCP stream before SSH begins + _, err = header.WriteTo(conn) + if err != nil { + t.Fatal(err) + } + + // Hand the same conn to the SSH stack — handshake starts from here. + clientConn, chans, reqs, err := gossh.NewClientConn(conn, l.Addr().String(), &gossh.ClientConfig{ + User: "testuser", + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + t.Fatal(err) + } + client := gossh.NewClient(clientConn, chans, reqs) + defer client.Close() + + session, err := client.NewSession() + if err != nil { + t.Fatal(err) + } + session.Run("") // triggers the handler; ignore exec error + + <-handlerDone + + if testResult != nil { + t.Fatal(testResult) + } +} From a949bbe45b304ce5d71b297972b88cbcd885844c Mon Sep 17 00:00:00 2001 From: Donovan Hubbard <37090676+donovanhubbard@users.noreply.github.com> Date: Sat, 16 May 2026 20:56:41 -0700 Subject: [PATCH 2/3] Making EnableProxyProtocol public like the other configuration items. --- options.go | 2 +- server.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/options.go b/options.go index 96f3b63f..76724806 100644 --- a/options.go +++ b/options.go @@ -114,7 +114,7 @@ func AllocatePty() Option { // EnableProxyProtocol returns a functional option that sets EnableProxyProtocol on the server func EnableProxyProtocol() Option { return func(srv *Server) error { - srv.enableProxyProtocol = true + srv.EnableProxyProtocol = true return nil } } diff --git a/server.go b/server.go index a6c7d6e0..c387113d 100644 --- a/server.go +++ b/server.go @@ -68,6 +68,8 @@ type Server struct { IdleTimeout time.Duration // connection timeout when no activity, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty + EnableProxyProtocol bool // Enable support for HA Proxy's and NGinx's PROXY protocol + // ChannelHandlers allow overriding the built-in session handlers or provide // extensions to the protocol, such as tcpip forwarding. By default only the // "session" handler is enabled. @@ -88,8 +90,6 @@ type Server struct { conns map[*gossh.ServerConn]struct{} connWg sync.WaitGroup doneChan chan struct{} - - enableProxyProtocol bool // Enable support for HA Proxy's and NGinx's PROXY protocol } func (srv *Server) ensureHostSigner() error { @@ -276,7 +276,7 @@ func (srv *Server) Shutdown(ctx context.Context) error { // // Serve always returns a non-nil error. func (srv *Server) Serve(l net.Listener) error { - if srv.enableProxyProtocol { + if srv.EnableProxyProtocol { _, ok := l.(*proxyproto.Listener) if !ok { l = &proxyproto.Listener{Listener: l} From 1419e9a3cf20fe86d585973244c2eecea6311b9f Mon Sep 17 00:00:00 2001 From: Donovan Hubbard <37090676+donovanhubbard@users.noreply.github.com> Date: Mon, 1 Jun 2026 22:26:21 -0700 Subject: [PATCH 3/3] Added a negative test of PROXY enabled --- server_test.go | 102 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) diff --git a/server_test.go b/server_test.go index a23c2ae2..9b258c1f 100644 --- a/server_test.go +++ b/server_test.go @@ -133,7 +133,7 @@ func TestServerClose(t *testing.T) { } } -func TestProxyProtocol(t *testing.T) { +func TestProxyProtocolEnabled(t *testing.T) { const ( CORRECT_IP = "1.1.1.1" CORRECT_PORT = 55555 @@ -232,3 +232,103 @@ func TestProxyProtocol(t *testing.T) { t.Fatal(testResult) } } + +func TestProxyProtocolDisabled(t *testing.T) { + const ( + INCORRECT_IP = "1.1.1.1" + INCORRECT_PORT = 55555 + ) + handlerDone := make(chan struct{}) + var testResult error + + handler := func(sess Session) { + defer close(handlerDone) + sourceAddress := sess.RemoteAddr() + + index := strings.Index(sourceAddress.String(), ":") + ip := sourceAddress.String()[:index] + portStr := sourceAddress.String()[index+1:] + + if ip == INCORRECT_IP { + errorMsg := fmt.Sprintf("Expected source address to be anything but '%s' but got '%s'", INCORRECT_IP, ip) + testResult = errors.Join(testResult, fmt.Errorf("%s", errorMsg)) + } + port, err := strconv.Atoi(portStr) + if err != nil { + testResult = errors.Join(testResult, fmt.Errorf("%s", err)) + } else if port == INCORRECT_PORT { + errorMsg := fmt.Sprintf("Expected source port anything but '%d' but got '%d'", INCORRECT_PORT, port) + testResult = errors.Join(testResult, fmt.Errorf("%s", errorMsg)) + } + } + + // Bind the port before starting the goroutine so net.Dial never races + // with the server not yet listening. + l := newLocalListener() + srv := &Server{Handler: handler} + + serverDone := make(chan error, 1) + + go func() { + serverDone <- srv.Serve(l) + }() + + defer func() { + srv.Close() + if err := <-serverDone; err != nil && err != ErrServerClosed { + t.Error(err) + } + }() + + serverIP, serverPortStr, _ := net.SplitHostPort(l.Addr().String()) + serverPort, _ := strconv.Atoi(serverPortStr) + conn, err := net.Dial("tcp", l.Addr().String()) + + if err != nil { + t.Fatal(err) + } + + //Set the PROXY header information. The server should not read it + header := &proxyproto.Header{ + Version: 1, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP(INCORRECT_IP), + Port: INCORRECT_PORT, + }, + DestinationAddr: &net.TCPAddr{ + IP: net.ParseIP(serverIP), + Port: serverPort, + }, + } + + // Writes the PROXY header to the TCP stream before SSH begins + _, err = header.WriteTo(conn) + if err != nil { + t.Fatal(err) + } + + // Hand the same conn to the SSH stack — handshake starts from here. + clientConn, chans, reqs, err := gossh.NewClientConn(conn, l.Addr().String(), &gossh.ClientConfig{ + User: "testuser", + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + t.Fatal(err) + } + client := gossh.NewClient(clientConn, chans, reqs) + defer client.Close() + + session, err := client.NewSession() + if err != nil { + t.Fatal(err) + } + session.Run("") // triggers the handler; ignore exec error + + <-handlerDone + + if testResult != nil { + t.Fatal(testResult) + } +}