Skip to content

Commit 0e3d6b8

Browse files
committed
fix(gateway): 隔离 Web 工作区会话绑定
1 parent f288568 commit 0e3d6b8

11 files changed

Lines changed: 294 additions & 61 deletions

internal/gateway/bootstrap.go

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ func handleAuthenticateFrame(ctx context.Context, frame MessageFrame) MessageFra
108108
}
109109

110110
// handleBindStreamFrame 处理 gateway.bindStream 并注册连接订阅关系。
111-
func handleBindStreamFrame(ctx context.Context, frame MessageFrame) MessageFrame {
111+
func handleBindStreamFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame {
112112
params, err := decodeBindStreamParams(frame.Payload)
113113
if err != nil {
114114
return errorFrame(frame, err)
@@ -120,13 +120,18 @@ func handleBindStreamFrame(ctx context.Context, frame MessageFrame) MessageFrame
120120
return errorFrame(frame, NewFrameError(ErrorCodeInternalError, "stream relay context is unavailable"))
121121
}
122122

123+
if validationFrame := validateBindStreamSession(ctx, frame, runtimePort, params.SessionID); validationFrame != nil {
124+
return *validationFrame
125+
}
126+
123127
if bindErr := relay.BindConnection(connectionID, StreamBinding{
124-
SessionID: params.SessionID,
125-
RunID: params.RunID,
126-
Channel: params.Channel,
127-
Role: params.Role,
128-
State: cloneMapValue(params.State),
129-
Explicit: true,
128+
SessionID: params.SessionID,
129+
RunID: params.RunID,
130+
WorkspaceHash: WorkspaceHashFromContext(ctx),
131+
Channel: params.Channel,
132+
Role: params.Role,
133+
State: cloneMapValue(params.State),
134+
Explicit: true,
130135
}); bindErr != nil {
131136
return errorFrame(frame, bindErr)
132137
}
@@ -146,6 +151,34 @@ func handleBindStreamFrame(ctx context.Context, frame MessageFrame) MessageFrame
146151
}
147152
}
148153

154+
// validateBindStreamSession 确认事件流绑定的会话在当前工作区 runtime 中可见。
155+
func validateBindStreamSession(
156+
ctx context.Context,
157+
frame MessageFrame,
158+
runtimePort RuntimePort,
159+
sessionID string,
160+
) *MessageFrame {
161+
if runtimePort == nil {
162+
return nil
163+
}
164+
normalizedSessionID := strings.TrimSpace(sessionID)
165+
if normalizedSessionID == "" {
166+
return nil
167+
}
168+
169+
callCtx, cancel := withRuntimeOperationTimeout(ctx)
170+
defer cancel()
171+
_, err := runtimePort.LoadSession(callCtx, LoadSessionInput{
172+
SubjectID: AuthenticatedSubjectIDFromContext(ctx),
173+
SessionID: normalizedSessionID,
174+
})
175+
if err == nil {
176+
return nil
177+
}
178+
failedFrame := runtimeCallFailedFrame(callCtx, frame, err, "bind_stream")
179+
return &failedFrame
180+
}
181+
149182
// handleAskFrame 处理 gateway.ask 请求,并以异步方式转发到底层 Ask 编排能力。
150183
func handleAskFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame {
151184
if runtimePort == nil {

internal/gateway/bootstrap_test.go

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,7 +1232,7 @@ func TestHandleBindStreamFrameErrors(t *testing.T) {
12321232
Payload: protocol.BindStreamParams{
12331233
SessionID: "session-1",
12341234
},
1235-
})
1235+
}, nil)
12361236
if response.Type != FrameTypeError {
12371237
t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError)
12381238
}
@@ -1271,7 +1271,7 @@ func TestHandleBindStreamFrameErrors(t *testing.T) {
12711271
SessionID: "session-1",
12721272
Channel: "ipc",
12731273
},
1274-
})
1274+
}, nil)
12751275
if response.Type != FrameTypeError {
12761276
t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError)
12771277
}
@@ -1281,6 +1281,57 @@ func TestHandleBindStreamFrameErrors(t *testing.T) {
12811281
})
12821282
}
12831283

1284+
func TestHandleBindStreamFrameRejectsSessionOutsideCurrentWorkspace(t *testing.T) {
1285+
relay := NewStreamRelay(StreamRelayOptions{})
1286+
ctx, cancel := context.WithCancel(context.Background())
1287+
defer cancel()
1288+
1289+
connectionID := NewConnectionID()
1290+
workspaceState := NewConnectionWorkspaceState()
1291+
workspaceState.SetWorkspaceHash("workspace-b")
1292+
connectionCtx := WithConnectionID(ctx, connectionID)
1293+
connectionCtx = WithConnectionWorkspaceState(connectionCtx, workspaceState)
1294+
connectionCtx = WithStreamRelay(connectionCtx, relay)
1295+
if err := relay.RegisterConnection(ConnectionRegistration{
1296+
ConnectionID: connectionID,
1297+
Channel: StreamChannelIPC,
1298+
Context: connectionCtx,
1299+
Cancel: cancel,
1300+
Write: func(message RelayMessage) error {
1301+
_ = message
1302+
return nil
1303+
},
1304+
Close: func() {},
1305+
}); err != nil {
1306+
t.Fatalf("register connection: %v", err)
1307+
}
1308+
defer relay.dropConnection(connectionID)
1309+
1310+
runtimeStub := &bootstrapRuntimeStub{
1311+
loadSessionFn: func(context.Context, LoadSessionInput) (Session, error) {
1312+
return Session{}, ErrRuntimeResourceNotFound
1313+
},
1314+
}
1315+
response := handleBindStreamFrame(connectionCtx, MessageFrame{
1316+
Type: FrameTypeRequest,
1317+
Action: FrameActionBindStream,
1318+
RequestID: "bind-cross-workspace",
1319+
Payload: protocol.BindStreamParams{
1320+
SessionID: "session-from-workspace-a",
1321+
Channel: "all",
1322+
},
1323+
}, runtimeStub)
1324+
if response.Type != FrameTypeError {
1325+
t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError)
1326+
}
1327+
if response.Error == nil || response.Error.Code != ErrorCodeResourceNotFound.String() {
1328+
t.Fatalf("response error = %#v, want resource_not_found", response.Error)
1329+
}
1330+
if got := relay.ResolveFallbackSessionIDForWorkspace(connectionID, "workspace-b"); got != "" {
1331+
t.Fatalf("binding should not be written after validation failure, got fallback %q", got)
1332+
}
1333+
}
1334+
12841335
func TestHandleTriggerActionFrame(t *testing.T) {
12851336
registerConnection := func(
12861337
t *testing.T,

internal/gateway/registry.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ func (r *ActionRegistry) initCore() {
3838
r.core[FrameActionPing] = func(ctx context.Context, frame MessageFrame, _ RuntimePort) MessageFrame {
3939
return handlePingFrame(ctx, frame)
4040
}
41-
r.core[FrameActionBindStream] = func(ctx context.Context, frame MessageFrame, _ RuntimePort) MessageFrame {
42-
return handleBindStreamFrame(ctx, frame)
43-
}
41+
r.core[FrameActionBindStream] = handleBindStreamFrame
4442
r.core[FrameActionAsk] = handleAskFrame
4543
r.core[FrameActionDeleteAskSession] = handleDeleteAskSessionFrame
4644
r.core[FrameActionTriggerAction] = handleTriggerActionFrame

internal/gateway/request_logging.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@ import (
1212

1313
// RequestLogEntry 表示统一结构化请求日志字段。
1414
type RequestLogEntry struct {
15-
RequestID string `json:"request_id"`
16-
SessionID string `json:"session_id"`
17-
Method string `json:"method"`
18-
Source string `json:"source"`
19-
Status string `json:"status"`
20-
GatewayCode string `json:"gateway_code,omitempty"`
21-
LatencyMS int64 `json:"latency_ms"`
22-
ConnectionID string `json:"connection_id,omitempty"`
23-
AuthState string `json:"auth_state,omitempty"`
15+
RequestID string `json:"request_id"`
16+
SessionID string `json:"session_id"`
17+
Method string `json:"method"`
18+
Source string `json:"source"`
19+
Status string `json:"status"`
20+
WorkspaceHash string `json:"workspace_hash,omitempty"`
21+
GatewayCode string `json:"gateway_code,omitempty"`
22+
LatencyMS int64 `json:"latency_ms"`
23+
ConnectionID string `json:"connection_id,omitempty"`
24+
AuthState string `json:"auth_state,omitempty"`
2425
}
2526

2627
// emitRequestLog 输出网关结构化日志。
@@ -37,6 +38,9 @@ func emitRequestLog(ctx context.Context, logger *log.Logger, entry RequestLogEnt
3738
if connectionID, ok := ConnectionIDFromContext(ctx); ok {
3839
entry.ConnectionID = string(connectionID)
3940
}
41+
if entry.WorkspaceHash == "" {
42+
entry.WorkspaceHash = WorkspaceHashFromContext(ctx)
43+
}
4044
if authState, ok := ConnectionAuthStateFromContext(ctx); ok && authState.IsAuthenticated() {
4145
entry.AuthState = "authenticated"
4246
} else if _, ok := TokenAuthenticatorFromContext(ctx); ok {

internal/gateway/rpc_dispatch.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,10 @@ func hydrateFrameSessionFromConnection(ctx context.Context, frame MessageFrame)
277277
return frame
278278
}
279279

280-
frame.SessionID = strings.TrimSpace(relay.ResolveFallbackSessionID(connectionID))
280+
frame.SessionID = strings.TrimSpace(relay.ResolveFallbackSessionIDForWorkspace(
281+
connectionID,
282+
WorkspaceHashFromContext(ctx),
283+
))
281284
return frame
282285
}
283286

internal/gateway/stream_relay.go

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,13 @@ type ConnectionRegistration struct {
5050

5151
// StreamBinding 描述连接绑定到会话路由表的一条订阅关系。
5252
type StreamBinding struct {
53-
SessionID string
54-
RunID string
55-
Channel StreamChannel
56-
Role StreamRole
57-
State map[string]any
58-
Explicit bool
53+
SessionID string
54+
RunID string
55+
WorkspaceHash string
56+
Channel StreamChannel
57+
Role StreamRole
58+
State map[string]any
59+
Explicit bool
5960
}
6061

6162
// StreamRelayOptions 描述会话路由与流式中继的可选配置。
@@ -87,14 +88,15 @@ type bindingKey struct {
8788
}
8889

8990
type bindingState struct {
90-
sessionID string
91-
runID string
92-
channel StreamChannel
93-
role StreamRole
94-
state map[string]any
95-
explicit bool
96-
expireAt time.Time
97-
lastSeen time.Time
91+
sessionID string
92+
runID string
93+
workspaceHash string
94+
channel StreamChannel
95+
role StreamRole
96+
state map[string]any
97+
explicit bool
98+
expireAt time.Time
99+
lastSeen time.Time
98100
}
99101

100102
// StreamRelay 维护连接-会话-运行态映射,并负责运行事件的精确中继。
@@ -405,6 +407,11 @@ func (r *StreamRelay) BindConnection(connectionID ConnectionID, binding StreamBi
405407
return NewFrameError(ErrorCodeInvalidAction, "bind channel does not match connection channel")
406408
}
407409

410+
workspaceHash := strings.TrimSpace(binding.WorkspaceHash)
411+
if workspaceHash == "" {
412+
workspaceHash = WorkspaceHashFromContext(connection.ctx)
413+
}
414+
408415
key := bindingKey{sessionID: sessionID, runID: runID}
409416
connectionBindingMap := r.connectionBindings[normalizedConnectionID]
410417
if connectionBindingMap == nil {
@@ -424,14 +431,15 @@ func (r *StreamRelay) BindConnection(connectionID ConnectionID, binding StreamBi
424431
return NewFrameError(ErrorCodeInvalidAction, "too many stream bindings for connection")
425432
}
426433
connectionBindingMap[key] = &bindingState{
427-
sessionID: sessionID,
428-
runID: runID,
429-
channel: channel,
430-
role: role,
431-
state: state,
432-
explicit: binding.Explicit,
433-
expireAt: now.Add(r.bindingTTL),
434-
lastSeen: now,
434+
sessionID: sessionID,
435+
runID: runID,
436+
workspaceHash: workspaceHash,
437+
channel: channel,
438+
role: role,
439+
state: state,
440+
explicit: binding.Explicit,
441+
expireAt: now.Add(r.bindingTTL),
442+
lastSeen: now,
435443
}
436444
r.addConnectionToSessionIndexLocked(sessionID, normalizedConnectionID)
437445
if runID != "" {
@@ -443,6 +451,11 @@ func (r *StreamRelay) BindConnection(connectionID ConnectionID, binding StreamBi
443451

444452
// ResolveFallbackSessionID 返回连接当前可用绑定中的会话兜底值(取最近续期的绑定)。
445453
func (r *StreamRelay) ResolveFallbackSessionID(connectionID ConnectionID) string {
454+
return r.ResolveFallbackSessionIDForWorkspace(connectionID, "")
455+
}
456+
457+
// ResolveFallbackSessionIDForWorkspace 返回指定工作区内最近续期的连接兜底会话。
458+
func (r *StreamRelay) ResolveFallbackSessionIDForWorkspace(connectionID ConnectionID, workspaceHash string) string {
446459
if r == nil {
447460
return ""
448461
}
@@ -453,6 +466,7 @@ func (r *StreamRelay) ResolveFallbackSessionID(connectionID ConnectionID) string
453466
}
454467

455468
now := time.Now()
469+
normalizedWorkspaceHash := strings.TrimSpace(workspaceHash)
456470

457471
r.mu.RLock()
458472
connectionBindingMap := r.connectionBindings[normalizedConnectionID]
@@ -464,6 +478,10 @@ func (r *StreamRelay) ResolveFallbackSessionID(connectionID ConnectionID) string
464478
if state == nil || state.expireAt.Before(now) {
465479
continue
466480
}
481+
if normalizedWorkspaceHash != "" &&
482+
!strings.EqualFold(strings.TrimSpace(state.workspaceHash), normalizedWorkspaceHash) {
483+
continue
484+
}
467485
if state.lastSeen.After(latestSeen) {
468486
latestSeen = state.lastSeen
469487
latestSessionID = state.sessionID

internal/gateway/stream_relay_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,49 @@ func TestStreamRelayBindAndFallbackSession(t *testing.T) {
4747
}
4848
}
4949

50+
func TestStreamRelayFallbackSessionIsWorkspaceScoped(t *testing.T) {
51+
relay := NewStreamRelay(StreamRelayOptions{})
52+
ctx, cancel := context.WithCancel(context.Background())
53+
defer cancel()
54+
55+
connectionID := NewConnectionID()
56+
workspaceState := NewConnectionWorkspaceState()
57+
workspaceState.SetWorkspaceHash("workspace-a")
58+
connectionCtx := WithConnectionID(ctx, connectionID)
59+
connectionCtx = WithConnectionWorkspaceState(connectionCtx, workspaceState)
60+
connectionCtx = WithStreamRelay(connectionCtx, relay)
61+
if err := relay.RegisterConnection(ConnectionRegistration{
62+
ConnectionID: connectionID,
63+
Channel: StreamChannelIPC,
64+
Context: connectionCtx,
65+
Cancel: cancel,
66+
Write: func(message RelayMessage) error {
67+
_ = message
68+
return nil
69+
},
70+
Close: func() {},
71+
}); err != nil {
72+
t.Fatalf("register connection: %v", err)
73+
}
74+
defer relay.dropConnection(connectionID)
75+
76+
if bindErr := relay.BindConnection(connectionID, StreamBinding{
77+
SessionID: "session-a",
78+
Channel: StreamChannelAll,
79+
Explicit: true,
80+
}); bindErr != nil {
81+
t.Fatalf("bind workspace-a: %v", bindErr)
82+
}
83+
84+
workspaceState.SetWorkspaceHash("workspace-b")
85+
if got := relay.ResolveFallbackSessionIDForWorkspace(connectionID, "workspace-b"); got != "" {
86+
t.Fatalf("workspace-b fallback session id = %q, want empty", got)
87+
}
88+
if got := relay.ResolveFallbackSessionIDForWorkspace(connectionID, "workspace-a"); got != "session-a" {
89+
t.Fatalf("workspace-a fallback session id = %q, want session-a", got)
90+
}
91+
}
92+
5093
func TestStreamRelayPublishRuntimeEventNoCrossSession(t *testing.T) {
5194
relay := NewStreamRelay(StreamRelayOptions{})
5295
ctx, cancel := context.WithCancel(context.Background())

web/src/components/chat/ChatInput.test.tsx

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { useChatStore } from '@/stores/useChatStore'
55
import { useComposerStore } from '@/stores/useComposerStore'
66
import { useSessionStore } from '@/stores/useSessionStore'
77
import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore'
8+
import { useGatewayStore } from '@/stores/useGatewayStore'
89

910
const mockGatewayAPI = {
1011
listAvailableSkills: vi.fn(),
@@ -70,6 +71,7 @@ describe('ChatInput', () => {
7071

7172
useComposerStore.setState({ composerText: '' })
7273
useSessionStore.setState({ currentSessionId: '' } as never)
74+
useGatewayStore.setState({ currentRunId: '' } as never)
7375
useRuntimeInsightStore.getState().reset()
7476
useChatStore.setState({
7577
isGenerating: false,
@@ -335,4 +337,18 @@ describe('ChatInput', () => {
335337

336338
expect(ring).toHaveAttribute('stroke', 'var(--error)')
337339
})
340+
341+
it('sends session id when cancelling an active run', async () => {
342+
useSessionStore.setState({ currentSessionId: 'session-1' } as never)
343+
useGatewayStore.setState({ currentRunId: 'run-1' } as never)
344+
useChatStore.setState({ isGenerating: true } as never)
345+
mockGatewayAPI.cancel.mockResolvedValueOnce({ payload: { canceled: true, run_id: 'run-1' } })
346+
render(<ChatInput />)
347+
348+
fireEvent.click(screen.getByTitle('停止生成'))
349+
350+
await waitFor(() => {
351+
expect(mockGatewayAPI.cancel).toHaveBeenCalledWith({ session_id: 'session-1', run_id: 'run-1' })
352+
})
353+
})
338354
})

0 commit comments

Comments
 (0)