Skip to content

Commit 5a4f823

Browse files
authored
Fix Go client data races (#586)
* Fix Go JSON-RPC client data race * prevent race between startCLIServer and [Force]Stop * we require 1.24 * prevent races between Start and [Force]Stop * fail fast when CLI exits before reporting TCP port
1 parent 21a586d commit 5a4f823

File tree

4 files changed

+132
-32
lines changed

4 files changed

+132
-32
lines changed

go/client.go

Lines changed: 81 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ import (
4141
"strconv"
4242
"strings"
4343
"sync"
44+
"sync/atomic"
4445
"time"
4546

4647
"github.com/github/copilot-sdk/go/internal/embeddedcli"
@@ -86,8 +87,10 @@ type Client struct {
8687
lifecycleHandlers []SessionLifecycleHandler
8788
typedLifecycleHandlers map[SessionLifecycleEventType][]SessionLifecycleHandler
8889
lifecycleHandlersMux sync.Mutex
89-
processDone chan struct{} // closed when CLI process exits
90-
processError error // set before processDone is closed
90+
startStopMux sync.RWMutex // protects process and state during start/[force]stop
91+
processDone chan struct{}
92+
processErrorPtr *error
93+
osProcess atomic.Pointer[os.Process]
9194

9295
// RPC provides typed server-scoped RPC methods.
9396
// This field is nil until the client is connected via Start().
@@ -251,6 +254,9 @@ func parseCliUrl(url string) (string, int) {
251254
// }
252255
// // Now ready to create sessions
253256
func (c *Client) Start(ctx context.Context) error {
257+
c.startStopMux.Lock()
258+
defer c.startStopMux.Unlock()
259+
254260
if c.state == StateConnected {
255261
return nil
256262
}
@@ -260,21 +266,24 @@ func (c *Client) Start(ctx context.Context) error {
260266
// Only start CLI server process if not connecting to external server
261267
if !c.isExternalServer {
262268
if err := c.startCLIServer(ctx); err != nil {
269+
c.process = nil
263270
c.state = StateError
264271
return err
265272
}
266273
}
267274

268275
// Connect to the server
269276
if err := c.connectToServer(ctx); err != nil {
277+
killErr := c.killProcess()
270278
c.state = StateError
271-
return err
279+
return errors.Join(err, killErr)
272280
}
273281

274282
// Verify protocol version compatibility
275283
if err := c.verifyProtocolVersion(ctx); err != nil {
284+
killErr := c.killProcess()
276285
c.state = StateError
277-
return err
286+
return errors.Join(err, killErr)
278287
}
279288

280289
c.state = StateConnected
@@ -316,13 +325,16 @@ func (c *Client) Stop() error {
316325
c.sessions = make(map[string]*Session)
317326
c.sessionsMux.Unlock()
318327

328+
c.startStopMux.Lock()
329+
defer c.startStopMux.Unlock()
330+
319331
// Kill CLI process FIRST (this closes stdout and unblocks readLoop) - only if we spawned it
320332
if c.process != nil && !c.isExternalServer {
321-
if err := c.process.Process.Kill(); err != nil {
322-
errs = append(errs, fmt.Errorf("failed to kill CLI process: %w", err))
333+
if err := c.killProcess(); err != nil {
334+
errs = append(errs, err)
323335
}
324-
c.process = nil
325336
}
337+
c.process = nil
326338

327339
// Close external TCP connection if exists
328340
if c.isExternalServer && c.conn != nil {
@@ -375,16 +387,27 @@ func (c *Client) Stop() error {
375387
// client.ForceStop()
376388
// }
377389
func (c *Client) ForceStop() {
390+
// Kill the process without waiting for startStopMux, which Start may hold.
391+
// This unblocks any I/O Start is doing (connect, version check).
392+
if p := c.osProcess.Swap(nil); p != nil {
393+
p.Kill()
394+
}
395+
378396
// Clear sessions immediately without trying to destroy them
379397
c.sessionsMux.Lock()
380398
c.sessions = make(map[string]*Session)
381399
c.sessionsMux.Unlock()
382400

401+
c.startStopMux.Lock()
402+
defer c.startStopMux.Unlock()
403+
383404
// Kill CLI process (only if we spawned it)
405+
// This is a fallback in case the process wasn't killed above (e.g. if Start hadn't set
406+
// osProcess yet), or if the process was restarted and osProcess now points to a new process.
384407
if c.process != nil && !c.isExternalServer {
385-
c.process.Process.Kill() // Ignore errors
386-
c.process = nil
408+
_ = c.killProcess() // Ignore errors since we're force stopping
387409
}
410+
c.process = nil
388411

389412
// Close external TCP connection if exists
390413
if c.isExternalServer && c.conn != nil {
@@ -886,6 +909,8 @@ func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) {
886909
// })
887910
// }
888911
func (c *Client) State() ConnectionState {
912+
c.startStopMux.RLock()
913+
defer c.startStopMux.RUnlock()
889914
return c.state
890915
}
891916

@@ -1096,21 +1121,11 @@ func (c *Client) startCLIServer(ctx context.Context) error {
10961121
return fmt.Errorf("failed to start CLI server: %w", err)
10971122
}
10981123

1099-
// Monitor process exit to signal pending requests
1100-
c.processDone = make(chan struct{})
1101-
go func() {
1102-
waitErr := c.process.Wait()
1103-
if waitErr != nil {
1104-
c.processError = fmt.Errorf("CLI process exited: %v", waitErr)
1105-
} else {
1106-
c.processError = fmt.Errorf("CLI process exited unexpectedly")
1107-
}
1108-
close(c.processDone)
1109-
}()
1124+
c.monitorProcess()
11101125

11111126
// Create JSON-RPC client immediately
11121127
c.client = jsonrpc2.NewClient(stdin, stdout)
1113-
c.client.SetProcessDone(c.processDone, &c.processError)
1128+
c.client.SetProcessDone(c.processDone, c.processErrorPtr)
11141129
c.RPC = rpc.NewServerRpc(c.client)
11151130
c.setupNotificationHandler()
11161131
c.client.Start()
@@ -1127,22 +1142,28 @@ func (c *Client) startCLIServer(ctx context.Context) error {
11271142
return fmt.Errorf("failed to start CLI server: %w", err)
11281143
}
11291144

1130-
// Wait for port announcement
1145+
c.monitorProcess()
1146+
11311147
scanner := bufio.NewScanner(stdout)
11321148
timeout := time.After(10 * time.Second)
11331149
portRegex := regexp.MustCompile(`listening on port (\d+)`)
11341150

11351151
for {
11361152
select {
11371153
case <-timeout:
1138-
return fmt.Errorf("timeout waiting for CLI server to start")
1154+
killErr := c.killProcess()
1155+
return errors.Join(errors.New("timeout waiting for CLI server to start"), killErr)
1156+
case <-c.processDone:
1157+
killErr := c.killProcess()
1158+
return errors.Join(errors.New("CLI server process exited before reporting port"), killErr)
11391159
default:
11401160
if scanner.Scan() {
11411161
line := scanner.Text()
11421162
if matches := portRegex.FindStringSubmatch(line); len(matches) > 1 {
11431163
port, err := strconv.Atoi(matches[1])
11441164
if err != nil {
1145-
return fmt.Errorf("failed to parse port: %w", err)
1165+
killErr := c.killProcess()
1166+
return errors.Join(fmt.Errorf("failed to parse port: %w", err), killErr)
11461167
}
11471168
c.actualPort = port
11481169
return nil
@@ -1153,6 +1174,39 @@ func (c *Client) startCLIServer(ctx context.Context) error {
11531174
}
11541175
}
11551176

1177+
func (c *Client) killProcess() error {
1178+
if p := c.osProcess.Swap(nil); p != nil {
1179+
if err := p.Kill(); err != nil {
1180+
return fmt.Errorf("failed to kill CLI process: %w", err)
1181+
}
1182+
}
1183+
c.process = nil
1184+
return nil
1185+
}
1186+
1187+
// monitorProcess signals when the CLI process exits and captures any exit error.
1188+
// processError is intentionally a local: each process lifecycle gets its own
1189+
// error value, so goroutines from previous processes can't overwrite the
1190+
// current one. Closing the channel synchronizes with readers, guaranteeing
1191+
// they see the final processError value.
1192+
func (c *Client) monitorProcess() {
1193+
done := make(chan struct{})
1194+
c.processDone = done
1195+
proc := c.process
1196+
c.osProcess.Store(proc.Process)
1197+
var processError error
1198+
c.processErrorPtr = &processError
1199+
go func() {
1200+
waitErr := proc.Wait()
1201+
if waitErr != nil {
1202+
processError = fmt.Errorf("CLI process exited: %w", waitErr)
1203+
} else {
1204+
processError = errors.New("CLI process exited unexpectedly")
1205+
}
1206+
close(done)
1207+
}()
1208+
}
1209+
11561210
// connectToServer establishes a connection to the server.
11571211
func (c *Client) connectToServer(ctx context.Context) error {
11581212
if c.useStdio {
@@ -1184,6 +1238,9 @@ func (c *Client) connectViaTcp(ctx context.Context) error {
11841238

11851239
// Create JSON-RPC client with the connection
11861240
c.client = jsonrpc2.NewClient(conn, conn)
1241+
if c.processDone != nil {
1242+
c.client.SetProcessDone(c.processDone, c.processErrorPtr)
1243+
}
11871244
c.RPC = rpc.NewServerRpc(c.client)
11881245
c.setupNotificationHandler()
11891246
c.client.Start()

go/client_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"path/filepath"
77
"reflect"
88
"regexp"
9+
"sync"
910
"testing"
1011
)
1112

@@ -486,3 +487,44 @@ func TestClient_ResumeSession_RequiresPermissionHandler(t *testing.T) {
486487
}
487488
})
488489
}
490+
491+
func TestClient_StartStopRace(t *testing.T) {
492+
cliPath := findCLIPathForTest()
493+
if cliPath == "" {
494+
t.Skip("CLI not found")
495+
}
496+
client := NewClient(&ClientOptions{CLIPath: cliPath})
497+
defer client.ForceStop()
498+
errChan := make(chan error)
499+
wg := sync.WaitGroup{}
500+
for range 10 {
501+
wg.Add(3)
502+
go func() {
503+
defer wg.Done()
504+
if err := client.Start(t.Context()); err != nil {
505+
select {
506+
case errChan <- err:
507+
default:
508+
}
509+
}
510+
}()
511+
go func() {
512+
defer wg.Done()
513+
if err := client.Stop(); err != nil {
514+
select {
515+
case errChan <- err:
516+
default:
517+
}
518+
}
519+
}()
520+
go func() {
521+
defer wg.Done()
522+
client.ForceStop()
523+
}()
524+
}
525+
wg.Wait()
526+
close(errChan)
527+
if err := <-errChan; err != nil {
528+
t.Fatal(err)
529+
}
530+
}

go/internal/jsonrpc2/jsonrpc2.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"io"
99
"reflect"
1010
"sync"
11+
"sync/atomic"
1112
)
1213

1314
// Error represents a JSON-RPC error response
@@ -54,7 +55,7 @@ type Client struct {
5455
mu sync.Mutex
5556
pendingRequests map[string]chan *Response
5657
requestHandlers map[string]RequestHandler
57-
running bool
58+
running atomic.Bool
5859
stopChan chan struct{}
5960
wg sync.WaitGroup
6061
processDone chan struct{} // closed when the underlying process exits
@@ -97,17 +98,17 @@ func (c *Client) getProcessError() error {
9798

9899
// Start begins listening for messages in a background goroutine
99100
func (c *Client) Start() {
100-
c.running = true
101+
c.running.Store(true)
101102
c.wg.Add(1)
102103
go c.readLoop()
103104
}
104105

105106
// Stop stops the client and cleans up
106107
func (c *Client) Stop() {
107-
if !c.running {
108+
if !c.running.Load() {
108109
return
109110
}
110-
c.running = false
111+
c.running.Store(false)
111112
close(c.stopChan)
112113

113114
// Close stdout to unblock the readLoop
@@ -298,14 +299,14 @@ func (c *Client) readLoop() {
298299

299300
reader := bufio.NewReader(c.stdout)
300301

301-
for c.running {
302+
for c.running.Load() {
302303
// Read Content-Length header
303304
var contentLength int
304305
for {
305306
line, err := reader.ReadString('\n')
306307
if err != nil {
307308
// Only log unexpected errors (not EOF or closed pipe during shutdown)
308-
if err != io.EOF && c.running {
309+
if err != io.EOF && c.running.Load() {
309310
fmt.Printf("Error reading header: %v\n", err)
310311
}
311312
return

go/test.sh

100644100755
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ echo
88

99
# Check prerequisites
1010
if ! command -v go &> /dev/null; then
11-
echo "❌ Go is not installed. Please install Go 1.21 or later."
11+
echo "❌ Go is not installed. Please install Go 1.24 or later."
1212
echo " Visit: https://golang.org/dl/"
1313
exit 1
1414
fi
@@ -43,7 +43,7 @@ cd "$(dirname "$0")"
4343
echo "=== Running Go SDK E2E Tests ==="
4444
echo
4545

46-
go test -v ./...
46+
go test -v ./... -race
4747

4848
echo
4949
echo "✅ All tests passed!"

0 commit comments

Comments
 (0)