Skip to content

Commit 32b86a0

Browse files
committed
feat(gateway): 切换 Provider/Model 时同步更新所有工作区会话元数据
在 SelectProviderModel / CreateProvider 调用链中,选择完成后同步 广播到当前所有已加载工作区,将其会话 Head 中的 Provider/Model 元数据更新为新值,避免非管理端口对应工作区的会话滞留旧值, 导致后续 listModels 解析到过期 provider/model。
1 parent ae49927 commit 32b86a0

3 files changed

Lines changed: 226 additions & 0 deletions

File tree

internal/cli/gateway_runtime_bridge.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,6 +1040,9 @@ func (b *gatewayRuntimePortBridge) SelectProviderModel(ctx context.Context, inpu
10401040
return gateway.ProviderSelectionResult{}, err
10411041
}
10421042
}
1043+
if err := b.SyncSessionsProviderModel(ctx, selection.ProviderID, selection.ModelID); err != nil {
1044+
return gateway.ProviderSelectionResult{}, err
1045+
}
10431046
return gateway.ProviderSelectionResult{ProviderID: selection.ProviderID, ModelID: selection.ModelID}, nil
10441047
}
10451048

@@ -1950,6 +1953,52 @@ func (b *gatewayRuntimePortBridge) loadStoredSession(ctx context.Context, sessio
19501953
return loader.LoadSession(ctx, strings.TrimSpace(sessionID))
19511954
}
19521955

1956+
// SyncSessionsProviderModel 将当前工作区已列出的会话统一切换到新的 provider/model,避免全局切换后会话元数据继续滞留旧值。
1957+
func (b *gatewayRuntimePortBridge) SyncSessionsProviderModel(
1958+
ctx context.Context,
1959+
providerID string,
1960+
modelID string,
1961+
) error {
1962+
if b == nil || b.sessionStore == nil || b.runtime == nil {
1963+
return nil
1964+
}
1965+
providerID = strings.TrimSpace(providerID)
1966+
modelID = strings.TrimSpace(modelID)
1967+
if providerID == "" || modelID == "" {
1968+
return nil
1969+
}
1970+
1971+
summaries, err := b.runtime.ListSessions(ctx)
1972+
if err != nil {
1973+
return err
1974+
}
1975+
for _, summary := range summaries {
1976+
sessionID := strings.TrimSpace(summary.ID)
1977+
if sessionID == "" {
1978+
continue
1979+
}
1980+
session, loadErr := b.loadStoredSession(ctx, sessionID)
1981+
if loadErr != nil {
1982+
if errors.Is(loadErr, agentsession.ErrSessionNotFound) {
1983+
continue
1984+
}
1985+
return loadErr
1986+
}
1987+
head := session.HeadSnapshot()
1988+
head.Provider = providerID
1989+
head.Model = modelID
1990+
if updateErr := b.sessionStore.UpdateSessionState(ctx, agentsession.UpdateSessionStateInput{
1991+
SessionID: session.ID,
1992+
Title: session.Title,
1993+
UpdatedAt: time.Now().UTC(),
1994+
Head: head,
1995+
}); updateErr != nil {
1996+
return updateErr
1997+
}
1998+
}
1999+
return nil
2000+
}
2001+
19532002
// resolveSafeListFilesPath 将前端传入的相对路径限制在根目录内。
19542003
func resolveSafeListFilesPath(root string, rawPath string) (string, string, error) {
19552004
rootAbs, err := filepath.Abs(filepath.Clean(root))

internal/cli/gateway_runtime_bridge_test.go

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2211,6 +2211,161 @@ func TestGatewayRuntimePortBridgeSelectProviderModelSelectError(t *testing.T) {
22112211
}
22122212
}
22132213

2214+
func TestGatewayRuntimePortBridgeSelectProviderModelSyncsWorkspaceSessions(t *testing.T) {
2215+
updated := make([]agentsession.UpdateSessionStateInput, 0)
2216+
store := &bridgeSessionStoreWithLoader{
2217+
bridgeSessionStoreStub: bridgeSessionStoreStub{
2218+
updateFn: func(_ context.Context, input agentsession.UpdateSessionStateInput) error {
2219+
updated = append(updated, input)
2220+
return nil
2221+
},
2222+
},
2223+
session: agentsession.Session{
2224+
ID: "session-1",
2225+
Title: "Session 1",
2226+
Provider: "openai",
2227+
Model: "gpt-4.1",
2228+
},
2229+
}
2230+
ps := &providerSelectionStub{
2231+
selectRes: configstate.Selection{ProviderID: "gemini", ModelID: "gemini-2.5-pro"},
2232+
}
2233+
stub := &runtimeStub{
2234+
eventsCh: make(chan agentruntime.RuntimeEvent, 1),
2235+
sessionList: []agentsession.Summary{
2236+
{ID: "session-1", Title: "Session 1"},
2237+
{ID: "session-2", Title: "Session 2"},
2238+
},
2239+
}
2240+
bridge, _ := newGatewayRuntimePortBridge(context.Background(), stub, store, nil, ps)
2241+
defer bridge.Close()
2242+
2243+
result, err := bridge.SelectProviderModel(context.Background(), gateway.SelectProviderModelInput{
2244+
SubjectID: testBridgeSubjectID,
2245+
ProviderID: "gemini",
2246+
})
2247+
if err != nil {
2248+
t.Fatalf("SelectProviderModel() error = %v", err)
2249+
}
2250+
if result.ProviderID != "gemini" || result.ModelID != "gemini-2.5-pro" {
2251+
t.Fatalf("result = %+v, want gemini/gemini-2.5-pro", result)
2252+
}
2253+
if len(updated) != 2 {
2254+
t.Fatalf("updated len = %d, want 2", len(updated))
2255+
}
2256+
for _, input := range updated {
2257+
if input.Head.Provider != "gemini" || input.Head.Model != "gemini-2.5-pro" {
2258+
t.Fatalf("updated head = %+v, want gemini/gemini-2.5-pro", input.Head)
2259+
}
2260+
}
2261+
}
2262+
2263+
func TestGatewayRuntimePortBridgeSelectProviderModelWithExplicitModelSyncsWorkspaceSessions(t *testing.T) {
2264+
updated := make([]agentsession.UpdateSessionStateInput, 0)
2265+
store := &bridgeSessionStoreWithLoader{
2266+
bridgeSessionStoreStub: bridgeSessionStoreStub{
2267+
updateFn: func(_ context.Context, input agentsession.UpdateSessionStateInput) error {
2268+
updated = append(updated, input)
2269+
return nil
2270+
},
2271+
},
2272+
session: agentsession.Session{
2273+
ID: "session-1",
2274+
Title: "Session 1",
2275+
Provider: "openai",
2276+
Model: "gpt-4.1",
2277+
},
2278+
}
2279+
ps := &providerSelectionStub{
2280+
selectRes: configstate.Selection{ProviderID: "openai", ModelID: "gpt-4.1"},
2281+
setModelRes: configstate.Selection{ProviderID: "openai", ModelID: "gpt-4o"},
2282+
}
2283+
stub := &runtimeStub{
2284+
eventsCh: make(chan agentruntime.RuntimeEvent, 1),
2285+
sessionList: []agentsession.Summary{
2286+
{ID: "session-1", Title: "Session 1"},
2287+
},
2288+
}
2289+
bridge, _ := newGatewayRuntimePortBridge(context.Background(), stub, store, nil, ps)
2290+
defer bridge.Close()
2291+
2292+
result, err := bridge.SelectProviderModel(context.Background(), gateway.SelectProviderModelInput{
2293+
SubjectID: testBridgeSubjectID,
2294+
ProviderID: "openai",
2295+
ModelID: "gpt-4o",
2296+
})
2297+
if err != nil {
2298+
t.Fatalf("SelectProviderModel() error = %v", err)
2299+
}
2300+
if result.ProviderID != "openai" || result.ModelID != "gpt-4o" {
2301+
t.Fatalf("result = %+v, want openai/gpt-4o", result)
2302+
}
2303+
if len(updated) != 1 {
2304+
t.Fatalf("updated len = %d, want 1", len(updated))
2305+
}
2306+
if updated[0].Head.Provider != "openai" || updated[0].Head.Model != "gpt-4o" {
2307+
t.Fatalf("updated head = %+v, want openai/gpt-4o", updated[0].Head)
2308+
}
2309+
}
2310+
2311+
func TestGatewayRuntimePortBridgeSelectProviderModelSyncWorkspaceLoadError(t *testing.T) {
2312+
store := &bridgeSessionStoreWithLoader{loadErr: errors.New("load failed")}
2313+
ps := &providerSelectionStub{
2314+
selectRes: configstate.Selection{ProviderID: "gemini", ModelID: "gemini-2.5-pro"},
2315+
}
2316+
stub := &runtimeStub{
2317+
eventsCh: make(chan agentruntime.RuntimeEvent, 1),
2318+
sessionList: []agentsession.Summary{
2319+
{ID: "session-1", Title: "Session 1"},
2320+
},
2321+
}
2322+
bridge, _ := newGatewayRuntimePortBridge(context.Background(), stub, store, nil, ps)
2323+
defer bridge.Close()
2324+
2325+
_, err := bridge.SelectProviderModel(context.Background(), gateway.SelectProviderModelInput{
2326+
SubjectID: testBridgeSubjectID,
2327+
ProviderID: "gemini",
2328+
})
2329+
if err == nil || err.Error() != "load failed" {
2330+
t.Fatalf("expected load failed, got %v", err)
2331+
}
2332+
}
2333+
2334+
func TestGatewayRuntimePortBridgeSelectProviderModelSyncWorkspaceUpdateError(t *testing.T) {
2335+
store := &bridgeSessionStoreWithLoader{
2336+
bridgeSessionStoreStub: bridgeSessionStoreStub{
2337+
updateFn: func(_ context.Context, _ agentsession.UpdateSessionStateInput) error {
2338+
return errors.New("update failed")
2339+
},
2340+
},
2341+
session: agentsession.Session{
2342+
ID: "session-1",
2343+
Title: "Session 1",
2344+
Provider: "openai",
2345+
Model: "gpt-4.1",
2346+
},
2347+
}
2348+
ps := &providerSelectionStub{
2349+
selectRes: configstate.Selection{ProviderID: "gemini", ModelID: "gemini-2.5-pro"},
2350+
}
2351+
stub := &runtimeStub{
2352+
eventsCh: make(chan agentruntime.RuntimeEvent, 1),
2353+
sessionList: []agentsession.Summary{
2354+
{ID: "session-1", Title: "Session 1"},
2355+
},
2356+
}
2357+
bridge, _ := newGatewayRuntimePortBridge(context.Background(), stub, store, nil, ps)
2358+
defer bridge.Close()
2359+
2360+
_, err := bridge.SelectProviderModel(context.Background(), gateway.SelectProviderModelInput{
2361+
SubjectID: testBridgeSubjectID,
2362+
ProviderID: "gemini",
2363+
})
2364+
if err == nil || err.Error() != "update failed" {
2365+
t.Fatalf("expected update failed, got %v", err)
2366+
}
2367+
}
2368+
22142369
// ---- UpsertMCPServer ----
22152370

22162371
func TestGatewayRuntimePortBridgeUpsertMCPServerNilConfigManager(t *testing.T) {

internal/gateway/multi_workspace_runtime.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,26 @@ func (m *MultiWorkspaceRuntime) syncAllWorkspaceMCP() {
560560
}
561561
}
562562

563+
// syncAllWorkspaceSessionsProviderModel 将全局 provider/model 选择同步到所有已加载工作区的会话元数据,
564+
// 避免非管理端口对应工作区的会话滞留旧值,导致 listModels 解析到过期 provider/model。
565+
func (m *MultiWorkspaceRuntime) syncAllWorkspaceSessionsProviderModel(ctx context.Context, providerID, modelID string) {
566+
m.mu.RLock()
567+
bundles := make([]*workspaceBundle, 0, len(m.bundles))
568+
for _, b := range m.bundles {
569+
bundles = append(bundles, b)
570+
}
571+
m.mu.RUnlock()
572+
573+
for _, b := range bundles {
574+
type sessionSyncer interface {
575+
SyncSessionsProviderModel(ctx context.Context, providerID, modelID string) error
576+
}
577+
if syncer, ok := b.port.(sessionSyncer); ok {
578+
_ = syncer.SyncSessionsProviderModel(ctx, providerID, modelID)
579+
}
580+
}
581+
}
582+
563583
// ---- ManagementRuntimePort implementation ----
564584

565585
func (m *MultiWorkspaceRuntime) ListProviders(ctx context.Context, input ListProvidersInput) ([]ProviderOption, error) {
@@ -580,6 +600,7 @@ func (m *MultiWorkspaceRuntime) CreateProvider(ctx context.Context, input Create
580600
return ProviderSelectionResult{}, err
581601
}
582602
m.syncAllWorkspaceConfigs(ctx)
603+
m.syncAllWorkspaceSessionsProviderModel(ctx, result.ProviderID, result.ModelID)
583604
return result, nil
584605
}
585606

@@ -605,6 +626,7 @@ func (m *MultiWorkspaceRuntime) SelectProviderModel(ctx context.Context, input S
605626
return ProviderSelectionResult{}, err
606627
}
607628
m.syncAllWorkspaceConfigs(ctx)
629+
m.syncAllWorkspaceSessionsProviderModel(ctx, result.ProviderID, result.ModelID)
608630
return result, nil
609631
}
610632

0 commit comments

Comments
 (0)