Skip to content

Commit 48d83cb

Browse files
committed
drpcmanager: fix context cancellation error for unary RPCs
Fix a race condition where a unary RPC with a cancelled context could return io.EOF instead of codes.Canceled. Two changes, mirroring how gRPC handles this: 1. Early ctx.Err() check in NewClientStream before creating the stream. 2. Deferred stream.CheckCancelError in doInvoke to convert io.EOF to the cancel error if the stream was cancelled mid-operation. The problem: With multiplexing, each stream gets a manageStream goroutine that watches ctx.Done() and calls SendCancel + Cancel when the context is cancelled. This races with doInvoke, which writes invoke and message frames through the same stream. The race has three outcomes depending on who acquires the stream's write lock first: 1. doInvoke wins the lock, completes all writes, and the RPC succeeds even though it should have been cancelled. 2. SendCancel wins, sets send=io.EOF before doInvoke runs. rawWriteLocked sees send.IsSet() and returns io.EOF. Invoke's ToRPCErr passes io.EOF through unchanged, so the caller gets the wrong error code. 3. doInvoke finishes writes, then MsgRecv sees the cancellation and returns codes.Canceled. This is the correct outcome but only happens by luck of timing. Why this didn't happen before multiplexing: The old single-stream manager used a non-blocking SendCancel that returned (busy=true) when the write lock was held by an in-progress write. With SoftCancel=false (the default), the fallback path was: manageStream calls stream.Cancel(ctx.Err()). The stream is not finished because doInvoke holds the write lock, so the manager calls m.terminate(), which closes the entire transport. The in-flight Writer.Write() fails with an IO error, and checkCancelError sees cancel.IsSet() and returns context.Canceled. The correct error surfaced, but through connection termination. This was fine in single-stream mode where one stream is one connection. With multiplexing, we cannot terminate the entire connection for one stream's cancellation. The new SendCancel blocks on the write lock to guarantee the cancel frame is sent, and that introduced this race. How gRPC handles this (verified against grpc-go source): gRPC uses two mechanisms. First, newAttemptLocked (stream.go:408) checks cs.ctx.Err() before creating the transport stream. This catches the already-cancelled case without allocating resources. Second, for unary RPCs, csAttempt.sendMsg (stream.go:1092) swallows write errors and returns nil when !cs.desc.ClientStreams. The real error always surfaces from RecvMsg, which detects context cancellation via recvBufferReader.readClient (transport.go:239) and returns status.Error(codes.Canceled, ...). This means gRPC never returns io.EOF from a unary RPC because it never short-circuits on a send error. For streaming RPCs, gRPC returns io.EOF from Send() after cancel (the stream is done for writing) and codes.Canceled from Recv() (the actual reason). Our grpccompat tests confirm this by comparing gRPC and DRPC error results for identical cancel scenarios. Our fix: Rather than restructuring doInvoke to swallow send errors like gRPC, we use the stream's existing CheckCancelError mechanism. NewClientStream checks ctx.Err() before creating the stream. This mirrors gRPC's newAttemptLocked check and avoids wasting a stream ID, spawning a goroutine, and allocating stream resources. doInvoke defers stream.CheckCancelError on its return value. If any operation in doInvoke fails because SendCancel won the write lock race (returning io.EOF via the send signal), CheckCancelError replaces it with the cancel signal's error (context.Canceled). This is the same function the stream already uses internally for transport write failures. CheckCancelError is exported (was checkCancelError) so that doInvoke in the drpcconn package can call it. On TOCTOU: The NewClientStream check is technically TOCTOU: the context could be cancelled immediately after the check passes. This is acceptable because Go's context cancellation model is cooperative, not preemptive. The context package provides Done() "for use in select statements," and operations check at natural boundaries rather than continuously. The standard library follows this pattern: http.Client.Do checks between redirect hops, database/sql checks before query execution, and gRPC checks in newAttemptLocked before creating the transport stream. If the context is cancelled mid-operation, manageStream handles cleanup and the deferred CheckCancelError corrects the error code.
1 parent 68e31cb commit 48d83cb

3 files changed

Lines changed: 11 additions & 7 deletions

File tree

drpcconn/conn.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, ou
8989
}
9090

9191
func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string, data []byte, metadata []byte, out drpc.Message) (err error) {
92+
defer func() { err = stream.CheckCancelError(err) }()
9293
if err := stream.WriteInvoke(rpc, metadata); err != nil {
9394
return err
9495
}

drpcmanager/manager.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,9 @@ func (m *Manager) Close() error {
328328

329329
// NewClientStream starts a stream on the managed transport for use by a client.
330330
func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpcstream.Stream, err error) {
331+
if err := ctx.Err(); err != nil {
332+
return nil, err
333+
}
331334
return m.newStream(ctx, m.lastStreamID.Add(1), drpc.StreamKindClient, rpc)
332335
}
333336

drpcstream/stream.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,10 @@ func (s *Stream) checkFinished() {
297297
}
298298
}
299299

300-
// checkCancelError will replace the error with one from the cancel signal if it
300+
// CheckCancelError will replace the error with one from the cancel signal if it
301301
// is set. This is to prevent errors from reads/writes to a transport after it
302302
// has been asynchronously closed due to context cancelation.
303-
func (s *Stream) checkCancelError(err error) error {
303+
func (s *Stream) CheckCancelError(err error) error {
304304
if s.sigs.cancel.IsSet() {
305305
return s.sigs.cancel.Err()
306306
}
@@ -401,7 +401,7 @@ func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) {
401401
s.log("SEND", fr.String)
402402

403403
if err := s.wr.WriteFrame(fr); err != nil {
404-
return s.checkCancelError(errs.Wrap(err))
404+
return s.CheckCancelError(errs.Wrap(err))
405405
} else if fr.Done {
406406
return nil
407407
}
@@ -496,7 +496,7 @@ func (s *Stream) SendError(serr error) (err error) {
496496
s.terminate(termError)
497497
s.mu.Unlock()
498498

499-
return s.checkCancelError(s.sendPacketLocked(drpcwire.KindError, false, drpcwire.MarshalError(serr)))
499+
return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindError, false, drpcwire.MarshalError(serr)))
500500
}
501501

502502
// SendCancel terminates the stream and sends a cancel to the remote side. It
@@ -519,7 +519,7 @@ func (s *Stream) SendCancel(err error) error {
519519
s.terminate(err)
520520
s.mu.Unlock()
521521

522-
return s.checkCancelError(s.sendPacketLocked(drpcwire.KindCancel, true, nil))
522+
return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindCancel, true, nil))
523523
}
524524

525525
// Close terminates the stream and sends that the stream has been closed to the
@@ -540,7 +540,7 @@ func (s *Stream) Close() (err error) {
540540
s.terminate(termClosed)
541541
s.mu.Unlock()
542542

543-
return s.checkCancelError(s.sendPacketLocked(drpcwire.KindClose, false, nil))
543+
return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindClose, false, nil))
544544
}
545545

546546
// CloseSend informs the remote that no more messages will be sent. If the remote has
@@ -563,7 +563,7 @@ func (s *Stream) CloseSend() (err error) {
563563
s.terminateIfBothClosed()
564564
s.mu.Unlock()
565565

566-
return s.checkCancelError(s.sendPacketLocked(drpcwire.KindCloseSend, false, nil))
566+
return s.CheckCancelError(s.sendPacketLocked(drpcwire.KindCloseSend, false, nil))
567567
}
568568

569569
// Cancel transitions the stream into a state where all writes to the transport will return

0 commit comments

Comments
 (0)