diff --git a/controller/relay.go b/controller/relay.go index c97ab45b4ac..e3fc7d6793c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -319,6 +319,9 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b if openaiErr == nil { return false } + if clientRequestDone(c) { + return false + } if service.ShouldSkipRetryAfterChannelAffinityFailure(c) { return false } @@ -347,6 +350,10 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b return operation_setting.ShouldRetryByStatusCode(code) } +func clientRequestDone(c *gin.Context) bool { + return c != nil && c.Request != nil && c.Request.Context().Err() != nil +} + func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) { logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 @@ -608,6 +615,9 @@ func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, if taskErr == nil { return false } + if clientRequestDone(c) { + return false + } if service.ShouldSkipRetryAfterChannelAffinityFailure(c) { return false } diff --git a/controller/relay_retry_test.go b/controller/relay_retry_test.go new file mode 100644 index 00000000000..f7b1516d79a --- /dev/null +++ b/controller/relay_retry_test.go @@ -0,0 +1,54 @@ +package controller + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/types" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestShouldRetryReturnsFalseWhenClientRequestDone(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + reqCtx, cancel := context.WithCancel(req.Context()) + cancel() + ctx.Request = req.WithContext(reqCtx) + + err := types.NewErrorWithStatusCode( + fmt.Errorf("upstream error"), + types.ErrorCodeBadResponse, + http.StatusInternalServerError, + ) + + require.False(t, shouldRetry(ctx, err, 1)) +} + +func TestShouldRetryTaskRelayReturnsFalseWhenClientRequestDone(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + reqCtx, cancel := context.WithCancel(req.Context()) + cancel() + ctx.Request = req.WithContext(reqCtx) + + taskErr := &dto.TaskError{ + Code: "upstream_error", + Message: "upstream error", + StatusCode: http.StatusInternalServerError, + } + + require.False(t, shouldRetryTaskRelay(ctx, 1, taskErr, 1)) +} diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 8dfb61d4009..32fa3f5452a 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -295,7 +295,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody if common2.DebugEnabled { println("fullRequestURL:", fullRequestURL) } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, fullRequestURL, requestBody) if err != nil { return nil, fmt.Errorf("new request failed: %w", err) } @@ -326,7 +326,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod if common2.DebugEnabled { println("fullRequestURL:", fullRequestURL) } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, fullRequestURL, requestBody) if err != nil { return nil, fmt.Errorf("new request failed: %w", err) } @@ -534,7 +534,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, req if err != nil { return nil, err } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, fullRequestURL, requestBody) if err != nil { return nil, fmt.Errorf("new request failed: %w", err) } diff --git a/relay/channel/api_request_test.go b/relay/channel/api_request_test.go index f697f855569..484e5211b31 100644 --- a/relay/channel/api_request_test.go +++ b/relay/channel/api_request_test.go @@ -1,15 +1,84 @@ package channel import ( + "context" + "io" "net/http" "net/http/httptest" + "strings" + "sync/atomic" "testing" + "github.com/QuantumNous/new-api/dto" relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/service" + "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) +type contextTestAdaptor struct { + url string +} + +func (a *contextTestAdaptor) Init(info *relaycommon.RelayInfo) {} + +func (a *contextTestAdaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return a.url, nil +} + +func (a *contextTestAdaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { + return nil +} + +func (a *contextTestAdaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *contextTestAdaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + +func (a *contextTestAdaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + return nil, nil +} + +func (a *contextTestAdaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + return nil, nil +} + +func (a *contextTestAdaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + return nil, nil +} + +func (a *contextTestAdaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + return nil, nil +} + +func (a *contextTestAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { + return nil, nil +} + +func (a *contextTestAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { + return nil, nil +} + +func (a *contextTestAdaptor) GetModelList() []string { + return nil +} + +func (a *contextTestAdaptor) GetChannelName() string { + return "context-test" +} + +func (a *contextTestAdaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { + return nil, nil +} + +func (a *contextTestAdaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) { + return nil, nil +} + func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) { t.Parallel() @@ -33,6 +102,36 @@ func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) { require.Empty(t, headers) } +func TestDoApiRequestUsesClientRequestContext(t *testing.T) { + service.InitHttpClient() + + var called atomic.Bool + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called.Store(true) + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(upstream.Close) + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{}`)) + reqCtx, cancel := context.WithCancel(req.Context()) + cancel() + ctx.Request = req.WithContext(reqCtx) + + resp, err := DoApiRequest( + &contextTestAdaptor{url: upstream.URL}, + ctx, + &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}, + strings.NewReader(`{}`), + ) + + require.Error(t, err) + require.Nil(t, resp) + require.False(t, called.Load(), "upstream must not be called after downstream request context is cancelled") +} + func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) { t.Parallel()