Skip to content

Commit be9e8ae

Browse files
committed
fix(server): release sm.mux during runtime I/O on model endpoints
1 parent ad1366a commit be9e8ae

1 file changed

Lines changed: 49 additions & 15 deletions

File tree

pkg/server/session_manager.go

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ type activeRuntimes struct {
3434
titleGen *sessiontitle.Generator // Title generator (includes fallback models)
3535

3636
streaming sync.Mutex // Held while a RunStream is in progress; serialises concurrent requests
37+
38+
// modelSwitch serialises concurrent SetSessionAgentModel calls on the
39+
// same session. It is held across the (potentially slow) runtime I/O
40+
// so we never overlap a SetAgentModel + rollback pair with another
41+
// switch on the same session, while still allowing other sessions to
42+
// make progress.
43+
modelSwitch sync.Mutex
3744
}
3845

3946
// SessionManager manages sessions for HTTP and Connect-RPC servers.
@@ -811,9 +818,6 @@ var ErrSessionNotRunning = errors.New("session not found or not running")
811818
// config, a synthetic choice is appended (mirrors App.AvailableModels via
812819
// the shared runtime.DecorateModelChoices helper).
813820
func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID string) (string, string, []runtime.ModelChoice, error) {
814-
sm.mux.Lock()
815-
defer sm.mux.Unlock()
816-
817821
rs, ok := sm.runtimeSessions.Load(sessionID)
818822
if !ok {
819823
return "", "", nil, ErrSessionNotRunning
@@ -824,12 +828,24 @@ func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID
824828
}
825829

826830
agentName := rs.runtime.CurrentAgentName()
831+
832+
// Snapshot the override and custom-model history under sm.mux so the
833+
// read is atomic with respect to SetSessionAgentModel writes. The
834+
// (potentially slow) runtime.AvailableModels call must NOT happen
835+
// under sm.mux: it can perform network I/O (provider discovery,
836+
// models.dev catalog lookup) and would block every other session
837+
// operation in the manager.
838+
sm.mux.Lock()
827839
current := ""
828840
var customRefs []string
829841
if rs.session != nil {
830842
current = rs.session.AgentModelOverrides[agentName]
831-
customRefs = rs.session.CustomModelsUsed
843+
if n := len(rs.session.CustomModelsUsed); n > 0 {
844+
customRefs = make([]string, n)
845+
copy(customRefs, rs.session.CustomModelsUsed)
846+
}
832847
}
848+
sm.mux.Unlock()
833849

834850
choices := runtime.DecorateModelChoices(rs.runtime.AvailableModels(ctx), current, customRefs)
835851
return agentName, current, choices, nil
@@ -842,10 +858,13 @@ func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID
842858
//
843859
// On store-write failure the in-memory session state and the runtime
844860
// override are rolled back so the next call observes a consistent state.
861+
//
862+
// Concurrent SetSessionAgentModel calls on the same session are
863+
// serialised via the session-scoped modelSwitch lock so the runtime,
864+
// session and store never observe interleaved updates. The manager-wide
865+
// sm.mux is only held briefly while reading or mutating session fields,
866+
// never while calling into the runtime or the store.
845867
func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, modelRef string) (string, string, error) {
846-
sm.mux.Lock()
847-
defer sm.mux.Unlock()
848-
849868
rs, ok := sm.runtimeSessions.Load(sessionID)
850869
if !ok {
851870
return "", "", ErrSessionNotRunning
@@ -855,26 +874,35 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m
855874
return "", "", ErrModelSwitchingNotSupported
856875
}
857876

877+
rs.modelSwitch.Lock()
878+
defer rs.modelSwitch.Unlock()
879+
858880
agentName := rs.runtime.CurrentAgentName()
881+
sess := rs.session
859882

860883
// Snapshot current state so we can roll back if persistence fails
861884
// after we've already mutated the runtime + in-memory session.
862-
sess := rs.session
863885
var (
864-
hadOverride bool
865-
prevOverride string
866-
prevCustomLen int
867-
hadOverridesMap bool
868-
appendedCustomUsed bool
886+
hadOverride bool
887+
prevOverride string
888+
prevCustomLen int
889+
hadOverridesMap bool
869890
)
870891
if sess != nil {
892+
sm.mux.Lock()
871893
hadOverridesMap = sess.AgentModelOverrides != nil
872894
if hadOverridesMap {
873895
prevOverride, hadOverride = sess.AgentModelOverrides[agentName]
874896
}
875897
prevCustomLen = len(sess.CustomModelsUsed)
898+
sm.mux.Unlock()
876899
}
877900

901+
// Runtime mutation runs without sm.mux so it doesn't block other
902+
// session operations during slow provider creation. The per-session
903+
// modelSwitch lock above keeps SetAgentModel + UpdateSession + any
904+
// rollback atomic with respect to other model-switch calls on this
905+
// session.
878906
if err := rs.runtime.SetAgentModel(ctx, agentName, modelRef); err != nil {
879907
return "", "", err
880908
}
@@ -883,6 +911,8 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m
883911
return agentName, modelRef, nil
884912
}
885913

914+
var appendedCustomUsed bool
915+
sm.mux.Lock()
886916
if modelRef == "" {
887917
delete(sess.AgentModelOverrides, agentName)
888918
} else {
@@ -898,10 +928,12 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m
898928
appendedCustomUsed = true
899929
}
900930
}
931+
sm.mux.Unlock()
901932

902933
if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
903-
// Roll back: restore in-memory map and runtime so callers don't see
904-
// a runtime/store mismatch on the next request.
934+
// Roll back in-memory under sm.mux first so concurrent readers
935+
// (e.g. AvailableSessionModels) never see the half-applied state.
936+
sm.mux.Lock()
905937
if hadOverride {
906938
sess.AgentModelOverrides[agentName] = prevOverride
907939
} else {
@@ -913,6 +945,8 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m
913945
if appendedCustomUsed {
914946
sess.CustomModelsUsed = sess.CustomModelsUsed[:prevCustomLen]
915947
}
948+
sm.mux.Unlock()
949+
916950
rollback := prevOverride
917951
if !hadOverride {
918952
rollback = ""

0 commit comments

Comments
 (0)