Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 56 additions & 6 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ package ttrpc

import (
"bufio"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"time"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -142,23 +145,70 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
return mh, p, nil
}

func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, flags uint8, p []byte) error {
if len(p) > messageLengthMax {
return OversizedMessageError(len(p))
}

if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
if ctx == nil {
ctx = context.Background()
}

if err := ctx.Err(); err != nil {
return err
}

if ctx.Done() != nil {
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
ch.conn.SetWriteDeadline(time.Now())
case <-done:
}
}()
defer close(done)
}

defer ch.conn.SetWriteDeadline(time.Time{})

if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
return ch.failSend(ctx, err)
}

if len(p) > 0 {
_, err := ch.bw.Write(p)
if err != nil {
return err
if _, err := ch.bw.Write(p); err != nil {
return ch.failSend(ctx, err)
}
}

return ch.bw.Flush()
if err := ch.bw.Flush(); err != nil {
return ch.failSend(ctx, err)
}

return nil
}

func mapWriteTimeout(ctx context.Context, err error) error {
if err == nil {
return nil
}

var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
if ctxErr := ctx.Err(); ctxErr != nil {
return ctxErr
}
}

return err
}

func (ch *channel) failSend(ctx context.Context, err error) error {
// Any write-side failure may leave buffered bytes in an indeterminate state.
// Close the connection so later sends cannot corrupt the framing stream.
_ = ch.conn.Close()
return mapWriteTimeout(ctx, err)
}

func (ch *channel) getmbuf(size int) []byte {
Expand Down
5 changes: 3 additions & 2 deletions channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package ttrpc

import (
"bytes"
"context"
"errors"
"io"
"net"
Expand All @@ -44,7 +45,7 @@ func TestReadWriteMessage(t *testing.T) {

go func() {
for i, msg := range messages {
if err := ch.send(uint32(i), 1, 0, msg); err != nil {
if err := ch.send(context.Background(), uint32(i), 1, 0, msg); err != nil {
errs <- err
return
}
Expand Down Expand Up @@ -96,7 +97,7 @@ func TestMessageOversize(t *testing.T) {
)

go func() {
errs <- wch.send(1, 1, 0, msg)
errs <- wch.send(context.Background(), 1, 1, 0, msg)
}()

err := <-errs
Expand Down
26 changes: 15 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
return c
}

func (c *Client) send(sid uint32, mt messageType, flags uint8, b []byte) error {
func (c *Client) send(ctx context.Context, sid uint32, mt messageType, flags uint8, b []byte) error {
c.sendLock.Lock()
defer c.sendLock.Unlock()
return c.channel.send(sid, mt, flags, b)
return c.channel.send(ctx, sid, mt, flags, b)
}

// Call makes a unary request and returns with response
Expand Down Expand Up @@ -214,7 +214,7 @@ func (cs *clientStream) CloseSend() error {
if cs.localClosed {
return ErrStreamClosed
}
err := cs.s.send(messageTypeData, flagRemoteClosed|flagNoData, nil)
err := cs.s.send(cs.ctx, messageTypeData, flagRemoteClosed|flagNoData, nil)
if err != nil {
return filterCloseErr(err)
}
Expand All @@ -241,7 +241,7 @@ func (cs *clientStream) SendMsg(m interface{}) error {
}
}

err = cs.s.send(messageTypeData, 0, payload)
err = cs.s.send(cs.ctx, messageTypeData, 0, payload)
if err != nil {
return filterCloseErr(err)
}
Expand Down Expand Up @@ -384,9 +384,9 @@ func (c *Client) receiveLoop() error {
}
}

// createStream creates a new stream and registers it with the client
// Introduce stream types for multiple or single response
func (c *Client) createStream(flags uint8, b []byte) (*stream, error) {
// createStreamWithContext creates a new stream and registers it with the client.
// Introduce stream types for multiple or single response.
func (c *Client) createStreamWithContext(ctx context.Context, flags uint8, b []byte) (*stream, error) {
// sendLock must be held across both allocation of the stream ID and sending it across the wire.
// This ensures that new stream IDs sent on the wire are always increasing, which is a
// requirement of the TTRPC protocol.
Expand Down Expand Up @@ -426,8 +426,12 @@ func (c *Client) createStream(flags uint8, b []byte) (*stream, error) {
return nil, err
}

if err := c.channel.send(uint32(s.id), messageTypeRequest, flags, b); err != nil {
return s, filterCloseErr(err)
if err := c.channel.send(ctx, uint32(s.id), messageTypeRequest, flags, b); err != nil {
c.streamLock.Lock()
delete(c.streams, s.id)
c.streamLock.Unlock()
s.closeWithError(err)
return nil, filterCloseErr(err)
}

return s, nil
Expand Down Expand Up @@ -517,7 +521,7 @@ func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, metho
} else {
flags = flagRemoteClosed
}
s, err := c.createStream(flags, p)
s, err := c.createStreamWithContext(ctx, flags, p)
if err != nil {
return nil, err
}
Expand All @@ -536,7 +540,7 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
return err
}

s, err := c.createStream(0, p)
s, err := c.createStreamWithContext(ctx, 0, p)
if err != nil {
return err
}
Expand Down
58 changes: 58 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package ttrpc

import (
"context"
"errors"
"net"
"testing"
"time"

Expand Down Expand Up @@ -70,3 +72,59 @@ func TestUserOnCloseWait(t *testing.T) {
t.Fatalf("expected error nil , but got %v", err)
}
}

func TestCallSendBlocked(t *testing.T) {
verifyCleanup := func(t *testing.T, client *Client) {
t.Helper()
client.streamLock.RLock()
streamsLen := len(client.streams)
client.streamLock.RUnlock()
if streamsLen != 0 {
t.Fatalf("expected no active streams after send failure, got %d", streamsLen)
}

waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second)
defer waitCancel()
if err := client.UserOnCloseWait(waitCtx); err != nil {
t.Fatalf("expected client to close after send failure, got %v", err)
}
}

t.Run("Timeout", func(t *testing.T) {
serverConn, clientConn := net.Pipe()
client := NewClient(clientConn)
defer serverConn.Close()
defer client.Close()

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

err := client.Call(ctx, "service", "method", &internal.TestPayload{}, &internal.TestPayload{})
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected error %v, got %v", context.DeadlineExceeded, err)
}

verifyCleanup(t, client)
})

t.Run("Cancel", func(t *testing.T) {
serverConn, clientConn := net.Pipe()
client := NewClient(clientConn)
defer serverConn.Close()
defer client.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()

err := client.Call(ctx, "service", "method", &internal.TestPayload{}, &internal.TestPayload{})
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected error %v, got %v", context.Canceled, err)
}

verifyCleanup(t, client)
})
}
4 changes: 2 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ func (c *serverConn) run(sctx context.Context) {
return
}

if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil {
if err := ch.send(ctx, response.id, messageTypeResponse, 0, p); err != nil {
log.G(ctx).WithError(err).Error("failed sending message on channel")
return
}
Expand All @@ -537,7 +537,7 @@ func (c *serverConn) run(sctx context.Context) {
if response.data == nil {
flags = flags | flagNoData
}
if err := ch.send(response.id, messageTypeData, flags, response.data); err != nil {
if err := ch.send(ctx, response.id, messageTypeData, flags, response.data); err != nil {
log.G(ctx).WithError(err).Error("failed sending message on channel")
return
}
Expand Down
6 changes: 3 additions & 3 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ func (s *stream) closeWithError(err error) error {
return nil
}

func (s *stream) send(mt messageType, flags uint8, b []byte) error {
return s.sender.send(uint32(s.id), mt, flags, b)
func (s *stream) send(ctx context.Context, mt messageType, flags uint8, b []byte) error {
return s.sender.send(ctx, uint32(s.id), mt, flags, b)
}

func (s *stream) receive(ctx context.Context, msg *streamMessage) error {
Expand All @@ -80,5 +80,5 @@ func (s *stream) receive(ctx context.Context, msg *streamMessage) error {
}

type sender interface {
send(uint32, messageType, uint8, []byte) error
send(context.Context, uint32, messageType, uint8, []byte) error
}
Loading