Skip to content

Commit 81c9902

Browse files
systemime and claude committed
feat: improve streaming support and add two-phase database storage
- Add two-phase database storage: insert pending record on request, update on response
- Fix responseWriter to implement http.Flusher interface for streaming support
- Improve streaming detection: check both request 'stream' param and response Content-Type
- Add async database updates to avoid blocking streaming responses
- Add status field to track request state (pending/completed)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1d4278c commit 81c9902

3 files changed

Lines changed: 174 additions & 52 deletions

File tree

internal/proxy/proxy.go

Lines changed: 89 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,15 @@ func (p *Proxy) Forward(w http.ResponseWriter, r *http.Request) {
9494
}
9595
defer r.Body.Close()
9696

97-
_, model, inputTokens := parseRequestMetadata(body)
97+
reqBody, model, inputTokens := parseRequestMetadata(body)
98+
99+
// 检查请求是否要求流式响应
100+
isStreamRequest := false
101+
if reqBody != nil {
102+
if stream, ok := reqBody["stream"].(bool); ok {
103+
isStreamRequest = stream
104+
}
105+
}
98106

99107
// 验证本地 API Key
100108
if !p.validateLocalAPIKey(r) {
@@ -129,9 +137,44 @@ func (p *Proxy) Forward(w http.ResponseWriter, r *http.Request) {
129137
// 日志记录
130138
p.logForwardRequest(model, inputTokens)
131139

140+
// 插入待处理记录到数据库
141+
pendingRecord := &storage.RequestRecord{
142+
Timestamp: startTime,
143+
Provider: p.cfg.Provider,
144+
Model: model,
145+
Method: r.Method,
146+
Path: r.URL.Path,
147+
ClientIP: clientIP,
148+
RequestBody: string(body),
149+
InputTokens: inputTokens,
150+
}
151+
recordID, err := p.storage.InsertPendingRequest(pendingRecord)
152+
if err != nil {
153+
p.logger.Error("插入待处理记录失败", zap.Error(err))
154+
}
155+
156+
// 辅助函数:异步更新记录为失败状态
157+
updateFailedRecord := func(statusCode int, errMsg string) {
158+
if recordID > 0 {
159+
go func() {
160+
duration := time.Since(startTime).Milliseconds()
161+
updateRecord := &storage.RequestRecord{
162+
StatusCode: statusCode,
163+
Duration: float64(duration),
164+
Success: false,
165+
ErrorMsg: errMsg,
166+
}
167+
if err := p.storage.UpdateRequestWithResponse(recordID, updateRecord); err != nil {
168+
p.logger.Error("更新失败记录失败", zap.Error(err))
169+
}
170+
}()
171+
}
172+
}
173+
132174
// 创建上游请求
133175
upstreamReq, err := http.NewRequestWithContext(r.Context(), "POST", targetURL, bytes.NewReader(body))
134176
if err != nil {
177+
updateFailedRecord(http.StatusInternalServerError, "创建请求失败")
135178
p.writeError(w, http.StatusInternalServerError, "创建请求失败")
136179
return
137180
}
@@ -145,20 +188,24 @@ func (p *Proxy) Forward(w http.ResponseWriter, r *http.Request) {
145188
resp, err := p.client.Do(upstreamReq)
146189
if err != nil {
147190
if strings.Contains(err.Error(), "context canceled") {
191+
updateFailedRecord(499, "请求被取消")
148192
p.logger.Info("请求被取消", zap.String("model", model))
149193
return
150194
}
195+
updateFailedRecord(http.StatusBadGateway, "上游服务不可用: "+err.Error())
151196
p.logger.Error("上游请求失败", zap.Error(err))
152197
p.writeError(w, http.StatusBadGateway, "上游服务不可用")
153198
return
154199
}
155200
defer resp.Body.Close()
156201

157-
// 处理响应
158-
if resp.StatusCode == http.StatusOK && isEventStream(resp.Header.Get("Content-Type")) {
159-
p.handleStreamResponseWithStats(w, resp, startTime, r.Method, r.URL.Path, targetURL, model, clientIP, inputTokens, string(body))
202+
// 处理响应 - 方案A: 同时检查请求中的 stream 参数和响应 Content-Type
203+
// 如果客户端请求 stream=true,或者上游返回 SSE 格式,都使用流式处理
204+
isStreamResponse := isStreamRequest || isEventStream(resp.Header.Get("Content-Type"))
205+
if resp.StatusCode == http.StatusOK && isStreamResponse {
206+
p.handleStreamResponseWithStats(w, resp, startTime, r.Method, r.URL.Path, targetURL, model, clientIP, inputTokens, string(body), recordID)
160207
} else {
161-
p.handleNormalResponseWithStats(w, resp, startTime, r.Method, r.URL.Path, targetURL, model, clientIP, inputTokens, string(body))
208+
p.handleNormalResponseWithStats(w, resp, startTime, r.Method, r.URL.Path, targetURL, model, clientIP, inputTokens, string(body), recordID)
162209
}
163210
}
164211

@@ -262,7 +309,7 @@ func (p *Proxy) buildHeaders(provider *config.ProviderConfig, apiKey string, req
262309
}
263310

264311
// handleStreamResponseWithStats 处理流式响应并统计
265-
func (p *Proxy) handleStreamResponseWithStats(w http.ResponseWriter, resp *http.Response, startTime time.Time, method, path, targetURL, model, clientIP string, inputTokens int, requestBody string) {
312+
func (p *Proxy) handleStreamResponseWithStats(w http.ResponseWriter, resp *http.Response, startTime time.Time, method, path, targetURL, model, clientIP string, inputTokens int, requestBody string, recordID int64) {
266313
copyHeaders(w.Header(), resp.Header)
267314

268315
// 设置 SSE 头
@@ -272,15 +319,16 @@ func (p *Proxy) handleStreamResponseWithStats(w http.ResponseWriter, resp *http.
272319
w.Header().Set("Cache-Control", "no-cache")
273320
w.Header().Set("Connection", "keep-alive")
274321
w.Header().Set("X-Accel-Buffering", "no")
275-
w.WriteHeader(resp.StatusCode)
276322

277-
// 获取 flusher
323+
// 获取 flusher(在 WriteHeader 之前检查)
278324
flusher, ok := w.(http.Flusher)
279325
if !ok {
280326
p.writeError(w, http.StatusInternalServerError, "不支持流式响应")
281327
return
282328
}
283329

330+
w.WriteHeader(resp.StatusCode)
331+
284332
// 读取并转发响应,同时收集数据
285333
var responseBuf bytes.Buffer
286334
var outputTokens int
@@ -339,30 +387,26 @@ func (p *Proxy) handleStreamResponseWithStats(w http.ResponseWriter, resp *http.
339387
// 打印响应日志
340388
p.logResponse(method, path, targetURL, resp.StatusCode, duration, clientIP, responseBuf.String())
341389

342-
// 保存记录
343-
record := &storage.RequestRecord{
344-
Timestamp: startTime,
345-
Provider: p.cfg.Provider,
346-
Model: model,
347-
Stream: true,
348-
Method: method,
349-
Path: path,
350-
ClientIP: clientIP,
351-
RequestBody: requestBody,
352-
ResponseBody: responseBuf.String(),
353-
StatusCode: resp.StatusCode,
354-
Duration: float64(duration),
355-
InputTokens: inputTokens,
356-
OutputTokens: outputTokens,
357-
TotalTokens: totalTokens,
358-
Success: resp.StatusCode == 200,
359-
}
360-
361-
go p.storage.SaveRequest(record)
390+
// 异步更新记录(不影响响应)
391+
if recordID > 0 {
392+
go func() {
393+
record := &storage.RequestRecord{
394+
ResponseBody: responseBuf.String(),
395+
StatusCode: resp.StatusCode,
396+
Duration: float64(duration),
397+
OutputTokens: outputTokens,
398+
TotalTokens: totalTokens,
399+
Success: resp.StatusCode == 200,
400+
}
401+
if err := p.storage.UpdateRequestWithResponse(recordID, record); err != nil {
402+
p.logger.Error("更新请求记录失败", zap.Error(err))
403+
}
404+
}()
405+
}
362406
}
363407

364408
// handleNormalResponseWithStats 处理普通响应并统计
365-
func (p *Proxy) handleNormalResponseWithStats(w http.ResponseWriter, resp *http.Response, startTime time.Time, method, path, targetURL, model, clientIP string, inputTokens int, requestBody string) {
409+
func (p *Proxy) handleNormalResponseWithStats(w http.ResponseWriter, resp *http.Response, startTime time.Time, method, path, targetURL, model, clientIP string, inputTokens int, requestBody string, recordID int64) {
366410
// 读取响应体
367411
respBody, err := io.ReadAll(resp.Body)
368412
if err != nil {
@@ -401,26 +445,22 @@ func (p *Proxy) handleNormalResponseWithStats(w http.ResponseWriter, resp *http.
401445
// 打印响应日志
402446
p.logResponse(method, path, targetURL, resp.StatusCode, duration, clientIP, string(respBody))
403447

404-
// 保存记录
405-
record := &storage.RequestRecord{
406-
Timestamp: startTime,
407-
Provider: p.cfg.Provider,
408-
Model: model,
409-
Stream: false,
410-
Method: method,
411-
Path: path,
412-
ClientIP: clientIP,
413-
RequestBody: requestBody,
414-
ResponseBody: string(respBody),
415-
StatusCode: resp.StatusCode,
416-
Duration: float64(duration),
417-
InputTokens: inputTokens,
418-
OutputTokens: outputTokens,
419-
TotalTokens: totalTokens,
420-
Success: resp.StatusCode == 200,
421-
}
422-
423-
go p.storage.SaveRequest(record)
448+
// 异步更新记录(不影响响应)
449+
if recordID > 0 {
450+
go func() {
451+
record := &storage.RequestRecord{
452+
ResponseBody: string(respBody),
453+
StatusCode: resp.StatusCode,
454+
Duration: float64(duration),
455+
OutputTokens: outputTokens,
456+
TotalTokens: totalTokens,
457+
Success: resp.StatusCode == 200,
458+
}
459+
if err := p.storage.UpdateRequestWithResponse(recordID, record); err != nil {
460+
p.logger.Error("更新请求记录失败", zap.Error(err))
461+
}
462+
}()
463+
}
424464
}
425465

426466
// writeError 写入错误响应

internal/server/server.go

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -200,6 +200,13 @@ func (rw *responseWriter) WriteHeader(code int) {
200200
rw.ResponseWriter.WriteHeader(code)
201201
}
202202

203+
// Flush 实现 http.Flusher 接口,支持流式响应
204+
func (rw *responseWriter) Flush() {
205+
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
206+
flusher.Flush()
207+
}
208+
}
209+
203210
// writeJSON 写入 JSON 响应
204211
func (s *Server) writeJSON(w http.ResponseWriter, code int, data interface{}) {
205212
w.Header().Set("Content-Type", "application/json")

internal/storage/storage.go

Lines changed: 78 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -139,12 +139,14 @@ CREATE TABLE IF NOT EXISTS requests (
139139
output_tokens INTEGER DEFAULT 0,
140140
total_tokens INTEGER DEFAULT 0,
141141
success INTEGER DEFAULT 1,
142-
error_msg TEXT
142+
error_msg TEXT,
143+
status TEXT DEFAULT 'pending'
143144
);
144145
145146
CREATE INDEX IF NOT EXISTS idx_requests_timestamp ON requests(timestamp);
146147
CREATE INDEX IF NOT EXISTS idx_requests_provider ON requests(provider);
147148
CREATE INDEX IF NOT EXISTS idx_requests_model ON requests(model);
149+
CREATE INDEX IF NOT EXISTS idx_requests_status ON requests(status);
148150
`)
149151
return err
150152
}
@@ -188,8 +190,8 @@ func (s *Storage) SaveRequest(record *RequestRecord) error {
188190
INSERT INTO requests (
189191
timestamp, provider, model, stream, method, path, client_ip,
190192
request_body, response_body, status_code, duration_ms,
191-
input_tokens, output_tokens, total_tokens, success, error_msg
192-
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
193+
input_tokens, output_tokens, total_tokens, success, error_msg, status
194+
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'completed')
193195
`,
194196
record.Timestamp,
195197
record.Provider,
@@ -220,6 +222,79 @@ func (s *Storage) SaveRequest(record *RequestRecord) error {
220222
return err
221223
}
222224

225+
// InsertPendingRequest inserts a new row into requests with status 'pending'
// (called when a request starts) and returns the id reported by
// result.LastInsertId.
//
// Only the request-side columns are written: timestamp, provider, model,
// stream, method, path, client_ip, request_body and input_tokens. The stream
// column is taken from record.Stream exactly as the caller supplied it.
//
// The whole operation runs while s.mu is held. On success the cached
// totalRequests and totalInputTokens counters are advanced. If either Exec or
// LastInsertId fails, it returns 0 and the error without touching the cached
// counters.
// NOTE(review): if LastInsertId fails after Exec succeeded, the row has
// presumably been written but is not counted in the cache — confirm whether
// that matters for the driver in use.
226+
func (s *Storage) InsertPendingRequest(record *RequestRecord) (int64, error) {
227+
s.mu.Lock()
228+
defer s.mu.Unlock()
229+
230+
result, err := s.db.Exec(`
231+
INSERT INTO requests (
232+
timestamp, provider, model, stream, method, path, client_ip,
233+
request_body, input_tokens, status
234+
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 'pending')
235+
`,
236+
record.Timestamp,
237+
record.Provider,
238+
record.Model,
239+
record.Stream,
240+
record.Method,
241+
record.Path,
242+
record.ClientIP,
243+
record.RequestBody,
244+
record.InputTokens,
245+
)
246+
if err != nil {
247+
return 0, err
248+
}
249+
250+
id, err := result.LastInsertId()
251+
if err != nil {
252+
return 0, err
253+
}
254+
255+
// Update the cached totals (the request is counted up front, at insert time).
256+
s.totalRequests++
257+
s.totalInputTokens += int64(record.InputTokens)
258+
259+
return id, nil
260+
}
261+
262+
// UpdateRequestWithResponse updates the requests row identified by id with the
// response-side fields of record (called when the response has finished).
//
// It writes response_body, status_code, duration_ms, output_tokens,
// total_tokens, success and error_msg, and always sets status to 'completed',
// whatever the value of record.Success. No other column is modified; in
// particular stream keeps whatever value was stored at insert time.
//
// The sql.Result from Exec is discarded, so the number of rows matched by id
// is not checked. The cached totalOutputTokens and totalTokens counters are
// advanced whenever Exec returns no error. The whole operation runs while
// s.mu is held. Returns the error from Exec, if any.
263+
func (s *Storage) UpdateRequestWithResponse(id int64, record *RequestRecord) error {
264+
s.mu.Lock()
265+
defer s.mu.Unlock()
266+
267+
_, err := s.db.Exec(`
268+
UPDATE requests SET
269+
response_body = ?,
270+
status_code = ?,
271+
duration_ms = ?,
272+
output_tokens = ?,
273+
total_tokens = ?,
274+
success = ?,
275+
error_msg = ?,
276+
status = 'completed'
277+
WHERE id = ?
278+
`,
279+
record.ResponseBody,
280+
record.StatusCode,
281+
record.Duration,
282+
record.OutputTokens,
283+
record.TotalTokens,
284+
record.Success,
285+
record.ErrorMsg,
286+
id,
287+
)
288+
289+
if err == nil {
290+
// Update the cached totals.
291+
s.totalOutputTokens += int64(record.OutputTokens)
292+
s.totalTokens += int64(record.TotalTokens)
293+
}
294+
295+
return err
296+
}
297+
223298
// GetStats 获取统计信息
224299
func (s *Storage) GetStats() (*Stats, error) {
225300
s.mu.RLock()

0 commit comments

Comments
 (0)