Skip to content

Commit 9c50ab8

Browse files
authored
Merge pull request #5 from pionxe/fork-pr-374-1776686256
fix: harden gateway runtime migration reliability
2 parents 6b68f48 + bc989c1 commit 9c50ab8

7 files changed

Lines changed: 177 additions & 33 deletions

File tree

internal/runtime/create_session_test.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package runtime
33
import (
44
"context"
55
"fmt"
6+
"os"
67
"testing"
78

89
agentsession "neo-code/internal/session"
@@ -160,7 +161,7 @@ func TestServiceCreateSessionDuplicateCreateFallsBackToLoad(t *testing.T) {
160161
memoryStore: newMemoryStore(),
161162
missingErr: fmt.Errorf("load session row: %w", agentsession.ErrSessionNotFound),
162163
},
163-
createErr: fmt.Errorf("unique constraint failed"),
164+
createErr: fmt.Errorf("sqlite: %w", agentsession.ErrSessionAlreadyExists),
164165
loaded: agentsession.Session{ID: "session-dup", Title: "loaded"},
165166
}
166167
service := &Service{
@@ -190,9 +191,15 @@ func TestCreateSessionErrorPredicates(t *testing.T) {
190191
if isRuntimeSessionAlreadyExistsError(nil) {
191192
t.Fatalf("isRuntimeSessionAlreadyExistsError(nil) should be false")
192193
}
194+
if !isRuntimeSessionAlreadyExistsError(fmt.Errorf("wrapped: %w", agentsession.ErrSessionAlreadyExists)) {
195+
t.Fatalf("wrapped ErrSessionAlreadyExists should be detected")
196+
}
197+
if !isRuntimeSessionAlreadyExistsError(fmt.Errorf("wrapped: %w", os.ErrExist)) {
198+
t.Fatalf("wrapped os.ErrExist should be detected")
199+
}
193200
for _, text := range []string{"already exists", "UNIQUE CONSTRAINT", "duplicate key"} {
194-
if !isRuntimeSessionAlreadyExistsError(fmt.Errorf("%s", text)) {
195-
t.Fatalf("expected %q to be treated as already exists", text)
201+
if isRuntimeSessionAlreadyExistsError(fmt.Errorf("%s", text)) {
202+
t.Fatalf("plain text %q should not be treated as already exists", text)
196203
}
197204
}
198205
}

internal/runtime/runtime.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package runtime
33
import (
44
"context"
55
"errors"
6+
"os"
67
"strings"
78
"sync"
89
"time"
@@ -310,10 +311,7 @@ func isRuntimeSessionAlreadyExistsError(err error) bool {
310311
if err == nil {
311312
return false
312313
}
313-
normalized := strings.ToLower(strings.TrimSpace(err.Error()))
314-
return strings.Contains(normalized, "already exists") ||
315-
strings.Contains(normalized, "unique constraint") ||
316-
strings.Contains(normalized, "duplicate")
314+
return errors.Is(err, agentsession.ErrSessionAlreadyExists) || errors.Is(err, os.ErrExist)
317315
}
318316

319317
// SetAutoCompactThresholdResolver 注入自动压缩阈值解析能力,避免 runtime 直接处理模型目录细节。

internal/session/sqlite_store.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ INSERT INTO sessions (
177177
session.TokenOutputTotal,
178178
)
179179
if err != nil {
180+
if isSQLiteSessionUniqueConstraintError(err) {
181+
return Session{}, wrapSessionAlreadyExists(err)
182+
}
180183
return Session{}, fmt.Errorf("session: insert session %s: %w", session.ID, err)
181184
}
182185
if err := tx.Commit(); err != nil {
@@ -1306,6 +1309,14 @@ func wrapSessionNotFound(cause error) error {
13061309
return fmt.Errorf("%w: %w", ErrSessionNotFound, fmt.Errorf("%w: %w", os.ErrNotExist, cause))
13071310
}
13081311

1312+
// wrapSessionAlreadyExists wraps a duplicate-session creation failure so that
// callers can match it with errors.Is against ErrSessionAlreadyExists (and
// os.ErrExist). A nil cause is replaced with os.ErrExist before wrapping.
1313+
func wrapSessionAlreadyExists(cause error) error {
1314+
if cause == nil {
1315+
cause = os.ErrExist
1316+
}
1317+
return fmt.Errorf("%w: %w", ErrSessionAlreadyExists, fmt.Errorf("%w: %w", os.ErrExist, cause))
1318+
}
1319+
13091320
// cloneMessage 深拷贝消息,避免共享底层切片和映射。
13101321
// mapSessionAssetInsertError 统一收敛附件元数据插入阶段的缺失会话语义,避免向上泄漏底层 SQLite 错误。
13111322
func mapSessionAssetInsertError(assetID string, err error) error {
@@ -1324,6 +1335,18 @@ func isSQLiteForeignKeyConstraintError(err error) bool {
13241335
return false
13251336
}
13261337

1338+
// isSQLiteSessionUniqueConstraintError reports whether err wraps a SQLite driver
// error whose code is a primary-key or unique constraint failure.
// NOTE(review): the generic SQLITE_CONSTRAINT code is accepted as well; if the
// driver can report that code for other constraint kinds (NOT NULL, CHECK, ...),
// those would also be mapped to "already exists" — confirm against the driver.
1339+
func isSQLiteSessionUniqueConstraintError(err error) bool {
1340+
var sqliteErr *sqlitedriver.Error
1341+
if !errors.As(err, &sqliteErr) {
1342+
return false
1343+
}
1344+
code := sqliteErr.Code()
1345+
return code == sqlite3.SQLITE_CONSTRAINT ||
1346+
code == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY ||
1347+
code == sqlite3.SQLITE_CONSTRAINT_UNIQUE
1348+
}
1349+
13271350
func cloneMessage(message providertypes.Message) providertypes.Message {
13281351
next := message
13291352
next.Parts = providertypes.CloneParts(message.Parts)

internal/session/store.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ var storageIDPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,127}$`)
2828
// ErrSessionNotFound indicates that the session does not exist in the storage layer; runtime uses it for precise error dispatch.
2929
var ErrSessionNotFound = errors.New("session: session not found")
3030

31+
// ErrSessionAlreadyExists indicates that the session already exists in the storage layer; runtime uses it to handle concurrent-creation conflicts.
32+
var ErrSessionAlreadyExists = errors.New("session: session already exists")
33+
3134
// Session 表示单个会话的运行态与持久化聚合模型。
3235
type Session struct {
3336
ID string

internal/session/store_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,28 @@ func TestSQLiteStoreLifecycleRoundTrip(t *testing.T) {
122122
}
123123
}
124124

125+
// TestSQLiteStoreCreateSessionDuplicateReturnsSentinel creates the same session ID
// twice and checks that the second call fails with an error that matches both
// ErrSessionAlreadyExists and os.ErrExist via errors.Is.
func TestSQLiteStoreCreateSessionDuplicateReturnsSentinel(t *testing.T) {
126+
t.Parallel()
127+
128+
ctx := context.Background()
129+
store := newTestStore(t)
130+
input := CreateSessionInput{ID: "dup_session", Title: "dup"}
131+
if _, err := store.CreateSession(ctx, input); err != nil {
132+
t.Fatalf("first CreateSession() error = %v", err)
133+
}
134+
135+
_, err := store.CreateSession(ctx, input)
136+
if err == nil {
137+
t.Fatalf("expected duplicate CreateSession() to fail")
138+
}
139+
if !errors.Is(err, ErrSessionAlreadyExists) {
140+
t.Fatalf("expected ErrSessionAlreadyExists, got %v", err)
141+
}
142+
if !errors.Is(err, os.ErrExist) {
143+
t.Fatalf("expected os.ErrExist chain, got %v", err)
144+
}
145+
}
146+
125147
func TestSQLiteStoreListSummariesSortedAndLegacyJSONIgnored(t *testing.T) {
126148
ctx := context.Background()
127149
baseDir, err := os.MkdirTemp("", "session-base-")

internal/tui/services/gateway_rpc_client.go

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"log"
89
"net"
910
"strings"
1011
"sync"
@@ -17,10 +18,11 @@ import (
1718
)
1819

1920
const (
20-
defaultGatewayRPCRequestTimeout = 8 * time.Second
21-
defaultGatewayRPCRetryCount = 1
22-
defaultGatewayNotificationBuffer = 64
23-
defaultGatewayNotificationQueue = 256
21+
defaultGatewayRPCRequestTimeout = 8 * time.Second
22+
defaultGatewayRPCRetryCount = 1
23+
defaultGatewayNotificationBuffer = 64
24+
defaultGatewayNotificationQueue = 256
25+
defaultGatewayNotificationEnqueueTimeout = 3 * time.Second
2426
)
2527

2628
// GatewayRPCClientOptions 描述网关 JSON-RPC 客户端的初始化参数。
@@ -107,11 +109,12 @@ type GatewayRPCClient struct {
107109
conn net.Conn
108110
pending map[string]chan gatewayRPCResponse
109111

110-
notifications chan gatewayRPCNotification
111-
notificationQueue chan gatewayRPCNotification
112-
notificationWG sync.WaitGroup
113-
notificationStart sync.Once
114-
sequence uint64
112+
notifications chan gatewayRPCNotification
113+
notificationQueue chan gatewayRPCNotification
114+
notificationEnqueueTimeout time.Duration
115+
notificationWG sync.WaitGroup
116+
notificationStart sync.Once
117+
sequence uint64
115118
}
116119

117120
// NewGatewayRPCClient 创建网关 RPC 客户端,并在启动时静默读取认证 Token。
@@ -146,15 +149,16 @@ func NewGatewayRPCClient(options GatewayRPCClientOptions) (*GatewayRPCClient, er
146149
}
147150

148151
return &GatewayRPCClient{
149-
listenAddress: listenAddress,
150-
token: token,
151-
requestTimeout: requestTimeout,
152-
retryCount: retryCount,
153-
dialFn: dialFn,
154-
closed: make(chan struct{}),
155-
pending: make(map[string]chan gatewayRPCResponse),
156-
notifications: make(chan gatewayRPCNotification, defaultGatewayNotificationBuffer),
157-
notificationQueue: make(chan gatewayRPCNotification, defaultGatewayNotificationQueue),
152+
listenAddress: listenAddress,
153+
token: token,
154+
requestTimeout: requestTimeout,
155+
retryCount: retryCount,
156+
dialFn: dialFn,
157+
closed: make(chan struct{}),
158+
pending: make(map[string]chan gatewayRPCResponse),
159+
notifications: make(chan gatewayRPCNotification, defaultGatewayNotificationBuffer),
160+
notificationQueue: make(chan gatewayRPCNotification, defaultGatewayNotificationQueue),
161+
notificationEnqueueTimeout: defaultGatewayNotificationEnqueueTimeout,
158162
}, nil
159163
}
160164

@@ -232,7 +236,6 @@ func (c *GatewayRPCClient) Close() error {
232236
c.closeOnce.Do(func() {
233237
close(c.closed)
234238
firstErr = c.forceCloseWithError(errors.New("gateway rpc client closed"))
235-
close(c.notificationQueue)
236239
c.notificationWG.Wait()
237240
close(c.notifications)
238241
})
@@ -372,7 +375,9 @@ func (c *GatewayRPCClient) readLoop(conn net.Conn) {
372375
if paramsRaw, hasParams := envelope["params"]; hasParams {
373376
notification.Params = cloneJSONRawMessage(paramsRaw)
374377
}
375-
c.enqueueNotification(notification)
378+
if !c.enqueueNotification(notification) {
379+
return
380+
}
376381
continue
377382
}
378383

@@ -387,7 +392,7 @@ func (c *GatewayRPCClient) readLoop(conn net.Conn) {
387392
}
388393
}
389394

390-
// startNotificationDispatcher 启动通知转发协程,确保 readLoop 不会被 UI 消费速度阻塞
395+
// startNotificationDispatcher 启动通知转发协程,配合 enqueue 超时保护避免 readLoop 长时间背压阻塞
391396
func (c *GatewayRPCClient) startNotificationDispatcher() {
392397
c.notificationStart.Do(func() {
393398
c.notificationWG.Add(1)
@@ -412,12 +417,25 @@ func (c *GatewayRPCClient) startNotificationDispatcher() {
412417
})
413418
}
414419

415-
// enqueueNotification delivers the notification with a blocking send so that gateway.event is not silently dropped when the queue is full.
416-
func (c *GatewayRPCClient) enqueueNotification(notification gatewayRPCNotification) {
420+
// enqueueNotification delivers the notification to the internal queue and returns true on success. It returns false when the client is closed or when the send stays blocked past the enqueue timeout; in the timeout case it also force-closes the connection so readLoop cannot block indefinitely.
421+
func (c *GatewayRPCClient) enqueueNotification(notification gatewayRPCNotification) bool {
422+
enqueueTimeout := c.notificationEnqueueTimeout
423+
if enqueueTimeout <= 0 {
424+
enqueueTimeout = defaultGatewayNotificationEnqueueTimeout
425+
}
426+
timer := time.NewTimer(enqueueTimeout)
427+
defer timer.Stop()
428+
417429
select {
418430
case <-c.closed:
419-
return
431+
return false
420432
case c.notificationQueue <- notification:
433+
return true
434+
case <-timer.C:
435+
err := fmt.Errorf("gateway rpc client: notification queue blocked for %s", enqueueTimeout)
436+
log.Printf("warning: gateway rpc client force close due to notification backpressure method=%s err=%v", notification.Method, err)
437+
_ = c.forceCloseWithError(err)
438+
return false
421439
}
422440
}
423441

internal/tui/services/gateway_rpc_client_additional_test.go

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ func TestGatewayRPCClientReadLoopAdditionalBranches(t *testing.T) {
443443
_ = client.Close()
444444
}
445445

446-
func TestGatewayRPCClientNotificationDispatcherStopsOnQueueClose(t *testing.T) {
446+
func TestGatewayRPCClientNotificationDispatcherStopsOnCloseSignal(t *testing.T) {
447447
t.Parallel()
448448

449449
client := &GatewayRPCClient{
@@ -453,7 +453,7 @@ func TestGatewayRPCClientNotificationDispatcherStopsOnQueueClose(t *testing.T) {
453453
notificationQueue: make(chan gatewayRPCNotification, 1),
454454
}
455455
client.startNotificationDispatcher()
456-
close(client.notificationQueue)
456+
close(client.closed)
457457
client.notificationWG.Wait()
458458
}
459459

@@ -510,6 +510,79 @@ func TestGatewayRPCClientEnqueueNotificationDoesNotDropUnderQueuePressure(t *tes
510510
}
511511
}
512512

513+
// TestGatewayRPCClientReadLoopFailsFastOnNotificationBackpressure writes three
// protocol.MethodGatewayEvent notifications over a net.Pipe while nothing in the
// test receives from client.notifications, and expects readLoop to return within
// one second given the 50ms enqueue timeout.
func TestGatewayRPCClientReadLoopFailsFastOnNotificationBackpressure(t *testing.T) {
514+
t.Parallel()
515+
516+
clientConn, serverConn := net.Pipe()
517+
t.Cleanup(func() {
518+
_ = clientConn.Close()
519+
_ = serverConn.Close()
520+
})
521+
522+
client := &GatewayRPCClient{
523+
closed: make(chan struct{}),
524+
pending: make(map[string]chan gatewayRPCResponse),
525+
notifications: make(chan gatewayRPCNotification),
526+
notificationQueue: make(chan gatewayRPCNotification, 1),
527+
notificationEnqueueTimeout: 50 * time.Millisecond,
528+
}
529+
client.startNotificationDispatcher()
530+
t.Cleanup(func() { _ = client.Close() })
531+
532+
readDone := make(chan struct{})
533+
go func() {
534+
defer close(readDone)
535+
client.readLoop(clientConn)
536+
}()
537+
encoder := json.NewEncoder(serverConn)
538+
if err := encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"idx": 1}}); err != nil {
539+
t.Fatalf("encode first notification: %v", err)
540+
}
541+
if err := encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"idx": 2}}); err != nil {
542+
t.Fatalf("encode second notification: %v", err)
543+
}
544+
if err := encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"idx": 3}}); err != nil {
545+
t.Fatalf("encode third notification: %v", err)
546+
}
547+
548+
select {
549+
case <-readDone:
550+
case <-time.After(time.Second):
551+
t.Fatalf("expected readLoop to fail-fast on sustained notification backpressure")
552+
}
553+
}
554+
555+
// TestGatewayRPCClientEnqueueNotificationUnblocksOnClose checks that a call to
// enqueueNotification running in a goroutine returns within one second after
// the client is closed.
func TestGatewayRPCClientEnqueueNotificationUnblocksOnClose(t *testing.T) {
556+
t.Parallel()
557+
558+
client := &GatewayRPCClient{
559+
closed: make(chan struct{}),
560+
pending: make(map[string]chan gatewayRPCResponse),
561+
notifications: make(chan gatewayRPCNotification),
562+
notificationQueue: make(chan gatewayRPCNotification, 1),
563+
notificationEnqueueTimeout: time.Second,
564+
}
565+
client.startNotificationDispatcher()
566+
567+
// The first notification is meant to fill the queue so that the second one blocks in enqueue; once the client is closed, enqueue should return immediately.
568+
client.notificationQueue <- gatewayRPCNotification{Method: protocol.MethodGatewayEvent}
569+
570+
done := make(chan struct{})
571+
go func() {
572+
defer close(done)
573+
client.enqueueNotification(gatewayRPCNotification{Method: protocol.MethodGatewayEvent})
574+
}()
575+
576+
time.Sleep(20 * time.Millisecond)
577+
_ = client.Close()
578+
579+
select {
580+
case <-done:
581+
case <-time.After(time.Second):
582+
t.Fatalf("enqueueNotification should unblock when client closes")
583+
}
584+
}
585+
513586
func TestGatewayRPCClientWriteRequestFailure(t *testing.T) {
514587
t.Parallel()
515588

0 commit comments

Comments
 (0)