diff --git a/client.go b/client.go index db82b4182..5a3060a2f 100644 --- a/client.go +++ b/client.go @@ -22,6 +22,7 @@ import ( "fmt" "io" "net" + "runtime" "strings" "sync" "syscall" @@ -107,9 +108,18 @@ func chainUnaryInterceptors(interceptors []UnaryClientInterceptor, final Invoker } } -// NewClient creates a new ttrpc client using the given connection +// NewClient creates a new ttrpc client using the given connection. +// It is equivalent to [NewClientWithContext] with [context.Background] as +// the parent context. func NewClient(conn net.Conn, opts ...ClientOpts) *Client { - ctx, cancel := context.WithCancel(context.Background()) + return NewClientWithContext(context.Background(), conn, opts...) +} + +// NewClientWithContext creates a new ttrpc client using the given connection, +// deriving the client's internal context from ctx. Cancellation of ctx +// shuts the client down; the client's own Close does not cancel ctx. +func NewClientWithContext(ctx context.Context, conn net.Conn, opts ...ClientOpts) *Client { + ctx, cancel := context.WithCancel(ctx) channel := newChannel(conn) c := &Client{ codec: codec{}, @@ -434,10 +444,22 @@ func (c *Client) createStream(flags uint8, b []byte, recvBuf int) (*stream, erro } func (c *Client) deleteStream(s *stream) { + c.deleteStreamWithError(s, nil) +} + +// deleteStreamWithError removes the stream from the client and closes it, +// propagating the supplied error to anyone still observing the stream via +// receive (the connection read loop) or RecvMsg. A nil error closes the +// stream with the default ErrClosed. +// +// The stream is closed before being removed from the map so that any +// in-flight message dispatch in the read loop observes recvErr through the +// normal receive path rather than falling through to "inactive stream". +func (c *Client) deleteStreamWithError(s *stream, err error) { + s.closeWithError(err) c.streamLock.Lock() delete(c.streams, s.id) c.streamLock.Unlock() - s.closeWithError(nil) } func (c *Client) getStream(sid streamID) *stream { @@ -522,12 +544,46 @@ func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, metho return nil, err } - return &clientStream{ + cs := &clientStream{ ctx: ctx, s: s, c: c, desc: desc, - }, nil + } + // Attach a cleanup as a safety net for callers that drop the stream + // without consuming it to completion. In the common case the stream is + // already closed by the time GC reaches it and the cleanup is a no-op. + // If it is still open, the caller leaked the stream; force-close it + // with errStreamAbandoned so the connection's receive loop is not + // blocked by a buffer that will never drain, and the abandon surfaces + // in logs with a specific cause. + runtime.AddCleanup(cs, finalizeClientStream, clientStreamCleanupArgs{c: c, s: s}) + return cs, nil +} + +// clientStreamCleanupArgs carries the state needed by finalizeClientStream. +// It must not reference the *clientStream that the cleanup is attached to, +// otherwise the cleanup would never fire. +type clientStreamCleanupArgs struct { + c *Client + s *stream +} + +// finalizeClientStream is the runtime.AddCleanup callback registered for each +// clientStream returned by NewStream. The fast path is the common case: the +// stream has already been closed (recvClose is closed) and there is nothing +// to do. 
The slow path indicates the caller dropped the stream without +// closing it; force-close the stream with errStreamAbandoned so the abandon +// surfaces through the connection read loop's "failed to handle message" +// log when a frame is in flight, and so the receive loop is not blocked by +// a buffer that will never drain. +func finalizeClientStream(args clientStreamCleanupArgs) { + select { + case <-args.s.recvClose: + return + default: + } + args.c.deleteStreamWithError(args.s, errStreamAbandoned) } func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error { diff --git a/errors.go b/errors.go index 1e6f6b9c9..d6930c1b8 100644 --- a/errors.go +++ b/errors.go @@ -44,6 +44,14 @@ var ( ErrStreamFull = errors.New("ttrpc: stream buffer full") ) +// errStreamAbandoned is set on a stream's recvErr by the runtime cleanup +// when the caller dropped a clientStream without closing it. It is not +// exported because it cannot reach external callers: by the time the +// cleanup runs every reference to the clientStream is gone, so no RecvMsg +// or dispatch is left to observe it. Its purpose is to differentiate the +// abandon case in the connection read loop's error log from a normal close. +var errStreamAbandoned = errors.New("ttrpc: stream abandoned by caller") + // OversizedMessageErr is used to indicate refusal to send an oversized message. // It wraps a ResourceExhausted grpc Status together with the offending message // length. diff --git a/go.mod b/go.mod index 413e01a7b..4aa0883ce 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/containerd/ttrpc -go 1.23 +go 1.24 require ( github.com/containerd/log v0.1.0 @@ -13,4 +13,4 @@ require ( google.golang.org/protobuf v1.36.0 ) -require github.com/sirupsen/logrus v1.9.3 // indirect +require github.com/sirupsen/logrus v1.9.3 diff --git a/services.go b/services.go index ac7c752f7..2ed2bc7ce 100644 --- a/services.go +++ b/services.go @@ -163,11 +163,17 @@ func (s *serviceSet) handle(ctx context.Context, req *Request, respond func(*sta return nil, status.Errorf(codes.Unimplemented, "method %v", req.Method) } -// streamRecvBufferSize is the buffer size for stream recv channels. It -// should be large enough to absorb normal bursts without hitting the -// 1-second timeout fallback in receive/data, but small enough that -// per-stream memory overhead stays trivial. -const streamRecvBufferSize = 64 +// streamRecvBufferSize is the buffer size for stream recv channels. +// +// Consumers are expected to either process incoming messages immediately +// or hand them off to a goroutine, so the buffer only needs to absorb +// brief scheduling jitter (a GC pause, a blocked syscall, the time to +// pass a value to a worker channel). 16 is comfortably above realistic +// jitter without inflating per-stream memory. Sustained slowness will +// hit the streamFullTimeout fallback in stream.receive / streamHandler.data, +// at which point either runtime.AddCleanup (for abandoned streams) or +// ErrStreamFull (for buggy held-but-unconsumed consumers) recovers. 
+const streamRecvBufferSize = 16 type streamHandler struct { ctx context.Context @@ -203,7 +209,7 @@ func (s *streamHandler) data(unmarshal Unmarshaler) error { return nil case <-s.ctx.Done(): return s.ctx.Err() - case <-time.After(time.Second): + case <-time.After(streamFullTimeout): return ErrStreamFull } } diff --git a/stream.go b/stream.go index a6a71def6..f07879401 100644 --- a/stream.go +++ b/stream.go @@ -29,6 +29,23 @@ type streamMessage struct { payload []byte } +// streamFullTimeout bounds how long the receive loop will wait for a stream's +// recv buffer to drain before giving up. The fallback prevents a single +// unconsumed stream from indefinitely blocking the connection-level receive +// loop. +// +// Most buffer fillups in practice are abandoned streams (the caller dropped +// the clientStream without consuming it), and those are handled faster by +// the runtime.AddCleanup attached in NewStream than by waiting out this +// timeout. Five seconds therefore primarily bounds head-of-line blocking +// for the rarer "held but not consumed" case (a goroutine leak in the +// caller), where neither the cleanup nor the contract that consumers +// process immediately or hand off to a goroutine can help. +// +// Exposed as a var (rather than const) so tests can extend it to observe +// the abandon-via-cleanup unblock path without racing the timeout. +var streamFullTimeout = 5 * time.Second + type stream struct { id streamID sender sender @@ -92,7 +109,7 @@ func (s *stream) receive(ctx context.Context, msg *streamMessage) error { return nil case <-ctx.Done(): return ctx.Err() - case <-time.After(time.Second): + case <-time.After(streamFullTimeout): s.closeWithError(ErrStreamFull) return ErrStreamFull } diff --git a/stream_cleanup_test.go b/stream_cleanup_test.go new file mode 100644 index 000000000..a754a229a --- /dev/null +++ b/stream_cleanup_test.go @@ -0,0 +1,286 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import ( + "bytes" + "context" + "errors" + "io" + "net" + "runtime" + "sync" + "testing" + "time" + + "github.com/containerd/log" + "github.com/containerd/ttrpc/internal" + "github.com/sirupsen/logrus" +) + +// TestClientStreamCleanupOnAbandon verifies the runtime.AddCleanup safety +// net attached to clientStream in NewStream. It guarantees three properties: +// +// 1. When the caller drops a clientStream without consuming it, the +// underlying *stream is closed with errStreamAbandoned. +// 2. The cleanup unblocks the connection's receive loop, so other streams +// and unary calls on the same connection continue to make progress +// even when the buffer-full fallback timeout would otherwise apply. +// 3. errStreamAbandoned reaches the connection read loop's "failed to +// handle message" log, identifying the abandon as the cause. +// +// streamFullTimeout is extended for the duration of the test so the only +// mechanism that can unblock the receive loop is the cleanup itself. 
+// Without the cleanup, the receive loop would stay blocked for the full
+// hour and the test would fail its deadlines.
+func TestClientStreamCleanupOnAbandon(t *testing.T) {
+	prev := streamFullTimeout
+	streamFullTimeout = time.Hour
+	t.Cleanup(func() { streamFullTimeout = prev })
+
+	// Build a context whose attached logger captures the abandon error
+	// (forwarding all other entries to t.Log). The client is constructed
+	// with this context so its internal receive loop logs through the
+	// captured logger rather than the standard one.
+	abandonSeen := make(chan struct{})
+	ctx := withCaptureLogger(t, errStreamAbandoned, abandonSeen)
+
+	server := mustServer(t)(NewServer())
+	addr, listener := newTestListener(t)
+	defer listener.Close()
+
+	conn, err := net.Dial("unix", addr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	client := NewClientWithContext(ctx, conn)
+	defer func() {
+		conn.Close()
+		client.Close()
+	}()
+
+	const serviceName = "streamService"
+	desc := &ServiceDesc{
+		Methods: map[string]Method{
+			"Echo": func(_ context.Context, unmarshal func(interface{}) error) (interface{}, error) {
+				var req internal.EchoPayload
+				if err := unmarshal(&req); err != nil {
+					return nil, err
+				}
+				req.Seq++
+				return &req, nil
+			},
+		},
+		Streams: map[string]Stream{
+			"EchoStream": {
+				Handler: func(_ context.Context, ss StreamServer) (interface{}, error) {
+					for {
+						var req internal.EchoPayload
+						if err := ss.RecvMsg(&req); err != nil {
+							return nil, err
+						}
+						req.Seq++
+						if err := ss.SendMsg(&req); err != nil {
+							return nil, err
+						}
+					}
+				},
+				StreamingClient: true,
+				StreamingServer: true,
+			},
+		},
+	}
+	server.RegisterService(serviceName, desc)
+
+	go server.Serve(ctx, listener)
+	defer server.Close()
+
+	// Open a stream, fill the receive buffer to capacity so the
+	// connection's read loop is reliably blocked in s.receive, and
+	// abandon the *clientStream. Filling the buffer before returning is
+	// what makes the next assertion deterministic: when the cleanup
+	// fires, the blocked receive call is the path that surfaces
+	// errStreamAbandoned to the log.
+	s := abandonClientStream(ctx, t, client, serviceName)
+
+	waitForStreamCleanup(t, s, 10*time.Second)
+
+	// With streamFullTimeout pinned at one hour, only the cleanup can
+	// have unblocked the receive loop. A unary call must therefore
+	// complete promptly.
+	callCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+	defer cancel()
+
+	var req, resp internal.EchoPayload
+	req.Seq = 42
+	req.Msg = "must not deadlock"
+	if err := client.Call(callCtx, serviceName, "Echo", &req, &resp); err != nil {
+		t.Fatalf("unary Call did not complete after abandoned stream cleanup: %v", err)
+	}
+	if resp.Seq != 43 {
+		t.Fatalf("unexpected sequence: got %d, want 43", resp.Seq)
+	}
+
+	// Wait for the abandon error to reach the receive-path log. This
+	// is the path that surfaces the abandon to operators in production.
+	// The log should already have fired by the time waitForStreamCleanup
+	// returned (the cleanup itself is what unblocks the receive call
+	// that emits the log), but allow a generous window to avoid any
+	// scheduling-related flakiness.
+	select {
+	case <-abandonSeen:
+	case <-time.After(5 * time.Second):
+		t.Fatal("expected errStreamAbandoned to surface in the connection read loop's error log")
+	}
+}
+
+// waitForStreamCleanup drives GC repeatedly until the runtime.AddCleanup
+// callback registered for the stream's parent clientStream has run, or the
+// supplied timeout elapses.
AddCleanup callbacks execute on a separate +// goroutine after a GC cycle marks the object unreachable, so polling is +// required. +func waitForStreamCleanup(t *testing.T, s *stream, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for { + runtime.GC() + select { + case <-s.recvClose: + if s.recvErr != errStreamAbandoned { + t.Fatalf("expected recvErr to be errStreamAbandoned, got %v", s.recvErr) + } + return + default: + } + if time.Now().After(deadline) { + t.Fatal("clientStream cleanup did not run within deadline") + } + time.Sleep(10 * time.Millisecond) + } +} + +// abandonClientStream creates a streaming RPC, sends enough messages to +// fill the client-side recv buffer (and thereby block the connection's +// read loop in s.receive), then returns the underlying *stream. The +// *clientStream is local to this function so it becomes unreachable as +// soon as the function returns, allowing GC to reclaim it. +// +// Waiting for the buffer to be full before returning is what makes the +// abandon-error log assertion deterministic: when the cleanup fires, the +// blocked s.receive call is the one path that emits "failed to handle +// message" with errStreamAbandoned. If the buffer were not full, the read +// loop might be idle between iterations when the cleanup runs, in which +// case the abandon never surfaces in the log. +// +//go:noinline +func abandonClientStream(ctx context.Context, t *testing.T, c *Client, service string) *stream { + t.Helper() + cs, err := c.NewStream(ctx, &StreamDesc{StreamingClient: true, StreamingServer: true}, service, "EchoStream", nil) + if err != nil { + t.Fatal(err) + } + s := cs.(*clientStream).s + + // Send well above buffer capacity to overflow the channel once the + // server echoes back; SendMsg only writes to the wire and does not + // wait for the receive buffer to drain. + for i := 0; i < cap(s.recv)*4; i++ { + if err := cs.SendMsg(&internal.EchoPayload{Seq: int64(i), Msg: "fill"}); err != nil { + break + } + } + + // Wait for the buffer to be full (read loop blocked in s.receive). + bufferFull := time.Now().Add(5 * time.Second) + for time.Now().Before(bufferFull) { + if len(s.recv) == cap(s.recv) { + return s + } + time.Sleep(time.Millisecond) + } + t.Fatalf("recv buffer did not fill within deadline: have %d/%d", len(s.recv), cap(s.recv)) + return nil +} + +// withCaptureLogger attaches a logger to t.Context() that records whether +// any log entry's "error" field matches target (signaling via the supplied +// channel) and forwards every other entry to t.Log. This mirrors the +// pattern used by github.com/containerd/log/logtest, but adds an +// assertion hook so a test can wait until a specific error has been +// observed without scraping or replacing the standard logger. +// +// Repeated "ttrpc: received message on inactive stream" entries are +// dropped: the test intentionally creates the condition that produces +// them, and they would otherwise drown out other diagnostic output. 
+func withCaptureLogger(t *testing.T, target error, seen chan<- struct{}) context.Context { + t.Helper() + hook := &captureHook{ + t: t, + target: target, + seen: seen, + fmt: &logrus.TextFormatter{ + DisableColors: true, + TimestampFormat: log.RFC3339NanoFixed, + }, + } + + l := logrus.New() + l.SetLevel(logrus.DebugLevel) + l.SetOutput(io.Discard) + l.AddHook(hook) + + entry := logrus.NewEntry(l).WithField("testcase", t.Name()) + return log.WithLogger(t.Context(), entry) +} + +// captureHook is a logrus.Hook that signals (via seen) the first time it +// observes an "error" field matching target, and forwards every other +// log entry to t.Log. Noisy logs intentionally produced by the test are +// dropped so they do not bury other output. +type captureHook struct { + t testing.TB + target error + seen chan<- struct{} + once sync.Once + fmt logrus.Formatter + mu sync.Mutex +} + +func (*captureHook) Levels() []logrus.Level { return logrus.AllLevels } + +func (h *captureHook) Fire(e *logrus.Entry) error { + if raw, ok := e.Data["error"]; ok { + if err, ok := raw.(error); ok && errors.Is(err, h.target) { + h.once.Do(func() { close(h.seen) }) + } + } + + // Drop logs the test intentionally provokes; everything else is + // forwarded to t.Log so failures retain useful diagnostic context. + if e.Message == "ttrpc: received message on inactive stream" { + return nil + } + + formatted, err := h.fmt.Format(e) + if err != nil { + return err + } + h.mu.Lock() + defer h.mu.Unlock() + h.t.Log(string(bytes.TrimRight(formatted, "\n"))) + return nil +} diff --git a/stream_full_test.go b/stream_full_test.go index f454d3135..8e3ff8646 100644 --- a/stream_full_test.go +++ b/stream_full_test.go @@ -83,16 +83,18 @@ func TestStreamNotConsumedDoesNotBlockConnection(t *testing.T) { defer server.Close() // Create a bidirectional streaming RPC and send messages into it, - // but never call RecvMsg. This will fill up the stream's receive - // buffer (capacity 1) once the server echoes back. + // but never call RecvMsg. Once the server echoes back more than the + // client-side recv buffer can hold, the connection's receive loop + // will hit the buffer-full fallback in stream.receive. abandonedStream, err := client.NewStream(ctx, &StreamDesc{true, true}, serviceName, "EchoStream", nil) if err != nil { t.Fatal(err) } - // Send enough messages to guarantee the server has echoed back more - // than the client-side buffer (capacity 1) can hold. - for i := 0; i < 10; i++ { + // Send buffer+1 messages so exactly one echo will be left blocked + // in stream.receive, which keeps total wait bounded by a single + // streamFullTimeout regardless of buffer/timeout tuning. + for i := 0; i < streamRecvBufferSize+1; i++ { if err := abandonedStream.SendMsg(&internal.EchoPayload{ Seq: int64(i), Msg: "abandoned", @@ -103,15 +105,12 @@ func TestStreamNotConsumedDoesNotBlockConnection(t *testing.T) { } } - // Wait for the receive loop to detect the abandoned stream. The buffer - // fills immediately, then the 1-second timeout fires, closing the - // stream and unblocking the receive loop for other streams. - time.Sleep(2 * time.Second) - // A unary call on the same connection must succeed. Without the - // timeout in stream.receive, the receiveLoop would still be blocked - // trying to deliver to the abandoned stream. 
-	callCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+	// timeout in stream.receive (or the AddCleanup on dropped streams),
+	// the receiveLoop would still be blocked trying to deliver to the
+	// abandoned stream. The deadline must clear streamFullTimeout with
+	// margin for the unary round-trip.
+	callCtx, cancel := context.WithTimeout(ctx, streamFullTimeout+5*time.Second)
 	defer cancel()
 
 	var req, resp internal.EchoPayload
@@ -202,13 +201,13 @@ func TestStreamFullOnServer(t *testing.T) {
 		t.Fatal("timed out waiting for handler to start")
 	}
 
-	// Send many messages to fill up the server's recv buffer (capacity 5).
-	// The server handler is not consuming, so these will pile up.
+	// Send buffer+1 messages so exactly one will be left blocked in the
+	// server's data(), bounding total wait by a single streamFullTimeout.
 	// We send in a goroutine because sends may eventually block.
 	sendDone := make(chan struct{})
 	go func() {
 		defer close(sendDone)
-		for i := 0; i < 20; i++ {
+		for i := 0; i < streamRecvBufferSize+1; i++ {
 			if err := slowStream.SendMsg(&internal.EchoPayload{
 				Seq: int64(i),
 				Msg: "filling buffer",
@@ -218,13 +217,10 @@ func TestStreamFullOnServer(t *testing.T) {
 		}
 	}()
 
-	// Wait for the server receive goroutine to detect the full buffer.
-	// The 1-second timeout in data() fires, after which the receive
-	// goroutine can process other streams again.
-	time.Sleep(2 * time.Second)
-
 	// Verify we can still make a unary call on the same connection.
-	callCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+	// The deadline must clear streamFullTimeout with margin for the
+	// unary round-trip.
+	callCtx, cancel := context.WithTimeout(ctx, streamFullTimeout+5*time.Second)
 	defer cancel()
 
 	var req, resp internal.EchoPayload