diff --git a/pkg/server/server.go b/pkg/server/server.go index b9cf3b626..c81ce54bd 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -4,6 +4,7 @@ import ( "cmp" "context" "encoding/json" + "errors" "fmt" "log/slog" "net" @@ -285,6 +286,9 @@ func (s *Server) runAgent(c echo.Context) error { streamChan, err := s.sm.RunSession(c.Request().Context(), sessionID, agentFilename, currentAgent, messages) if err != nil { + if errors.Is(err, ErrSessionBusy) { + return echo.NewHTTPError(http.StatusConflict, err.Error()) + } return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to run session: %v", err)) } diff --git a/pkg/server/session_manager.go b/pkg/server/session_manager.go index 6c26d3b58..98eddfb8b 100644 --- a/pkg/server/session_manager.go +++ b/pkg/server/session_manager.go @@ -27,6 +27,8 @@ type activeRuntimes struct { cancel context.CancelFunc session *session.Session // The actual session object used by the runtime titleGen *sessiontitle.Generator // Title generator (includes fallback models) + + streaming sync.Mutex // Held while a RunStream is in progress; serialises concurrent requests } // SessionManager manages sessions for HTTP and Connect-RPC servers. @@ -134,6 +136,9 @@ func (sm *SessionManager) DeleteSession(ctx context.Context, sessionID string) e return nil } +// ErrSessionBusy is returned when a session is already processing a request. +var ErrSessionBusy = errors.New("session is already processing a request") + // RunSession runs a session with the given messages. func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilename, currentAgent string, messages []api.Message) (<-chan runtime.Event, error) { sm.mux.Lock() @@ -146,19 +151,6 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena rc := sm.runConfig.Clone() rc.WorkingDir = sess.WorkingDir - // Collect user messages for potential title generation - var userMessages []string - for _, msg := range messages { - sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...)) - if msg.Content != "" { - userMessages = append(userMessages, msg.Content) - } - } - - if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil { - return nil, err - } - runtimeSession, exists := sm.runtimeSessions.Load(sessionID) streamCtx, cancel := context.WithCancel(ctx) var titleGen *sessiontitle.Generator @@ -177,17 +169,45 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena } sm.runtimeSessions.Store(sessionID, runtimeSession) } else { - // Update the session pointer in case it was reloaded - runtimeSession.session = sess titleGen = runtimeSession.titleGen } + // Reject the request immediately if the session is already streaming. + // This prevents interleaving user messages while a tool call is in + // progress, which would produce a tool_use without a matching + // tool_result and cause provider errors. + if !runtimeSession.streaming.TryLock() { + cancel() + return nil, ErrSessionBusy + } + + // Now that we hold the streaming lock, it is safe to mutate the session. + // Collect user messages for potential title generation + var userMessages []string + for _, msg := range messages { + sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...)) + if msg.Content != "" { + userMessages = append(userMessages, msg.Content) + } + } + + if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil { + runtimeSession.streaming.Unlock() + cancel() + return nil, err + } + + // Update the session pointer so the runtime sees the latest messages. + runtimeSession.session = sess + streamChan := make(chan runtime.Event) // Check if we need to generate a title needsTitle := sess.Title == "" && len(userMessages) > 0 && titleGen != nil go func() { + defer runtimeSession.streaming.Unlock() + // Start title generation in parallel if needed if needsTitle { go sm.generateTitle(ctx, sess, titleGen, userMessages, streamChan) diff --git a/pkg/server/session_manager_test.go b/pkg/server/session_manager_test.go new file mode 100644 index 000000000..39b93fd68 --- /dev/null +++ b/pkg/server/session_manager_test.go @@ -0,0 +1,224 @@ +package server + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/api" + "github.com/docker/docker-agent/pkg/concurrent" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/runtime" + "github.com/docker/docker-agent/pkg/session" + "github.com/docker/docker-agent/pkg/sessiontitle" + "github.com/docker/docker-agent/pkg/tools" +) + +// fakeRuntime is a minimal Runtime that records concurrent RunStream calls. +type fakeRuntime struct { + runtime.Runtime + + concurrentStreams atomic.Int32 + maxConcurrent atomic.Int32 + streamDelay time.Duration +} + +func (f *fakeRuntime) RunStream(_ context.Context, _ *session.Session) <-chan runtime.Event { + cur := f.concurrentStreams.Add(1) + for { + old := f.maxConcurrent.Load() + if cur <= old || f.maxConcurrent.CompareAndSwap(old, cur) { + break + } + } + + ch := make(chan runtime.Event) + go func() { + time.Sleep(f.streamDelay) + f.concurrentStreams.Add(-1) + close(ch) + }() + return ch +} + +func (f *fakeRuntime) Resume(_ context.Context, _ runtime.ResumeRequest) {} + +func (f *fakeRuntime) ResumeElicitation(_ context.Context, _ tools.ElicitationAction, _ map[string]any) error { + return nil +} + +func newTestSessionManager(t *testing.T, sess *session.Session, fake *fakeRuntime) *SessionManager { + t.Helper() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + require.NoError(t, store.AddSession(ctx, sess)) + + sm := &SessionManager{ + runtimeSessions: concurrent.NewMap[string, *activeRuntimes](), + sessionStore: store, + Sources: config.Sources{}, + runConfig: &config.RuntimeConfig{}, + } + + // Pre-register a runtime for this session so RunSession skips agent loading. + sm.runtimeSessions.Store(sess.ID, &activeRuntimes{ + runtime: fake, + session: sess, + titleGen: (*sessiontitle.Generator)(nil), + }) + + return sm +} + +// TestRunSession_ConcurrentRequestReturnsErrSessionBusy verifies that a +// second RunSession call on a session that is already streaming returns +// ErrSessionBusy instead of silently interleaving messages. +func TestRunSession_ConcurrentRequestReturnsErrSessionBusy(t *testing.T) { + t.Parallel() + + ctx := t.Context() + sess := session.New() + fake := &fakeRuntime{streamDelay: 500 * time.Millisecond} + sm := newTestSessionManager(t, sess, fake) + + // Start the first stream. + ch1, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{ + {Content: "first"}, + }) + require.NoError(t, err) + + // Give the goroutine a moment to acquire the streaming lock. + time.Sleep(50 * time.Millisecond) + + // The second request should fail immediately with ErrSessionBusy. + _, err = sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{ + {Content: "second"}, + }) + require.ErrorIs(t, err, ErrSessionBusy) + + // Drain first stream to let it complete. + for range ch1 { + } + + // After the first stream finishes, a new request should succeed. + ch3, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{ + {Content: "third"}, + }) + require.NoError(t, err) + for range ch3 { + } +} + +// TestRunSession_MessagesNotAddedWhenBusy verifies that when a session +// is busy, the rejected request does not mutate the session's messages. +func TestRunSession_MessagesNotAddedWhenBusy(t *testing.T) { + t.Parallel() + + ctx := t.Context() + sess := session.New() + fake := &fakeRuntime{streamDelay: 500 * time.Millisecond} + sm := newTestSessionManager(t, sess, fake) + + ch1, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{ + {Content: "first"}, + }) + require.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + + msgCountBefore := len(sess.GetAllMessages()) + + _, err = sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{ + {Content: "should not be added"}, + }) + require.ErrorIs(t, err, ErrSessionBusy) + + // Messages should not have been added. + assert.Len(t, sess.GetAllMessages(), msgCountBefore) + + for range ch1 { + } +} + +// TestRunSession_SequentialRequestsSucceed verifies that sequential +// (non-overlapping) requests on the same session work normally. +func TestRunSession_SequentialRequestsSucceed(t *testing.T) { + t.Parallel() + + ctx := t.Context() + sess := session.New() + fake := &fakeRuntime{streamDelay: 10 * time.Millisecond} + sm := newTestSessionManager(t, sess, fake) + + for range 3 { + ch, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{ + {Content: "hello"}, + }) + require.NoError(t, err) + for range ch { + } + } + + assert.Equal(t, int32(1), fake.maxConcurrent.Load()) +} + +// TestRunSession_DifferentSessionsConcurrently verifies that concurrent +// requests on *different* sessions are not blocked by each other. +func TestRunSession_DifferentSessionsConcurrently(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + fake1 := &fakeRuntime{streamDelay: 200 * time.Millisecond} + fake2 := &fakeRuntime{streamDelay: 200 * time.Millisecond} + + sess1 := session.New() + sess2 := session.New() + require.NoError(t, store.AddSession(ctx, sess1)) + require.NoError(t, store.AddSession(ctx, sess2)) + + sm := &SessionManager{ + runtimeSessions: concurrent.NewMap[string, *activeRuntimes](), + sessionStore: store, + Sources: config.Sources{}, + runConfig: &config.RuntimeConfig{}, + } + + sm.runtimeSessions.Store(sess1.ID, &activeRuntimes{ + runtime: fake1, session: sess1, titleGen: (*sessiontitle.Generator)(nil), + }) + sm.runtimeSessions.Store(sess2.ID, &activeRuntimes{ + runtime: fake2, session: sess2, titleGen: (*sessiontitle.Generator)(nil), + }) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + ch, err := sm.RunSession(ctx, sess1.ID, "agent", "root", []api.Message{{Content: "a"}}) + assert.NoError(t, err) + for range ch { + } + }() + + go func() { + defer wg.Done() + ch, err := sm.RunSession(ctx, sess2.ID, "agent", "root", []api.Message{{Content: "b"}}) + assert.NoError(t, err) + for range ch { + } + }() + + wg.Wait() + + // Both sessions should have streamed (1 each). + assert.Equal(t, int32(1), fake1.maxConcurrent.Load()) + assert.Equal(t, int32(1), fake2.maxConcurrent.Load()) +}