Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 55 additions & 13 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,17 @@ func UseFstat(value bool) ClientOption {
}
}

// CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written.
//
// The writer passed in will not be automatically closed.
// It is the responsibility of the caller to coordinate closure of any writers.
func CopyStderrTo(wr io.Writer) ClientOption {
return func(c *Client) error {
c.stderrTo = wr
return nil
}
}

// Client represents an SFTP session on a *ssh.ClientConn SSH connection.
// Multiple Clients can be active on a single SSH connection, and a Client
// may be called concurrently from multiple Goroutines.
Expand All @@ -166,6 +177,8 @@ func UseFstat(value bool) ClientOption {
type Client struct {
clientConn

stderrTo io.Writer

ext map[string]string // Extensions (name -> data).

maxPacket int // max packet size read or written.
Expand All @@ -186,9 +199,7 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
if err != nil {
return nil, err
}
if err := s.RequestSubsystem("sftp"); err != nil {
return nil, err
}

pw, err := s.StdinPipe()
if err != nil {
return nil, err
Expand All @@ -197,22 +208,35 @@ func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
if err != nil {
return nil, err
}
perr, err := s.StderrPipe()
if err != nil {
return nil, err
}

return NewClientPipe(pr, pw, opts...)
if err := s.RequestSubsystem("sftp"); err != nil {
return nil, err
}

return newClientPipe(pr, perr, pw, s.Wait, opts...)
}

// NewClientPipe creates a new SFTP client given a Reader and a WriteCloser.
// This can be used for connecting to an SFTP server over TCP/TLS or by using
// the system's ssh client program (e.g. via exec.Command).
func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) {
sftp := &Client{
return newClientPipe(rd, nil, wr, nil, opts...)
}

func newClientPipe(rd, stderr io.Reader, wr io.WriteCloser, wait func() error, opts ...ClientOption) (*Client, error) {
c := &Client{
clientConn: clientConn{
conn: conn{
Reader: rd,
WriteCloser: wr,
},
inflight: make(map[uint32]chan<- result),
closed: make(chan struct{}),
wait: wait,
},

ext: make(map[string]string),
Expand All @@ -222,32 +246,50 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
}

for _, opt := range opts {
if err := opt(sftp); err != nil {
if err := opt(c); err != nil {
wr.Close()
return nil, err
}
}

if err := sftp.sendInit(); err != nil {
if stderr != nil {
wr := io.Discard
if c.stderrTo != nil {
wr = c.stderrTo
}

go func() {
// DO NOT close the writer!
// Programs may pass in `os.Stderr` to write the remote stderr to,
// and the program may continue after disconnect by reconnecting.
// But if we've closed their stderr, then we just messed everything up.

if _, err := io.Copy(wr, stderr); err != nil {
debug("error copying stderr: %v", err)
}
}()
}

if err := c.sendInit(); err != nil {
wr.Close()
return nil, fmt.Errorf("error sending init packet to server: %w", err)
}

if err := sftp.recvVersion(); err != nil {
if err := c.recvVersion(); err != nil {
wr.Close()
return nil, fmt.Errorf("error receiving version packet from server: %w", err)
}

sftp.clientConn.wg.Add(1)
c.clientConn.wg.Add(1)
go func() {
defer sftp.clientConn.wg.Done()
defer c.clientConn.wg.Done()

if err := sftp.clientConn.recv(); err != nil {
sftp.clientConn.broadcastErr(err)
if err := c.clientConn.recv(); err != nil {
c.clientConn.broadcastErr(err)
}
}()

return sftp, nil
return c, nil
}

// Create creates the named file mode 0666 (before umask), truncating it if it
Expand Down
29 changes: 26 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type conn struct {
// For the client mode just pass 0.
// It returns io.EOF if the connection is closed and
// there are no more packets to read.
func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
func (c *conn) recvPacket(orderID uint32) (fxp, []byte, error) {
return recvPacket(c, c.alloc, orderID)
}

Expand All @@ -43,6 +43,8 @@ type clientConn struct {
conn
wg sync.WaitGroup

wait func() error // if non-nil, call this during Wait() to get a possible remote status error.

sync.Mutex // protects inflight
inflight map[uint32]chan<- result // outstanding requests

Expand All @@ -55,6 +57,27 @@ type clientConn struct {
// goroutines.
func (c *clientConn) Wait() error {
<-c.closed

if c.wait == nil {
// Only return this error if c.wait won't return something more useful.
return c.err
}

if err := c.wait(); err != nil {

// TODO: when https://github.com/golang/go/issues/35025 is fixed,
// we can remove this if block entirely.
// Right now, it’s always going to return this, so it is not useful.
// But we have this code here so that as soon as the ssh library is updated,
// we can return a possibly more useful error.
if err.Error() == "ssh: session not started" {
return c.err
}

return err
}

// c.wait returned no error; so, let's return something maybe more useful.
return c.err
}

Expand Down Expand Up @@ -119,7 +142,7 @@ func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) {

// result captures the result of receiving the a packet from the server
type result struct {
typ byte
typ fxp
data []byte
err error
}
Expand All @@ -129,7 +152,7 @@ type idmarshaler interface {
encoding.BinaryMarshaler
}

func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) {
func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (fxp, []byte, error) {
if cap(ch) < 1 {
ch = make(chan result, 1)
}
Expand Down
39 changes: 30 additions & 9 deletions packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,16 +304,22 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
return nil
}

func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, error) {
func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (fxp, []byte, error) {
var b []byte
if alloc != nil {
b = alloc.GetPage(orderID)
} else {
b = make([]byte, 4)
}
if _, err := io.ReadFull(r, b[:4]); err != nil {
return 0, nil, err

if n, err := io.ReadFull(r, b[:4]); err != nil {
if err == io.EOF {
return 0, nil, err
}

return 0, nil, fmt.Errorf("error reading packet length: %d of 4: %w", n, err)
}

length, _ := unmarshalUint32(b)
if length > maxMsgLength {
debug("recv packet %d bytes too long", length)
Expand All @@ -323,24 +329,39 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, e
debug("recv packet of 0 bytes too short")
return 0, nil, errShortPacket
}

if alloc == nil {
b = make([]byte, length)
}
if _, err := io.ReadFull(r, b[:length]); err != nil {

n, err := io.ReadFull(r, b[:length])
b = b[:n]

if err != nil {
debug("recv packet error: %d of %d bytes: %x", n, length, b)

// ReadFull only returns EOF if it has read no bytes.
// In this case, that means a partial packet, and thus unexpected.
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
debug("recv packet %d bytes: err %v", length, err)
return 0, nil, err

if n == 0 {
return 0, nil, fmt.Errorf("error reading packet body: %d of %d: %w", n, length, err)
}

return 0, nil, fmt.Errorf("error reading packet body: %d of %d: (%s) %w", n, length, fxp(b[0]), err)
}

typ, payload := fxp(b[0]), b[1:n]

if debugDumpRxPacketBytes {
debug("recv packet: %s %d bytes %x", fxp(b[0]), length, b[1:length])
debug("recv packet: %s %d bytes %x", typ, length, payload)
} else if debugDumpRxPacket {
debug("recv packet: %s %d bytes", fxp(b[0]), length)
debug("recv packet: %s %d bytes", typ, length)
}
return b[0], b[1:length], nil

return typ, payload, nil
}

type extensionPair struct {
Expand Down
2 changes: 1 addition & 1 deletion packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ func TestRecvPacket(t *testing.T) {
var recvPacketTests = []struct {
b []byte

want uint8
want fxp
body []byte
wantErr error
}{
Expand Down
4 changes: 2 additions & 2 deletions request-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {

var err error
var pkt requestPacket
var pktType uint8
var pktType fxp
var pktBytes []byte

for {
Expand All @@ -158,7 +158,7 @@ func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
return err
}

pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
pkt, err = makePacket(rxPacket{pktType, pktBytes})
if err != nil {
switch {
case errors.Is(err, errUnknownExtendedPacket):
Expand Down
4 changes: 2 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ func (svr *Server) Serve() error {

var err error
var pkt requestPacket
var pktType uint8
var pktType fxp
var pktBytes []byte
for {
pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID())
Expand All @@ -403,7 +403,7 @@ func (svr *Server) Serve() error {
break
}

pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
pkt, err = makePacket(rxPacket{pktType, pktBytes})
if err != nil {
switch {
case errors.Is(err, errUnknownExtendedPacket):
Expand Down
8 changes: 4 additions & 4 deletions sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,15 @@ func (f fx) String() string {
}

type unexpectedPacketErr struct {
want, got uint8
want, got fxp
}

func (u *unexpectedPacketErr) Error() string {
return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got))
return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", u.want, u.got)
}

func unimplementedPacketErr(u uint8) error {
return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u))
func unimplementedPacketErr(u fxp) error {
return fmt.Errorf("sftp: unimplemented packet type: got %v", u)
}

type unexpectedIDErr struct{ want, got uint32 }
Expand Down