Skip to content

Commit b88fc2f

Browse files
authored
Merge pull request #663 from Yumiue/codex/fix-workspace-session-isolation
修复 Web 工作区会话串线问题
2 parents f288568 + 82a77f7 commit b88fc2f

12 files changed

Lines changed: 464 additions & 63 deletions

internal/gateway/bootstrap.go

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ func handleAuthenticateFrame(ctx context.Context, frame MessageFrame) MessageFra
108108
}
109109

110110
// handleBindStreamFrame 处理 gateway.bindStream 并注册连接订阅关系。
111-
func handleBindStreamFrame(ctx context.Context, frame MessageFrame) MessageFrame {
111+
func handleBindStreamFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame {
112112
params, err := decodeBindStreamParams(frame.Payload)
113113
if err != nil {
114114
return errorFrame(frame, err)
@@ -120,13 +120,18 @@ func handleBindStreamFrame(ctx context.Context, frame MessageFrame) MessageFrame
120120
return errorFrame(frame, NewFrameError(ErrorCodeInternalError, "stream relay context is unavailable"))
121121
}
122122

123+
if validationFrame := validateBindStreamSession(ctx, frame, runtimePort, params.SessionID); validationFrame != nil {
124+
return *validationFrame
125+
}
126+
123127
if bindErr := relay.BindConnection(connectionID, StreamBinding{
124-
SessionID: params.SessionID,
125-
RunID: params.RunID,
126-
Channel: params.Channel,
127-
Role: params.Role,
128-
State: cloneMapValue(params.State),
129-
Explicit: true,
128+
SessionID: params.SessionID,
129+
RunID: params.RunID,
130+
WorkspaceHash: WorkspaceHashFromContext(ctx),
131+
Channel: params.Channel,
132+
Role: params.Role,
133+
State: cloneMapValue(params.State),
134+
Explicit: true,
130135
}); bindErr != nil {
131136
return errorFrame(frame, bindErr)
132137
}
@@ -146,6 +151,34 @@ func handleBindStreamFrame(ctx context.Context, frame MessageFrame) MessageFrame
146151
}
147152
}
148153

154+
// validateBindStreamSession 确认事件流绑定的会话在当前工作区 runtime 中可见。
155+
func validateBindStreamSession(
156+
ctx context.Context,
157+
frame MessageFrame,
158+
runtimePort RuntimePort,
159+
sessionID string,
160+
) *MessageFrame {
161+
if runtimePort == nil {
162+
return nil
163+
}
164+
normalizedSessionID := strings.TrimSpace(sessionID)
165+
if normalizedSessionID == "" {
166+
return nil
167+
}
168+
169+
callCtx, cancel := withRuntimeOperationTimeout(ctx)
170+
defer cancel()
171+
_, err := runtimePort.LoadSession(callCtx, LoadSessionInput{
172+
SubjectID: AuthenticatedSubjectIDFromContext(ctx),
173+
SessionID: normalizedSessionID,
174+
})
175+
if err == nil {
176+
return nil
177+
}
178+
failedFrame := runtimeCallFailedFrame(callCtx, frame, err, "bind_stream")
179+
return &failedFrame
180+
}
181+
149182
// handleAskFrame 处理 gateway.ask 请求,并以异步方式转发到底层 Ask 编排能力。
150183
func handleAskFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame {
151184
if runtimePort == nil {

internal/gateway/bootstrap_test.go

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,7 +1232,7 @@ func TestHandleBindStreamFrameErrors(t *testing.T) {
12321232
Payload: protocol.BindStreamParams{
12331233
SessionID: "session-1",
12341234
},
1235-
})
1235+
}, nil)
12361236
if response.Type != FrameTypeError {
12371237
t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError)
12381238
}
@@ -1271,7 +1271,7 @@ func TestHandleBindStreamFrameErrors(t *testing.T) {
12711271
SessionID: "session-1",
12721272
Channel: "ipc",
12731273
},
1274-
})
1274+
}, nil)
12751275
if response.Type != FrameTypeError {
12761276
t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError)
12771277
}
@@ -1281,6 +1281,110 @@ func TestHandleBindStreamFrameErrors(t *testing.T) {
12811281
})
12821282
}
12831283

1284+
func TestHandleBindStreamFrameRejectsSessionOutsideCurrentWorkspace(t *testing.T) {
1285+
relay := NewStreamRelay(StreamRelayOptions{})
1286+
ctx, cancel := context.WithCancel(context.Background())
1287+
defer cancel()
1288+
1289+
connectionID := NewConnectionID()
1290+
workspaceState := NewConnectionWorkspaceState()
1291+
workspaceState.SetWorkspaceHash("workspace-b")
1292+
connectionCtx := WithConnectionID(ctx, connectionID)
1293+
connectionCtx = WithConnectionWorkspaceState(connectionCtx, workspaceState)
1294+
connectionCtx = WithStreamRelay(connectionCtx, relay)
1295+
if err := relay.RegisterConnection(ConnectionRegistration{
1296+
ConnectionID: connectionID,
1297+
Channel: StreamChannelIPC,
1298+
Context: connectionCtx,
1299+
Cancel: cancel,
1300+
Write: func(message RelayMessage) error {
1301+
_ = message
1302+
return nil
1303+
},
1304+
Close: func() {},
1305+
}); err != nil {
1306+
t.Fatalf("register connection: %v", err)
1307+
}
1308+
defer relay.dropConnection(connectionID)
1309+
1310+
runtimeStub := &bootstrapRuntimeStub{
1311+
loadSessionFn: func(context.Context, LoadSessionInput) (Session, error) {
1312+
return Session{}, ErrRuntimeResourceNotFound
1313+
},
1314+
}
1315+
response := handleBindStreamFrame(connectionCtx, MessageFrame{
1316+
Type: FrameTypeRequest,
1317+
Action: FrameActionBindStream,
1318+
RequestID: "bind-cross-workspace",
1319+
Payload: protocol.BindStreamParams{
1320+
SessionID: "session-from-workspace-a",
1321+
Channel: "all",
1322+
},
1323+
}, runtimeStub)
1324+
if response.Type != FrameTypeError {
1325+
t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError)
1326+
}
1327+
if response.Error == nil || response.Error.Code != ErrorCodeResourceNotFound.String() {
1328+
t.Fatalf("response error = %#v, want resource_not_found", response.Error)
1329+
}
1330+
if got := relay.ResolveFallbackSessionIDForWorkspace(connectionID, "workspace-b"); got != "" {
1331+
t.Fatalf("binding should not be written after validation failure, got fallback %q", got)
1332+
}
1333+
}
1334+
1335+
func TestHandleBindStreamFrameValidatesVisibleSessionBeforeBinding(t *testing.T) {
1336+
relay := NewStreamRelay(StreamRelayOptions{})
1337+
ctx, cancel := context.WithCancel(context.Background())
1338+
defer cancel()
1339+
1340+
connectionID := NewConnectionID()
1341+
workspaceState := NewConnectionWorkspaceState()
1342+
workspaceState.SetWorkspaceHash("workspace-a")
1343+
connectionCtx := WithConnectionID(ctx, connectionID)
1344+
connectionCtx = WithConnectionWorkspaceState(connectionCtx, workspaceState)
1345+
connectionCtx = WithStreamRelay(connectionCtx, relay)
1346+
if err := relay.RegisterConnection(ConnectionRegistration{
1347+
ConnectionID: connectionID,
1348+
Channel: StreamChannelIPC,
1349+
Context: connectionCtx,
1350+
Cancel: cancel,
1351+
Write: func(message RelayMessage) error {
1352+
_ = message
1353+
return nil
1354+
},
1355+
Close: func() {},
1356+
}); err != nil {
1357+
t.Fatalf("register connection: %v", err)
1358+
}
1359+
defer relay.dropConnection(connectionID)
1360+
1361+
var loaded LoadSessionInput
1362+
runtimeStub := &bootstrapRuntimeStub{
1363+
loadSessionFn: func(_ context.Context, input LoadSessionInput) (Session, error) {
1364+
loaded = input
1365+
return Session{ID: input.SessionID}, nil
1366+
},
1367+
}
1368+
response := handleBindStreamFrame(connectionCtx, MessageFrame{
1369+
Type: FrameTypeRequest,
1370+
Action: FrameActionBindStream,
1371+
RequestID: "bind-visible-session",
1372+
Payload: protocol.BindStreamParams{
1373+
SessionID: "session-visible",
1374+
Channel: "all",
1375+
},
1376+
}, runtimeStub)
1377+
if response.Type != FrameTypeAck {
1378+
t.Fatalf("response type = %q, want %q: %#v", response.Type, FrameTypeAck, response.Error)
1379+
}
1380+
if loaded.SessionID != "session-visible" {
1381+
t.Fatalf("validated session_id = %q, want %q", loaded.SessionID, "session-visible")
1382+
}
1383+
if got := relay.ResolveFallbackSessionIDForWorkspace(connectionID, "workspace-a"); got != "session-visible" {
1384+
t.Fatalf("fallback session = %q, want %q", got, "session-visible")
1385+
}
1386+
}
1387+
12841388
func TestHandleTriggerActionFrame(t *testing.T) {
12851389
registerConnection := func(
12861390
t *testing.T,

internal/gateway/registry.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ func (r *ActionRegistry) initCore() {
3838
r.core[FrameActionPing] = func(ctx context.Context, frame MessageFrame, _ RuntimePort) MessageFrame {
3939
return handlePingFrame(ctx, frame)
4040
}
41-
r.core[FrameActionBindStream] = func(ctx context.Context, frame MessageFrame, _ RuntimePort) MessageFrame {
42-
return handleBindStreamFrame(ctx, frame)
43-
}
41+
r.core[FrameActionBindStream] = handleBindStreamFrame
4442
r.core[FrameActionAsk] = handleAskFrame
4543
r.core[FrameActionDeleteAskSession] = handleDeleteAskSessionFrame
4644
r.core[FrameActionTriggerAction] = handleTriggerActionFrame

internal/gateway/request_logging.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@ import (
1212

1313
// RequestLogEntry 表示统一结构化请求日志字段。
1414
type RequestLogEntry struct {
15-
RequestID string `json:"request_id"`
16-
SessionID string `json:"session_id"`
17-
Method string `json:"method"`
18-
Source string `json:"source"`
19-
Status string `json:"status"`
20-
GatewayCode string `json:"gateway_code,omitempty"`
21-
LatencyMS int64 `json:"latency_ms"`
22-
ConnectionID string `json:"connection_id,omitempty"`
23-
AuthState string `json:"auth_state,omitempty"`
15+
RequestID string `json:"request_id"`
16+
SessionID string `json:"session_id"`
17+
Method string `json:"method"`
18+
Source string `json:"source"`
19+
Status string `json:"status"`
20+
WorkspaceHash string `json:"workspace_hash,omitempty"`
21+
GatewayCode string `json:"gateway_code,omitempty"`
22+
LatencyMS int64 `json:"latency_ms"`
23+
ConnectionID string `json:"connection_id,omitempty"`
24+
AuthState string `json:"auth_state,omitempty"`
2425
}
2526

2627
// emitRequestLog 输出网关结构化日志。
@@ -37,6 +38,9 @@ func emitRequestLog(ctx context.Context, logger *log.Logger, entry RequestLogEnt
3738
if connectionID, ok := ConnectionIDFromContext(ctx); ok {
3839
entry.ConnectionID = string(connectionID)
3940
}
41+
if entry.WorkspaceHash == "" {
42+
entry.WorkspaceHash = WorkspaceHashFromContext(ctx)
43+
}
4044
if authState, ok := ConnectionAuthStateFromContext(ctx); ok && authState.IsAuthenticated() {
4145
entry.AuthState = "authenticated"
4246
} else if _, ok := TokenAuthenticatorFromContext(ctx); ok {

internal/gateway/rpc_dispatch.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,10 @@ func hydrateFrameSessionFromConnection(ctx context.Context, frame MessageFrame)
277277
return frame
278278
}
279279

280-
frame.SessionID = strings.TrimSpace(relay.ResolveFallbackSessionID(connectionID))
280+
frame.SessionID = strings.TrimSpace(relay.ResolveFallbackSessionIDForWorkspace(
281+
connectionID,
282+
WorkspaceHashFromContext(ctx),
283+
))
281284
return frame
282285
}
283286

internal/gateway/stream_relay.go

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,13 @@ type ConnectionRegistration struct {
5050

5151
// StreamBinding 描述连接绑定到会话路由表的一条订阅关系。
5252
type StreamBinding struct {
53-
SessionID string
54-
RunID string
55-
Channel StreamChannel
56-
Role StreamRole
57-
State map[string]any
58-
Explicit bool
53+
SessionID string
54+
RunID string
55+
WorkspaceHash string
56+
Channel StreamChannel
57+
Role StreamRole
58+
State map[string]any
59+
Explicit bool
5960
}
6061

6162
// StreamRelayOptions 描述会话路由与流式中继的可选配置。
@@ -87,14 +88,15 @@ type bindingKey struct {
8788
}
8889

8990
type bindingState struct {
90-
sessionID string
91-
runID string
92-
channel StreamChannel
93-
role StreamRole
94-
state map[string]any
95-
explicit bool
96-
expireAt time.Time
97-
lastSeen time.Time
91+
sessionID string
92+
runID string
93+
workspaceHash string
94+
channel StreamChannel
95+
role StreamRole
96+
state map[string]any
97+
explicit bool
98+
expireAt time.Time
99+
lastSeen time.Time
98100
}
99101

100102
// StreamRelay 维护连接-会话-运行态映射,并负责运行事件的精确中继。
@@ -405,6 +407,11 @@ func (r *StreamRelay) BindConnection(connectionID ConnectionID, binding StreamBi
405407
return NewFrameError(ErrorCodeInvalidAction, "bind channel does not match connection channel")
406408
}
407409

410+
workspaceHash := strings.TrimSpace(binding.WorkspaceHash)
411+
if workspaceHash == "" {
412+
workspaceHash = WorkspaceHashFromContext(connection.ctx)
413+
}
414+
408415
key := bindingKey{sessionID: sessionID, runID: runID}
409416
connectionBindingMap := r.connectionBindings[normalizedConnectionID]
410417
if connectionBindingMap == nil {
@@ -424,14 +431,15 @@ func (r *StreamRelay) BindConnection(connectionID ConnectionID, binding StreamBi
424431
return NewFrameError(ErrorCodeInvalidAction, "too many stream bindings for connection")
425432
}
426433
connectionBindingMap[key] = &bindingState{
427-
sessionID: sessionID,
428-
runID: runID,
429-
channel: channel,
430-
role: role,
431-
state: state,
432-
explicit: binding.Explicit,
433-
expireAt: now.Add(r.bindingTTL),
434-
lastSeen: now,
434+
sessionID: sessionID,
435+
runID: runID,
436+
workspaceHash: workspaceHash,
437+
channel: channel,
438+
role: role,
439+
state: state,
440+
explicit: binding.Explicit,
441+
expireAt: now.Add(r.bindingTTL),
442+
lastSeen: now,
435443
}
436444
r.addConnectionToSessionIndexLocked(sessionID, normalizedConnectionID)
437445
if runID != "" {
@@ -443,6 +451,11 @@ func (r *StreamRelay) BindConnection(connectionID ConnectionID, binding StreamBi
443451

444452
// ResolveFallbackSessionID 返回连接当前可用绑定中的会话兜底值(取最近续期的绑定)。
445453
func (r *StreamRelay) ResolveFallbackSessionID(connectionID ConnectionID) string {
454+
return r.ResolveFallbackSessionIDForWorkspace(connectionID, "")
455+
}
456+
457+
// ResolveFallbackSessionIDForWorkspace 返回指定工作区内最近续期的连接兜底会话。
458+
func (r *StreamRelay) ResolveFallbackSessionIDForWorkspace(connectionID ConnectionID, workspaceHash string) string {
446459
if r == nil {
447460
return ""
448461
}
@@ -453,6 +466,7 @@ func (r *StreamRelay) ResolveFallbackSessionID(connectionID ConnectionID) string
453466
}
454467

455468
now := time.Now()
469+
normalizedWorkspaceHash := strings.TrimSpace(workspaceHash)
456470

457471
r.mu.RLock()
458472
connectionBindingMap := r.connectionBindings[normalizedConnectionID]
@@ -464,6 +478,10 @@ func (r *StreamRelay) ResolveFallbackSessionID(connectionID ConnectionID) string
464478
if state == nil || state.expireAt.Before(now) {
465479
continue
466480
}
481+
if normalizedWorkspaceHash != "" &&
482+
!strings.EqualFold(strings.TrimSpace(state.workspaceHash), normalizedWorkspaceHash) {
483+
continue
484+
}
467485
if state.lastSeen.After(latestSeen) {
468486
latestSeen = state.lastSeen
469487
latestSessionID = state.sessionID

0 commit comments

Comments
 (0)