@@ -94,7 +94,15 @@ func (p *Proxy) Forward(w http.ResponseWriter, r *http.Request) {
9494 }
9595 defer r .Body .Close ()
9696
97- _ , model , inputTokens := parseRequestMetadata (body )
97+ reqBody , model , inputTokens := parseRequestMetadata (body )
98+
99+ // 检查请求是否要求流式响应
100+ isStreamRequest := false
101+ if reqBody != nil {
102+ if stream , ok := reqBody ["stream" ].(bool ); ok {
103+ isStreamRequest = stream
104+ }
105+ }
98106
99107 // 验证本地 API Key
100108 if ! p .validateLocalAPIKey (r ) {
@@ -129,9 +137,44 @@ func (p *Proxy) Forward(w http.ResponseWriter, r *http.Request) {
129137 // 日志记录
130138 p .logForwardRequest (model , inputTokens )
131139
140+ // 插入待处理记录到数据库
141+ pendingRecord := & storage.RequestRecord {
142+ Timestamp : startTime ,
143+ Provider : p .cfg .Provider ,
144+ Model : model ,
145+ Method : r .Method ,
146+ Path : r .URL .Path ,
147+ ClientIP : clientIP ,
148+ RequestBody : string (body ),
149+ InputTokens : inputTokens ,
150+ }
151+ recordID , err := p .storage .InsertPendingRequest (pendingRecord )
152+ if err != nil {
153+ p .logger .Error ("插入待处理记录失败" , zap .Error (err ))
154+ }
155+
156+ // 辅助函数:异步更新记录为失败状态
157+ updateFailedRecord := func (statusCode int , errMsg string ) {
158+ if recordID > 0 {
159+ go func () {
160+ duration := time .Since (startTime ).Milliseconds ()
161+ updateRecord := & storage.RequestRecord {
162+ StatusCode : statusCode ,
163+ Duration : float64 (duration ),
164+ Success : false ,
165+ ErrorMsg : errMsg ,
166+ }
167+ if err := p .storage .UpdateRequestWithResponse (recordID , updateRecord ); err != nil {
168+ p .logger .Error ("更新失败记录失败" , zap .Error (err ))
169+ }
170+ }()
171+ }
172+ }
173+
132174 // 创建上游请求
133175 upstreamReq , err := http .NewRequestWithContext (r .Context (), "POST" , targetURL , bytes .NewReader (body ))
134176 if err != nil {
177+ updateFailedRecord (http .StatusInternalServerError , "创建请求失败" )
135178 p .writeError (w , http .StatusInternalServerError , "创建请求失败" )
136179 return
137180 }
@@ -145,20 +188,24 @@ func (p *Proxy) Forward(w http.ResponseWriter, r *http.Request) {
145188 resp , err := p .client .Do (upstreamReq )
146189 if err != nil {
147190 if strings .Contains (err .Error (), "context canceled" ) {
191+ updateFailedRecord (499 , "请求被取消" )
148192 p .logger .Info ("请求被取消" , zap .String ("model" , model ))
149193 return
150194 }
195+ updateFailedRecord (http .StatusBadGateway , "上游服务不可用: " + err .Error ())
151196 p .logger .Error ("上游请求失败" , zap .Error (err ))
152197 p .writeError (w , http .StatusBadGateway , "上游服务不可用" )
153198 return
154199 }
155200 defer resp .Body .Close ()
156201
157- // 处理响应
158- if resp .StatusCode == http .StatusOK && isEventStream (resp .Header .Get ("Content-Type" )) {
159- p .handleStreamResponseWithStats (w , resp , startTime , r .Method , r .URL .Path , targetURL , model , clientIP , inputTokens , string (body ))
202+ // 处理响应 - 方案A: 同时检查请求中的 stream 参数和响应 Content-Type
203+ // 如果客户端请求 stream=true,或者上游返回 SSE 格式,都使用流式处理
204+ isStreamResponse := isStreamRequest || isEventStream (resp .Header .Get ("Content-Type" ))
205+ if resp .StatusCode == http .StatusOK && isStreamResponse {
206+ p .handleStreamResponseWithStats (w , resp , startTime , r .Method , r .URL .Path , targetURL , model , clientIP , inputTokens , string (body ), recordID )
160207 } else {
161- p .handleNormalResponseWithStats (w , resp , startTime , r .Method , r .URL .Path , targetURL , model , clientIP , inputTokens , string (body ))
208+ p .handleNormalResponseWithStats (w , resp , startTime , r .Method , r .URL .Path , targetURL , model , clientIP , inputTokens , string (body ), recordID )
162209 }
163210}
164211
@@ -262,7 +309,7 @@ func (p *Proxy) buildHeaders(provider *config.ProviderConfig, apiKey string, req
262309}
263310
264311// handleStreamResponseWithStats 处理流式响应并统计
265- func (p * Proxy ) handleStreamResponseWithStats (w http.ResponseWriter , resp * http.Response , startTime time.Time , method , path , targetURL , model , clientIP string , inputTokens int , requestBody string ) {
312+ func (p * Proxy ) handleStreamResponseWithStats (w http.ResponseWriter , resp * http.Response , startTime time.Time , method , path , targetURL , model , clientIP string , inputTokens int , requestBody string , recordID int64 ) {
266313 copyHeaders (w .Header (), resp .Header )
267314
268315 // 设置 SSE 头
@@ -272,15 +319,16 @@ func (p *Proxy) handleStreamResponseWithStats(w http.ResponseWriter, resp *http.
272319 w .Header ().Set ("Cache-Control" , "no-cache" )
273320 w .Header ().Set ("Connection" , "keep-alive" )
274321 w .Header ().Set ("X-Accel-Buffering" , "no" )
275- w .WriteHeader (resp .StatusCode )
276322
277- // 获取 flusher
323+ // 获取 flusher(在 WriteHeader 之前检查)
278324 flusher , ok := w .(http.Flusher )
279325 if ! ok {
280326 p .writeError (w , http .StatusInternalServerError , "不支持流式响应" )
281327 return
282328 }
283329
330+ w .WriteHeader (resp .StatusCode )
331+
284332 // 读取并转发响应,同时收集数据
285333 var responseBuf bytes.Buffer
286334 var outputTokens int
@@ -339,30 +387,26 @@ func (p *Proxy) handleStreamResponseWithStats(w http.ResponseWriter, resp *http.
339387 // 打印响应日志
340388 p .logResponse (method , path , targetURL , resp .StatusCode , duration , clientIP , responseBuf .String ())
341389
342- // 保存记录
343- record := & storage.RequestRecord {
344- Timestamp : startTime ,
345- Provider : p .cfg .Provider ,
346- Model : model ,
347- Stream : true ,
348- Method : method ,
349- Path : path ,
350- ClientIP : clientIP ,
351- RequestBody : requestBody ,
352- ResponseBody : responseBuf .String (),
353- StatusCode : resp .StatusCode ,
354- Duration : float64 (duration ),
355- InputTokens : inputTokens ,
356- OutputTokens : outputTokens ,
357- TotalTokens : totalTokens ,
358- Success : resp .StatusCode == 200 ,
359- }
360-
361- go p .storage .SaveRequest (record )
390+ // 异步更新记录(不影响响应)
391+ if recordID > 0 {
392+ go func () {
393+ record := & storage.RequestRecord {
394+ ResponseBody : responseBuf .String (),
395+ StatusCode : resp .StatusCode ,
396+ Duration : float64 (duration ),
397+ OutputTokens : outputTokens ,
398+ TotalTokens : totalTokens ,
399+ Success : resp .StatusCode == 200 ,
400+ }
401+ if err := p .storage .UpdateRequestWithResponse (recordID , record ); err != nil {
402+ p .logger .Error ("更新请求记录失败" , zap .Error (err ))
403+ }
404+ }()
405+ }
362406}
363407
364408// handleNormalResponseWithStats 处理普通响应并统计
365- func (p * Proxy ) handleNormalResponseWithStats (w http.ResponseWriter , resp * http.Response , startTime time.Time , method , path , targetURL , model , clientIP string , inputTokens int , requestBody string ) {
409+ func (p * Proxy ) handleNormalResponseWithStats (w http.ResponseWriter , resp * http.Response , startTime time.Time , method , path , targetURL , model , clientIP string , inputTokens int , requestBody string , recordID int64 ) {
366410 // 读取响应体
367411 respBody , err := io .ReadAll (resp .Body )
368412 if err != nil {
@@ -401,26 +445,22 @@ func (p *Proxy) handleNormalResponseWithStats(w http.ResponseWriter, resp *http.
401445 // 打印响应日志
402446 p .logResponse (method , path , targetURL , resp .StatusCode , duration , clientIP , string (respBody ))
403447
404- // 保存记录
405- record := & storage.RequestRecord {
406- Timestamp : startTime ,
407- Provider : p .cfg .Provider ,
408- Model : model ,
409- Stream : false ,
410- Method : method ,
411- Path : path ,
412- ClientIP : clientIP ,
413- RequestBody : requestBody ,
414- ResponseBody : string (respBody ),
415- StatusCode : resp .StatusCode ,
416- Duration : float64 (duration ),
417- InputTokens : inputTokens ,
418- OutputTokens : outputTokens ,
419- TotalTokens : totalTokens ,
420- Success : resp .StatusCode == 200 ,
421- }
422-
423- go p .storage .SaveRequest (record )
448+ // 异步更新记录(不影响响应)
449+ if recordID > 0 {
450+ go func () {
451+ record := & storage.RequestRecord {
452+ ResponseBody : string (respBody ),
453+ StatusCode : resp .StatusCode ,
454+ Duration : float64 (duration ),
455+ OutputTokens : outputTokens ,
456+ TotalTokens : totalTokens ,
457+ Success : resp .StatusCode == 200 ,
458+ }
459+ if err := p .storage .UpdateRequestWithResponse (recordID , record ); err != nil {
460+ p .logger .Error ("更新请求记录失败" , zap .Error (err ))
461+ }
462+ }()
463+ }
424464}
425465
426466// writeError 写入错误响应
0 commit comments