@@ -41,6 +41,7 @@ import (
4141 "strconv"
4242 "strings"
4343 "sync"
44+ "sync/atomic"
4445 "time"
4546
4647 "github.com/github/copilot-sdk/go/internal/embeddedcli"
@@ -86,8 +87,10 @@ type Client struct {
8687 lifecycleHandlers []SessionLifecycleHandler
8788 typedLifecycleHandlers map [SessionLifecycleEventType ][]SessionLifecycleHandler
8889 lifecycleHandlersMux sync.Mutex
89- processDone chan struct {} // closed when CLI process exits
90- processError error // set before processDone is closed
90+ startStopMux sync.RWMutex // protects process and state during start/[force]stop
91+ processDone chan struct {}
92+ processErrorPtr * error
93+ osProcess atomic.Pointer [os.Process ]
9194
9295 // RPC provides typed server-scoped RPC methods.
9396 // This field is nil until the client is connected via Start().
@@ -251,6 +254,9 @@ func parseCliUrl(url string) (string, int) {
251254// }
252255// // Now ready to create sessions
253256func (c * Client ) Start (ctx context.Context ) error {
257+ c .startStopMux .Lock ()
258+ defer c .startStopMux .Unlock ()
259+
254260 if c .state == StateConnected {
255261 return nil
256262 }
@@ -260,21 +266,24 @@ func (c *Client) Start(ctx context.Context) error {
260266 // Only start CLI server process if not connecting to external server
261267 if ! c .isExternalServer {
262268 if err := c .startCLIServer (ctx ); err != nil {
269+ c .process = nil
263270 c .state = StateError
264271 return err
265272 }
266273 }
267274
268275 // Connect to the server
269276 if err := c .connectToServer (ctx ); err != nil {
277+ killErr := c .killProcess ()
270278 c .state = StateError
271- return err
279+ return errors . Join ( err , killErr )
272280 }
273281
274282 // Verify protocol version compatibility
275283 if err := c .verifyProtocolVersion (ctx ); err != nil {
284+ killErr := c .killProcess ()
276285 c .state = StateError
277- return err
286+ return errors . Join ( err , killErr )
278287 }
279288
280289 c .state = StateConnected
@@ -316,13 +325,16 @@ func (c *Client) Stop() error {
316325 c .sessions = make (map [string ]* Session )
317326 c .sessionsMux .Unlock ()
318327
328+ c .startStopMux .Lock ()
329+ defer c .startStopMux .Unlock ()
330+
319331 // Kill CLI process FIRST (this closes stdout and unblocks readLoop) - only if we spawned it
320332 if c .process != nil && ! c .isExternalServer {
321- if err := c .process . Process . Kill (); err != nil {
322- errs = append (errs , fmt . Errorf ( "failed to kill CLI process: %w" , err ) )
333+ if err := c .killProcess (); err != nil {
334+ errs = append (errs , err )
323335 }
324- c .process = nil
325336 }
337+ c .process = nil
326338
327339 // Close external TCP connection if exists
328340 if c .isExternalServer && c .conn != nil {
@@ -375,16 +387,27 @@ func (c *Client) Stop() error {
375387// client.ForceStop()
376388// }
377389func (c * Client ) ForceStop () {
390+ // Kill the process without waiting for startStopMux, which Start may hold.
391+ // This unblocks any I/O Start is doing (connect, version check).
392+ if p := c .osProcess .Swap (nil ); p != nil {
393+ p .Kill ()
394+ }
395+
378396 // Clear sessions immediately without trying to destroy them
379397 c .sessionsMux .Lock ()
380398 c .sessions = make (map [string ]* Session )
381399 c .sessionsMux .Unlock ()
382400
401+ c .startStopMux .Lock ()
402+ defer c .startStopMux .Unlock ()
403+
383404 // Kill CLI process (only if we spawned it)
405+ // This is a fallback in case the process wasn't killed above (e.g. if Start hadn't set
406+ // osProcess yet), or if the process was restarted and osProcess now points to a new process.
384407 if c .process != nil && ! c .isExternalServer {
385- c .process .Process .Kill () // Ignore errors
386- c .process = nil
408+ _ = c .killProcess () // Ignore errors since we're force stopping
387409 }
410+ c .process = nil
388411
389412 // Close external TCP connection if exists
390413 if c .isExternalServer && c .conn != nil {
@@ -886,6 +909,8 @@ func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) {
886909// })
887910// }
888911func (c * Client ) State () ConnectionState {
912+ c .startStopMux .RLock ()
913+ defer c .startStopMux .RUnlock ()
889914 return c .state
890915}
891916
@@ -1096,21 +1121,11 @@ func (c *Client) startCLIServer(ctx context.Context) error {
10961121 return fmt .Errorf ("failed to start CLI server: %w" , err )
10971122 }
10981123
1099- // Monitor process exit to signal pending requests
1100- c .processDone = make (chan struct {})
1101- go func () {
1102- waitErr := c .process .Wait ()
1103- if waitErr != nil {
1104- c .processError = fmt .Errorf ("CLI process exited: %v" , waitErr )
1105- } else {
1106- c .processError = fmt .Errorf ("CLI process exited unexpectedly" )
1107- }
1108- close (c .processDone )
1109- }()
1124+ c .monitorProcess ()
11101125
11111126 // Create JSON-RPC client immediately
11121127 c .client = jsonrpc2 .NewClient (stdin , stdout )
1113- c .client .SetProcessDone (c .processDone , & c . processError )
1128+ c .client .SetProcessDone (c .processDone , c . processErrorPtr )
11141129 c .RPC = rpc .NewServerRpc (c .client )
11151130 c .setupNotificationHandler ()
11161131 c .client .Start ()
@@ -1127,22 +1142,28 @@ func (c *Client) startCLIServer(ctx context.Context) error {
11271142 return fmt .Errorf ("failed to start CLI server: %w" , err )
11281143 }
11291144
1130- // Wait for port announcement
1145+ c .monitorProcess ()
1146+
11311147 scanner := bufio .NewScanner (stdout )
11321148 timeout := time .After (10 * time .Second )
11331149 portRegex := regexp .MustCompile (`listening on port (\d+)` )
11341150
11351151 for {
11361152 select {
11371153 case <- timeout :
1138- return fmt .Errorf ("timeout waiting for CLI server to start" )
1154+ killErr := c .killProcess ()
1155+ return errors .Join (errors .New ("timeout waiting for CLI server to start" ), killErr )
1156+ case <- c .processDone :
1157+ killErr := c .killProcess ()
1158+ return errors .Join (errors .New ("CLI server process exited before reporting port" ), killErr )
11391159 default :
11401160 if scanner .Scan () {
11411161 line := scanner .Text ()
11421162 if matches := portRegex .FindStringSubmatch (line ); len (matches ) > 1 {
11431163 port , err := strconv .Atoi (matches [1 ])
11441164 if err != nil {
1145- return fmt .Errorf ("failed to parse port: %w" , err )
1165+ killErr := c .killProcess ()
1166+ return errors .Join (fmt .Errorf ("failed to parse port: %w" , err ), killErr )
11461167 }
11471168 c .actualPort = port
11481169 return nil
@@ -1153,6 +1174,39 @@ func (c *Client) startCLIServer(ctx context.Context) error {
11531174 }
11541175}
11551176
1177+ func (c * Client ) killProcess () error {
1178+ if p := c .osProcess .Swap (nil ); p != nil {
1179+ if err := p .Kill (); err != nil {
1180+ return fmt .Errorf ("failed to kill CLI process: %w" , err )
1181+ }
1182+ }
1183+ c .process = nil
1184+ return nil
1185+ }
1186+
1187+ // monitorProcess signals when the CLI process exits and captures any exit error.
1188+ // processError is intentionally a local: each process lifecycle gets its own
1189+ // error value, so goroutines from previous processes can't overwrite the
1190+ // current one. Closing the channel synchronizes with readers, guaranteeing
1191+ // they see the final processError value.
1192+ func (c * Client ) monitorProcess () {
1193+ done := make (chan struct {})
1194+ c .processDone = done
1195+ proc := c .process
1196+ c .osProcess .Store (proc .Process )
1197+ var processError error
1198+ c .processErrorPtr = & processError
1199+ go func () {
1200+ waitErr := proc .Wait ()
1201+ if waitErr != nil {
1202+ processError = fmt .Errorf ("CLI process exited: %w" , waitErr )
1203+ } else {
1204+ processError = errors .New ("CLI process exited unexpectedly" )
1205+ }
1206+ close (done )
1207+ }()
1208+ }
1209+
11561210// connectToServer establishes a connection to the server.
11571211func (c * Client ) connectToServer (ctx context.Context ) error {
11581212 if c .useStdio {
@@ -1184,6 +1238,9 @@ func (c *Client) connectViaTcp(ctx context.Context) error {
11841238
11851239 // Create JSON-RPC client with the connection
11861240 c .client = jsonrpc2 .NewClient (conn , conn )
1241+ if c .processDone != nil {
1242+ c .client .SetProcessDone (c .processDone , c .processErrorPtr )
1243+ }
11871244 c .RPC = rpc .NewServerRpc (c .client )
11881245 c .setupNotificationHandler ()
11891246 c .client .Start ()
0 commit comments