Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
if openaiErr == nil {
return false
}
if clientRequestDone(c) {
return false
}
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
return false
}
Expand Down Expand Up @@ -347,6 +350,10 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
return operation_setting.ShouldRetryByStatusCode(code)
}

func clientRequestDone(c *gin.Context) bool {
return c != nil && c.Request != nil && c.Request.Context().Err() != nil
}

func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
Expand Down Expand Up @@ -608,6 +615,9 @@ func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError,
if taskErr == nil {
return false
}
if clientRequestDone(c) {
return false
}
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
return false
}
Expand Down
54 changes: 54 additions & 0 deletions controller/relay_retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package controller

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)

func TestShouldRetryReturnsFalseWhenClientRequestDone(t *testing.T) {
t.Parallel()

gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
reqCtx, cancel := context.WithCancel(req.Context())
cancel()
ctx.Request = req.WithContext(reqCtx)

err := types.NewErrorWithStatusCode(
fmt.Errorf("upstream error"),
types.ErrorCodeBadResponse,
http.StatusInternalServerError,
)

require.False(t, shouldRetry(ctx, err, 1))
}

func TestShouldRetryTaskRelayReturnsFalseWhenClientRequestDone(t *testing.T) {
t.Parallel()

gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
reqCtx, cancel := context.WithCancel(req.Context())
cancel()
ctx.Request = req.WithContext(reqCtx)

taskErr := &dto.TaskError{
Code: "upstream_error",
Message: "upstream error",
StatusCode: http.StatusInternalServerError,
}

require.False(t, shouldRetryTaskRelay(ctx, 1, taskErr, 1))
}
6 changes: 3 additions & 3 deletions relay/channel/api_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if common2.DebugEnabled {
println("fullRequestURL:", fullRequestURL)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
Expand Down Expand Up @@ -326,7 +326,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
if common2.DebugEnabled {
println("fullRequestURL:", fullRequestURL)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
Expand Down Expand Up @@ -534,7 +534,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, req
if err != nil {
return nil, err
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
Expand Down
99 changes: 99 additions & 0 deletions relay/channel/api_request_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,84 @@
package channel

import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"

"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)

type contextTestAdaptor struct {
url string
}

func (a *contextTestAdaptor) Init(info *relaycommon.RelayInfo) {}

func (a *contextTestAdaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return a.url, nil
}

func (a *contextTestAdaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
return nil
}

func (a *contextTestAdaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return nil, nil
}

func (a *contextTestAdaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

func (a *contextTestAdaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, nil
}

func (a *contextTestAdaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, nil
}

func (a *contextTestAdaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return nil, nil
}

func (a *contextTestAdaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, nil
}

func (a *contextTestAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return nil, nil
}

func (a *contextTestAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
return nil, nil
}

func (a *contextTestAdaptor) GetModelList() []string {
return nil
}

func (a *contextTestAdaptor) GetChannelName() string {
return "context-test"
}

func (a *contextTestAdaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return nil, nil
}

func (a *contextTestAdaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
return nil, nil
}

func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
t.Parallel()

Expand All @@ -33,6 +102,36 @@ func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
require.Empty(t, headers)
}

func TestDoApiRequestUsesClientRequestContext(t *testing.T) {
service.InitHttpClient()

var called atomic.Bool
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called.Store(true)
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(upstream.Close)

gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{}`))
reqCtx, cancel := context.WithCancel(req.Context())
cancel()
ctx.Request = req.WithContext(reqCtx)

resp, err := DoApiRequest(
&contextTestAdaptor{url: upstream.URL},
ctx,
&relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}},
strings.NewReader(`{}`),
)

require.Error(t, err)
require.Nil(t, resp)
require.False(t, called.Load(), "upstream must not be called after downstream request context is cancelled")
}

func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
t.Parallel()

Expand Down