diff --git a/_examples/ssh-proxy-protocol/proxy_protocol.go b/_examples/ssh-proxy-protocol/proxy_protocol.go new file mode 100644 index 0000000..d4d507f --- /dev/null +++ b/_examples/ssh-proxy-protocol/proxy_protocol.go @@ -0,0 +1,22 @@ +package main + +import ( + "fmt" + "io" + "log" + + "github.com/donovanhubbard/ssh" +) + +const ( + ADDR = "0.0.0.0:4444" +) + +func main() { + ssh.Handle(func(s ssh.Session) { + io.WriteString(s, fmt.Sprintf("Your address is %s\n", s.RemoteAddr())) + }) + + log.Println("starting ssh server on " + ADDR) + log.Fatal(ssh.ListenAndServe(ADDR, nil, ssh.EnableProxyProtocol())) +} diff --git a/go.mod b/go.mod index 3be0917..585a724 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,15 @@ module github.com/charmbracelet/ssh -go 1.23.0 +go 1.25 -toolchain go1.24.1 +toolchain go1.25.0 require ( github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/charmbracelet/x/conpty v0.1.0 github.com/charmbracelet/x/termios v0.1.0 github.com/creack/pty v1.1.21 + github.com/pires/go-proxyproto v0.12.0 golang.org/x/crypto v0.37.0 golang.org/x/sys v0.32.0 ) diff --git a/go.sum b/go.sum index f379dfe..09b07fd 100644 --- a/go.sum +++ b/go.sum @@ -14,3 +14,5 @@ golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= +github.com/pires/go-proxyproto v0.12.0 h1:TTCxD66dU898tahivkqc3hoceZp7P44FnorWyo9d5vM= +github.com/pires/go-proxyproto v0.12.0/go.mod h1:qUvfqUMEoX7T8g0q7TQLDnhMjdTrxnG0hvpMn+7ePNI= diff --git a/options.go b/options.go index fa87e19..7672480 100644 --- a/options.go +++ b/options.go @@ -110,3 +110,11 @@ func AllocatePty() Option { return nil } } + +// EnableProxyProtocol returns a functional option that sets EnableProxyProtocol on the server +func EnableProxyProtocol() Option { + return func(srv *Server) error { + srv.EnableProxyProtocol = true + return nil + } +} diff --git a/server.go b/server.go index b1335da..c387113 100644 --- a/server.go +++ b/server.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/pires/go-proxyproto" gossh "golang.org/x/crypto/ssh" ) @@ -67,6 +68,8 @@ type Server struct { IdleTimeout time.Duration // connection timeout when no activity, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty + EnableProxyProtocol bool // Enable support for HA Proxy's and NGinx's PROXY protocol + // ChannelHandlers allow overriding the built-in session handlers or provide // extensions to the protocol, such as tcpip forwarding. By default only the // "session" handler is enabled. @@ -273,6 +276,13 @@ func (srv *Server) Shutdown(ctx context.Context) error { // // Serve always returns a non-nil error. func (srv *Server) Serve(l net.Listener) error { + if srv.EnableProxyProtocol { + _, ok := l.(*proxyproto.Listener) + if !ok { + l = &proxyproto.Listener{Listener: l} + } + } + srv.ensureHandlers() defer l.Close() if err := srv.ensureHostSigner(); err != nil { diff --git a/server_test.go b/server_test.go index 8028a3a..9b258c1 100644 --- a/server_test.go +++ b/server_test.go @@ -3,9 +3,17 @@ package ssh import ( "bytes" "context" + "errors" + "fmt" "io" + "net" + "strconv" + "strings" "testing" "time" + + "github.com/pires/go-proxyproto" + gossh "golang.org/x/crypto/ssh" ) func TestAddHostKey(t *testing.T) { @@ -124,3 +132,203 @@ func TestServerClose(t *testing.T) { return } } + +func TestProxyProtocolEnabled(t *testing.T) { + const ( + CORRECT_IP = "1.1.1.1" + CORRECT_PORT = 55555 + ) + handlerDone := make(chan struct{}) + var testResult error + + handler := func(sess Session) { + defer close(handlerDone) + sourceAddress := sess.RemoteAddr() + + index := strings.Index(sourceAddress.String(), ":") + ip := sourceAddress.String()[:index] + portStr := sourceAddress.String()[index+1:] + + if ip != CORRECT_IP { + errorMsg := fmt.Sprintf("Expected source address '%s' but got '%s'", CORRECT_IP, ip) + testResult = errors.Join(testResult, fmt.Errorf("%s", errorMsg)) + } + port, err := strconv.Atoi(portStr) + if err != nil { + testResult = errors.Join(testResult, fmt.Errorf("%s", err)) + } else if port != CORRECT_PORT { + errorMsg := fmt.Sprintf("Expected source port '%d' but got '%d'", CORRECT_PORT, port) + testResult = errors.Join(testResult, fmt.Errorf("%s", errorMsg)) + } + } + + // Bind the port before starting the goroutine so net.Dial never races + // with the server not yet listening. + l := newLocalListener() + srv := &Server{Handler: handler} + srv.SetOption(EnableProxyProtocol()) + + serverDone := make(chan error, 1) + + go func() { + serverDone <- srv.Serve(l) + }() + + defer func() { + srv.Close() + if err := <-serverDone; err != nil && err != ErrServerClosed { + t.Error(err) + } + }() + + serverIP, serverPortStr, _ := net.SplitHostPort(l.Addr().String()) + serverPort, _ := strconv.Atoi(serverPortStr) + conn, err := net.Dial("tcp", l.Addr().String()) + + if err != nil { + t.Fatal(err) + } + + header := &proxyproto.Header{ + Version: 1, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP(CORRECT_IP), + Port: CORRECT_PORT, + }, + DestinationAddr: &net.TCPAddr{ + IP: net.ParseIP(serverIP), + Port: serverPort, + }, + } + + // Writes the PROXY header to the TCP stream before SSH begins + _, err = header.WriteTo(conn) + if err != nil { + t.Fatal(err) + } + + // Hand the same conn to the SSH stack — handshake starts from here. + clientConn, chans, reqs, err := gossh.NewClientConn(conn, l.Addr().String(), &gossh.ClientConfig{ + User: "testuser", + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + t.Fatal(err) + } + client := gossh.NewClient(clientConn, chans, reqs) + defer client.Close() + + session, err := client.NewSession() + if err != nil { + t.Fatal(err) + } + session.Run("") // triggers the handler; ignore exec error + + <-handlerDone + + if testResult != nil { + t.Fatal(testResult) + } +} + +func TestProxyProtocolDisabled(t *testing.T) { + const ( + INCORRECT_IP = "1.1.1.1" + INCORRECT_PORT = 55555 + ) + handlerDone := make(chan struct{}) + var testResult error + + handler := func(sess Session) { + defer close(handlerDone) + sourceAddress := sess.RemoteAddr() + + index := strings.Index(sourceAddress.String(), ":") + ip := sourceAddress.String()[:index] + portStr := sourceAddress.String()[index+1:] + + if ip == INCORRECT_IP { + errorMsg := fmt.Sprintf("Expected source address to be anything but '%s' but got '%s'", INCORRECT_IP, ip) + testResult = errors.Join(testResult, fmt.Errorf("%s", errorMsg)) + } + port, err := strconv.Atoi(portStr) + if err != nil { + testResult = errors.Join(testResult, fmt.Errorf("%s", err)) + } else if port == INCORRECT_PORT { + errorMsg := fmt.Sprintf("Expected source port anything but '%d' but got '%d'", INCORRECT_PORT, port) + testResult = errors.Join(testResult, fmt.Errorf("%s", errorMsg)) + } + } + + // Bind the port before starting the goroutine so net.Dial never races + // with the server not yet listening. + l := newLocalListener() + srv := &Server{Handler: handler} + + serverDone := make(chan error, 1) + + go func() { + serverDone <- srv.Serve(l) + }() + + defer func() { + srv.Close() + if err := <-serverDone; err != nil && err != ErrServerClosed { + t.Error(err) + } + }() + + serverIP, serverPortStr, _ := net.SplitHostPort(l.Addr().String()) + serverPort, _ := strconv.Atoi(serverPortStr) + conn, err := net.Dial("tcp", l.Addr().String()) + + if err != nil { + t.Fatal(err) + } + + //Set the PROXY header information. The server should not read it + header := &proxyproto.Header{ + Version: 1, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP(INCORRECT_IP), + Port: INCORRECT_PORT, + }, + DestinationAddr: &net.TCPAddr{ + IP: net.ParseIP(serverIP), + Port: serverPort, + }, + } + + // Writes the PROXY header to the TCP stream before SSH begins + _, err = header.WriteTo(conn) + if err != nil { + t.Fatal(err) + } + + // Hand the same conn to the SSH stack — handshake starts from here. + clientConn, chans, reqs, err := gossh.NewClientConn(conn, l.Addr().String(), &gossh.ClientConfig{ + User: "testuser", + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + t.Fatal(err) + } + client := gossh.NewClient(clientConn, chans, reqs) + defer client.Close() + + session, err := client.NewSession() + if err != nil { + t.Fatal(err) + } + session.Run("") // triggers the handler; ignore exec error + + <-handlerDone + + if testResult != nil { + t.Fatal(testResult) + } +}