Skip to content

Commit b2d350b

Browse files
authored
Ignore error if can't write back echoed protocol in negotiate (#87)
* Ignore error if can't write back echoed protocol * Add test to verify we negotiate even if peer closes the stream after sending data * Rework comment
1 parent 2a41ec3 commit b2d350b

2 files changed

Lines changed: 79 additions & 3 deletions

File tree

multistream.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,10 @@ loop:
235235
continue loop
236236
}
237237

238-
if err := delimWriteBuffered(rwc, []byte(tok)); err != nil {
239-
return "", nil, err
240-
}
238+
// Ignore the error here. We want the handshake to finish, even if the
239+
// other side has closed this rwc for writing. They may have sent us a
240+
// message and closed. Future writers will get an error anyways.
241+
_ = delimWriteBuffered(rwc, []byte(tok))
241242

242243
// hand off processing to the sub-protocol handler
243244
return tok, h.Handle, nil

multistream_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,81 @@ func TestNegotiateFail(t *testing.T) {
688688
}
689689
}
690690

691+
type mockStream struct {
692+
expectWrite [][]byte
693+
toRead [][]byte
694+
}
695+
696+
func (s *mockStream) Close() error {
697+
return nil
698+
}
699+
700+
func (s *mockStream) Write(p []byte) (n int, err error) {
701+
if len(s.expectWrite) == 0 {
702+
return 0, fmt.Errorf("no more writes expected")
703+
}
704+
705+
if !bytes.Equal(s.expectWrite[0], p) {
706+
return 0, fmt.Errorf("unexpected write")
707+
}
708+
709+
s.expectWrite = s.expectWrite[1:]
710+
return len(p), nil
711+
}
712+
713+
func (s *mockStream) Read(p []byte) (n int, err error) {
714+
if len(s.toRead) == 0 {
715+
return 0, fmt.Errorf("no more reads expected")
716+
}
717+
718+
if len(p) < len(s.toRead[0]) {
719+
copy(p, s.toRead[0])
720+
s.toRead[0] = s.toRead[0][len(p):]
721+
n = len(p)
722+
} else {
723+
copy(p, s.toRead[0])
724+
n = len(s.toRead[0])
725+
s.toRead = s.toRead[1:]
726+
}
727+
728+
return n, nil
729+
}
730+
731+
func TestNegotiatePeerSendsAndCloses(t *testing.T) {
732+
// Tests the case where a peer will negotiate a protocol, send data, then close the stream immediately
733+
var buf bytes.Buffer
734+
err := delimWrite(&buf, []byte(ProtocolID))
735+
if err != nil {
736+
t.Fatal(err)
737+
}
738+
delimtedProtocolID := make([]byte, buf.Len())
739+
copy(delimtedProtocolID, buf.Bytes())
740+
741+
err = delimWrite(&buf, []byte("foo"))
742+
if err != nil {
743+
t.Fatal(err)
744+
}
745+
err = delimWrite(&buf, []byte("somedata"))
746+
if err != nil {
747+
t.Fatal(err)
748+
}
749+
750+
s := &mockStream{
751+
// We mock the closed stream by only expecting a single write. The
752+
// mockstream will error on any more writes (same as writing to a closed
753+
// stream)
754+
expectWrite: [][]byte{delimtedProtocolID},
755+
toRead: [][]byte{buf.Bytes()},
756+
}
757+
758+
mux := NewMultistreamMuxer()
759+
mux.AddHandler("foo", nil)
760+
_, _, err = mux.Negotiate(s)
761+
if err != nil {
762+
t.Fatal("Negotiate should not fail here", err)
763+
}
764+
}
765+
691766
func TestSimopenClientServer(t *testing.T) {
692767
a, b := newPipe(t)
693768

0 commit comments

Comments
 (0)