Skip to content

Commit 9f3d6d2

Browse files
authored
Merge pull request #2375 from dgageot/board/session-concurrency-issue-with-tool-call-492e1cda
fix: serialize concurrent RunSession calls to prevent tool_use/tool_result mismatch
2 parents d773763 + 8e0348b commit 9f3d6d2

File tree

3 files changed

+263
-15
lines changed

3 files changed

+263
-15
lines changed

pkg/server/server.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"cmp"
55
"context"
66
"encoding/json"
7+
"errors"
78
"fmt"
89
"log/slog"
910
"net"
@@ -285,6 +286,9 @@ func (s *Server) runAgent(c echo.Context) error {
285286

286287
streamChan, err := s.sm.RunSession(c.Request().Context(), sessionID, agentFilename, currentAgent, messages)
287288
if err != nil {
289+
if errors.Is(err, ErrSessionBusy) {
290+
return echo.NewHTTPError(http.StatusConflict, err.Error())
291+
}
288292
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to run session: %v", err))
289293
}
290294

pkg/server/session_manager.go

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ type activeRuntimes struct {
2727
cancel context.CancelFunc
2828
session *session.Session // The actual session object used by the runtime
2929
titleGen *sessiontitle.Generator // Title generator (includes fallback models)
30+
31+
streaming sync.Mutex // Held while a RunStream is in progress; serialises concurrent requests
3032
}
3133

3234
// SessionManager manages sessions for HTTP and Connect-RPC servers.
@@ -134,6 +136,9 @@ func (sm *SessionManager) DeleteSession(ctx context.Context, sessionID string) e
134136
return nil
135137
}
136138

139+
// ErrSessionBusy is returned when a session is already processing a request.
140+
var ErrSessionBusy = errors.New("session is already processing a request")
141+
137142
// RunSession runs a session with the given messages.
138143
func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilename, currentAgent string, messages []api.Message) (<-chan runtime.Event, error) {
139144
sm.mux.Lock()
@@ -146,19 +151,6 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
146151
rc := sm.runConfig.Clone()
147152
rc.WorkingDir = sess.WorkingDir
148153

149-
// Collect user messages for potential title generation
150-
var userMessages []string
151-
for _, msg := range messages {
152-
sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...))
153-
if msg.Content != "" {
154-
userMessages = append(userMessages, msg.Content)
155-
}
156-
}
157-
158-
if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
159-
return nil, err
160-
}
161-
162154
runtimeSession, exists := sm.runtimeSessions.Load(sessionID)
163155
streamCtx, cancel := context.WithCancel(ctx)
164156
var titleGen *sessiontitle.Generator
@@ -177,17 +169,45 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
177169
}
178170
sm.runtimeSessions.Store(sessionID, runtimeSession)
179171
} else {
180-
// Update the session pointer in case it was reloaded
181-
runtimeSession.session = sess
182172
titleGen = runtimeSession.titleGen
183173
}
184174

175+
// Reject the request immediately if the session is already streaming.
176+
// This prevents interleaving user messages while a tool call is in
177+
// progress, which would produce a tool_use without a matching
178+
// tool_result and cause provider errors.
179+
if !runtimeSession.streaming.TryLock() {
180+
cancel()
181+
return nil, ErrSessionBusy
182+
}
183+
184+
// Now that we hold the streaming lock, it is safe to mutate the session.
185+
// Collect user messages for potential title generation
186+
var userMessages []string
187+
for _, msg := range messages {
188+
sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...))
189+
if msg.Content != "" {
190+
userMessages = append(userMessages, msg.Content)
191+
}
192+
}
193+
194+
if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
195+
runtimeSession.streaming.Unlock()
196+
cancel()
197+
return nil, err
198+
}
199+
200+
// Update the session pointer so the runtime sees the latest messages.
201+
runtimeSession.session = sess
202+
185203
streamChan := make(chan runtime.Event)
186204

187205
// Check if we need to generate a title
188206
needsTitle := sess.Title == "" && len(userMessages) > 0 && titleGen != nil
189207

190208
go func() {
209+
defer runtimeSession.streaming.Unlock()
210+
191211
// Start title generation in parallel if needed
192212
if needsTitle {
193213
go sm.generateTitle(ctx, sess, titleGen, userMessages, streamChan)

pkg/server/session_manager_test.go

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
package server
2+
3+
import (
4+
"context"
5+
"sync"
6+
"sync/atomic"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
13+
"github.com/docker/docker-agent/pkg/api"
14+
"github.com/docker/docker-agent/pkg/concurrent"
15+
"github.com/docker/docker-agent/pkg/config"
16+
"github.com/docker/docker-agent/pkg/runtime"
17+
"github.com/docker/docker-agent/pkg/session"
18+
"github.com/docker/docker-agent/pkg/sessiontitle"
19+
"github.com/docker/docker-agent/pkg/tools"
20+
)
21+
22+
// fakeRuntime is a minimal Runtime that records concurrent RunStream calls.
23+
type fakeRuntime struct {
24+
runtime.Runtime
25+
26+
concurrentStreams atomic.Int32
27+
maxConcurrent atomic.Int32
28+
streamDelay time.Duration
29+
}
30+
31+
func (f *fakeRuntime) RunStream(_ context.Context, _ *session.Session) <-chan runtime.Event {
32+
cur := f.concurrentStreams.Add(1)
33+
for {
34+
old := f.maxConcurrent.Load()
35+
if cur <= old || f.maxConcurrent.CompareAndSwap(old, cur) {
36+
break
37+
}
38+
}
39+
40+
ch := make(chan runtime.Event)
41+
go func() {
42+
time.Sleep(f.streamDelay)
43+
f.concurrentStreams.Add(-1)
44+
close(ch)
45+
}()
46+
return ch
47+
}
48+
49+
func (f *fakeRuntime) Resume(_ context.Context, _ runtime.ResumeRequest) {}
50+
51+
func (f *fakeRuntime) ResumeElicitation(_ context.Context, _ tools.ElicitationAction, _ map[string]any) error {
52+
return nil
53+
}
54+
55+
func newTestSessionManager(t *testing.T, sess *session.Session, fake *fakeRuntime) *SessionManager {
56+
t.Helper()
57+
58+
ctx := t.Context()
59+
store := session.NewInMemorySessionStore()
60+
require.NoError(t, store.AddSession(ctx, sess))
61+
62+
sm := &SessionManager{
63+
runtimeSessions: concurrent.NewMap[string, *activeRuntimes](),
64+
sessionStore: store,
65+
Sources: config.Sources{},
66+
runConfig: &config.RuntimeConfig{},
67+
}
68+
69+
// Pre-register a runtime for this session so RunSession skips agent loading.
70+
sm.runtimeSessions.Store(sess.ID, &activeRuntimes{
71+
runtime: fake,
72+
session: sess,
73+
titleGen: (*sessiontitle.Generator)(nil),
74+
})
75+
76+
return sm
77+
}
78+
79+
// TestRunSession_ConcurrentRequestReturnsErrSessionBusy verifies that a
80+
// second RunSession call on a session that is already streaming returns
81+
// ErrSessionBusy instead of silently interleaving messages.
82+
func TestRunSession_ConcurrentRequestReturnsErrSessionBusy(t *testing.T) {
83+
t.Parallel()
84+
85+
ctx := t.Context()
86+
sess := session.New()
87+
fake := &fakeRuntime{streamDelay: 500 * time.Millisecond}
88+
sm := newTestSessionManager(t, sess, fake)
89+
90+
// Start the first stream.
91+
ch1, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
92+
{Content: "first"},
93+
})
94+
require.NoError(t, err)
95+
96+
// Give the goroutine a moment to acquire the streaming lock.
97+
time.Sleep(50 * time.Millisecond)
98+
99+
// The second request should fail immediately with ErrSessionBusy.
100+
_, err = sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
101+
{Content: "second"},
102+
})
103+
require.ErrorIs(t, err, ErrSessionBusy)
104+
105+
// Drain first stream to let it complete.
106+
for range ch1 {
107+
}
108+
109+
// After the first stream finishes, a new request should succeed.
110+
ch3, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
111+
{Content: "third"},
112+
})
113+
require.NoError(t, err)
114+
for range ch3 {
115+
}
116+
}
117+
118+
// TestRunSession_MessagesNotAddedWhenBusy verifies that when a session
119+
// is busy, the rejected request does not mutate the session's messages.
120+
func TestRunSession_MessagesNotAddedWhenBusy(t *testing.T) {
121+
t.Parallel()
122+
123+
ctx := t.Context()
124+
sess := session.New()
125+
fake := &fakeRuntime{streamDelay: 500 * time.Millisecond}
126+
sm := newTestSessionManager(t, sess, fake)
127+
128+
ch1, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
129+
{Content: "first"},
130+
})
131+
require.NoError(t, err)
132+
133+
time.Sleep(50 * time.Millisecond)
134+
135+
msgCountBefore := len(sess.GetAllMessages())
136+
137+
_, err = sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
138+
{Content: "should not be added"},
139+
})
140+
require.ErrorIs(t, err, ErrSessionBusy)
141+
142+
// Messages should not have been added.
143+
assert.Len(t, sess.GetAllMessages(), msgCountBefore)
144+
145+
for range ch1 {
146+
}
147+
}
148+
149+
// TestRunSession_SequentialRequestsSucceed verifies that sequential
150+
// (non-overlapping) requests on the same session work normally.
151+
func TestRunSession_SequentialRequestsSucceed(t *testing.T) {
152+
t.Parallel()
153+
154+
ctx := t.Context()
155+
sess := session.New()
156+
fake := &fakeRuntime{streamDelay: 10 * time.Millisecond}
157+
sm := newTestSessionManager(t, sess, fake)
158+
159+
for range 3 {
160+
ch, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
161+
{Content: "hello"},
162+
})
163+
require.NoError(t, err)
164+
for range ch {
165+
}
166+
}
167+
168+
assert.Equal(t, int32(1), fake.maxConcurrent.Load())
169+
}
170+
171+
// TestRunSession_DifferentSessionsConcurrently verifies that concurrent
172+
// requests on *different* sessions are not blocked by each other.
173+
func TestRunSession_DifferentSessionsConcurrently(t *testing.T) {
174+
t.Parallel()
175+
176+
ctx := t.Context()
177+
store := session.NewInMemorySessionStore()
178+
fake1 := &fakeRuntime{streamDelay: 200 * time.Millisecond}
179+
fake2 := &fakeRuntime{streamDelay: 200 * time.Millisecond}
180+
181+
sess1 := session.New()
182+
sess2 := session.New()
183+
require.NoError(t, store.AddSession(ctx, sess1))
184+
require.NoError(t, store.AddSession(ctx, sess2))
185+
186+
sm := &SessionManager{
187+
runtimeSessions: concurrent.NewMap[string, *activeRuntimes](),
188+
sessionStore: store,
189+
Sources: config.Sources{},
190+
runConfig: &config.RuntimeConfig{},
191+
}
192+
193+
sm.runtimeSessions.Store(sess1.ID, &activeRuntimes{
194+
runtime: fake1, session: sess1, titleGen: (*sessiontitle.Generator)(nil),
195+
})
196+
sm.runtimeSessions.Store(sess2.ID, &activeRuntimes{
197+
runtime: fake2, session: sess2, titleGen: (*sessiontitle.Generator)(nil),
198+
})
199+
200+
var wg sync.WaitGroup
201+
wg.Add(2)
202+
203+
go func() {
204+
defer wg.Done()
205+
ch, err := sm.RunSession(ctx, sess1.ID, "agent", "root", []api.Message{{Content: "a"}})
206+
assert.NoError(t, err)
207+
for range ch {
208+
}
209+
}()
210+
211+
go func() {
212+
defer wg.Done()
213+
ch, err := sm.RunSession(ctx, sess2.ID, "agent", "root", []api.Message{{Content: "b"}})
214+
assert.NoError(t, err)
215+
for range ch {
216+
}
217+
}()
218+
219+
wg.Wait()
220+
221+
// Both sessions should have streamed (1 each).
222+
assert.Equal(t, int32(1), fake1.maxConcurrent.Load())
223+
assert.Equal(t, int32(1), fake2.maxConcurrent.Load())
224+
}

0 commit comments

Comments
 (0)