diff --git a/docs/features/api-server/index.md b/docs/features/api-server/index.md index 79e3f03b0..63c2089e5 100644 --- a/docs/features/api-server/index.md +++ b/docs/features/api-server/index.md @@ -52,6 +52,9 @@ All endpoints are under the `/api` prefix. | `POST` | `/api/sessions/:id/elicitation` | Respond to an MCP tool elicitation request | | `POST` | `/api/sessions/:id/steer` | Inject messages into a running turn (pre-empts current) | | `POST` | `/api/sessions/:id/followup` | Enqueue messages to run after the current turn finishes | +| `GET` | `/api/sessions/:id/models` | List available models for the session's current agent | +| `PATCH` | `/api/sessions/:id/model` | Set or clear the agent's model override | +| `POST` | `/api/sessions/:id/model` | Set or clear the agent's model override (backward compat with RemoteRuntime) | ### Agent Execution diff --git a/pkg/api/types.go b/pkg/api/types.go index 0bfa5cd7a..f35143169 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -270,3 +270,15 @@ type SessionStatusResponse struct { OutputTokens int64 `json:"output_tokens"` NumMessages int `json:"num_messages"` } + +// SetSessionModelRequest is the body of PATCH /api/sessions/:id/model. +// An empty Model clears the override and reverts to the agent's default. +type SetSessionModelRequest struct { + Model string `json:"model"` +} + +// SetSessionModelResponse is the response from PATCH /api/sessions/:id/model. +type SetSessionModelResponse struct { + Agent string `json:"agent"` + Model string `json:"model,omitempty"` +} diff --git a/pkg/app/app.go b/pkg/app/app.go index ec8aba3ae..732ccef2a 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -777,76 +777,15 @@ func (a *App) AvailableModels(ctx context.Context) []runtime.ModelChoice { if !a.runtime.SupportsModelSwitching() { return nil } - models := a.runtime.AvailableModels(ctx) - // Determine the currently active model for this agent agentName := a.runtime.CurrentAgentName() - currentModelRef := "" - if a.session != nil && a.session.AgentModelOverrides != nil { - currentModelRef = a.session.AgentModelOverrides[agentName] - } - - // Build a set of model refs already in the list - existingRefs := make(map[string]bool) - for _, m := range models { - existingRefs[m.Ref] = true - } - - // Check if current model is in the list and mark it - currentFound := currentModelRef == "" - for i := range models { - if currentModelRef != "" { - // An override is set - mark the override as current - if models[i].Ref == currentModelRef { - models[i].IsCurrent = true - currentFound = true - } - } else { - // No override - the default model is current - models[i].IsCurrent = models[i].IsDefault - } - } - - // Add custom models from the session that aren't already in the list + currentRef := "" + var customRefs []string if a.session != nil { - for _, customRef := range a.session.CustomModelsUsed { - if existingRefs[customRef] { - continue // Already in the list - } - existingRefs[customRef] = true - - providerName, modelName, _ := strings.Cut(customRef, "/") - isCurrent := customRef == currentModelRef - if isCurrent { - currentFound = true - } - models = append(models, runtime.ModelChoice{ - Name: customRef, - Ref: customRef, - Provider: providerName, - Model: modelName, - IsDefault: false, - IsCurrent: isCurrent, - IsCustom: true, - }) - } + currentRef = a.session.AgentModelOverrides[agentName] + customRefs = a.session.CustomModelsUsed } - - // If current model is a custom model not in the list, add it - if !currentFound && strings.Contains(currentModelRef, "/") { - providerName, modelName, _ := strings.Cut(currentModelRef, "/") - models = append(models, runtime.ModelChoice{ - Name: currentModelRef, - Ref: currentModelRef, - Provider: providerName, - Model: modelName, - IsDefault: false, - IsCurrent: true, - IsCustom: true, - }) - } - - return models + return runtime.DecorateModelChoices(a.runtime.AvailableModels(ctx), currentRef, customRefs) } // trackCustomModel adds a custom model to the session's history if not already present. diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index 50231dee4..81b6caea6 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -16,49 +16,136 @@ import ( "github.com/docker/docker-agent/pkg/modelsdev" ) -// ModelChoice represents a model available for selection in the TUI picker. +// ModelChoice represents a model available for selection in the model picker. +// +// JSON tags are part of the public wire format used by +// GET /api/sessions/:id/models; renaming a tag is a breaking change. type ModelChoice struct { // Name is the display name (config key) - Name string + Name string `json:"name"` // Ref is the model reference used internally (e.g., "my_model" or "openai/gpt-4o") - Ref string + Ref string `json:"ref"` // Provider is the provider name (e.g., "openai", "anthropic") - Provider string + Provider string `json:"provider,omitempty"` // Model is the specific model name (e.g., "gpt-4o", "claude-sonnet-4-0") - Model string + Model string `json:"model,omitempty"` // IsDefault indicates this is the agent's configured default model - IsDefault bool + IsDefault bool `json:"is_default,omitempty"` // IsCurrent indicates this is the currently active model for the agent - IsCurrent bool + IsCurrent bool `json:"is_current,omitempty"` // IsCustom indicates this is a custom model from the session history (not from config) - IsCustom bool + IsCustom bool `json:"is_custom,omitempty"` // IsCatalog indicates this is a model from the models.dev catalog - IsCatalog bool + IsCatalog bool `json:"is_catalog,omitempty"` // The fields below are populated (best-effort) from the models.dev // catalog. They are optional and may all be zero/empty when no // catalog entry is found for the model. // Family is the model family (e.g., "claude", "gpt"). - Family string + Family string `json:"family,omitempty"` // InputCost is the price (in USD) per 1M input tokens. - InputCost float64 + InputCost float64 `json:"input_cost,omitempty"` // OutputCost is the price (in USD) per 1M output tokens. - OutputCost float64 + OutputCost float64 `json:"output_cost,omitempty"` // CacheReadCost is the price (in USD) per 1M cached input tokens. - CacheReadCost float64 + CacheReadCost float64 `json:"cache_read_cost,omitempty"` // CacheWriteCost is the price (in USD) per 1M cache-write tokens. - CacheWriteCost float64 + CacheWriteCost float64 `json:"cache_write_cost,omitempty"` // ContextLimit is the maximum context window size in tokens. - ContextLimit int + ContextLimit int `json:"context_limit,omitempty"` // OutputLimit is the maximum number of tokens the model can produce // in a single response. - OutputLimit int64 + OutputLimit int64 `json:"output_limit,omitempty"` // InputModalities lists the input modalities supported by the model // (e.g., "text", "image", "audio"). - InputModalities []string + InputModalities []string `json:"input_modalities,omitempty"` // OutputModalities lists the output modalities the model can produce. - OutputModalities []string + OutputModalities []string `json:"output_modalities,omitempty"` +} + +// SessionModelsResponse is the response returned by +// GET /api/sessions/:id/models. CurrentModelRef is the active override for +// the named agent (empty when the agent is using its configured default). +type SessionModelsResponse struct { + Agent string `json:"agent"` + CurrentModelRef string `json:"current_model_ref,omitempty"` + Models []ModelChoice `json:"models"` +} + +// DecorateModelChoices marks the active selection with IsCurrent and +// appends any custom (provider/model) refs from the session history that +// the runtime does not already expose. It is used by every consumer that +// wants to render a model picker (the TUI App, the HTTP /sessions/:id/models +// endpoint, …) so they all agree on which entry is current and what the +// final list looks like. +// +// currentRef is the model override active for the agent ("" when none), +// and customRefs is the session's CustomModelsUsed history. +// +// The input slice is never mutated: callers can safely pass a slice that +// is shared with or backed by an internal cache. +func DecorateModelChoices(models []ModelChoice, currentRef string, customRefs []string) []ModelChoice { + // Defensive copy: AvailableModels implementations may return a slice + // backed by an internal cache. Mutating its IsCurrent flag in place + // would leak picker state across sessions/agents. + result := make([]ModelChoice, len(models), len(models)+len(customRefs)+1) + copy(result, models) + + existingRefs := make(map[string]bool, len(result)) + for _, m := range result { + existingRefs[m.Ref] = true + } + + currentFound := currentRef == "" + for i := range result { + if currentRef != "" { + if result[i].Ref == currentRef { + result[i].IsCurrent = true + currentFound = true + } + } else { + result[i].IsCurrent = result[i].IsDefault + } + } + + for _, ref := range customRefs { + if existingRefs[ref] { + continue + } + existingRefs[ref] = true + + prov, name, _ := strings.Cut(ref, "/") + isCurrent := ref == currentRef + if isCurrent { + currentFound = true + } + result = append(result, ModelChoice{ + Name: ref, + Ref: ref, + Provider: prov, + Model: name, + IsCurrent: isCurrent, + IsCustom: true, + }) + } + + // If the override points at an inline provider/model not in the + // runtime's list nor in the session's history, fabricate a synthetic + // choice so the picker can still highlight the active selection. + if !currentFound && strings.Contains(currentRef, "/") { + prov, name, _ := strings.Cut(currentRef, "/") + result = append(result, ModelChoice{ + Name: currentRef, + Ref: currentRef, + Provider: prov, + Model: name, + IsCurrent: true, + IsCustom: true, + }) + } + + return result } // ModelSwitcherConfig holds the configuration needed for model switching. diff --git a/pkg/runtime/model_switcher_test.go b/pkg/runtime/model_switcher_test.go index 935f9319d..b2f7eaba5 100644 --- a/pkg/runtime/model_switcher_test.go +++ b/pkg/runtime/model_switcher_test.go @@ -486,3 +486,124 @@ func TestResolveModelRef_InvalidFormat(t *testing.T) { }) } } + +func TestDecorateModelChoices(t *testing.T) { + t.Parallel() + + t.Run("default marked current when no override is set", func(t *testing.T) { + t.Parallel() + got := DecorateModelChoices( + []ModelChoice{ + {Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}, + {Name: "other", Ref: "openai/gpt-4o"}, + }, + "", + nil, + ) + require.Len(t, got, 2) + assert.True(t, got[0].IsCurrent, "the IsDefault model must be marked IsCurrent when no override is set") + assert.False(t, got[1].IsCurrent) + }) + + t.Run("override matching a known choice marks it current", func(t *testing.T) { + t.Parallel() + got := DecorateModelChoices( + []ModelChoice{ + {Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}, + {Name: "other", Ref: "openai/gpt-4o"}, + }, + "openai/gpt-4o", + nil, + ) + require.Len(t, got, 2) + assert.False(t, got[0].IsCurrent, "default must not be marked current when an override is active") + assert.True(t, got[1].IsCurrent) + }) + + t.Run("synthesizes choice for inline override not in list", func(t *testing.T) { + t.Parallel() + got := DecorateModelChoices( + []ModelChoice{{Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}}, + "anthropic/claude-sonnet-4-0", + nil, + ) + require.Len(t, got, 2) + assert.Equal(t, "anthropic/claude-sonnet-4-0", got[1].Ref) + assert.Equal(t, "anthropic", got[1].Provider) + assert.Equal(t, "claude-sonnet-4-0", got[1].Model) + assert.True(t, got[1].IsCurrent) + assert.True(t, got[1].IsCustom) + }) + + t.Run("appends custom refs from session history", func(t *testing.T) { + t.Parallel() + got := DecorateModelChoices( + []ModelChoice{{Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}}, + "", + []string{"openai/gpt-4o", "anthropic/claude-sonnet-4-0"}, + ) + require.Len(t, got, 3) + assert.Equal(t, "openai/gpt-4o-mini", got[0].Ref) + assert.Equal(t, "openai/gpt-4o", got[1].Ref) + assert.True(t, got[1].IsCustom) + assert.Equal(t, "anthropic/claude-sonnet-4-0", got[2].Ref) + assert.True(t, got[2].IsCustom) + }) + + t.Run("does not duplicate custom ref already in list", func(t *testing.T) { + t.Parallel() + got := DecorateModelChoices( + []ModelChoice{{Name: "default", Ref: "openai/gpt-4o", IsDefault: true}}, + "", + []string{"openai/gpt-4o"}, + ) + require.Len(t, got, 1) + assert.Equal(t, "openai/gpt-4o", got[0].Ref) + }) + + t.Run("non-provider override does not synthesize a fabricated choice", func(t *testing.T) { + t.Parallel() + // "my_model" is a config key (no slash); when not in the runtime's + // list we should NOT fabricate a choice for it because we have no + // provider/model breakdown to display. + got := DecorateModelChoices( + []ModelChoice{{Name: "default", Ref: "default", IsDefault: true}}, + "my_model", + nil, + ) + require.Len(t, got, 1) + assert.False(t, got[0].IsCurrent, "default must not be marked current when override is unknown") + }) + + t.Run("custom ref matching the active override is marked current", func(t *testing.T) { + t.Parallel() + got := DecorateModelChoices( + []ModelChoice{{Name: "default", Ref: "default", IsDefault: true}}, + "openai/gpt-4o", + []string{"openai/gpt-4o"}, + ) + require.Len(t, got, 2) + assert.False(t, got[0].IsCurrent) + assert.Equal(t, "openai/gpt-4o", got[1].Ref) + assert.True(t, got[1].IsCurrent) + assert.True(t, got[1].IsCustom) + }) + + // AvailableModels implementations may return a slice backed by an + // internal cache; mutating its IsCurrent flag in place would leak + // state across sessions. The function must therefore never modify + // the input slice or its underlying array. + t.Run("does not mutate the input slice", func(t *testing.T) { + t.Parallel() + input := []ModelChoice{ + {Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}, + {Name: "other", Ref: "openai/gpt-4o"}, + } + orig := make([]ModelChoice, len(input)) + copy(orig, input) + + _ = DecorateModelChoices(input, "openai/gpt-4o", []string{"anthropic/claude-sonnet-4-0"}) + + assert.Equal(t, orig, input, "DecorateModelChoices must not modify the input slice") + }) +} diff --git a/pkg/server/server.go b/pkg/server/server.go index b9a182302..9215586c1 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -19,6 +19,7 @@ import ( "github.com/docker/docker-agent/pkg/api" "github.com/docker/docker-agent/pkg/config" "github.com/docker/docker-agent/pkg/echolog" + "github.com/docker/docker-agent/pkg/runtime" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/upstream" ) @@ -70,6 +71,13 @@ func (s *Server) registerRoutes() { group.PATCH("/sessions/:id/title", s.updateSessionTitle) group.PATCH("/sessions/:id/tokens", s.updateSessionTokens) group.PATCH("/sessions/:id/starred", s.setSessionStarred) + group.GET("/sessions/:id/models", s.getSessionModels) + // PATCH is the canonical verb for updating the agent's model. POST is + // also accepted because pkg/runtime Client.SetAgentModel (used by + // RemoteRuntime) was historically a POST; keep both so a client + // upgrade is not required. + group.PATCH("/sessions/:id/model", s.setSessionModel) + group.POST("/sessions/:id/model", s.setSessionModel) group.POST("/sessions", s.createSession) group.DELETE("/sessions/:id", s.deleteSession) group.POST("/sessions/:id/agent/:agent", s.runAgent) @@ -557,6 +565,66 @@ func (s *Server) setSessionStarred(c echo.Context) error { return c.JSON(http.StatusOK, map[string]string{"status": "updated"}) } +// getSessionModels lists the models the user can pick from for the +// session's current agent. Returns 404 if the session has no active runtime +// (it must have been started at least once or be attached out-of-band) +// or 422 if the runtime does not support model switching. +func (s *Server) getSessionModels(c echo.Context) error { + sessionID := c.Param("id") + + agentName, current, choices, err := s.sm.AvailableSessionModels(c.Request().Context(), sessionID) + if err != nil { + switch { + case errors.Is(err, ErrModelSwitchingNotSupported): + return echo.NewHTTPError(http.StatusUnprocessableEntity, err.Error()) + case errors.Is(err, ErrSessionNotRunning): + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + default: + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + } + + return c.JSON(http.StatusOK, runtime.SessionModelsResponse{ + Agent: agentName, + CurrentModelRef: current, + Models: choices, + }) +} + +// setSessionModel applies a model override on the session's current agent +// and persists it. An empty `model` clears the override and reverts the +// agent to its configured default. +func (s *Server) setSessionModel(c echo.Context) error { + sessionID := c.Param("id") + + var req api.SetSessionModelRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err)) + } + + agentName, modelRef, err := s.sm.SetSessionAgentModel(c.Request().Context(), sessionID, req.Model) + if err != nil { + switch { + case errors.Is(err, ErrModelSwitchingNotSupported): + return echo.NewHTTPError(http.StatusUnprocessableEntity, err.Error()) + case errors.Is(err, ErrSessionNotRunning): + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + default: + // Unknown errors come from the runtime (e.g. an inline model + // ref that fails provider creation) or from the session store + // (e.g. a write failure). Both are server-side concerns, not + // client mistakes, so map to 500. Validation of the request + // body itself is handled above by Bind which returns 400. + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + } + + return c.JSON(http.StatusOK, api.SetSessionModelResponse{ + Agent: agentName, + Model: modelRef, + }) +} + func (s *Server) batchDeleteSessions(c echo.Context) error { var req api.BatchDeleteSessionsRequest if err := c.Bind(&req); err != nil { diff --git a/pkg/server/session_manager.go b/pkg/server/session_manager.go index 57dd3025a..f8e60346c 100644 --- a/pkg/server/session_manager.go +++ b/pkg/server/session_manager.go @@ -5,8 +5,10 @@ import ( "errors" "fmt" "log/slog" + "maps" "os" "path/filepath" + "slices" "strings" "sync" "time" @@ -32,6 +34,13 @@ type activeRuntimes struct { titleGen *sessiontitle.Generator // Title generator (includes fallback models) streaming sync.Mutex // Held while a RunStream is in progress; serialises concurrent requests + + // modelSwitch serialises concurrent SetSessionAgentModel calls on the + // same session. It is held across the (potentially slow) runtime I/O + // so we never overlap a SetAgentModel + rollback pair with another + // switch on the same session, while still allowing other sessions to + // make progress. + modelSwitch sync.Mutex } // SessionManager manages sessions for HTTP and Connect-RPC servers. @@ -223,6 +232,19 @@ func (sm *SessionManager) CreateSession(ctx context.Context, sessionTemplate *se } sess := session.New(opts...) + + // Copy model-related fields from the template so callers can pin a + // specific model when creating a session over the API. The runtime + // will pick these up the first time it is built for the session + // (see runtimeForSession). Callers that want a model to also appear + // in the picker history should include it in CustomModelsUsed. + if len(sessionTemplate.AgentModelOverrides) > 0 { + sess.AgentModelOverrides = maps.Clone(sessionTemplate.AgentModelOverrides) + } + if len(sessionTemplate.CustomModelsUsed) > 0 { + sess.CustomModelsUsed = append([]string(nil), sessionTemplate.CustomModelsUsed...) + } + return sess, sm.sessionStore.AddSession(ctx, sess) } @@ -449,7 +471,7 @@ func (sm *SessionManager) ResumeSession(ctx context.Context, sessionID, confirma func (sm *SessionManager) SteerSession(_ context.Context, sessionID string, messages []api.Message) error { rt, exists := sm.runtimeSessions.Load(sessionID) if !exists { - return errors.New("session not found or not running") + return ErrSessionNotRunning } for _, msg := range messages { @@ -474,7 +496,7 @@ func (sm *SessionManager) SteerSession(_ context.Context, sessionID string, mess func (sm *SessionManager) FollowUpSession(_ context.Context, sessionID string, messages []api.Message) (streaming bool, err error) { rt, exists := sm.runtimeSessions.Load(sessionID) if !exists { - return false, errors.New("session not found or not running") + return false, ErrSessionNotRunning } for _, msg := range messages { @@ -603,10 +625,11 @@ func (sm *SessionManager) runtimeForSession(ctx context.Context, sess *session.S // constructor: it must not touch sm.runtimeSessions, otherwise it would // briefly publish a half-initialised activeRuntimes (e.g. without the // cancel func) that other goroutines could observe. - t, err := sm.loadTeam(ctx, agentFilename, rc) + loadResult, err := sm.loadTeamWithConfig(ctx, agentFilename, rc) if err != nil { return nil, nil, err } + t := loadResult.Team // Resolve the team's default agent when no specific agent was requested. agt, err := t.AgentOrDefault(currentAgent) @@ -618,17 +641,31 @@ func (sm *SessionManager) runtimeForSession(ctx context.Context, sess *session.S sess.MaxConsecutiveToolCalls = agt.MaxConsecutiveToolCalls() sess.MaxOldToolCallTokens = agt.MaxOldToolCallTokens() + modelSwitcherCfg := &runtime.ModelSwitcherConfig{ + Models: loadResult.Models, + Providers: loadResult.Providers, + ModelsGateway: rc.ModelsGateway, + EnvProvider: rc.EnvProvider(), + AgentDefaultModels: loadResult.AgentDefaultModels, + } + opts := []runtime.Opt{ runtime.WithCurrentAgent(currentAgent), runtime.WithManagedOAuth(false), runtime.WithSessionStore(sm.sessionStore), runtime.WithTracer(otel.Tracer("cagent")), + runtime.WithModelSwitcherConfig(modelSwitcherCfg), } run, err := runtime.New(t, opts...) if err != nil { return nil, nil, err } + // Apply any stored per-agent model overrides so that a session + // resumed (or freshly created with overrides via CreateSession) uses + // the requested models instead of the agent's defaults. + applyStoredOverrides(ctx, sess.ID, run, sess.AgentModelOverrides) + titleGen := sessiontitle.New(agt.Model(ctx), agt.FallbackModels()...) slog.DebugContext(ctx, "Runtime created for session", "session_id", sess.ID) @@ -645,6 +682,33 @@ func (sm *SessionManager) loadTeam(ctx context.Context, agentFilename string, ru return teamloader.Load(ctx, agentSource, runConfig) } +// loadTeamWithConfig is like loadTeam but also returns the loaded model and +// provider configuration so the runtime can be wired for model switching. +func (sm *SessionManager) loadTeamWithConfig(ctx context.Context, agentFilename string, runConfig *config.RuntimeConfig) (*teamloader.LoadResult, error) { + agentSource, found := sm.Sources[agentFilename] + if !found { + return nil, fmt.Errorf("agent not found: %s", agentFilename) + } + + return teamloader.LoadWithConfig(ctx, agentSource, runConfig) +} + +// applyStoredOverrides applies the persisted per-agent model overrides on +// the freshly created runtime. Failures are logged at WARN and otherwise +// ignored: a stored override that no longer resolves (e.g. because the +// model was removed from the agent's config) must not prevent the +// session from being resumed with the agent's default model. +func applyStoredOverrides(ctx context.Context, sessionID string, run runtime.Runtime, overrides map[string]string) { + if len(overrides) == 0 || !run.SupportsModelSwitching() { + return + } + for agentName, modelRef := range overrides { + if err := run.SetAgentModel(ctx, agentName, modelRef); err != nil { + slog.WarnContext(ctx, "Failed to apply stored model override", "session_id", sessionID, "agent", agentName, "model", modelRef, "error", err) + } + } +} + // GetAgentToolCount loads the agent's team and returns the number of // tools available to the given agent. When agentName is empty, it // resolves to the team's default agent. @@ -729,6 +793,203 @@ func (sm *SessionManager) SetSessionStarred(ctx context.Context, sessionID strin return sm.sessionStore.SetSessionStarred(ctx, sessionID, starred) } +// ErrModelSwitchingNotSupported is returned when the runtime backing a +// session does not support runtime model switching (e.g. when the agent +// was created without a ModelSwitcherConfig). +var ErrModelSwitchingNotSupported = errors.New("model switching not supported by this runtime") + +// ErrSessionNotRunning is returned by methods that require an active +// runtime for the session (i.e. RunSession must have been called or +// AttachRuntime invoked) when none is found. HTTP handlers map this to +// 404 to distinguish from other runtime errors. +var ErrSessionNotRunning = errors.New("session not found or not running") + +// AvailableSessionModels returns the list of models available for the +// session's current agent. The agent's name and the active model override +// (if any) are returned alongside the choices so callers don't have to +// peek into the runtime registry. A session-scoped runtime is required, +// so the session must have been started at least once (RunSession called) +// or be attached out-of-band via AttachRuntime. +// +// Each returned ModelChoice has IsCurrent set so the picker can highlight +// the active selection without a second round-trip. When no override is +// active, the agent's configured default carries IsCurrent=true; if the +// override points at an inline provider/model not present in the agent +// config, a synthetic choice is appended (mirrors App.AvailableModels via +// the shared runtime.DecorateModelChoices helper). +func (sm *SessionManager) AvailableSessionModels(ctx context.Context, sessionID string) (string, string, []runtime.ModelChoice, error) { + rs, ok := sm.runtimeSessions.Load(sessionID) + if !ok { + return "", "", nil, ErrSessionNotRunning + } + + if !rs.runtime.SupportsModelSwitching() { + return "", "", nil, ErrModelSwitchingNotSupported + } + + agentName := rs.runtime.CurrentAgentName() + + // Snapshot the override and custom-model history under sm.mux so the + // read is atomic with respect to SetSessionAgentModel writes. The + // (potentially slow) runtime.AvailableModels call must NOT happen + // under sm.mux: it can perform network I/O (provider discovery, + // models.dev catalog lookup) and would block every other session + // operation in the manager. + sm.mux.Lock() + current := "" + var customRefs []string + if rs.session != nil { + current = rs.session.AgentModelOverrides[agentName] + if n := len(rs.session.CustomModelsUsed); n > 0 { + customRefs = make([]string, n) + copy(customRefs, rs.session.CustomModelsUsed) + } + } + sm.mux.Unlock() + + choices := runtime.DecorateModelChoices(rs.runtime.AvailableModels(ctx), current, customRefs) + return agentName, current, choices, nil +} + +// SetSessionAgentModel applies modelRef as the model override for the +// current agent of the session, persists it to the session store, and +// tracks custom models for later re-selection. Pass an empty modelRef +// to clear the override and revert to the agent's default model. +// +// On store-write failure the in-memory session state and the runtime +// override are rolled back so the next call observes a consistent state. +// +// Concurrent SetSessionAgentModel calls on the same session are +// serialised via the session-scoped modelSwitch lock so the runtime, +// session and store never observe interleaved updates. The manager-wide +// sm.mux is only held briefly while reading or mutating session fields, +// never while calling into the runtime or the store. +func (sm *SessionManager) SetSessionAgentModel(ctx context.Context, sessionID, modelRef string) (string, string, error) { + rs, ok := sm.runtimeSessions.Load(sessionID) + if !ok { + return "", "", ErrSessionNotRunning + } + + if !rs.runtime.SupportsModelSwitching() { + return "", "", ErrModelSwitchingNotSupported + } + + rs.modelSwitch.Lock() + defer rs.modelSwitch.Unlock() + + agentName := rs.runtime.CurrentAgentName() + sess := rs.session + + // Snapshot current state so we can roll back if persistence fails + // after we've already mutated the runtime. + var ( + hadOverride bool + prevOverride string + hadOverridesMap bool + ) + if sess != nil { + sm.mux.Lock() + hadOverridesMap = sess.AgentModelOverrides != nil + if hadOverridesMap { + prevOverride, hadOverride = sess.AgentModelOverrides[agentName] + } + sm.mux.Unlock() + } + + // Runtime mutation runs without sm.mux so it doesn't block other + // session operations during slow provider creation. The per-session + // modelSwitch lock above keeps SetAgentModel + UpdateSession + any + // rollback atomic with respect to other model-switch calls on this + // session. + if err := rs.runtime.SetAgentModel(ctx, agentName, modelRef); err != nil { + return "", "", err + } + + if sess == nil { + return agentName, modelRef, nil + } + + // Clone the session for the store write. We'll apply mutations to the + // clone, persist it, and only then update the live session. This ensures + // concurrent readers never observe a not-yet-persisted state. + updatedSess := &session.Session{ + ID: sess.ID, + Title: sess.Title, + CreatedAt: sess.CreatedAt, + WorkingDir: sess.WorkingDir, + ToolsApproved: sess.ToolsApproved, + Permissions: sess.Permissions, + MaxIterations: sess.MaxIterations, + MaxConsecutiveToolCalls: sess.MaxConsecutiveToolCalls, + MaxOldToolCallTokens: sess.MaxOldToolCallTokens, + InputTokens: sess.InputTokens, + OutputTokens: sess.OutputTokens, + Cost: sess.Cost, + Starred: sess.Starred, + } + + // Clone the maps/slices under sm.mux to avoid data races + sm.mux.Lock() + if sess.AgentModelOverrides != nil { + updatedSess.AgentModelOverrides = maps.Clone(sess.AgentModelOverrides) + } + if len(sess.CustomModelsUsed) > 0 { + updatedSess.CustomModelsUsed = append([]string(nil), sess.CustomModelsUsed...) + } + sm.mux.Unlock() + + // Apply the mutations to the cloned session + var appendedCustomUsed bool + if modelRef == "" { + delete(updatedSess.AgentModelOverrides, agentName) + } else { + if updatedSess.AgentModelOverrides == nil { + updatedSess.AgentModelOverrides = make(map[string]string) + } + updatedSess.AgentModelOverrides[agentName] = modelRef + + // Track inline provider/model references so they remain easy to + // re-select via the model picker (mirrors App.SetCurrentAgentModel). + if strings.Contains(modelRef, "/") && !slices.Contains(updatedSess.CustomModelsUsed, modelRef) { + updatedSess.CustomModelsUsed = append(updatedSess.CustomModelsUsed, modelRef) + appendedCustomUsed = true + } + } + + // Persist the cloned session. If this fails, the live session is + // unchanged and we only need to roll back the runtime. + if err := sm.sessionStore.UpdateSession(ctx, updatedSess); err != nil { + rollback := prevOverride + if !hadOverride { + rollback = "" + } + if rbErr := rs.runtime.SetAgentModel(ctx, agentName, rollback); rbErr != nil { + slog.ErrorContext(ctx, "Failed to roll back runtime model override", "session_id", sessionID, "agent", agentName, "error", rbErr) + } + return "", "", fmt.Errorf("failed to persist model override: %w", err) + } + + // Store write succeeded. Now apply the mutations to the live session + // under sm.mux so concurrent readers observe the change atomically. + sm.mux.Lock() + if modelRef == "" { + delete(sess.AgentModelOverrides, agentName) + } else { + if sess.AgentModelOverrides == nil { + sess.AgentModelOverrides = make(map[string]string) + } + sess.AgentModelOverrides[agentName] = modelRef + + if appendedCustomUsed { + sess.CustomModelsUsed = append(sess.CustomModelsUsed, modelRef) + } + } + sm.mux.Unlock() + + slog.DebugContext(ctx, "Updated session model override", "session_id", sessionID, "agent", agentName, "model", modelRef) + return agentName, modelRef, nil +} + // BatchDeleteSessions deletes multiple sessions in a single operation. func (sm *SessionManager) BatchDeleteSessions(ctx context.Context, sessionIDs []string) (int, []string) { sm.mux.Lock() diff --git a/pkg/server/session_manager_test.go b/pkg/server/session_manager_test.go index f30526afd..b5ad8e953 100644 --- a/pkg/server/session_manager_test.go +++ b/pkg/server/session_manager_test.go @@ -58,6 +58,10 @@ func (f *fakeRuntime) ResumeElicitation(_ context.Context, _ tools.ElicitationAc func (f *fakeRuntime) CurrentAgentName() string { return "root" } +// SupportsModelSwitching reports false by default. Tests that exercise +// the /models endpoints embed fakeRuntime and override this. +func (f *fakeRuntime) SupportsModelSwitching() bool { return false } + func newTestSessionManager(t *testing.T, sess *session.Session, fake *fakeRuntime) *SessionManager { t.Helper() diff --git a/pkg/server/session_models_test.go b/pkg/server/session_models_test.go new file mode 100644 index 000000000..53bc09f16 --- /dev/null +++ b/pkg/server/session_models_test.go @@ -0,0 +1,681 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/api" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/runtime" + "github.com/docker/docker-agent/pkg/session" +) + +// modelSwitchingRuntime is a fakeRuntime variant that supports model +// switching so the /models and /model endpoints can be exercised +// without spinning up a real LocalRuntime. +type modelSwitchingRuntime struct { + fakeRuntime + + mu sync.Mutex + currentAgent string + availableModels []runtime.ModelChoice + overrides map[string]string + setErr error + + // availableModelsCalled fires every time AvailableModels is invoked. + // Tests can use it to coordinate with a deliberately-slow runtime + // call (see availableModelsDelay). + availableModelsCalled chan struct{} + // availableModelsDelay, when set, makes AvailableModels block on + // this channel before returning. Used by lock-contention tests to + // hold a runtime call open while another goroutine probes the + // SessionManager. + availableModelsDelay <-chan struct{} + // setAgentModelCalled fires every time SetAgentModel is invoked. + setAgentModelCalled chan struct{} + // setAgentModelDelay, when set, makes SetAgentModel block on this + // channel before returning. + setAgentModelDelay <-chan struct{} +} + +func newModelSwitchingRuntime(models []runtime.ModelChoice) *modelSwitchingRuntime { + return &modelSwitchingRuntime{ + currentAgent: "root", + availableModels: models, + overrides: make(map[string]string), + } +} + +func (m *modelSwitchingRuntime) CurrentAgentName() string { return m.currentAgent } + +func (m *modelSwitchingRuntime) SupportsModelSwitching() bool { return true } + +func (m *modelSwitchingRuntime) AvailableModels(_ context.Context) []runtime.ModelChoice { + m.mu.Lock() + delay := m.availableModelsDelay + called := m.availableModelsCalled + out := make([]runtime.ModelChoice, len(m.availableModels)) + copy(out, m.availableModels) + m.mu.Unlock() + + if called != nil { + select { + case called <- struct{}{}: + default: + } + } + if delay != nil { + <-delay + } + return out +} + +func (m *modelSwitchingRuntime) SetAgentModel(_ context.Context, agentName, modelRef string) error { + m.mu.Lock() + setErr := m.setErr + delay := m.setAgentModelDelay + called := m.setAgentModelCalled + m.mu.Unlock() + + if called != nil { + select { + case called <- struct{}{}: + default: + } + } + if delay != nil { + <-delay + } + + m.mu.Lock() + defer m.mu.Unlock() + if setErr != nil { + return setErr + } + if modelRef == "" { + delete(m.overrides, agentName) + return nil + } + m.overrides[agentName] = modelRef + return nil +} + +// startAttachedServer wires a SessionManager + HTTP server backed by an +// in-process listener and registers a t.Cleanup that closes the listener +// (and unblocks the Serve goroutine) when the test finishes. +func startAttachedServer(t *testing.T, ctx context.Context, sm *SessionManager) string { + t.Helper() + srv := NewWithManager(sm, "") + ln, err := Listen(ctx, "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { _ = ln.Close() }) + go func() { _ = srv.Serve(ctx, ln) }() + return "http://" + ln.Addr().String() +} + +func TestSessionManager_CreateSession_KeepsModelOverrides(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + + template := &session.Session{ + AgentModelOverrides: map[string]string{ + "root": "openai/gpt-4o", + "researcher": "anthropic/claude-sonnet-4-0", + }, + CustomModelsUsed: []string{"openai/gpt-4o"}, + } + + created, err := sm.CreateSession(ctx, template) + require.NoError(t, err) + require.NotEmpty(t, created.ID) + + assert.Equal(t, "openai/gpt-4o", created.AgentModelOverrides["root"]) + assert.Equal(t, "anthropic/claude-sonnet-4-0", created.AgentModelOverrides["researcher"]) + assert.Equal(t, []string{"openai/gpt-4o"}, created.CustomModelsUsed, + "CreateSession is a passthrough: only refs explicitly listed in CustomModelsUsed should be tracked") + + // Mutating the template after creation must not affect the stored session. + template.AgentModelOverrides["root"] = "mutated" + assert.Equal(t, "openai/gpt-4o", created.AgentModelOverrides["root"]) + + stored, err := store.GetSession(ctx, created.ID) + require.NoError(t, err) + assert.Equal(t, "openai/gpt-4o", stored.AgentModelOverrides["root"]) +} + +func TestAttachedServer_GetSessionModels(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + store := session.NewInMemorySessionStore() + sess := session.New() + sess.AgentModelOverrides = map[string]string{"root": "openai/gpt-4o"} + require.NoError(t, store.AddSession(ctx, sess)) + + choices := []runtime.ModelChoice{ + {Name: "default", Ref: "openai/gpt-4o-mini", Provider: "openai", Model: "gpt-4o-mini", IsDefault: true}, + {Name: "custom", Ref: "openai/gpt-4o", Provider: "openai", Model: "gpt-4o"}, + } + fake := newModelSwitchingRuntime(choices) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + addr := startAttachedServer(t, ctx, sm) + resp := httpDoTCP(t, ctx, http.MethodGet, addr+"/api/sessions/"+sess.ID+"/models", nil) + + var got runtime.SessionModelsResponse + require.NoError(t, json.Unmarshal(resp, &got)) + + assert.Equal(t, "root", got.Agent) + assert.Equal(t, "openai/gpt-4o", got.CurrentModelRef) + require.Len(t, got.Models, 2) + assert.Equal(t, "openai/gpt-4o-mini", got.Models[0].Ref) + assert.True(t, got.Models[0].IsDefault) + assert.False(t, got.Models[0].IsCurrent, "default must not be marked current when an override is active") + assert.Equal(t, "openai/gpt-4o", got.Models[1].Ref) + assert.True(t, got.Models[1].IsCurrent, "the model matching the override must be marked current") +} + +// When no override is set, the agent's default model must be marked +// IsCurrent so the picker can highlight it without a second round trip. +func TestAttachedServer_GetSessionModels_DefaultMarkedCurrent(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + store := session.NewInMemorySessionStore() + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + choices := []runtime.ModelChoice{ + {Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}, + {Name: "other", Ref: "openai/gpt-4o"}, + } + fake := newModelSwitchingRuntime(choices) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + addr := startAttachedServer(t, ctx, sm) + resp := httpDoTCP(t, ctx, http.MethodGet, addr+"/api/sessions/"+sess.ID+"/models", nil) + + var got runtime.SessionModelsResponse + require.NoError(t, json.Unmarshal(resp, &got)) + + assert.Empty(t, got.CurrentModelRef) + require.Len(t, got.Models, 2) + assert.True(t, got.Models[0].IsCurrent, "default model must be marked current when no override is set") + assert.False(t, got.Models[1].IsCurrent) +} + +// Custom (provider/model) refs from the session history must be appended +// to the picker so a user can pick a previously used model again. +func TestAttachedServer_GetSessionModels_AppendsCustomModels(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + store := session.NewInMemorySessionStore() + sess := session.New() + sess.CustomModelsUsed = []string{"openai/gpt-4o"} + require.NoError(t, store.AddSession(ctx, sess)) + + choices := []runtime.ModelChoice{ + {Name: "default", Ref: "openai/gpt-4o-mini", IsDefault: true}, + } + fake := newModelSwitchingRuntime(choices) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + addr := startAttachedServer(t, ctx, sm) + resp := httpDoTCP(t, ctx, http.MethodGet, addr+"/api/sessions/"+sess.ID+"/models", nil) + + var got runtime.SessionModelsResponse + require.NoError(t, json.Unmarshal(resp, &got)) + + require.Len(t, got.Models, 2) + assert.Equal(t, "openai/gpt-4o-mini", got.Models[0].Ref) + assert.Equal(t, "openai/gpt-4o", got.Models[1].Ref) + assert.True(t, got.Models[1].IsCustom) +} + +func TestAttachedServer_SetSessionModel_PersistsOverride(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + store := session.NewInMemorySessionStore() + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + fake := newModelSwitchingRuntime(nil) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + addr := startAttachedServer(t, ctx, sm) + resp := httpDoTCP(t, ctx, http.MethodPatch, addr+"/api/sessions/"+sess.ID+"/model", + api.SetSessionModelRequest{Model: "anthropic/claude-sonnet-4-0"}) + + var got api.SetSessionModelResponse + require.NoError(t, json.Unmarshal(resp, &got)) + assert.Equal(t, "root", got.Agent) + assert.Equal(t, "anthropic/claude-sonnet-4-0", got.Model) + + // The runtime must have received the override. + fake.mu.Lock() + assert.Equal(t, "anthropic/claude-sonnet-4-0", fake.overrides["root"]) + fake.mu.Unlock() + + // The session in the store must reflect the override and track the + // custom model for future picks. + stored, err := store.GetSession(ctx, sess.ID) + require.NoError(t, err) + assert.Equal(t, "anthropic/claude-sonnet-4-0", stored.AgentModelOverrides["root"]) + assert.Contains(t, stored.CustomModelsUsed, "anthropic/claude-sonnet-4-0") +} + +func TestAttachedServer_SetSessionModel_EmptyClearsOverride(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + store := session.NewInMemorySessionStore() + sess := session.New() + sess.AgentModelOverrides = map[string]string{"root": "openai/gpt-4o"} + require.NoError(t, store.AddSession(ctx, sess)) + + fake := newModelSwitchingRuntime(nil) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + addr := startAttachedServer(t, ctx, sm) + _ = httpDoTCP(t, ctx, http.MethodPatch, addr+"/api/sessions/"+sess.ID+"/model", + api.SetSessionModelRequest{Model: ""}) + + stored, err := store.GetSession(ctx, sess.ID) + require.NoError(t, err) + _, exists := stored.AgentModelOverrides["root"] + assert.False(t, exists, "override should be cleared") +} + +func TestAttachedServer_SetSessionModel_PostVerbAlsoWorks(t *testing.T) { + // The pre-existing pkg/runtime Client.SetAgentModel POSTs to + // /api/sessions/:id/model. The server must accept POST as well as + // PATCH so RemoteRuntime keeps working without a coordinated bump. + t.Parallel() + + ctx := t.Context() + + store := session.NewInMemorySessionStore() + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + fake := newModelSwitchingRuntime(nil) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + addr := startAttachedServer(t, ctx, sm) + _ = httpDoTCP(t, ctx, http.MethodPost, addr+"/api/sessions/"+sess.ID+"/model", + api.SetSessionModelRequest{Model: "openai/gpt-4o"}) + + fake.mu.Lock() + assert.Equal(t, "openai/gpt-4o", fake.overrides["root"]) + fake.mu.Unlock() +} + +func TestAttachedServer_GetSessionModels_NotSupported(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, &fakeRuntime{}, sess) + + addr := startAttachedServer(t, ctx, sm) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr+"/api/sessions/"+sess.ID+"/models", http.NoBody) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode) +} + +// failingStore wraps an in-memory store so UpdateSession can be made to +// fail on demand to exercise the rollback path of SetSessionAgentModel. +type failingStore struct { + session.Store + + mu sync.Mutex + failUpdate bool +} + +func (s *failingStore) UpdateSession(ctx context.Context, sess *session.Session) error { + s.mu.Lock() + fail := s.failUpdate + s.mu.Unlock() + if fail { + return errors.New("synthetic store failure") + } + return s.Store.UpdateSession(ctx, sess) +} + +// When the session store rejects the persistence write, the in-memory +// session and the runtime override must both be rolled back so the next +// read does not surface state that was never persisted. +func TestSessionManager_SetSessionAgentModel_RollsBackOnStoreFailure(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := &failingStore{Store: session.NewInMemorySessionStore()} + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + fake := newModelSwitchingRuntime(nil) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + store.mu.Lock() + store.failUpdate = true + store.mu.Unlock() + + _, _, err := sm.SetSessionAgentModel(ctx, sess.ID, "openai/gpt-4o") + require.Error(t, err) + + // In-memory session must not contain the override. + _, exists := sess.AgentModelOverrides["root"] + assert.False(t, exists, "in-memory override must be rolled back") + assert.NotContains(t, sess.CustomModelsUsed, "openai/gpt-4o", "CustomModelsUsed must be rolled back") + + // Runtime must not have the override either. + fake.mu.Lock() + _, runtimeHas := fake.overrides["root"] + fake.mu.Unlock() + assert.False(t, runtimeHas, "runtime override must be rolled back") +} + +// When the runtime rejects SetAgentModel, no in-memory or persisted +// state must be mutated; the error must propagate verbatim. +func TestSessionManager_SetSessionAgentModel_RuntimeFailureLeavesStateUntouched(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + sess := session.New() + sess.AgentModelOverrides = map[string]string{"root": "openai/gpt-4o"} + sess.CustomModelsUsed = []string{"openai/gpt-4o"} + require.NoError(t, store.AddSession(ctx, sess)) + + fake := newModelSwitchingRuntime(nil) + fake.setErr = errors.New("runtime says no") + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + _, _, err := sm.SetSessionAgentModel(ctx, sess.ID, "anthropic/claude-sonnet-4-0") + require.Error(t, err) + + // The original override must be intact. + assert.Equal(t, "openai/gpt-4o", sess.AgentModelOverrides["root"]) + assert.Equal(t, []string{"openai/gpt-4o"}, sess.CustomModelsUsed) +} + +// Server-side errors (store-write failures, runtime errors that aren't +// the well-known sentinels) must be reported as 500, not 400. 400 is +// reserved for client-side mistakes like an invalid request body. +func TestAttachedServer_SetSessionModel_StoreFailureReturns500(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := &failingStore{Store: session.NewInMemorySessionStore()} + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + fake := newModelSwitchingRuntime(nil) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(sess.ID, fake, sess) + + store.mu.Lock() + store.failUpdate = true + store.mu.Unlock() + + addr := startAttachedServer(t, ctx, sm) + body := bytes.NewReader([]byte(`{"model":"openai/gpt-4o"}`)) + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, addr+"/api/sessions/"+sess.ID+"/model", body) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) +} + +// runtime.DecorateModelChoices is exercised end-to-end through the GET +// handler tests above; unit-level corner cases live in pkg/runtime +// (see model_switcher_test.go). + +// When no runtime is attached for the session, the endpoints must +// return 404 (not 400 or 500) so callers can tell apart a stale id +// from an actual server-side problem. +func TestAttachedServer_ModelEndpoints_404WhenNotRunning(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + sess := session.New() + require.NoError(t, store.AddSession(ctx, sess)) + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + // Note: no AttachRuntime call. + + addr := startAttachedServer(t, ctx, sm) + + t.Run("GET", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr+"/api/sessions/"+sess.ID+"/models", http.NoBody) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + }) + + t.Run("PATCH", func(t *testing.T) { + body := bytes.NewReader([]byte(`{"model":"openai/gpt-4o"}`)) + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, addr+"/api/sessions/"+sess.ID+"/model", body) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + }) +} + +// AvailableSessionModels must NOT hold sm.mux while the runtime's +// AvailableModels call is in progress. If it did, an unrelated session +// operation that takes sm.mux (e.g. SetSessionStarred on a different +// session) would block for the duration of the runtime call. We verify +// this by holding the runtime call open and then making sure another +// sm.mux-acquiring method completes before we release the runtime call. +func TestSessionManager_AvailableSessionModels_DoesNotHoldMuxDuringRuntimeIO(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + slowSess := session.New() + unrelatedSess := session.New() + require.NoError(t, store.AddSession(ctx, slowSess)) + require.NoError(t, store.AddSession(ctx, unrelatedSess)) + + called := make(chan struct{}, 1) + release := make(chan struct{}) + slow := newModelSwitchingRuntime(nil) + slow.availableModelsCalled = called + slow.availableModelsDelay = release + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(slowSess.ID, slow, slowSess) + sm.AttachRuntime(unrelatedSess.ID, &fakeRuntime{}, unrelatedSess) + + // Start a slow AvailableSessionModels call on the first session. + done := make(chan struct{}) + go func() { + defer close(done) + _, _, _, _ = sm.AvailableSessionModels(ctx, slowSess.ID) + }() + + // Wait until the runtime is actually inside AvailableModels (and so + // stuck on `release`). + select { + case <-called: + case <-time.After(2 * time.Second): + t.Fatal("runtime AvailableModels was never called") + } + + // While the runtime call is parked, an unrelated method that + // acquires sm.mux must complete promptly. If sm.mux were held for + // the duration of the runtime call this would deadlock. + unrelatedDone := make(chan error, 1) + go func() { + unrelatedDone <- sm.SetSessionStarred(ctx, unrelatedSess.ID, true) + }() + + select { + case err := <-unrelatedDone: + require.NoError(t, err) + case <-time.After(2 * time.Second): + close(release) + <-done + t.Fatal("sm.mux is held across runtime I/O: unrelated session op blocked") + } + + // Let the slow call finish so the test cleanup can proceed. + close(release) + <-done +} + +// SetSessionAgentModel must also avoid holding sm.mux while the runtime's +// SetAgentModel call is in progress. +func TestSessionManager_SetSessionAgentModel_DoesNotHoldMuxDuringRuntimeIO(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + slowSess := session.New() + unrelatedSess := session.New() + require.NoError(t, store.AddSession(ctx, slowSess)) + require.NoError(t, store.AddSession(ctx, unrelatedSess)) + + called := make(chan struct{}, 1) + release := make(chan struct{}) + slow := newModelSwitchingRuntime(nil) + slow.setAgentModelCalled = called + slow.setAgentModelDelay = release + + sm := NewSessionManager(ctx, config.Sources{}, store, 0, &config.RuntimeConfig{}) + sm.AttachRuntime(slowSess.ID, slow, slowSess) + sm.AttachRuntime(unrelatedSess.ID, &fakeRuntime{}, unrelatedSess) + + done := make(chan struct{}) + go func() { + defer close(done) + _, _, _ = sm.SetSessionAgentModel(ctx, slowSess.ID, "openai/gpt-4o") + }() + + select { + case <-called: + case <-time.After(2 * time.Second): + t.Fatal("runtime SetAgentModel was never called") + } + + unrelatedDone := make(chan error, 1) + go func() { + unrelatedDone <- sm.SetSessionStarred(ctx, unrelatedSess.ID, true) + }() + + select { + case err := <-unrelatedDone: + require.NoError(t, err) + case <-time.After(2 * time.Second): + close(release) + <-done + t.Fatal("sm.mux is held across runtime I/O: unrelated session op blocked") + } + + close(release) + <-done +} + +// applyStoredOverrides is the helper called by runtimeForSession to +// re-apply persisted overrides on a freshly-built runtime. We can't +// drive runtimeForSession with a fake runtime (it constructs a real +// LocalRuntime), so we cover the helper's contract directly here. +func TestApplyStoredOverrides(t *testing.T) { + t.Parallel() + + t.Run("no-op when no overrides", func(t *testing.T) { + t.Parallel() + fake := newModelSwitchingRuntime(nil) + applyStoredOverrides(t.Context(), "sess", fake, nil) + + fake.mu.Lock() + assert.Empty(t, fake.overrides) + fake.mu.Unlock() + }) + + t.Run("no-op when runtime does not support switching", func(t *testing.T) { + t.Parallel() + // fakeRuntime has SupportsModelSwitching == false; SetAgentModel + // must NOT be called on it (it would panic since fakeRuntime + // does not implement the method either). + applyStoredOverrides(t.Context(), "sess", &fakeRuntime{}, map[string]string{"root": "openai/gpt-4o"}) + // Reaching this point without panic is the assertion. + }) + + t.Run("applies each override on the runtime", func(t *testing.T) { + t.Parallel() + fake := newModelSwitchingRuntime(nil) + applyStoredOverrides(t.Context(), "sess", fake, map[string]string{ + "root": "openai/gpt-4o", + "researcher": "anthropic/claude-sonnet-4-0", + }) + + fake.mu.Lock() + assert.Equal(t, "openai/gpt-4o", fake.overrides["root"]) + assert.Equal(t, "anthropic/claude-sonnet-4-0", fake.overrides["researcher"]) + fake.mu.Unlock() + }) + + t.Run("runtime errors are swallowed (logged) so the session still loads", func(t *testing.T) { + t.Parallel() + fake := newModelSwitchingRuntime(nil) + fake.setErr = errors.New("model not in config anymore") + + require.NotPanics(t, func() { + applyStoredOverrides(t.Context(), "sess", fake, map[string]string{"root": "gone/model"}) + }) + }) +}