Skip to content

Commit 80eb2e4

Browse files
authored
Merge pull request #2775 from dgageot/fix-tui-control-api-bugs
fix: two TUI control-plane bugs (SSE cancel, IPv6 listen)
2 parents 0a765ba + 1f57030 commit 80eb2e4

4 files changed

Lines changed: 164 additions & 3 deletions

File tree

pkg/runtime/client.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,19 +476,24 @@ func (c *Client) GetAgentToolCount(ctx context.Context, agentFilename, agentName
476476
}
477477

478478
// StreamSessionEvents streams events for a session as they occur via Server-Sent Events.
479+
// The returned channel is closed when ctx is cancelled, the stream's max
480+
// duration is reached, or the server closes the connection.
479481
func (c *Client) StreamSessionEvents(ctx context.Context, sessionID string) (<-chan Event, error) {
480482
endpoint := fmt.Sprintf("/api/sessions/%s/events", sessionID)
481483

482484
u := *c.baseURL
483485
u.Path = path.Join(u.Path, endpoint)
484486

485-
// Use long timeout for streaming
487+
// Bound the maximum lifetime of a single SSE connection. The cancel
488+
// must be tied to the goroutine consuming the stream, not to this
489+
// function's return: cancelling streamCtx kills the in-flight HTTP
490+
// request, which would turn the stream into a one-shot read.
486491
timeout := c.timeoutFor("streaming")
487492
streamCtx, cancel := context.WithTimeout(ctx, timeout)
488-
defer cancel()
489493

490494
req, err := http.NewRequestWithContext(streamCtx, http.MethodGet, u.String(), http.NoBody)
491495
if err != nil {
496+
cancel()
492497
return nil, fmt.Errorf("creating request: %w", err)
493498
}
494499

@@ -501,10 +506,12 @@ func (c *Client) StreamSessionEvents(ctx context.Context, sessionID string) (<-c
501506

502507
resp, err := c.httpClient.Do(req) //nolint:bodyclose // body is closed in the goroutine below
503508
if err != nil {
509+
cancel()
504510
return nil, fmt.Errorf("performing request: %w", err)
505511
}
506512

507513
if resp.StatusCode >= 400 {
514+
defer cancel()
508515
defer resp.Body.Close()
509516
respBody, err := io.ReadAll(resp.Body)
510517
if err != nil {
@@ -521,6 +528,7 @@ func (c *Client) StreamSessionEvents(ctx context.Context, sessionID string) (<-c
521528
eventChan := make(chan Event, defaultEventChannelCapacity)
522529

523530
go func() {
531+
defer cancel()
524532
defer close(eventChan)
525533
defer resp.Body.Close()
526534

pkg/runtime/client_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package runtime
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
"net/http/httptest"
8+
"testing"
9+
"time"
10+
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
// TestClient_StreamSessionEvents_DeliversMultipleEvents verifies that the
16+
// SSE stream stays open across multiple events instead of being torn down
17+
// when StreamSessionEvents returns. This is a regression test for a bug
18+
// where a deferred cancel() on the streaming context killed the in-flight
19+
// HTTP request as soon as the function returned, turning the stream into
20+
// a one-shot read.
21+
func TestClient_StreamSessionEvents_DeliversMultipleEvents(t *testing.T) {
22+
t.Parallel()
23+
24+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
25+
w.Header().Set("Content-Type", "text/event-stream")
26+
w.WriteHeader(http.StatusOK)
27+
flusher, ok := w.(http.Flusher)
28+
if !ok {
29+
t.Errorf("ResponseWriter must support flushing")
30+
return
31+
}
32+
33+
for i := 1; i <= 3; i++ {
34+
fmt.Fprintf(w, "data: {\"type\":\"session_title\",\"session_id\":\"s\",\"title\":\"t%d\"}\n\n", i)
35+
flusher.Flush()
36+
time.Sleep(20 * time.Millisecond)
37+
}
38+
}))
39+
t.Cleanup(srv.Close)
40+
41+
c, err := NewClient(srv.URL)
42+
require.NoError(t, err)
43+
44+
ch, err := c.StreamSessionEvents(t.Context(), "s")
45+
require.NoError(t, err)
46+
47+
var titles []string
48+
for ev := range ch {
49+
titleEv, ok := ev.(*SessionTitleEvent)
50+
if !ok {
51+
continue
52+
}
53+
titles = append(titles, titleEv.Title)
54+
}
55+
56+
assert.Equal(t, []string{"t1", "t2", "t3"}, titles)
57+
}
58+
59+
// TestClient_StreamSessionEvents_StopsWhenContextCancelled verifies that
60+
// cancelling the caller's context tears down the stream and closes the
61+
// returned channel.
62+
func TestClient_StreamSessionEvents_StopsWhenContextCancelled(t *testing.T) {
63+
t.Parallel()
64+
65+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
66+
w.Header().Set("Content-Type", "text/event-stream")
67+
w.WriteHeader(http.StatusOK)
68+
flusher, _ := w.(http.Flusher)
69+
70+
ticker := time.NewTicker(20 * time.Millisecond)
71+
defer ticker.Stop()
72+
for {
73+
select {
74+
case <-r.Context().Done():
75+
return
76+
case <-ticker.C:
77+
fmt.Fprint(w, "data: {\"type\":\"session_title\",\"session_id\":\"s\",\"title\":\"x\"}\n\n")
78+
flusher.Flush()
79+
}
80+
}
81+
}))
82+
t.Cleanup(srv.Close)
83+
84+
c, err := NewClient(srv.URL)
85+
require.NoError(t, err)
86+
87+
ctx, cancel := context.WithCancel(t.Context())
88+
t.Cleanup(cancel)
89+
90+
ch, err := c.StreamSessionEvents(ctx, "s")
91+
require.NoError(t, err)
92+
93+
// Drain at least one event to confirm the stream is live.
94+
select {
95+
case <-ch:
96+
case <-time.After(2 * time.Second):
97+
t.Fatal("no events received before cancel")
98+
}
99+
100+
cancel()
101+
102+
// Channel must close in a bounded time after cancel.
103+
deadline := time.After(2 * time.Second)
104+
for {
105+
select {
106+
case _, ok := <-ch:
107+
if !ok {
108+
return
109+
}
110+
case <-deadline:
111+
t.Fatal("channel was not closed after context cancel")
112+
}
113+
}
114+
}

pkg/server/listen.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,5 @@ func listenUnix(ctx context.Context, path string) (net.Listener, error) {
6262

6363
func listenTCP(ctx context.Context, addr string) (net.Listener, error) {
6464
var lc net.ListenConfig
65-
return lc.Listen(ctx, "tcp4", addr)
65+
return lc.Listen(ctx, "tcp", addr)
6666
}

pkg/server/listen_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,42 @@ func TestListen_FD_InvalidDescriptor(t *testing.T) {
6262
require.Error(t, err)
6363
assert.Contains(t, err.Error(), "file descriptor 999999")
6464
}
65+
66+
// TestListen_TCP_IPv4 verifies that the default TCP listener binds to an
67+
// IPv4 loopback address. Regression test for the listener being hard-coded
68+
// to "tcp4" without that being the documented intent.
69+
func TestListen_TCP_IPv4(t *testing.T) {
70+
t.Parallel()
71+
72+
ln, err := Listen(t.Context(), "127.0.0.1:0")
73+
require.NoError(t, err)
74+
defer ln.Close()
75+
76+
tcpAddr, ok := ln.Addr().(*net.TCPAddr)
77+
require.True(t, ok)
78+
assert.NotNil(t, tcpAddr.IP.To4())
79+
}
80+
81+
// TestListen_TCP_IPv6 verifies that an IPv6 loopback bind succeeds. The
82+
// listener used to force "tcp4" which made this fail with
83+
// "address ::1: non-IPv4 address" on dual-stack hosts.
84+
func TestListen_TCP_IPv6(t *testing.T) {
85+
t.Parallel()
86+
87+
// Probe whether the host actually has IPv6 before asserting.
88+
var probeLC net.ListenConfig
89+
probe, probeErr := probeLC.Listen(t.Context(), "tcp6", "[::1]:0")
90+
if probeErr != nil {
91+
t.Skipf("host does not support IPv6 loopback: %v", probeErr)
92+
}
93+
_ = probe.Close()
94+
95+
ln, err := Listen(t.Context(), "[::1]:0")
96+
require.NoError(t, err)
97+
defer ln.Close()
98+
99+
tcpAddr, ok := ln.Addr().(*net.TCPAddr)
100+
require.True(t, ok)
101+
assert.True(t, tcpAddr.IP.IsLoopback())
102+
assert.Nil(t, tcpAddr.IP.To4(), "expected an IPv6-only address")
103+
}

0 commit comments

Comments
 (0)