From 23090072d844edc5f74a802f9bbb5db625832bd8 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 13 May 2026 14:37:20 +0200 Subject: [PATCH 01/12] feat: support model overrides when creating sessions Allow callers to specify agent_model_overrides and custom_models_used when creating sessions via POST /api/sessions, and wire the runtime with model switcher configuration to apply those overrides when the session is first run. --- pkg/server/session_manager.go | 128 ++++++++++++++++++++++++++++- pkg/server/session_manager_test.go | 4 + 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/pkg/server/session_manager.go b/pkg/server/session_manager.go index 57dd3025a..ec9b3e9be 100644 --- a/pkg/server/session_manager.go +++ b/pkg/server/session_manager.go @@ -5,8 +5,10 @@ import ( "errors" "fmt" "log/slog" + "maps" "os" "path/filepath" + "slices" "strings" "sync" "time" @@ -223,6 +225,18 @@ func (sm *SessionManager) CreateSession(ctx context.Context, sessionTemplate *se } sess := session.New(opts...) + + // Copy model-related fields from the template so callers can pin a + // specific model when creating a session over the API. The runtime + // will pick these up the first time it is built for the session + // (see runtimeForSession). + if len(sessionTemplate.AgentModelOverrides) > 0 { + sess.AgentModelOverrides = maps.Clone(sessionTemplate.AgentModelOverrides) + } + if len(sessionTemplate.CustomModelsUsed) > 0 { + sess.CustomModelsUsed = append([]string(nil), sessionTemplate.CustomModelsUsed...) + } + return sess, sm.sessionStore.AddSession(ctx, sess) } @@ -603,10 +617,11 @@ func (sm *SessionManager) runtimeForSession(ctx context.Context, sess *session.S // constructor: it must not touch sm.runtimeSessions, otherwise it would // briefly publish a half-initialised activeRuntimes (e.g. without the // cancel func) that other goroutines could observe. - t, err := sm.loadTeam(ctx, agentFilename, rc) + loadResult, err := sm.loadTeamWithConfig(ctx, agentFilename, rc) if err != nil { return nil, nil, err } + t := loadResult.Team // Resolve the team's default agent when no specific agent was requested. agt, err := t.AgentOrDefault(currentAgent) @@ -618,17 +633,37 @@ func (sm *SessionManager) runtimeForSession(ctx context.Context, sess *session.S sess.MaxConsecutiveToolCalls = agt.MaxConsecutiveToolCalls() sess.MaxOldToolCallTokens = agt.MaxOldToolCallTokens() + modelSwitcherCfg := &runtime.ModelSwitcherConfig{ + Models: loadResult.Models, + Providers: loadResult.Providers, + ModelsGateway: rc.ModelsGateway, + EnvProvider: rc.EnvProvider(), + AgentDefaultModels: loadResult.AgentDefaultModels, + } + opts := []runtime.Opt{ runtime.WithCurrentAgent(currentAgent), runtime.WithManagedOAuth(false), runtime.WithSessionStore(sm.sessionStore), runtime.WithTracer(otel.Tracer("cagent")), + runtime.WithModelSwitcherConfig(modelSwitcherCfg), } run, err := runtime.New(t, opts...) if err != nil { return nil, nil, err } + // Apply any stored per-agent model overrides so that a session + // resumed (or freshly created with overrides via CreateSession) uses + // the requested models instead of the agent's defaults. + if len(sess.AgentModelOverrides) > 0 && run.SupportsModelSwitching() { + for agentName, modelRef := range sess.AgentModelOverrides { + if err := run.SetAgentModel(ctx, agentName, modelRef); err != nil { + slog.WarnContext(ctx, "Failed to apply stored model override", "session_id", sess.ID, "agent", agentName, "model", modelRef, "error", err) + } + } + } + titleGen := sessiontitle.New(agt.Model(ctx), agt.FallbackModels()...) slog.DebugContext(ctx, "Runtime created for session", "session_id", sess.ID) @@ -645,6 +680,17 @@ func (sm *SessionManager) loadTeam(ctx context.Context, agentFilename string, ru return teamloader.Load(ctx, agentSource, runConfig) } +// loadTeamWithConfig is like loadTeam but also returns the loaded model and +// provider configuration so the runtime can be wired for model switching. +func (sm *SessionManager) loadTeamWithConfig(ctx context.Context, agentFilename string, runConfig *config.RuntimeConfig) (*teamloader.LoadResult, error) { + agentSource, found := sm.Sources[agentFilename] + if !found { + return nil, fmt.Errorf("agent not found: %s", agentFilename) + } + + return teamloader.LoadWithConfig(ctx, agentSource, runConfig) +} + // GetAgentToolCount loads the agent's team and returns the number of // tools available to the given agent. When agentName is empty, it // resolves to the team's default agent. @@ -729,6 +775,86 @@ func (sm *SessionManager) SetSessionStarred(ctx context.Context, sessionID strin return sm.sessionStore.SetSessionStarred(ctx, sessionID, starred) } +// ErrModelSwitchingNotSupported is returned when the runtime backing a +// session does not support runtime model switching (e.g. when the agent +// was created without a ModelSwitcherConfig). +var ErrModelSwitchingNotSupported = errors.New("model switching not supported by this runtime") + +// AvailableSessionModels returns the list of models available for the +// session's current agent. The agent's name and the active model override +// (if any) are returned alongside the choices so callers don't have to +// peek into the runtime registry. A session-scoped runtime is required, +// so the session must have been started at least once (RunSession called) +// or be attached out-of-band via AttachRuntime. +func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID string) (string, string, []runtime.ModelChoice, error) { + rs, ok := sm.runtimeSessions.Load(sessionID) + if !ok { + return "", "", nil, errors.New("session not found or not running") + } + + if !rs.runtime.SupportsModelSwitching() { + return "", "", nil, ErrModelSwitchingNotSupported + } + + agentName := rs.runtime.CurrentAgentName() + current := "" + if rs.session != nil { + current = rs.session.AgentModelOverrides[agentName] + } + + return agentName, current, rs.runtime.AvailableModels(ctx), nil +} + +// SetSessionAgentModel applies modelRef as the model override for the +// current agent of the session, persists it to the session store, and +// tracks custom models for later re-selection. Pass an empty modelRef +// to clear the override and revert to the agent's default model. +func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, modelRef string) (string, string, error) { + sm.mux.Lock() + defer sm.mux.Unlock() + + rs, ok := sm.runtimeSessions.Load(sessionID) + if !ok { + return "", "", errors.New("session not found or not running") + } + + if !rs.runtime.SupportsModelSwitching() { + return "", "", ErrModelSwitchingNotSupported + } + + agentName := rs.runtime.CurrentAgentName() + if err := rs.runtime.SetAgentModel(ctx, agentName, modelRef); err != nil { + return "", "", err + } + + sess := rs.session + if sess == nil { + return agentName, modelRef, nil + } + + if modelRef == "" { + delete(sess.AgentModelOverrides, agentName) + } else { + if sess.AgentModelOverrides == nil { + sess.AgentModelOverrides = make(map[string]string) + } + sess.AgentModelOverrides[agentName] = modelRef + + // Track inline provider/model references so they remain easy to + // re-select via the model picker (mirrors App.SetCurrentAgentModel). + if strings.Contains(modelRef, "/") && !slices.Contains(sess.CustomModelsUsed, modelRef) { + sess.CustomModelsUsed = append(sess.CustomModelsUsed, modelRef) + } + } + + if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil { + return "", "", fmt.Errorf("failed to persist model override: %w", err) + } + + slog.DebugContext(ctx, "Updated session model override", "session_id", sessionID, "agent", agentName, "model", modelRef) + return agentName, modelRef, nil +} + // BatchDeleteSessions deletes multiple sessions in a single operation. func (sm *SessionManager) BatchDeleteSessions(ctx context.Context, sessionIDs []string) (int, []string) { sm.mux.Lock() diff --git a/pkg/server/session_manager_test.go b/pkg/server/session_manager_test.go index f30526afd..b5ad8e953 100644 --- a/pkg/server/session_manager_test.go +++ b/pkg/server/session_manager_test.go @@ -58,6 +58,10 @@ func (f *fakeRuntime) ResumeElicitation(_ context.Context, _ tools.ElicitationAc func (f *fakeRuntime) CurrentAgentName() string { return "root" } +// SupportsModelSwitching reports false by default. Tests that exercise +// the /models endpoints embed fakeRuntime and override this. +func (f *fakeRuntime) SupportsModelSwitching() bool { return false } + func newTestSessionManager(t *testing.T, sess *session.Session, fake *fakeRuntime) *SessionManager { t.Helper() From e324eadf8d10ada7385318fc02913293841cef0b Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 13 May 2026 14:37:26 +0200 Subject: [PATCH 02/12] feat: add GET/PATCH/POST endpoints for runtime model switching Add GET /api/sessions/:id/models to list available models and the current override (if any), and PATCH/POST /api/sessions/:id/model to apply a model override for the session's current agent. Extend ModelChoice with JSON tags for wire format and add SessionModelsResponse type. Both endpoints return 422 if the runtime doesn't support model switching. --- pkg/api/types.go | 12 ++ pkg/runtime/model_switcher.go | 48 ++++-- pkg/server/server.go | 51 ++++++ pkg/server/session_models_test.go | 263 ++++++++++++++++++++++++++++++ 4 files changed, 356 insertions(+), 18 deletions(-) create mode 100644 pkg/server/session_models_test.go diff --git a/pkg/api/types.go b/pkg/api/types.go index 0bfa5cd7a..f35143169 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -270,3 +270,15 @@ type SessionStatusResponse struct { OutputTokens int64 `json:"output_tokens"` NumMessages int `json:"num_messages"` } + +// SetSessionModelRequest is the body of PATCH /api/sessions/:id/model. +// An empty Model clears the override and reverts to the agent's default. +type SetSessionModelRequest struct { + Model string `json:"model"` +} + +// SetSessionModelResponse is the response from PATCH /api/sessions/:id/model. +type SetSessionModelResponse struct { + Agent string `json:"agent"` + Model string `json:"model,omitempty"` +} diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index 50231dee4..84b55b265 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -16,49 +16,61 @@ import ( "github.com/docker/docker-agent/pkg/modelsdev" ) -// ModelChoice represents a model available for selection in the TUI picker. +// ModelChoice represents a model available for selection in the model picker. +// +// JSON tags are part of the public wire format used by +// GET /api/sessions/:id/models; renaming a tag is a breaking change. type ModelChoice struct { // Name is the display name (config key) - Name string + Name string `json:"name"` // Ref is the model reference used internally (e.g., "my_model" or "openai/gpt-4o") - Ref string + Ref string `json:"ref"` // Provider is the provider name (e.g., "openai", "anthropic") - Provider string + Provider string `json:"provider,omitempty"` // Model is the specific model name (e.g., "gpt-4o", "claude-sonnet-4-0") - Model string + Model string `json:"model,omitempty"` // IsDefault indicates this is the agent's configured default model - IsDefault bool + IsDefault bool `json:"is_default,omitempty"` // IsCurrent indicates this is the currently active model for the agent - IsCurrent bool + IsCurrent bool `json:"is_current,omitempty"` // IsCustom indicates this is a custom model from the session history (not from config) - IsCustom bool + IsCustom bool `json:"is_custom,omitempty"` // IsCatalog indicates this is a model from the models.dev catalog - IsCatalog bool + IsCatalog bool `json:"is_catalog,omitempty"` // The fields below are populated (best-effort) from the models.dev // catalog. They are optional and may all be zero/empty when no // catalog entry is found for the model. // Family is the model family (e.g., "claude", "gpt"). - Family string + Family string `json:"family,omitempty"` // InputCost is the price (in USD) per 1M input tokens. - InputCost float64 + InputCost float64 `json:"input_cost,omitempty"` // OutputCost is the price (in USD) per 1M output tokens. - OutputCost float64 + OutputCost float64 `json:"output_cost,omitempty"` // CacheReadCost is the price (in USD) per 1M cached input tokens. - CacheReadCost float64 + CacheReadCost float64 `json:"cache_read_cost,omitempty"` // CacheWriteCost is the price (in USD) per 1M cache-write tokens. - CacheWriteCost float64 + CacheWriteCost float64 `json:"cache_write_cost,omitempty"` // ContextLimit is the maximum context window size in tokens. - ContextLimit int + ContextLimit int `json:"context_limit,omitempty"` // OutputLimit is the maximum number of tokens the model can produce // in a single response. - OutputLimit int64 + OutputLimit int64 `json:"output_limit,omitempty"` // InputModalities lists the input modalities supported by the model // (e.g., "text", "image", "audio"). - InputModalities []string + InputModalities []string `json:"input_modalities,omitempty"` // OutputModalities lists the output modalities the model can produce. - OutputModalities []string + OutputModalities []string `json:"output_modalities,omitempty"` +} + +// SessionModelsResponse is the response returned by +// GET /api/sessions/:id/models. CurrentModelRef is the active override for +// the named agent (empty when the agent is using its configured default). +type SessionModelsResponse struct { + Agent string `json:"agent"` + CurrentModelRef string `json:"current_model_ref,omitempty"` + Models []ModelChoice `json:"models"` } // ModelSwitcherConfig holds the configuration needed for model switching. diff --git a/pkg/server/server.go b/pkg/server/server.go index b9a182302..e6858068d 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -19,6 +19,7 @@ import ( "github.com/docker/docker-agent/pkg/api" "github.com/docker/docker-agent/pkg/config" "github.com/docker/docker-agent/pkg/echolog" + "github.com/docker/docker-agent/pkg/runtime" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/upstream" ) @@ -70,6 +71,9 @@ func (s *Server) registerRoutes() { group.PATCH("/sessions/:id/title", s.updateSessionTitle) group.PATCH("/sessions/:id/tokens", s.updateSessionTokens) group.PATCH("/sessions/:id/starred", s.setSessionStarred) + group.GET("/sessions/:id/models", s.getSessionModels) + group.PATCH("/sessions/:id/model", s.setSessionModel) + group.POST("/sessions/:id/model", s.setSessionModel) group.POST("/sessions", s.createSession) group.DELETE("/sessions/:id", s.deleteSession) group.POST("/sessions/:id/agent/:agent", s.runAgent) @@ -557,6 +561,53 @@ func (s *Server) setSessionStarred(c echo.Context) error { return c.JSON(http.StatusOK, map[string]string{"status": "updated"}) } +// getSessionModels lists the models the user can pick from for the +// session's current agent. Returns 404 if the session has no active runtime +// (it must have been started at least once or be attached out-of-band) +// or 422 if the runtime does not support model switching. +func (s *Server) getSessionModels(c echo.Context) error { + sessionID := c.Param("id") + + agentName, current, choices, err := s.sm.AvailableSessionModels(c.Request().Context(), sessionID) + if err != nil { + if errors.Is(err, ErrModelSwitchingNotSupported) { + return echo.NewHTTPError(http.StatusUnprocessableEntity, err.Error()) + } + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + } + + return c.JSON(http.StatusOK, runtime.SessionModelsResponse{ + Agent: agentName, + CurrentModelRef: current, + Models: choices, + }) +} + +// setSessionModel applies a model override on the session's current agent +// and persists it. An empty `model` clears the override and reverts the +// agent to its configured default. +func (s *Server) setSessionModel(c echo.Context) error { + sessionID := c.Param("id") + + var req api.SetSessionModelRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err)) + } + + agentName, modelRef, err := s.sm.SetSessionAgentModel(c.Request().Context(), sessionID, req.Model) + if err != nil { + if errors.Is(err, ErrModelSwitchingNotSupported) { + return echo.NewHTTPError(http.StatusUnprocessableEntity, err.Error()) + } + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + return c.JSON(http.StatusOK, api.SetSessionModelResponse{ + Agent: agentName, + Model: modelRef, + }) +} + func (s *Server) batchDeleteSessions(c echo.Context) error { var req api.BatchDeleteSessionsRequest if err := c.Bind(&req); err != nil { diff --git a/pkg/server/session_models_test.go b/pkg/server/session_models_test.go new file mode 100644 index 000000000..6f5b52123 --- /dev/null +++ b/pkg/server/session_models_test.go @@ -0,0 +1,263 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/api" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/runtime" + "github.com/docker/docker-agent/pkg/session" +) + +// modelSwitchingRuntime is a fakeRuntime variant that supports model +// switching so the /models and /model endpoints can be exercised +// without spinning up a real LocalRuntime. +type modelSwitchingRuntime struct { + fakeRuntime + + mu sync.Mutex + currentAgent string + availableModels []runtime.ModelChoice + overrides map[string]string + setErr error +} + +func newModelSwitchingRuntime(models []runtime.ModelChoice) *modelSwitchingRuntime { + return &modelSwitchingRuntime{ + currentAgent: "root", + availableModels: models, + overrides: make(map[string]string), + } +} + +func (m *modelSwitchingRuntime) CurrentAgentName() string { return m.currentAgent } + +func (m *modelSwitchingRuntime) SupportsModelSwitching() bool { return true } + +func (m *modelSwitchingRuntime) AvailableModels(_ context.Context) []runtime.ModelChoice { + m.mu.Lock() + defer m.mu.Unlock() + out := make([]runtime.ModelChoice, len(m.availableModels)) + copy(out, m.availableModels) + return out +} + +func (m *modelSwitchingRuntime) SetAgentModel(_ context.Context, agentName, modelRef string) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.setErr != nil { + return m.setErr + } + if modelRef == "" { + delete(m.overrides, agentName) + return nil + } + m.overrides[agentName] = modelRef + return nil +} + +func TestSessionManager_CreateSession_KeepsModelOverrides(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + + template := &session.Session{ + AgentModelOverrides: map[string]string{ + "root": "openai/gpt-4o", + "researcher": "anthropic/claude-sonnet-4-0", + }, + CustomModelsUsed: []string{"openai/gpt-4o"}, + } + + created, err := sm.CreateSession(ctx, template) + require.NoError(t, err) + require.NotEmpty(t, created.ID) + + assert.Equal(t, "openai/gpt-4o", created.AgentModelOverrides["root"]) + assert.Equal(t, "anthropic/claude-sonnet-4-0", created.AgentModelOverrides["researcher"]) + assert.Equal(t, []string{"openai/gpt-4o"}, created.CustomModelsUsed) + + // Mutating the template after creation must not affect the stored session. + template.AgentModelOverrides["root"] = "mutated" + assert.Equal(t, "openai/gpt-4o", created.AgentModelOverrides["root"]) + + stored, err := store.GetSession(ctx, created.ID) + require.NoError(t, err) + assert.Equal(t, "openai/gpt-4o", stored.AgentModelOverrides["root"]) +} + +func TestAttachedServer_GetSessionModels(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + store := session.NewInMemorySessionStore() + sess := session.New() + sess.AgentModelOverrides = map[string]string{"root": "openai/gpt-4o"} + require.NoError(t, store.AddSession(ctx, sess)) + + choices := []runtime.ModelChoice{ + {Name: "default", Ref: "openai/gpt-4o-mini", Provider: "openai", Model: "gpt-4o-mini", IsDefault: true}, + {Name: "custom", Ref: "openai/gpt-4o", Provider: "openai", Model: "gpt-4o", IsCurrent: true}, + } + fake := newModelSwitchingRuntime(choices) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + srv := NewWithManager(sm, "") + ln, err := Listen(ctx, "127.0.0.1:0") + require.NoError(t, err) + go func() { _ = srv.Serve(ctx, ln) }() + + addr := "http://" + ln.Addr().String() + resp := httpDoTCP(t, ctx, http.MethodGet, addr+"/api/sessions/"+sess.ID+"/models", nil) + + var got runtime.SessionModelsResponse + require.NoError(t, json.Unmarshal(resp, &got)) + + assert.Equal(t, "root", got.Agent) + assert.Equal(t, "openai/gpt-4o", got.CurrentModelRef) + require.Len(t, got.Models, 2) + assert.Equal(t, "openai/gpt-4o-mini", got.Models[0].Ref) + assert.True(t, got.Models[0].IsDefault) + assert.Equal(t, "openai/gpt-4o", got.Models[1].Ref) + assert.True(t, got.Models[1].IsCurrent) +} + +func TestAttachedServer_SetSessionModel_PersistsOverride(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + store := session.NewInMemorySessionStore() + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + fake := newModelSwitchingRuntime(nil) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + srv := NewWithManager(sm, "") + ln, err := Listen(ctx, "127.0.0.1:0") + require.NoError(t, err) + go func() { _ = srv.Serve(ctx, ln) }() + + addr := "http://" + ln.Addr().String() + resp := httpDoTCP(t, ctx, http.MethodPatch, addr+"/api/sessions/"+sess.ID+"/model", + api.SetSessionModelRequest{Model: "anthropic/claude-sonnet-4-0"}) + + var got api.SetSessionModelResponse + require.NoError(t, json.Unmarshal(resp, &got)) + assert.Equal(t, "root", got.Agent) + assert.Equal(t, "anthropic/claude-sonnet-4-0", got.Model) + + // The runtime must have received the override. + fake.mu.Lock() + assert.Equal(t, "anthropic/claude-sonnet-4-0", fake.overrides["root"]) + fake.mu.Unlock() + + // The session in the store must reflect the override and track the + // custom model for future picks. + stored, err := store.GetSession(ctx, sess.ID) + require.NoError(t, err) + assert.Equal(t, "anthropic/claude-sonnet-4-0", stored.AgentModelOverrides["root"]) + assert.Contains(t, stored.CustomModelsUsed, "anthropic/claude-sonnet-4-0") +} + +func TestAttachedServer_SetSessionModel_EmptyClearsOverride(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + store := session.NewInMemorySessionStore() + sess := session.New() + sess.AgentModelOverrides = map[string]string{"root": "openai/gpt-4o"} + require.NoError(t, store.AddSession(ctx, sess)) + + fake := newModelSwitchingRuntime(nil) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + srv := NewWithManager(sm, "") + ln, err := Listen(ctx, "127.0.0.1:0") + require.NoError(t, err) + go func() { _ = srv.Serve(ctx, ln) }() + + addr := "http://" + ln.Addr().String() + _ = httpDoTCP(t, ctx, http.MethodPatch, addr+"/api/sessions/"+sess.ID+"/model", + api.SetSessionModelRequest{Model: ""}) + + stored, err := store.GetSession(ctx, sess.ID) + require.NoError(t, err) + _, exists := stored.AgentModelOverrides["root"] + assert.False(t, exists, "override should be cleared") +} + +func TestAttachedServer_SetSessionModel_PostVerbAlsoWorks(t *testing.T) { + // The pre-existing pkg/runtime Client.SetAgentModel POSTs to + // /api/sessions/:id/model. The server must accept POST as well as + // PATCH so RemoteRuntime keeps working without a coordinated bump. + t.Parallel() + + ctx := t.Context() + + store := session.NewInMemorySessionStore() + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + fake := newModelSwitchingRuntime(nil) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + srv := NewWithManager(sm, "") + ln, err := Listen(ctx, "127.0.0.1:0") + require.NoError(t, err) + go func() { _ = srv.Serve(ctx, ln) }() + + addr := "http://" + ln.Addr().String() + _ = httpDoTCP(t, ctx, http.MethodPost, addr+"/api/sessions/"+sess.ID+"/model", + api.SetSessionModelRequest{Model: "openai/gpt-4o"}) + + fake.mu.Lock() + assert.Equal(t, "openai/gpt-4o", fake.overrides["root"]) + fake.mu.Unlock() +} + +func TestAttachedServer_GetSessionModels_NotSupported(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, &fakeRuntime{}, sess) + + srv := NewWithManager(sm, "") + ln, err := Listen(ctx, "127.0.0.1:0") + require.NoError(t, err) + go func() { _ = srv.Serve(ctx, ln) }() + + addr := "http://" + ln.Addr().String() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr+"/api/sessions/"+sess.ID+"/models", http.NoBody) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode) +} From 3ed8b94c0c6c4944edb31cc13b2b85efa3a5afb0 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 13 May 2026 14:49:22 +0200 Subject: [PATCH 03/12] fix: add backward compat note for POST /sessions/:id/model PATCH is the canonical verb for updating the agent's model. POST is also accepted because pkg/runtime Client.SetAgentModel (used by RemoteRuntime) was historically a POST; keep both so a client upgrade is not required. --- pkg/server/server.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/server/server.go b/pkg/server/server.go index e6858068d..653ba95fa 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -72,6 +72,10 @@ func (s *Server) registerRoutes() { group.PATCH("/sessions/:id/tokens", s.updateSessionTokens) group.PATCH("/sessions/:id/starred", s.setSessionStarred) group.GET("/sessions/:id/models", s.getSessionModels) + // PATCH is the canonical verb for updating the agent's model. POST is + // also accepted because pkg/runtime Client.SetAgentModel (used by + // RemoteRuntime) was historically a POST; keep both so a client + // upgrade is not required. group.PATCH("/sessions/:id/model", s.setSessionModel) group.POST("/sessions/:id/model", s.setSessionModel) group.POST("/sessions", s.createSession) From 80d75e5ab3b84ccf516e372f81f641d7fa08d6a3 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 13 May 2026 14:49:50 +0200 Subject: [PATCH 04/12] fix: thread-safe model reads and persistence rollback Fix a data race in AvailableSessionModels where rs.session is read without holding sm.mux (the lock protecting reads/writes of AgentModelOverrides and CustomModelsUsed by SetSessionAgentModel). In SetSessionAgentModel, snapshot the in-memory override state before mutating the runtime and session. If sessionStore.UpdateSession fails after the runtime mutation, roll back both the in-memory session state and the runtime override so callers don't observe a runtime/store mismatch on the next request. Fixes issues identified in review and validated with go test -race. --- pkg/server/session_manager.go | 116 +++++++++++++++++++++++++++++++++- 1 file changed, 114 insertions(+), 2 deletions(-) diff --git a/pkg/server/session_manager.go b/pkg/server/session_manager.go index ec9b3e9be..439a42601 100644 --- a/pkg/server/session_manager.go +++ b/pkg/server/session_manager.go @@ -786,7 +786,16 @@ var ErrModelSwitchingNotSupported = errors.New("model switching not supported by // peek into the runtime registry. A session-scoped runtime is required, // so the session must have been started at least once (RunSession called) // or be attached out-of-band via AttachRuntime. +// +// Each returned ModelChoice has IsCurrent set so the picker can highlight +// the active selection without a second round-trip. When no override is +// active, the agent's configured default carries IsCurrent=true; if the +// override points at an inline provider/model not present in the agent +// config, a synthetic choice is appended (mirrors App.AvailableModels). func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID string) (string, string, []runtime.ModelChoice, error) { + sm.mux.Lock() + defer sm.mux.Unlock() + rs, ok := sm.runtimeSessions.Load(sessionID) if !ok { return "", "", nil, errors.New("session not found or not running") @@ -798,17 +807,81 @@ func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID agentName := rs.runtime.CurrentAgentName() current := "" + var customRefs []string if rs.session != nil { current = rs.session.AgentModelOverrides[agentName] + customRefs = rs.session.CustomModelsUsed + } + + choices := decorateModelChoices(rs.runtime.AvailableModels(ctx), current, customRefs) + return agentName, current, choices, nil +} + +// decorateModelChoices marks the active selection with IsCurrent and +// appends any custom (provider/model) refs from the session history that +// the runtime doesn't already expose. Mirrors the post-processing in +// App.AvailableModels so HTTP and TUI clients see the same picker state. +func decorateModelChoices(models []runtime.ModelChoice, current string, customRefs []string) []runtime.ModelChoice { + existingRefs := make(map[string]bool, len(models)) + for _, m := range models { + existingRefs[m.Ref] = true + } + + currentFound := current == "" + for i := range models { + if current != "" { + if models[i].Ref == current { + models[i].IsCurrent = true + currentFound = true + } + } else { + models[i].IsCurrent = models[i].IsDefault + } } - return agentName, current, rs.runtime.AvailableModels(ctx), nil + for _, ref := range customRefs { + if existingRefs[ref] { + continue + } + existingRefs[ref] = true + + prov, name, _ := strings.Cut(ref, "/") + isCurrent := ref == current + if isCurrent { + currentFound = true + } + models = append(models, runtime.ModelChoice{ + Name: ref, + Ref: ref, + Provider: prov, + Model: name, + IsCurrent: isCurrent, + IsCustom: true, + }) + } + + if !currentFound && strings.Contains(current, "/") { + prov, name, _ := strings.Cut(current, "/") + models = append(models, runtime.ModelChoice{ + Name: current, + Ref: current, + Provider: prov, + Model: name, + IsCurrent: true, + IsCustom: true, + }) + } + + return models } // SetSessionAgentModel applies modelRef as the model override for the // current agent of the session, persists it to the session store, and // tracks custom models for later re-selection. Pass an empty modelRef // to clear the override and revert to the agent's default model. +// +// On store-write failure the in-memory session state and the runtime +// override are rolled back so the next call observes a consistent state. func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, modelRef string) (string, string, error) { sm.mux.Lock() defer sm.mux.Unlock() @@ -823,11 +896,29 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m } agentName := rs.runtime.CurrentAgentName() + + // Snapshot current state so we can roll back if persistence fails + // after we've already mutated the runtime + in-memory session. + sess := rs.session + var ( + hadOverride bool + prevOverride string + prevCustomLen int + hadOverridesMap bool + appendedCustomUsed bool + ) + if sess != nil { + hadOverridesMap = sess.AgentModelOverrides != nil + if hadOverridesMap { + prevOverride, hadOverride = sess.AgentModelOverrides[agentName] + } + prevCustomLen = len(sess.CustomModelsUsed) + } + if err := rs.runtime.SetAgentModel(ctx, agentName, modelRef); err != nil { return "", "", err } - sess := rs.session if sess == nil { return agentName, modelRef, nil } @@ -844,10 +935,31 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m // re-select via the model picker (mirrors App.SetCurrentAgentModel). if strings.Contains(modelRef, "/") && !slices.Contains(sess.CustomModelsUsed, modelRef) { sess.CustomModelsUsed = append(sess.CustomModelsUsed, modelRef) + appendedCustomUsed = true } } if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil { + // Roll back: restore in-memory map and runtime so callers don't see + // a runtime/store mismatch on the next request. + if hadOverride { + sess.AgentModelOverrides[agentName] = prevOverride + } else { + delete(sess.AgentModelOverrides, agentName) + if !hadOverridesMap && len(sess.AgentModelOverrides) == 0 { + sess.AgentModelOverrides = nil + } + } + if appendedCustomUsed { + sess.CustomModelsUsed = sess.CustomModelsUsed[:prevCustomLen] + } + rollback := prevOverride + if !hadOverride { + rollback = "" + } + if rbErr := rs.runtime.SetAgentModel(ctx, agentName, rollback); rbErr != nil { + slog.ErrorContext(ctx, "Failed to roll back runtime model override", "session_id", sessionID, "agent", agentName, "error", rbErr) + } return "", "", fmt.Errorf("failed to persist model override: %w", err) } From c0e6fb97b239d1a44a6dc569e0cb68bafbedb413 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 13 May 2026 14:49:58 +0200 Subject: [PATCH 05/12] test: enhance model selection test coverage and fix goroutine leaks Extract a startAttachedServer helper that wires up a SessionManager + HTTP server with t.Cleanup(ln.Close) so the Serve goroutine exits cleanly when the test finishes. Replaces boilerplate in all model-switching tests and prevents goroutine leaks. Add tests for key picker scenarios: - TestAttachedServer_GetSessionModels_DefaultMarkedCurrent: default marked current when no override is active - TestAttachedServer_GetSessionModels_AppendsCustomModels: session's custom model history appears in the picker - TestSessionManager_SetSessionAgentModel_RollsBackOnStoreFailure: store failure rolls back both in-memory session and runtime state - TestDecorateModelChoices (table): corner cases for the decorate helper Update TestAttachedServer_GetSessionModels to verify the IsCurrent decoration flow end-to-end and to remove the manual IsCurrent: true from the fixture (now correctly set by the manager). --- pkg/server/session_models_test.go | 219 +++++++++++++++++++++++++----- 1 file changed, 187 insertions(+), 32 deletions(-) diff --git a/pkg/server/session_models_test.go b/pkg/server/session_models_test.go index 6f5b52123..016c18473 100644 --- a/pkg/server/session_models_test.go +++ b/pkg/server/session_models_test.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "errors" "net/http" "sync" "testing" @@ -63,6 +64,19 @@ func (m *modelSwitchingRuntime) SetAgentModel(_ context.Context, agentName, mode return nil } +// startAttachedServer wires a SessionManager + HTTP server backed by an +// in-process listener and registers a t.Cleanup that closes the listener +// (and unblocks the Serve goroutine) when the test finishes. +func startAttachedServer(t *testing.T, ctx context.Context, sm *SessionManager) string { + t.Helper() + srv := NewWithManager(sm, "") + ln, err := Listen(ctx, "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { _ = ln.Close() }) + go func() { _ = srv.Serve(ctx, ln) }() + return "http://" + ln.Addr().String() +} + func TestSessionManager_CreateSession_KeepsModelOverrides(t *testing.T) { t.Parallel() @@ -107,19 +121,14 @@ func TestAttachedServer_GetSessionModels(t *testing.T) { choices := []runtime.ModelChoice{ {Name: "default", Ref: "openai/gpt-4o-mini", Provider: "openai", Model: "gpt-4o-mini", IsDefault: true}, - {Name: "custom", Ref: "openai/gpt-4o", Provider: "openai", Model: "gpt-4o", IsCurrent: true}, + {Name: "custom", Ref: "openai/gpt-4o", Provider: "openai", Model: "gpt-4o"}, } fake := newModelSwitchingRuntime(choices) sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) sm.AttachRuntime(sess.ID, fake, sess) - srv := NewWithManager(sm, "") - ln, err := Listen(ctx, "127.0.0.1:0") - require.NoError(t, err) - go func() { _ = srv.Serve(ctx, ln) }() - - addr := "http://" + ln.Addr().String() + addr := startAttachedServer(t, ctx, sm) resp := httpDoTCP(t, ctx, http.MethodGet, addr+"/api/sessions/"+sess.ID+"/models", nil) var got runtime.SessionModelsResponse @@ -130,8 +139,73 @@ func TestAttachedServer_GetSessionModels(t *testing.T) { require.Len(t, got.Models, 2) assert.Equal(t, "openai/gpt-4o-mini", got.Models[0].Ref) assert.True(t, got.Models[0].IsDefault) + assert.False(t, got.Models[0].IsCurrent, "default must not be marked current when an override is active") assert.Equal(t, "openai/gpt-4o", got.Models[1].Ref) - assert.True(t, got.Models[1].IsCurrent) + assert.True(t, got.Models[1].IsCurrent, "the model matching the override must be marked current") +} + +// When no override is set, the agent's default model must be marked +// IsCurrent so the picker can highlight it without a second round trip. +func TestAttachedServer_GetSessionModels_DefaultMarkedCurrent(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + store := session.NewInMemorySessionStore() + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + choices := []runtime.ModelChoice{ + {Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}, + {Name: "other", Ref: "openai/gpt-4o"}, + } + fake := newModelSwitchingRuntime(choices) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + addr := startAttachedServer(t, ctx, sm) + resp := httpDoTCP(t, ctx, http.MethodGet, addr+"/api/sessions/"+sess.ID+"/models", nil) + + var got runtime.SessionModelsResponse + require.NoError(t, json.Unmarshal(resp, &got)) + + assert.Empty(t, got.CurrentModelRef) + require.Len(t, got.Models, 2) + assert.True(t, got.Models[0].IsCurrent, "default model must be marked current when no override is set") + assert.False(t, got.Models[1].IsCurrent) +} + +// Custom (provider/model) refs from the session history must be appended +// to the picker so a user can pick a previously used model again. +func TestAttachedServer_GetSessionModels_AppendsCustomModels(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + store := session.NewInMemorySessionStore() + sess := session.New() + sess.CustomModelsUsed = []string{"openai/gpt-4o"} + require.NoError(t, store.AddSession(ctx, sess)) + + choices := []runtime.ModelChoice{ + {Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}, + } + fake := newModelSwitchingRuntime(choices) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + addr := startAttachedServer(t, ctx, sm) + resp := httpDoTCP(t, ctx, http.MethodGet, addr+"/api/sessions/"+sess.ID+"/models", nil) + + var got runtime.SessionModelsResponse + require.NoError(t, json.Unmarshal(resp, &got)) + + require.Len(t, got.Models, 2) + assert.Equal(t, "openai/gpt-4o-mini", got.Models[0].Ref) + assert.Equal(t, "openai/gpt-4o", got.Models[1].Ref) + assert.True(t, got.Models[1].IsCustom) } func TestAttachedServer_SetSessionModel_PersistsOverride(t *testing.T) { @@ -148,12 +222,7 @@ func TestAttachedServer_SetSessionModel_PersistsOverride(t *testing.T) { sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) sm.AttachRuntime(sess.ID, fake, sess) - srv := NewWithManager(sm, "") - ln, err := Listen(ctx, "127.0.0.1:0") - require.NoError(t, err) - go func() { _ = srv.Serve(ctx, ln) }() - - addr := "http://" + ln.Addr().String() + addr := startAttachedServer(t, ctx, sm) resp := httpDoTCP(t, ctx, http.MethodPatch, addr+"/api/sessions/"+sess.ID+"/model", api.SetSessionModelRequest{Model: "anthropic/claude-sonnet-4-0"}) @@ -190,12 +259,7 @@ func TestAttachedServer_SetSessionModel_EmptyClearsOverride(t *testing.T) { sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) sm.AttachRuntime(sess.ID, fake, sess) - srv := NewWithManager(sm, "") - ln, err := Listen(ctx, "127.0.0.1:0") - require.NoError(t, err) - go func() { _ = srv.Serve(ctx, ln) }() - - addr := "http://" + ln.Addr().String() + addr := startAttachedServer(t, ctx, sm) _ = httpDoTCP(t, ctx, http.MethodPatch, addr+"/api/sessions/"+sess.ID+"/model", api.SetSessionModelRequest{Model: ""}) @@ -222,12 +286,7 @@ func TestAttachedServer_SetSessionModel_PostVerbAlsoWorks(t *testing.T) { sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) sm.AttachRuntime(sess.ID, fake, sess) - srv := NewWithManager(sm, "") - ln, err := Listen(ctx, "127.0.0.1:0") - require.NoError(t, err) - go func() { _ = srv.Serve(ctx, ln) }() - - addr := "http://" + ln.Addr().String() + addr := startAttachedServer(t, ctx, sm) _ = httpDoTCP(t, ctx, http.MethodPost, addr+"/api/sessions/"+sess.ID+"/model", api.SetSessionModelRequest{Model: "openai/gpt-4o"}) @@ -247,12 +306,7 @@ func TestAttachedServer_GetSessionModels_NotSupported(t *testing.T) { sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) sm.AttachRuntime(sess.ID, &fakeRuntime{}, sess) - srv := NewWithManager(sm, "") - ln, err := Listen(ctx, "127.0.0.1:0") - require.NoError(t, err) - go func() { _ = srv.Serve(ctx, ln) }() - - addr := "http://" + ln.Addr().String() + addr := startAttachedServer(t, ctx, sm) req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr+"/api/sessions/"+sess.ID+"/models", http.NoBody) require.NoError(t, err) @@ -261,3 +315,104 @@ func TestAttachedServer_GetSessionModels_NotSupported(t *testing.T) { defer resp.Body.Close() assert.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode) } + +// failingStore wraps an in-memory store so UpdateSession can be made to +// fail on demand to exercise the rollback path of SetSessionAgentModel. +type failingStore struct { + session.Store + + mu sync.Mutex + failUpdate bool +} + +func (s *failingStore) UpdateSession(ctx context.Context, sess *session.Session) error { + s.mu.Lock() + fail := s.failUpdate + s.mu.Unlock() + if fail { + return errors.New("synthetic store failure") + } + return s.Store.UpdateSession(ctx, sess) +} + +// When the session store rejects the persistence write, the in-memory +// session and the runtime override must both be rolled back so the next +// read does not surface state that was never persisted. +func TestSessionManager_SetSessionAgentModel_RollsBackOnStoreFailure(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := &failingStore{Store: session.NewInMemorySessionStore()} + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + fake := newModelSwitchingRuntime(nil) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + store.mu.Lock() + store.failUpdate = true + store.mu.Unlock() + + _, _, err := sm.SetSessionAgentModel(ctx, sess.ID, "openai/gpt-4o") + require.Error(t, err) + + // In-memory session must not contain the override. + _, exists := sess.AgentModelOverrides["root"] + assert.False(t, exists, "in-memory override must be rolled back") + assert.NotContains(t, sess.CustomModelsUsed, "openai/gpt-4o", "CustomModelsUsed must be rolled back") + + // Runtime must not have the override either. + fake.mu.Lock() + _, runtimeHas := fake.overrides["root"] + fake.mu.Unlock() + assert.False(t, runtimeHas, "runtime override must be rolled back") +} + +// decorateModelChoices is exercised through the GET handler tests above; +// this unit test pins a few important corner cases that are too tedious +// to reach over HTTP. +func TestDecorateModelChoices(t *testing.T) { + t.Parallel() + + t.Run("synthesizes choice for inline override not in list", func(t *testing.T) { + t.Parallel() + got := decorateModelChoices( + []runtime.ModelChoice{{Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}}, + "anthropic/claude-sonnet-4-0", + nil, + ) + require.Len(t, got, 2) + assert.Equal(t, "anthropic/claude-sonnet-4-0", got[1].Ref) + assert.Equal(t, "anthropic", got[1].Provider) + assert.Equal(t, "claude-sonnet-4-0", got[1].Model) + assert.True(t, got[1].IsCurrent) + assert.True(t, got[1].IsCustom) + }) + + t.Run("does not duplicate custom ref already in list", func(t *testing.T) { + t.Parallel() + got := decorateModelChoices( + []runtime.ModelChoice{{Name: "default", Ref: "openai/gpt-4o", IsDefault: true}}, + "", + []string{"openai/gpt-4o"}, + ) + require.Len(t, got, 1) + assert.Equal(t, "openai/gpt-4o", got[0].Ref) + }) + + t.Run("non-provider override (config key) does not synthesize choice", func(t *testing.T) { + t.Parallel() + // "my_model" is a config key (no slash); when not in the runtime's + // list we should NOT fabricate a choice for it because we have no + // provider/model breakdown to display. + got := decorateModelChoices( + []runtime.ModelChoice{{Name: "default", Ref: "default", IsDefault: true}}, + "my_model", + nil, + ) + require.Len(t, got, 1) + assert.False(t, got[0].IsCurrent, "default must not be marked current when override is unknown") + }) +} From e38635f722ccc12af791d7e0b71e92f87de8cddf Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 13 May 2026 15:19:44 +0200 Subject: [PATCH 06/12] refactor(runtime): extract DecorateModelChoices and remove duplication --- pkg/app/app.go | 71 ++------------------ pkg/runtime/model_switcher.go | 66 ++++++++++++++++++ pkg/runtime/model_switcher_test.go | 103 +++++++++++++++++++++++++++++ pkg/server/session_manager.go | 63 +----------------- pkg/server/session_models_test.go | 69 ++++++++----------- 5 files changed, 203 insertions(+), 169 deletions(-) diff --git a/pkg/app/app.go b/pkg/app/app.go index ec8aba3ae..732ccef2a 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -777,76 +777,15 @@ func (a *App) AvailableModels(ctx context.Context) []runtime.ModelChoice { if !a.runtime.SupportsModelSwitching() { return nil } - models := a.runtime.AvailableModels(ctx) - // Determine the currently active model for this agent agentName := a.runtime.CurrentAgentName() - currentModelRef := "" - if a.session != nil && a.session.AgentModelOverrides != nil { - currentModelRef = a.session.AgentModelOverrides[agentName] - } - - // Build a set of model refs already in the list - existingRefs := make(map[string]bool) - for _, m := range models { - existingRefs[m.Ref] = true - } - - // Check if current model is in the list and mark it - currentFound := currentModelRef == "" - for i := range models { - if currentModelRef != "" { - // An override is set - mark the override as current - if models[i].Ref == currentModelRef { - models[i].IsCurrent = true - currentFound = true - } - } else { - // No override - the default model is current - models[i].IsCurrent = models[i].IsDefault - } - } - - // Add custom models from the session that aren't already in the list + currentRef := "" + var customRefs []string if a.session != nil { - for _, customRef := range a.session.CustomModelsUsed { - if existingRefs[customRef] { - continue // Already in the list - } - existingRefs[customRef] = true - - providerName, modelName, _ := strings.Cut(customRef, "/") - isCurrent := customRef == currentModelRef - if isCurrent { - currentFound = true - } - models = append(models, runtime.ModelChoice{ - Name: customRef, - Ref: customRef, - Provider: providerName, - Model: modelName, - IsDefault: false, - IsCurrent: isCurrent, - IsCustom: true, - }) - } + currentRef = a.session.AgentModelOverrides[agentName] + customRefs = a.session.CustomModelsUsed } - - // If current model is a custom model not in the list, add it - if !currentFound && strings.Contains(currentModelRef, "/") { - providerName, modelName, _ := strings.Cut(currentModelRef, "/") - models = append(models, runtime.ModelChoice{ - Name: currentModelRef, - Ref: currentModelRef, - Provider: providerName, - Model: modelName, - IsDefault: false, - IsCurrent: true, - IsCustom: true, - }) - } - - return models + return runtime.DecorateModelChoices(a.runtime.AvailableModels(ctx), currentRef, customRefs) } // trackCustomModel adds a custom model to the session's history if not already present. diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index 84b55b265..2d45d0f31 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -73,6 +73,72 @@ type SessionModelsResponse struct { Models []ModelChoice `json:"models"` } +// DecorateModelChoices marks the active selection with IsCurrent and +// appends any custom (provider/model) refs from the session history that +// the runtime does not already expose. It is used by every consumer that +// wants to render a model picker (the TUI App, the HTTP /sessions/:id/models +// endpoint, …) so they all agree on which entry is current and what the +// final list looks like. +// +// currentRef is the model override active for the agent ("" when none), +// and customRefs is the session's CustomModelsUsed history. +func DecorateModelChoices(models []ModelChoice, currentRef string, customRefs []string) []ModelChoice { + existingRefs := make(map[string]bool, len(models)) + for _, m := range models { + existingRefs[m.Ref] = true + } + + currentFound := currentRef == "" + for i := range models { + if currentRef != "" { + if models[i].Ref == currentRef { + models[i].IsCurrent = true + currentFound = true + } + } else { + models[i].IsCurrent = models[i].IsDefault + } + } + + for _, ref := range customRefs { + if existingRefs[ref] { + continue + } + existingRefs[ref] = true + + prov, name, _ := strings.Cut(ref, "/") + isCurrent := ref == currentRef + if isCurrent { + currentFound = true + } + models = append(models, ModelChoice{ + Name: ref, + Ref: ref, + Provider: prov, + Model: name, + IsCurrent: isCurrent, + IsCustom: true, + }) + } + + // If the override points at an inline provider/model not in the + // runtime's list nor in the session's history, fabricate a synthetic + // choice so the picker can still highlight the active selection. + if !currentFound && strings.Contains(currentRef, "/") { + prov, name, _ := strings.Cut(currentRef, "/") + models = append(models, ModelChoice{ + Name: currentRef, + Ref: currentRef, + Provider: prov, + Model: name, + IsCurrent: true, + IsCustom: true, + }) + } + + return models +} + // ModelSwitcherConfig holds the configuration needed for model switching. // This is populated by the app layer when creating the runtime. type ModelSwitcherConfig struct { diff --git a/pkg/runtime/model_switcher_test.go b/pkg/runtime/model_switcher_test.go index 935f9319d..cc899189e 100644 --- a/pkg/runtime/model_switcher_test.go +++ b/pkg/runtime/model_switcher_test.go @@ -486,3 +486,106 @@ func TestResolveModelRef_InvalidFormat(t *testing.T) { }) } } + +func TestDecorateModelChoices(t *testing.T) { + t.Parallel() + + t.Run("default marked current when no override is set", func(t *testing.T) { + t.Parallel() + got := DecorateModelChoices( + []ModelChoice{ + {Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}, + {Name: "other", Ref: "openai/gpt-4o"}, + }, + "", + nil, + ) + require.Len(t, got, 2) + assert.True(t, got[0].IsCurrent, "the IsDefault model must be marked IsCurrent when no override is set") + assert.False(t, got[1].IsCurrent) + }) + + t.Run("override matching a known choice marks it current", func(t *testing.T) { + t.Parallel() + got := DecorateModelChoices( + []ModelChoice{ + {Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}, + {Name: "other", Ref: "openai/gpt-4o"}, + }, + "openai/gpt-4o", + nil, + ) + require.Len(t, got, 2) + assert.False(t, got[0].IsCurrent, "default must not be marked current when an override is active") + assert.True(t, got[1].IsCurrent) + }) + + t.Run("synthesizes choice for inline override not in list", func(t *testing.T) { + t.Parallel() + got := DecorateModelChoices( + []ModelChoice{{Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}}, + "anthropic/claude-sonnet-4-0", + nil, + ) + require.Len(t, got, 2) + assert.Equal(t, "anthropic/claude-sonnet-4-0", got[1].Ref) + assert.Equal(t, "anthropic", got[1].Provider) + assert.Equal(t, "claude-sonnet-4-0", got[1].Model) + assert.True(t, got[1].IsCurrent) + assert.True(t, got[1].IsCustom) + }) + + t.Run("appends custom refs from session history", func(t *testing.T) { + t.Parallel() + got := DecorateModelChoices( + []ModelChoice{{Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}}, + "", + []string{"openai/gpt-4o", "anthropic/claude-sonnet-4-0"}, + ) + require.Len(t, got, 3) + assert.Equal(t, "openai/gpt-4o-mini", got[0].Ref) + assert.Equal(t, "openai/gpt-4o", got[1].Ref) + assert.True(t, got[1].IsCustom) + assert.Equal(t, "anthropic/claude-sonnet-4-0", got[2].Ref) + assert.True(t, got[2].IsCustom) + }) + + t.Run("does not duplicate custom ref already in list", func(t *testing.T) { + t.Parallel() + got := DecorateModelChoices( + []ModelChoice{{Name: "default", Ref: "openai/gpt-4o", IsDefault: true}}, + "", + []string{"openai/gpt-4o"}, + ) + require.Len(t, got, 1) + assert.Equal(t, "openai/gpt-4o", got[0].Ref) + }) + + t.Run("non-provider override does not synthesize a fabricated choice", func(t *testing.T) { + t.Parallel() + // "my_model" is a config key (no slash); when not in the runtime's + // list we should NOT fabricate a choice for it because we have no + // provider/model breakdown to display. + got := DecorateModelChoices( + []ModelChoice{{Name: "default", Ref: "default", IsDefault: true}}, + "my_model", + nil, + ) + require.Len(t, got, 1) + assert.False(t, got[0].IsCurrent, "default must not be marked current when override is unknown") + }) + + t.Run("custom ref matching the active override is marked current", func(t *testing.T) { + t.Parallel() + got := DecorateModelChoices( + []ModelChoice{{Name: "default", Ref: "default", IsDefault: true}}, + "openai/gpt-4o", + []string{"openai/gpt-4o"}, + ) + require.Len(t, got, 2) + assert.False(t, got[0].IsCurrent) + assert.Equal(t, "openai/gpt-4o", got[1].Ref) + assert.True(t, got[1].IsCurrent) + assert.True(t, got[1].IsCustom) + }) +} diff --git a/pkg/server/session_manager.go b/pkg/server/session_manager.go index 439a42601..26e73576d 100644 --- a/pkg/server/session_manager.go +++ b/pkg/server/session_manager.go @@ -791,7 +791,8 @@ var ErrModelSwitchingNotSupported = errors.New("model switching not supported by // the active selection without a second round-trip. When no override is // active, the agent's configured default carries IsCurrent=true; if the // override points at an inline provider/model not present in the agent -// config, a synthetic choice is appended (mirrors App.AvailableModels). +// config, a synthetic choice is appended (mirrors App.AvailableModels via +// the shared runtime.DecorateModelChoices helper). func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID string) (string, string, []runtime.ModelChoice, error) { sm.mux.Lock() defer sm.mux.Unlock() @@ -813,68 +814,10 @@ func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID customRefs = rs.session.CustomModelsUsed } - choices := decorateModelChoices(rs.runtime.AvailableModels(ctx), current, customRefs) + choices := runtime.DecorateModelChoices(rs.runtime.AvailableModels(ctx), current, customRefs) return agentName, current, choices, nil } -// decorateModelChoices marks the active selection with IsCurrent and -// appends any custom (provider/model) refs from the session history that -// the runtime doesn't already expose. Mirrors the post-processing in -// App.AvailableModels so HTTP and TUI clients see the same picker state. -func decorateModelChoices(models []runtime.ModelChoice, current string, customRefs []string) []runtime.ModelChoice { - existingRefs := make(map[string]bool, len(models)) - for _, m := range models { - existingRefs[m.Ref] = true - } - - currentFound := current == "" - for i := range models { - if current != "" { - if models[i].Ref == current { - models[i].IsCurrent = true - currentFound = true - } - } else { - models[i].IsCurrent = models[i].IsDefault - } - } - - for _, ref := range customRefs { - if existingRefs[ref] { - continue - } - existingRefs[ref] = true - - prov, name, _ := strings.Cut(ref, "/") - isCurrent := ref == current - if isCurrent { - currentFound = true - } - models = append(models, runtime.ModelChoice{ - Name: ref, - Ref: ref, - Provider: prov, - Model: name, - IsCurrent: isCurrent, - IsCustom: true, - }) - } - - if !currentFound && strings.Contains(current, "/") { - prov, name, _ := strings.Cut(current, "/") - models = append(models, runtime.ModelChoice{ - Name: current, - Ref: current, - Provider: prov, - Model: name, - IsCurrent: true, - IsCustom: true, - }) - } - - return models -} - // SetSessionAgentModel applies modelRef as the model override for the // current agent of the session, persists it to the session store, and // tracks custom models for later re-selection. Pass an empty modelRef diff --git a/pkg/server/session_models_test.go b/pkg/server/session_models_test.go index 016c18473..3a2acde22 100644 --- a/pkg/server/session_models_test.go +++ b/pkg/server/session_models_test.go @@ -370,49 +370,32 @@ func TestSessionManager_SetSessionAgentModel_RollsBackOnStoreFailure(t *testing. assert.False(t, runtimeHas, "runtime override must be rolled back") } -// decorateModelChoices is exercised through the GET handler tests above; -// this unit test pins a few important corner cases that are too tedious -// to reach over HTTP. -func TestDecorateModelChoices(t *testing.T) { +// When the runtime rejects SetAgentModel, no in-memory or persisted +// state must be mutated; the error must propagate verbatim. +func TestSessionManager_SetSessionAgentModel_RuntimeFailureLeavesStateUntouched(t *testing.T) { t.Parallel() - t.Run("synthesizes choice for inline override not in list", func(t *testing.T) { - t.Parallel() - got := decorateModelChoices( - []runtime.ModelChoice{{Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}}, - "anthropic/claude-sonnet-4-0", - nil, - ) - require.Len(t, got, 2) - assert.Equal(t, "anthropic/claude-sonnet-4-0", got[1].Ref) - assert.Equal(t, "anthropic", got[1].Provider) - assert.Equal(t, "claude-sonnet-4-0", got[1].Model) - assert.True(t, got[1].IsCurrent) - assert.True(t, got[1].IsCustom) - }) - - t.Run("does not duplicate custom ref already in list", func(t *testing.T) { - t.Parallel() - got := decorateModelChoices( - []runtime.ModelChoice{{Name: "default", Ref: "openai/gpt-4o", IsDefault: true}}, - "", - []string{"openai/gpt-4o"}, - ) - require.Len(t, got, 1) - assert.Equal(t, "openai/gpt-4o", got[0].Ref) - }) - - t.Run("non-provider override (config key) does not synthesize choice", func(t *testing.T) { - t.Parallel() - // "my_model" is a config key (no slash); when not in the runtime's - // list we should NOT fabricate a choice for it because we have no - // provider/model breakdown to display. - got := decorateModelChoices( - []runtime.ModelChoice{{Name: "default", Ref: "default", IsDefault: true}}, - "my_model", - nil, - ) - require.Len(t, got, 1) - assert.False(t, got[0].IsCurrent, "default must not be marked current when override is unknown") - }) + ctx := t.Context() + store := session.NewInMemorySessionStore() + sess := session.New() + sess.AgentModelOverrides = map[string]string{"root": "openai/gpt-4o"} + sess.CustomModelsUsed = []string{"openai/gpt-4o"} + require.NoError(t, store.AddSession(ctx, sess)) + + fake := newModelSwitchingRuntime(nil) + fake.setErr = errors.New("runtime says no") + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + _, _, err := sm.SetSessionAgentModel(ctx, sess.ID, "anthropic/claude-sonnet-4-0") + require.Error(t, err) + + // The original override must be intact. + assert.Equal(t, "openai/gpt-4o", sess.AgentModelOverrides["root"]) + assert.Equal(t, []string{"openai/gpt-4o"}, sess.CustomModelsUsed) } + +// runtime.DecorateModelChoices is exercised end-to-end through the GET +// handler tests above; unit-level corner cases live in pkg/runtime +// (see model_switcher_test.go). From d90661b3f4884f99a44ef5596e524ec971e40912 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 13 May 2026 15:47:46 +0200 Subject: [PATCH 07/12] fix: distinguish "session not running" from other errors Map ErrSessionNotRunning to 404 so callers can tell apart a stale session ID from a server-side problem. Introduce the sentinel error and replace four ad-hoc errors.New() calls with it: SteerSession, FollowUpSession, AvailableSessionModels, SetSessionAgentModel. Update the HTTP handlers to check for this sentinel explicitly and return 404 accordingly. Also add TestAttachedServer_ModelEndpoints_404WhenNotRunning to verify both GET /sessions/:id/models and PATCH /sessions/:id/model return 404 when no runtime is attached. --- pkg/server/server.go | 16 ++++-- pkg/server/session_manager.go | 41 ++++++++++---- pkg/server/session_models_test.go | 92 ++++++++++++++++++++++++++++++- 3 files changed, 132 insertions(+), 17 deletions(-) diff --git a/pkg/server/server.go b/pkg/server/server.go index 653ba95fa..3bf0270ea 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -574,10 +574,14 @@ func (s *Server) getSessionModels(c echo.Context) error { agentName, current, choices, err := s.sm.AvailableSessionModels(c.Request().Context(), sessionID) if err != nil { - if errors.Is(err, ErrModelSwitchingNotSupported) { + switch { + case errors.Is(err, ErrModelSwitchingNotSupported): return echo.NewHTTPError(http.StatusUnprocessableEntity, err.Error()) + case errors.Is(err, ErrSessionNotRunning): + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + default: + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - return echo.NewHTTPError(http.StatusNotFound, err.Error()) } return c.JSON(http.StatusOK, runtime.SessionModelsResponse{ @@ -600,10 +604,14 @@ func (s *Server) setSessionModel(c echo.Context) error { agentName, modelRef, err := s.sm.SetSessionAgentModel(c.Request().Context(), sessionID, req.Model) if err != nil { - if errors.Is(err, ErrModelSwitchingNotSupported) { + switch { + case errors.Is(err, ErrModelSwitchingNotSupported): return echo.NewHTTPError(http.StatusUnprocessableEntity, err.Error()) + case errors.Is(err, ErrSessionNotRunning): + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + default: + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } return c.JSON(http.StatusOK, api.SetSessionModelResponse{ diff --git a/pkg/server/session_manager.go b/pkg/server/session_manager.go index 26e73576d..f943e7175 100644 --- a/pkg/server/session_manager.go +++ b/pkg/server/session_manager.go @@ -229,7 +229,8 @@ func (sm *SessionManager) CreateSession(ctx context.Context, sessionTemplate *se // Copy model-related fields from the template so callers can pin a // specific model when creating a session over the API. The runtime // will pick these up the first time it is built for the session - // (see runtimeForSession). + // (see runtimeForSession). Callers that want a model to also appear + // in the picker history should include it in CustomModelsUsed. if len(sessionTemplate.AgentModelOverrides) > 0 { sess.AgentModelOverrides = maps.Clone(sessionTemplate.AgentModelOverrides) } @@ -463,7 +464,7 @@ func (sm *SessionManager) ResumeSession(ctx context.Context, sessionID, confirma func (sm *SessionManager) SteerSession(_ context.Context, sessionID string, messages []api.Message) error { rt, exists := sm.runtimeSessions.Load(sessionID) if !exists { - return errors.New("session not found or not running") + return ErrSessionNotRunning } for _, msg := range messages { @@ -488,7 +489,7 @@ func (sm *SessionManager) SteerSession(_ context.Context, sessionID string, mess func (sm *SessionManager) FollowUpSession(_ context.Context, sessionID string, messages []api.Message) (streaming bool, err error) { rt, exists := sm.runtimeSessions.Load(sessionID) if !exists { - return false, errors.New("session not found or not running") + return false, ErrSessionNotRunning } for _, msg := range messages { @@ -656,13 +657,7 @@ func (sm *SessionManager) runtimeForSession(ctx context.Context, sess *session.S // Apply any stored per-agent model overrides so that a session // resumed (or freshly created with overrides via CreateSession) uses // the requested models instead of the agent's defaults. - if len(sess.AgentModelOverrides) > 0 && run.SupportsModelSwitching() { - for agentName, modelRef := range sess.AgentModelOverrides { - if err := run.SetAgentModel(ctx, agentName, modelRef); err != nil { - slog.WarnContext(ctx, "Failed to apply stored model override", "session_id", sess.ID, "agent", agentName, "model", modelRef, "error", err) - } - } - } + applyStoredOverrides(ctx, sess.ID, run, sess.AgentModelOverrides) titleGen := sessiontitle.New(agt.Model(ctx), agt.FallbackModels()...) @@ -691,6 +686,22 @@ func (sm *SessionManager) loadTeamWithConfig(ctx context.Context, agentFilename return teamloader.LoadWithConfig(ctx, agentSource, runConfig) } +// applyStoredOverrides applies the persisted per-agent model overrides on +// the freshly created runtime. Failures are logged at WARN and otherwise +// ignored: a stored override that no longer resolves (e.g. because the +// model was removed from the agent's config) must not prevent the +// session from being resumed with the agent's default model. +func applyStoredOverrides(ctx context.Context, sessionID string, run runtime.Runtime, overrides map[string]string) { + if len(overrides) == 0 || !run.SupportsModelSwitching() { + return + } + for agentName, modelRef := range overrides { + if err := run.SetAgentModel(ctx, agentName, modelRef); err != nil { + slog.WarnContext(ctx, "Failed to apply stored model override", "session_id", sessionID, "agent", agentName, "model", modelRef, "error", err) + } + } +} + // GetAgentToolCount loads the agent's team and returns the number of // tools available to the given agent. When agentName is empty, it // resolves to the team's default agent. @@ -780,6 +791,12 @@ func (sm *SessionManager) SetSessionStarred(ctx context.Context, sessionID strin // was created without a ModelSwitcherConfig). var ErrModelSwitchingNotSupported = errors.New("model switching not supported by this runtime") +// ErrSessionNotRunning is returned by methods that require an active +// runtime for the session (i.e. RunSession must have been called or +// AttachRuntime invoked) when none is found. HTTP handlers map this to +// 404 to distinguish from other runtime errors. +var ErrSessionNotRunning = errors.New("session not found or not running") + // AvailableSessionModels returns the list of models available for the // session's current agent. The agent's name and the active model override // (if any) are returned alongside the choices so callers don't have to @@ -799,7 +816,7 @@ func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID rs, ok := sm.runtimeSessions.Load(sessionID) if !ok { - return "", "", nil, errors.New("session not found or not running") + return "", "", nil, ErrSessionNotRunning } if !rs.runtime.SupportsModelSwitching() { @@ -831,7 +848,7 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m rs, ok := sm.runtimeSessions.Load(sessionID) if !ok { - return "", "", errors.New("session not found or not running") + return "", "", ErrSessionNotRunning } if !rs.runtime.SupportsModelSwitching() { diff --git a/pkg/server/session_models_test.go b/pkg/server/session_models_test.go index 3a2acde22..6cf029575 100644 --- a/pkg/server/session_models_test.go +++ b/pkg/server/session_models_test.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "context" "encoding/json" "errors" @@ -98,7 +99,8 @@ func TestSessionManager_CreateSession_KeepsModelOverrides(t *testing.T) { assert.Equal(t, "openai/gpt-4o", created.AgentModelOverrides["root"]) assert.Equal(t, "anthropic/claude-sonnet-4-0", created.AgentModelOverrides["researcher"]) - assert.Equal(t, []string{"openai/gpt-4o"}, created.CustomModelsUsed) + assert.Equal(t, []string{"openai/gpt-4o"}, created.CustomModelsUsed, + "CreateSession is a passthrough: only refs explicitly listed in CustomModelsUsed should be tracked") // Mutating the template after creation must not affect the stored session. template.AgentModelOverrides["root"] = "mutated" @@ -399,3 +401,91 @@ func TestSessionManager_SetSessionAgentModel_RuntimeFailureLeavesStateUntouched( // runtime.DecorateModelChoices is exercised end-to-end through the GET // handler tests above; unit-level corner cases live in pkg/runtime // (see model_switcher_test.go). + +// When no runtime is attached for the session, the endpoints must +// return 404 (not 400 or 500) so callers can tell apart a stale id +// from an actual server-side problem. +func TestAttachedServer_ModelEndpoints_404WhenNotRunning(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + // Note: no AttachRuntime call. + + addr := startAttachedServer(t, ctx, sm) + + t.Run("GET", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr+"/api/sessions/"+sess.ID+"/models", http.NoBody) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + }) + + t.Run("PATCH", func(t *testing.T) { + body := bytes.NewReader([]byte(`{"model":"openai/gpt-4o"}`)) + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, addr+"/api/sessions/"+sess.ID+"/model", body) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + }) +} + +// applyStoredOverrides is the helper called by runtimeForSession to +// re-apply persisted overrides on a freshly-built runtime. We can't +// drive runtimeForSession with a fake runtime (it constructs a real +// LocalRuntime), so we cover the helper's contract directly here. +func TestApplyStoredOverrides(t *testing.T) { + t.Parallel() + + t.Run("no-op when no overrides", func(t *testing.T) { + t.Parallel() + fake := newModelSwitchingRuntime(nil) + applyStoredOverrides(t.Context(), "sess", fake, nil) + + fake.mu.Lock() + assert.Empty(t, fake.overrides) + fake.mu.Unlock() + }) + + t.Run("no-op when runtime does not support switching", func(t *testing.T) { + t.Parallel() + // fakeRuntime has SupportsModelSwitching == false; SetAgentModel + // must NOT be called on it (it would panic since fakeRuntime + // does not implement the method either). + applyStoredOverrides(t.Context(), "sess", &fakeRuntime{}, map[string]string{"root": "openai/gpt-4o"}) + // Reaching this point without panic is the assertion. + }) + + t.Run("applies each override on the runtime", func(t *testing.T) { + t.Parallel() + fake := newModelSwitchingRuntime(nil) + applyStoredOverrides(t.Context(), "sess", fake, map[string]string{ + "root": "openai/gpt-4o", + "researcher": "anthropic/claude-sonnet-4-0", + }) + + fake.mu.Lock() + assert.Equal(t, "openai/gpt-4o", fake.overrides["root"]) + assert.Equal(t, "anthropic/claude-sonnet-4-0", fake.overrides["researcher"]) + fake.mu.Unlock() + }) + + t.Run("runtime errors are swallowed (logged) so the session still loads", func(t *testing.T) { + t.Parallel() + fake := newModelSwitchingRuntime(nil) + fake.setErr = errors.New("model not in config anymore") + + require.NotPanics(t, func() { + applyStoredOverrides(t.Context(), "sess", fake, map[string]string{"root": "gone/model"}) + }) + }) +} From 90c7b4cfe73d55b765d8673bcd4f75139051e430 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 13 May 2026 16:18:21 +0200 Subject: [PATCH 08/12] fix(runtime): make DecorateModelChoices defensively copy input slice --- pkg/runtime/model_switcher.go | 27 ++++++++++++++++++--------- pkg/runtime/model_switcher_test.go | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index 2d45d0f31..81b6caea6 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -82,21 +82,30 @@ type SessionModelsResponse struct { // // currentRef is the model override active for the agent ("" when none), // and customRefs is the session's CustomModelsUsed history. +// +// The input slice is never mutated: callers can safely pass a slice that +// is shared with or backed by an internal cache. func DecorateModelChoices(models []ModelChoice, currentRef string, customRefs []string) []ModelChoice { - existingRefs := make(map[string]bool, len(models)) - for _, m := range models { + // Defensive copy: AvailableModels implementations may return a slice + // backed by an internal cache. Mutating its IsCurrent flag in place + // would leak picker state across sessions/agents. + result := make([]ModelChoice, len(models), len(models)+len(customRefs)+1) + copy(result, models) + + existingRefs := make(map[string]bool, len(result)) + for _, m := range result { existingRefs[m.Ref] = true } currentFound := currentRef == "" - for i := range models { + for i := range result { if currentRef != "" { - if models[i].Ref == currentRef { - models[i].IsCurrent = true + if result[i].Ref == currentRef { + result[i].IsCurrent = true currentFound = true } } else { - models[i].IsCurrent = models[i].IsDefault + result[i].IsCurrent = result[i].IsDefault } } @@ -111,7 +120,7 @@ func DecorateModelChoices(models []ModelChoice, currentRef string, customRefs [] if isCurrent { currentFound = true } - models = append(models, ModelChoice{ + result = append(result, ModelChoice{ Name: ref, Ref: ref, Provider: prov, @@ -126,7 +135,7 @@ func DecorateModelChoices(models []ModelChoice, currentRef string, customRefs [] // choice so the picker can still highlight the active selection. if !currentFound && strings.Contains(currentRef, "/") { prov, name, _ := strings.Cut(currentRef, "/") - models = append(models, ModelChoice{ + result = append(result, ModelChoice{ Name: currentRef, Ref: currentRef, Provider: prov, @@ -136,7 +145,7 @@ func DecorateModelChoices(models []ModelChoice, currentRef string, customRefs [] }) } - return models + return result } // ModelSwitcherConfig holds the configuration needed for model switching. diff --git a/pkg/runtime/model_switcher_test.go b/pkg/runtime/model_switcher_test.go index cc899189e..b2f7eaba5 100644 --- a/pkg/runtime/model_switcher_test.go +++ b/pkg/runtime/model_switcher_test.go @@ -588,4 +588,22 @@ func TestDecorateModelChoices(t *testing.T) { assert.True(t, got[1].IsCurrent) assert.True(t, got[1].IsCustom) }) + + // AvailableModels implementations may return a slice backed by an + // internal cache; mutating its IsCurrent flag in place would leak + // state across sessions. The function must therefore never modify + // the input slice or its underlying array. + t.Run("does not mutate the input slice", func(t *testing.T) { + t.Parallel() + input := []ModelChoice{ + {Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}, + {Name: "other", Ref: "openai/gpt-4o"}, + } + orig := make([]ModelChoice, len(input)) + copy(orig, input) + + _ = DecorateModelChoices(input, "openai/gpt-4o", []string{"anthropic/claude-sonnet-4-0"}) + + assert.Equal(t, orig, input, "DecorateModelChoices must not modify the input slice") + }) } From baf326c888b8eb972671e44f78678be43c0eee8f Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 13 May 2026 16:18:24 +0200 Subject: [PATCH 09/12] fix(server): map unknown errors from setSessionModel to 500, not 400 --- pkg/server/server.go | 7 +- pkg/server/session_models_test.go | 196 +++++++++++++++++++++++++++++- 2 files changed, 199 insertions(+), 4 deletions(-) diff --git a/pkg/server/server.go b/pkg/server/server.go index 3bf0270ea..9215586c1 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -610,7 +610,12 @@ func (s *Server) setSessionModel(c echo.Context) error { case errors.Is(err, ErrSessionNotRunning): return echo.NewHTTPError(http.StatusNotFound, err.Error()) default: - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + // Unknown errors come from the runtime (e.g. an inline model + // ref that fails provider creation) or from the session store + // (e.g. a write failure). Both are server-side concerns, not + // client mistakes, so map to 500. Validation of the request + // body itself is handled above by Bind which returns 400. + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } } diff --git a/pkg/server/session_models_test.go b/pkg/server/session_models_test.go index 6cf029575..53bc09f16 100644 --- a/pkg/server/session_models_test.go +++ b/pkg/server/session_models_test.go @@ -8,6 +8,7 @@ import ( "net/http" "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -29,6 +30,21 @@ type modelSwitchingRuntime struct { availableModels []runtime.ModelChoice overrides map[string]string setErr error + + // availableModelsCalled fires every time AvailableModels is invoked. + // Tests can use it to coordinate with a deliberately-slow runtime + // call (see availableModelsDelay). + availableModelsCalled chan struct{} + // availableModelsDelay, when set, makes AvailableModels block on + // this channel before returning. Used by lock-contention tests to + // hold a runtime call open while another goroutine probes the + // SessionManager. + availableModelsDelay <-chan struct{} + // setAgentModelCalled fires every time SetAgentModel is invoked. + setAgentModelCalled chan struct{} + // setAgentModelDelay, when set, makes SetAgentModel block on this + // channel before returning. + setAgentModelDelay <-chan struct{} } func newModelSwitchingRuntime(models []runtime.ModelChoice) *modelSwitchingRuntime { @@ -45,17 +61,45 @@ func (m *modelSwitchingRuntime) SupportsModelSwitching() bool { return true } func (m *modelSwitchingRuntime) AvailableModels(_ context.Context) []runtime.ModelChoice { m.mu.Lock() - defer m.mu.Unlock() + delay := m.availableModelsDelay + called := m.availableModelsCalled out := make([]runtime.ModelChoice, len(m.availableModels)) copy(out, m.availableModels) + m.mu.Unlock() + + if called != nil { + select { + case called <- struct{}{}: + default: + } + } + if delay != nil { + <-delay + } return out } func (m *modelSwitchingRuntime) SetAgentModel(_ context.Context, agentName, modelRef string) error { + m.mu.Lock() + setErr := m.setErr + delay := m.setAgentModelDelay + called := m.setAgentModelCalled + m.mu.Unlock() + + if called != nil { + select { + case called <- struct{}{}: + default: + } + } + if delay != nil { + <-delay + } + m.mu.Lock() defer m.mu.Unlock() - if m.setErr != nil { - return m.setErr + if setErr != nil { + return setErr } if modelRef == "" { delete(m.overrides, agentName) @@ -398,6 +442,37 @@ func TestSessionManager_SetSessionAgentModel_RuntimeFailureLeavesStateUntouched( assert.Equal(t, []string{"openai/gpt-4o"}, sess.CustomModelsUsed) } +// Server-side errors (store-write failures, runtime errors that aren't +// the well-known sentinels) must be reported as 500, not 400. 400 is +// reserved for client-side mistakes like an invalid request body. +func TestAttachedServer_SetSessionModel_StoreFailureReturns500(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := &failingStore{Store: session.NewInMemorySessionStore()} + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + fake := newModelSwitchingRuntime(nil) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + store.mu.Lock() + store.failUpdate = true + store.mu.Unlock() + + addr := startAttachedServer(t, ctx, sm) + body := bytes.NewReader([]byte(`{"model":"openai/gpt-4o"}`)) + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, addr+"/api/sessions/"+sess.ID+"/model", body) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) +} + // runtime.DecorateModelChoices is exercised end-to-end through the GET // handler tests above; unit-level corner cases live in pkg/runtime // (see model_switcher_test.go). @@ -439,6 +514,121 @@ func TestAttachedServer_ModelEndpoints_404WhenNotRunning(t *testing.T) { }) } +// AvailableSessionModels must NOT hold sm.mux while the runtime's +// AvailableModels call is in progress. If it did, an unrelated session +// operation that takes sm.mux (e.g. SetSessionStarred on a different +// session) would block for the duration of the runtime call. We verify +// this by holding the runtime call open and then making sure another +// sm.mux-acquiring method completes before we release the runtime call. +func TestSessionManager_AvailableSessionModels_DoesNotHoldMuxDuringRuntimeIO(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + slowSess := session.New() + unrelatedSess := session.New() + require.NoError(t, store.AddSession(ctx, slowSess)) + require.NoError(t, store.AddSession(ctx, unrelatedSess)) + + called := make(chan struct{}, 1) + release := make(chan struct{}) + slow := newModelSwitchingRuntime(nil) + slow.availableModelsCalled = called + slow.availableModelsDelay = release + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(slowSess.ID, slow, slowSess) + sm.AttachRuntime(unrelatedSess.ID, &fakeRuntime{}, unrelatedSess) + + // Start a slow AvailableSessionModels call on the first session. + done := make(chan struct{}) + go func() { + defer close(done) + _, _, _, _ = sm.AvailableSessionModels(ctx, slowSess.ID) + }() + + // Wait until the runtime is actually inside AvailableModels (and so + // stuck on `release`). + select { + case <-called: + case <-time.After(2 * time.Second): + t.Fatal("runtime AvailableModels was never called") + } + + // While the runtime call is parked, an unrelated method that + // acquires sm.mux must complete promptly. If sm.mux were held for + // the duration of the runtime call this would deadlock. + unrelatedDone := make(chan error, 1) + go func() { + unrelatedDone <- sm.SetSessionStarred(ctx, unrelatedSess.ID, true) + }() + + select { + case err := <-unrelatedDone: + require.NoError(t, err) + case <-time.After(2 * time.Second): + close(release) + <-done + t.Fatal("sm.mux is held across runtime I/O: unrelated session op blocked") + } + + // Let the slow call finish so the test cleanup can proceed. + close(release) + <-done +} + +// SetSessionAgentModel must also avoid holding sm.mux while the runtime's +// SetAgentModel call is in progress. +func TestSessionManager_SetSessionAgentModel_DoesNotHoldMuxDuringRuntimeIO(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + slowSess := session.New() + unrelatedSess := session.New() + require.NoError(t, store.AddSession(ctx, slowSess)) + require.NoError(t, store.AddSession(ctx, unrelatedSess)) + + called := make(chan struct{}, 1) + release := make(chan struct{}) + slow := newModelSwitchingRuntime(nil) + slow.setAgentModelCalled = called + slow.setAgentModelDelay = release + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(slowSess.ID, slow, slowSess) + sm.AttachRuntime(unrelatedSess.ID, &fakeRuntime{}, unrelatedSess) + + done := make(chan struct{}) + go func() { + defer close(done) + _, _, _ = sm.SetSessionAgentModel(ctx, slowSess.ID, "openai/gpt-4o") + }() + + select { + case <-called: + case <-time.After(2 * time.Second): + t.Fatal("runtime SetAgentModel was never called") + } + + unrelatedDone := make(chan error, 1) + go func() { + unrelatedDone <- sm.SetSessionStarred(ctx, unrelatedSess.ID, true) + }() + + select { + case err := <-unrelatedDone: + require.NoError(t, err) + case <-time.After(2 * time.Second): + close(release) + <-done + t.Fatal("sm.mux is held across runtime I/O: unrelated session op blocked") + } + + close(release) + <-done +} + // applyStoredOverrides is the helper called by runtimeForSession to // re-apply persisted overrides on a freshly-built runtime. We can't // drive runtimeForSession with a fake runtime (it constructs a real From 49d1b09b7626641c13f2efb4f3f37ae98a8dc41b Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 13 May 2026 16:18:27 +0200 Subject: [PATCH 10/12] fix(server): release sm.mux during runtime I/O on model endpoints --- pkg/server/session_manager.go | 64 +++++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/pkg/server/session_manager.go b/pkg/server/session_manager.go index f943e7175..260cea08b 100644 --- a/pkg/server/session_manager.go +++ b/pkg/server/session_manager.go @@ -34,6 +34,13 @@ type activeRuntimes struct { titleGen *sessiontitle.Generator // Title generator (includes fallback models) streaming sync.Mutex // Held while a RunStream is in progress; serialises concurrent requests + + // modelSwitch serialises concurrent SetSessionAgentModel calls on the + // same session. It is held across the (potentially slow) runtime I/O + // so we never overlap a SetAgentModel + rollback pair with another + // switch on the same session, while still allowing other sessions to + // make progress. + modelSwitch sync.Mutex } // SessionManager manages sessions for HTTP and Connect-RPC servers. @@ -811,9 +818,6 @@ var ErrSessionNotRunning = errors.New("session not found or not running") // config, a synthetic choice is appended (mirrors App.AvailableModels via // the shared runtime.DecorateModelChoices helper). func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID string) (string, string, []runtime.ModelChoice, error) { - sm.mux.Lock() - defer sm.mux.Unlock() - rs, ok := sm.runtimeSessions.Load(sessionID) if !ok { return "", "", nil, ErrSessionNotRunning @@ -824,12 +828,24 @@ func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID } agentName := rs.runtime.CurrentAgentName() + + // Snapshot the override and custom-model history under sm.mux so the + // read is atomic with respect to SetSessionAgentModel writes. The + // (potentially slow) runtime.AvailableModels call must NOT happen + // under sm.mux: it can perform network I/O (provider discovery, + // models.dev catalog lookup) and would block every other session + // operation in the manager. + sm.mux.Lock() current := "" var customRefs []string if rs.session != nil { current = rs.session.AgentModelOverrides[agentName] - customRefs = rs.session.CustomModelsUsed + if n := len(rs.session.CustomModelsUsed); n > 0 { + customRefs = make([]string, n) + copy(customRefs, rs.session.CustomModelsUsed) + } } + sm.mux.Unlock() choices := runtime.DecorateModelChoices(rs.runtime.AvailableModels(ctx), current, customRefs) return agentName, current, choices, nil @@ -842,10 +858,13 @@ func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID // // On store-write failure the in-memory session state and the runtime // override are rolled back so the next call observes a consistent state. +// +// Concurrent SetSessionAgentModel calls on the same session are +// serialised via the session-scoped modelSwitch lock so the runtime, +// session and store never observe interleaved updates. The manager-wide +// sm.mux is only held briefly while reading or mutating session fields, +// never while calling into the runtime or the store. func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, modelRef string) (string, string, error) { - sm.mux.Lock() - defer sm.mux.Unlock() - rs, ok := sm.runtimeSessions.Load(sessionID) if !ok { return "", "", ErrSessionNotRunning @@ -855,26 +874,35 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m return "", "", ErrModelSwitchingNotSupported } + rs.modelSwitch.Lock() + defer rs.modelSwitch.Unlock() + agentName := rs.runtime.CurrentAgentName() + sess := rs.session // Snapshot current state so we can roll back if persistence fails // after we've already mutated the runtime + in-memory session. - sess := rs.session var ( - hadOverride bool - prevOverride string - prevCustomLen int - hadOverridesMap bool - appendedCustomUsed bool + hadOverride bool + prevOverride string + prevCustomLen int + hadOverridesMap bool ) if sess != nil { + sm.mux.Lock() hadOverridesMap = sess.AgentModelOverrides != nil if hadOverridesMap { prevOverride, hadOverride = sess.AgentModelOverrides[agentName] } prevCustomLen = len(sess.CustomModelsUsed) + sm.mux.Unlock() } + // Runtime mutation runs without sm.mux so it doesn't block other + // session operations during slow provider creation. The per-session + // modelSwitch lock above keeps SetAgentModel + UpdateSession + any + // rollback atomic with respect to other model-switch calls on this + // session. if err := rs.runtime.SetAgentModel(ctx, agentName, modelRef); err != nil { return "", "", err } @@ -883,6 +911,8 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m return agentName, modelRef, nil } + var appendedCustomUsed bool + sm.mux.Lock() if modelRef == "" { delete(sess.AgentModelOverrides, agentName) } else { @@ -898,10 +928,12 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m appendedCustomUsed = true } } + sm.mux.Unlock() if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil { - // Roll back: restore in-memory map and runtime so callers don't see - // a runtime/store mismatch on the next request. + // Roll back in-memory under sm.mux first so concurrent readers + // (e.g. AvailableSessionModels) never see the half-applied state. + sm.mux.Lock() if hadOverride { sess.AgentModelOverrides[agentName] = prevOverride } else { @@ -913,6 +945,8 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m if appendedCustomUsed { sess.CustomModelsUsed = sess.CustomModelsUsed[:prevCustomLen] } + sm.mux.Unlock() + rollback := prevOverride if !hadOverride { rollback = "" From da48ca80ff650ab5674b95e8c98e901bc989bb08 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 18 May 2026 10:25:22 +0200 Subject: [PATCH 11/12] docs(api-server): document model switching endpoints Add GET/PATCH/POST /api/sessions/:id/model(s) to the endpoint table. POST is accepted for backward compatibility with RemoteRuntime's historical SetAgentModel implementation. Assisted-By: docker-agent --- docs/features/api-server/index.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/features/api-server/index.md b/docs/features/api-server/index.md index 79e3f03b0..63c2089e5 100644 --- a/docs/features/api-server/index.md +++ b/docs/features/api-server/index.md @@ -52,6 +52,9 @@ All endpoints are under the `/api` prefix. | `POST` | `/api/sessions/:id/elicitation` | Respond to an MCP tool elicitation request | | `POST` | `/api/sessions/:id/steer` | Inject messages into a running turn (pre-empts current) | | `POST` | `/api/sessions/:id/followup` | Enqueue messages to run after the current turn finishes | +| `GET` | `/api/sessions/:id/models` | List available models for the session's current agent | +| `PATCH` | `/api/sessions/:id/model` | Set or clear the agent's model override | +| `POST` | `/api/sessions/:id/model` | Set or clear the agent's model override (backward compat with RemoteRuntime) | ### Agent Execution From 6d1457c69bc75077dbae860215c5eca6044486be Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 18 May 2026 10:54:19 +0200 Subject: [PATCH 12/12] fix(server): eliminate stale-read window in SetSessionAgentModel The previous implementation mutated the in-memory session before persisting it to the store. Between the mutation and the store write, concurrent calls to AvailableSessionModels could acquire sm.mux and observe the not-yet- persisted override. If the store write then failed, the rollback would correct the in-memory state, but any concurrent reader that polled during the window would have seen an incorrect IsCurrent marker. This commit restructures SetSessionAgentModel to: 1. Clone the session 2. Apply mutations to the clone 3. Persist the clone to the store 4. Only after persistence succeeds, apply the mutations to the live session under sm.mux Now concurrent readers never observe a state that hasn't been successfully persisted. The rollback path is simplified: if the store write fails, the live session is unchanged and we only need to roll back the runtime. Addresses the concurrency consistency issue identified in the latest docker-agent review (2026-05-18). Assisted-By: docker-agent --- pkg/server/session_manager.go | 85 +++++++++++++++++++++++------------ 1 file changed, 57 insertions(+), 28 deletions(-) diff --git a/pkg/server/session_manager.go b/pkg/server/session_manager.go index 260cea08b..f8e60346c 100644 --- a/pkg/server/session_manager.go +++ b/pkg/server/session_manager.go @@ -881,11 +881,10 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m sess := rs.session // Snapshot current state so we can roll back if persistence fails - // after we've already mutated the runtime + in-memory session. + // after we've already mutated the runtime. var ( hadOverride bool prevOverride string - prevCustomLen int hadOverridesMap bool ) if sess != nil { @@ -894,7 +893,6 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m if hadOverridesMap { prevOverride, hadOverride = sess.AgentModelOverrides[agentName] } - prevCustomLen = len(sess.CustomModelsUsed) sm.mux.Unlock() } @@ -911,42 +909,56 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m return agentName, modelRef, nil } - var appendedCustomUsed bool + // Clone the session for the store write. We'll apply mutations to the + // clone, persist it, and only then update the live session. This ensures + // concurrent readers never observe a not-yet-persisted state. + updatedSess := &session.Session{ + ID: sess.ID, + Title: sess.Title, + CreatedAt: sess.CreatedAt, + WorkingDir: sess.WorkingDir, + ToolsApproved: sess.ToolsApproved, + Permissions: sess.Permissions, + MaxIterations: sess.MaxIterations, + MaxConsecutiveToolCalls: sess.MaxConsecutiveToolCalls, + MaxOldToolCallTokens: sess.MaxOldToolCallTokens, + InputTokens: sess.InputTokens, + OutputTokens: sess.OutputTokens, + Cost: sess.Cost, + Starred: sess.Starred, + } + + // Clone the maps/slices under sm.mux to avoid data races sm.mux.Lock() + if sess.AgentModelOverrides != nil { + updatedSess.AgentModelOverrides = maps.Clone(sess.AgentModelOverrides) + } + if len(sess.CustomModelsUsed) > 0 { + updatedSess.CustomModelsUsed = append([]string(nil), sess.CustomModelsUsed...) + } + sm.mux.Unlock() + + // Apply the mutations to the cloned session + var appendedCustomUsed bool if modelRef == "" { - delete(sess.AgentModelOverrides, agentName) + delete(updatedSess.AgentModelOverrides, agentName) } else { - if sess.AgentModelOverrides == nil { - sess.AgentModelOverrides = make(map[string]string) + if updatedSess.AgentModelOverrides == nil { + updatedSess.AgentModelOverrides = make(map[string]string) } - sess.AgentModelOverrides[agentName] = modelRef + updatedSess.AgentModelOverrides[agentName] = modelRef // Track inline provider/model references so they remain easy to // re-select via the model picker (mirrors App.SetCurrentAgentModel). - if strings.Contains(modelRef, "/") && !slices.Contains(sess.CustomModelsUsed, modelRef) { - sess.CustomModelsUsed = append(sess.CustomModelsUsed, modelRef) + if strings.Contains(modelRef, "/") && !slices.Contains(updatedSess.CustomModelsUsed, modelRef) { + updatedSess.CustomModelsUsed = append(updatedSess.CustomModelsUsed, modelRef) appendedCustomUsed = true } } - sm.mux.Unlock() - - if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil { - // Roll back in-memory under sm.mux first so concurrent readers - // (e.g. AvailableSessionModels) never see the half-applied state. - sm.mux.Lock() - if hadOverride { - sess.AgentModelOverrides[agentName] = prevOverride - } else { - delete(sess.AgentModelOverrides, agentName) - if !hadOverridesMap && len(sess.AgentModelOverrides) == 0 { - sess.AgentModelOverrides = nil - } - } - if appendedCustomUsed { - sess.CustomModelsUsed = sess.CustomModelsUsed[:prevCustomLen] - } - sm.mux.Unlock() + // Persist the cloned session. If this fails, the live session is + // unchanged and we only need to roll back the runtime. + if err := sm.sessionStore.UpdateSession(ctx, updatedSess); err != nil { rollback := prevOverride if !hadOverride { rollback = "" @@ -957,6 +969,23 @@ func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, m return "", "", fmt.Errorf("failed to persist model override: %w", err) } + // Store write succeeded. Now apply the mutations to the live session + // under sm.mux so concurrent readers observe the change atomically. + sm.mux.Lock() + if modelRef == "" { + delete(sess.AgentModelOverrides, agentName) + } else { + if sess.AgentModelOverrides == nil { + sess.AgentModelOverrides = make(map[string]string) + } + sess.AgentModelOverrides[agentName] = modelRef + + if appendedCustomUsed { + sess.CustomModelsUsed = append(sess.CustomModelsUsed, modelRef) + } + } + sm.mux.Unlock() + slog.DebugContext(ctx, "Updated session model override", "session_id", sessionID, "agent", agentName, "model", modelRef) return agentName, modelRef, nil }