Skip to content

Commit ad5b4d3

Browse files
committed
Add SessionConfig.OnEvent
1 parent 7310091 commit ad5b4d3

15 files changed

Lines changed: 135 additions & 443 deletions

File tree

dotnet/src/Client.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,10 @@ public async Task<CopilotSession> CreateSessionAsync(SessionConfig config, Cance
396396
{
397397
session.RegisterHooks(config.Hooks);
398398
}
399+
if (config.OnEvent != null)
400+
{
401+
session.On(config.OnEvent);
402+
}
399403
_sessions[sessionId] = session;
400404

401405
try
@@ -495,6 +499,10 @@ public async Task<CopilotSession> ResumeSessionAsync(string sessionId, ResumeSes
495499
{
496500
session.RegisterHooks(config.Hooks);
497501
}
502+
if (config.OnEvent != null)
503+
{
504+
session.On(config.OnEvent);
505+
}
498506
_sessions[sessionId] = session;
499507

500508
try

dotnet/src/Types.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,7 @@ protected SessionConfig(SessionConfig? other)
766766
? new Dictionary<string, object>(other.McpServers, other.McpServers.Comparer)
767767
: null;
768768
Model = other.Model;
769+
OnEvent = other.OnEvent;
769770
OnPermissionRequest = other.OnPermissionRequest;
770771
OnUserInputRequest = other.OnUserInputRequest;
771772
Provider = other.Provider;
@@ -864,6 +865,18 @@ protected SessionConfig(SessionConfig? other)
864865
/// </summary>
865866
public InfiniteSessionConfig? InfiniteSessions { get; set; }
866867

868+
/// <summary>
869+
/// Optional event handler that is registered on the session before the
870+
/// session.create RPC is issued.
871+
/// </summary>
872+
/// </remarks>
873+
/// Equivalent to calling <see cref="CopilotSession.On"/> immediately
874+
/// after creation, but executes earlier in the lifecycle so no events are missed.
875+
/// Using this property rather than <see cref="CopilotSession.On"/> guarantees that early events emitted
876+
/// by the CLI during session creation (e.g. session.start) are delivered to the handler.
877+
/// <remarks>
878+
public SessionEventHandler? OnEvent { get; set; }
879+
867880
/// <summary>
868881
/// Creates a shallow clone of this <see cref="SessionConfig"/> instance.
869882
/// </summary>
@@ -905,6 +918,7 @@ protected ResumeSessionConfig(ResumeSessionConfig? other)
905918
? new Dictionary<string, object>(other.McpServers, other.McpServers.Comparer)
906919
: null;
907920
Model = other.Model;
921+
OnEvent = other.OnEvent;
908922
OnPermissionRequest = other.OnPermissionRequest;
909923
OnUserInputRequest = other.OnUserInputRequest;
910924
Provider = other.Provider;
@@ -1020,6 +1034,12 @@ protected ResumeSessionConfig(ResumeSessionConfig? other)
10201034
/// </summary>
10211035
public InfiniteSessionConfig? InfiniteSessions { get; set; }
10221036

1037+
/// <summary>
1038+
/// Optional event handler registered before the session.resume RPC is issued,
1039+
/// ensuring early events are delivered. See <see cref="SessionConfig.OnEvent"/>.
1040+
/// </summary>
1041+
public SessionEventHandler? OnEvent { get; set; }
1042+
10231043
/// <summary>
10241044
/// Creates a shallow clone of this <see cref="ResumeSessionConfig"/> instance.
10251045
/// </summary>

dotnet/test/SessionTests.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,17 @@ public async Task Should_Pass_Streaming_Option_To_Session_Creation()
296296
[Fact]
297297
public async Task Should_Receive_Session_Events()
298298
{
299-
var session = await CreateSessionAsync();
299+
// Use OnEvent to capture events dispatched during session creation.
300+
// session.start is emitted during the session.create RPC; if the session
301+
// weren't registered in the sessions map before the RPC, it would be dropped.
302+
var earlyEvents = new List<SessionEvent>();
303+
var session = await CreateSessionAsync(new SessionConfig
304+
{
305+
OnEvent = evt => earlyEvents.Add(evt),
306+
});
307+
308+
Assert.Contains(earlyEvents, evt => evt is SessionStartEvent);
309+
300310
var receivedEvents = new List<SessionEvent>();
301311
var idleReceived = new TaskCompletionSource<bool>();
302312

go/client.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,9 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses
536536
if config.Hooks != nil {
537537
session.registerHooks(config.Hooks)
538538
}
539+
if config.OnEvent != nil {
540+
session.On(config.OnEvent)
541+
}
539542

540543
c.sessionsMux.Lock()
541544
c.sessions[sessionID] = session
@@ -645,6 +648,9 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string,
645648
if config.Hooks != nil {
646649
session.registerHooks(config.Hooks)
647650
}
651+
if config.OnEvent != nil {
652+
session.On(config.OnEvent)
653+
}
648654

649655
c.sessionsMux.Lock()
650656
c.sessions[sessionID] = session

go/client_test.go

Lines changed: 0 additions & 232 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
package copilot
22

33
import (
4-
"bufio"
54
"encoding/json"
6-
"fmt"
7-
"io"
85
"os"
96
"path/filepath"
107
"reflect"
118
"regexp"
129
"sync"
1310
"testing"
14-
15-
"github.com/github/copilot-sdk/go/internal/jsonrpc2"
1611
)
1712

1813
// This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.test.go instead
@@ -574,230 +569,3 @@ func TestClient_StartStopRace(t *testing.T) {
574569
t.Fatal(err)
575570
}
576571
}
577-
578-
// fakeJSONRPCServer reads one JSON-RPC request from r and sends a response to w.
579-
// onRequest is called with the parsed method and params before the response is sent,
580-
// allowing the caller to inspect state (e.g. the sessions map) during the RPC.
581-
func fakeJSONRPCServer(t *testing.T, r io.Reader, w io.WriteCloser, onRequest func(method string, params json.RawMessage)) {
582-
t.Helper()
583-
reader := bufio.NewReader(r)
584-
585-
// Read Content-Length header
586-
var contentLength int
587-
for {
588-
line, err := reader.ReadString('\n')
589-
if err != nil {
590-
t.Errorf("failed to read header: %v", err)
591-
w.Close()
592-
return
593-
}
594-
if line == "\r\n" || line == "\n" {
595-
break
596-
}
597-
fmt.Sscanf(line, "Content-Length: %d", &contentLength)
598-
}
599-
600-
// Read body
601-
body := make([]byte, contentLength)
602-
if _, err := io.ReadFull(reader, body); err != nil {
603-
t.Errorf("failed to read body: %v", err)
604-
w.Close()
605-
return
606-
}
607-
608-
// Parse request
609-
var req struct {
610-
ID json.RawMessage `json:"id"`
611-
Method string `json:"method"`
612-
Params json.RawMessage `json:"params"`
613-
}
614-
if err := json.Unmarshal(body, &req); err != nil {
615-
t.Errorf("failed to unmarshal request: %v", err)
616-
w.Close()
617-
return
618-
}
619-
620-
onRequest(req.Method, req.Params)
621-
622-
// Echo sessionId from request params
623-
var params struct {
624-
SessionID string `json:"sessionId"`
625-
}
626-
json.Unmarshal(req.Params, &params)
627-
628-
result, _ := json.Marshal(map[string]any{"sessionId": params.SessionID, "workspacePath": "/tmp"})
629-
resp, _ := json.Marshal(map[string]any{
630-
"jsonrpc": "2.0",
631-
"id": req.ID,
632-
"result": json.RawMessage(result),
633-
})
634-
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(resp))
635-
w.Write([]byte(header))
636-
w.Write(resp)
637-
}
638-
639-
// fakeJSONRPCErrorServer reads one JSON-RPC request and returns an error response.
640-
func fakeJSONRPCErrorServer(t *testing.T, r io.Reader, w io.WriteCloser) {
641-
t.Helper()
642-
reader := bufio.NewReader(r)
643-
644-
var contentLength int
645-
for {
646-
line, err := reader.ReadString('\n')
647-
if err != nil {
648-
w.Close()
649-
return
650-
}
651-
if line == "\r\n" || line == "\n" {
652-
break
653-
}
654-
fmt.Sscanf(line, "Content-Length: %d", &contentLength)
655-
}
656-
657-
body := make([]byte, contentLength)
658-
if _, err := io.ReadFull(reader, body); err != nil {
659-
w.Close()
660-
return
661-
}
662-
663-
var req struct {
664-
ID json.RawMessage `json:"id"`
665-
}
666-
json.Unmarshal(body, &req)
667-
668-
resp, _ := json.Marshal(map[string]any{
669-
"jsonrpc": "2.0",
670-
"id": req.ID,
671-
"error": map[string]any{"code": -32000, "message": "test error"},
672-
})
673-
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(resp))
674-
w.Write([]byte(header))
675-
w.Write(resp)
676-
}
677-
678-
// newTestClientWithFakeServer creates a Client wired to a fake jsonrpc2.Client
679-
// backed by the provided io pipes. The caller must call jrpcClient.Stop() when done.
680-
func newTestClientWithFakeServer(clientWriter io.WriteCloser, clientReader io.ReadCloser) (*Client, *jsonrpc2.Client) {
681-
jrpcClient := jsonrpc2.NewClient(clientWriter, clientReader)
682-
jrpcClient.Start()
683-
684-
client := NewClient(nil)
685-
client.client = jrpcClient
686-
client.state = StateConnected
687-
client.sessions = make(map[string]*Session)
688-
return client, jrpcClient
689-
}
690-
691-
func TestClient_CreateSession_RegistersSessionBeforeRPC(t *testing.T) {
692-
// Create pipes: client writes to serverReader, server writes to clientReader
693-
serverReader, clientWriter := io.Pipe()
694-
clientReader, serverWriter := io.Pipe()
695-
client, jrpcClient := newTestClientWithFakeServer(clientWriter, clientReader)
696-
defer jrpcClient.Stop()
697-
698-
sessionInMap := false
699-
go fakeJSONRPCServer(t, serverReader, serverWriter, func(method string, params json.RawMessage) {
700-
if method != "session.create" {
701-
t.Errorf("expected session.create, got %s", method)
702-
}
703-
var p struct {
704-
SessionID string `json:"sessionId"`
705-
}
706-
json.Unmarshal(params, &p)
707-
client.sessionsMux.Lock()
708-
_, sessionInMap = client.sessions[p.SessionID]
709-
client.sessionsMux.Unlock()
710-
})
711-
712-
session, err := client.CreateSession(t.Context(), &SessionConfig{
713-
OnPermissionRequest: PermissionHandler.ApproveAll,
714-
})
715-
if err != nil {
716-
t.Fatalf("CreateSession failed: %v", err)
717-
}
718-
if session == nil {
719-
t.Fatal("expected non-nil session")
720-
}
721-
if !sessionInMap {
722-
t.Error("session was not in sessions map when session.create RPC was issued")
723-
}
724-
}
725-
726-
func TestClient_ResumeSession_RegistersSessionBeforeRPC(t *testing.T) {
727-
serverReader, clientWriter := io.Pipe()
728-
clientReader, serverWriter := io.Pipe()
729-
client, jrpcClient := newTestClientWithFakeServer(clientWriter, clientReader)
730-
defer jrpcClient.Stop()
731-
732-
sessionInMap := false
733-
go fakeJSONRPCServer(t, serverReader, serverWriter, func(method string, params json.RawMessage) {
734-
if method != "session.resume" {
735-
t.Errorf("expected session.resume, got %s", method)
736-
}
737-
var p struct {
738-
SessionID string `json:"sessionId"`
739-
}
740-
json.Unmarshal(params, &p)
741-
client.sessionsMux.Lock()
742-
_, sessionInMap = client.sessions[p.SessionID]
743-
client.sessionsMux.Unlock()
744-
})
745-
746-
session, err := client.ResumeSessionWithOptions(t.Context(), "test-session-id", &ResumeSessionConfig{
747-
OnPermissionRequest: PermissionHandler.ApproveAll,
748-
})
749-
if err != nil {
750-
t.Fatalf("ResumeSessionWithOptions failed: %v", err)
751-
}
752-
if session == nil {
753-
t.Fatal("expected non-nil session")
754-
}
755-
if !sessionInMap {
756-
t.Error("session was not in sessions map when session.resume RPC was issued")
757-
}
758-
}
759-
760-
func TestClient_CreateSession_CleansUpOnRPCFailure(t *testing.T) {
761-
serverReader, clientWriter := io.Pipe()
762-
clientReader, serverWriter := io.Pipe()
763-
client, jrpcClient := newTestClientWithFakeServer(clientWriter, clientReader)
764-
defer jrpcClient.Stop()
765-
766-
// Send a JSON-RPC error response to simulate failure
767-
go fakeJSONRPCErrorServer(t, serverReader, serverWriter)
768-
769-
_, err := client.CreateSession(t.Context(), &SessionConfig{
770-
OnPermissionRequest: PermissionHandler.ApproveAll,
771-
})
772-
if err == nil {
773-
t.Fatal("expected error from CreateSession")
774-
}
775-
client.sessionsMux.Lock()
776-
count := len(client.sessions)
777-
client.sessionsMux.Unlock()
778-
if count != 0 {
779-
t.Errorf("expected 0 sessions after failed create, got %d", count)
780-
}
781-
}
782-
783-
func TestClient_ResumeSession_CleansUpOnRPCFailure(t *testing.T) {
784-
serverReader, clientWriter := io.Pipe()
785-
clientReader, serverWriter := io.Pipe()
786-
client, jrpcClient := newTestClientWithFakeServer(clientWriter, clientReader)
787-
defer jrpcClient.Stop()
788-
789-
go fakeJSONRPCErrorServer(t, serverReader, serverWriter)
790-
791-
_, err := client.ResumeSessionWithOptions(t.Context(), "test-session-id", &ResumeSessionConfig{
792-
OnPermissionRequest: PermissionHandler.ApproveAll,
793-
})
794-
if err == nil {
795-
t.Fatal("expected error from ResumeSessionWithOptions")
796-
}
797-
client.sessionsMux.Lock()
798-
count := len(client.sessions)
799-
client.sessionsMux.Unlock()
800-
if count != 0 {
801-
t.Errorf("expected 0 sessions after failed resume, got %d", count)
802-
}
803-
}

go/internal/e2e/session_test.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,11 +661,31 @@ func TestSession(t *testing.T) {
661661
t.Run("should receive session events", func(t *testing.T) {
662662
ctx.ConfigureForTest(t)
663663

664-
session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{OnPermissionRequest: copilot.PermissionHandler.ApproveAll})
664+
// Use OnEvent to capture events dispatched during session creation.
665+
// session.start is emitted during the session.create RPC; if the session
666+
// weren't registered in the sessions map before the RPC, it would be dropped.
667+
var earlyEvents []copilot.SessionEvent
668+
session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{
669+
OnPermissionRequest: copilot.PermissionHandler.ApproveAll,
670+
OnEvent: func(event copilot.SessionEvent) {
671+
earlyEvents = append(earlyEvents, event)
672+
},
673+
})
665674
if err != nil {
666675
t.Fatalf("Failed to create session: %v", err)
667676
}
668677

678+
hasSessionStart := false
679+
for _, evt := range earlyEvents {
680+
if evt.Type == "session.start" {
681+
hasSessionStart = true
682+
break
683+
}
684+
}
685+
if !hasSessionStart {
686+
t.Error("Expected session.start event via OnEvent during creation")
687+
}
688+
669689
var receivedEvents []copilot.SessionEvent
670690
idle := make(chan bool)
671691

0 commit comments

Comments
 (0)