Skip to content

Commit d11c57b

Browse files
committed
feat(claude): forward Claude count_tokens to AWS Bedrock
Signed-off-by: B1F030 <b1fzhang@gmail.com>
1 parent 6f41542 commit d11c57b

19 files changed

Lines changed: 784 additions & 46 deletions

constant/context_key.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ const (
3737
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
3838
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
3939
ContextKeyChannelKey ContextKey = "channel_key"
40+
ContextKeyAllowedChannelTypes ContextKey = "allowed_channel_types"
4041

4142
ContextKeyAutoGroup ContextKey = "auto_group"
4243
ContextKeyAutoGroupIndex ContextKey = "auto_group_index"

controller/relay.go

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
123123
return
124124
}
125125

126+
if relayFormat == types.RelayFormatClaude && relayInfo.RelayMode == relayconstant.RelayModeClaudeCountTokens {
127+
newAPIError = relayClaudeCountTokens(c, relayInfo, request)
128+
return
129+
}
130+
126131
needSensitiveCheck := setting.ShouldCheckPromptSensitive()
127132
needCountToken := constant.CountToken
128133
// Avoid building huge CombineText (strings.Join) when token counting and sensitive check are both disabled.
@@ -179,10 +184,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
179184
}()
180185

181186
retryParam := &service.RetryParam{
182-
Ctx: c,
183-
TokenGroup: relayInfo.TokenGroup,
184-
ModelName: relayInfo.OriginModelName,
185-
Retry: common.GetPointer(0),
187+
Ctx: c,
188+
TokenGroup: relayInfo.TokenGroup,
189+
ModelName: relayInfo.OriginModelName,
190+
Retry: common.GetPointer(0),
191+
AllowedChannelTypes: service.GetAllowedChannelTypes(c),
186192
}
187193
relayInfo.RetryIndex = 0
188194
relayInfo.LastError = nil
@@ -247,6 +253,72 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
247253
}
248254
}
249255

256+
func relayClaudeCountTokens(c *gin.Context, relayInfo *relaycommon.RelayInfo, request dto.Request) *types.NewAPIError {
257+
if setting.ShouldCheckPromptSensitive() {
258+
meta := request.GetTokenCountMeta()
259+
if meta != nil {
260+
contains, words := service.CheckSensitiveText(meta.CombineText)
261+
if contains {
262+
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
263+
return types.NewError(errors.New("sensitive words detected"), types.ErrorCodeSensitiveWordsDetected)
264+
}
265+
}
266+
}
267+
268+
retryParam := &service.RetryParam{
269+
Ctx: c,
270+
TokenGroup: relayInfo.TokenGroup,
271+
ModelName: relayInfo.OriginModelName,
272+
Retry: common.GetPointer(0),
273+
AllowedChannelTypes: service.GetAllowedChannelTypes(c),
274+
}
275+
relayInfo.RetryIndex = 0
276+
relayInfo.LastError = nil
277+
278+
var newAPIError *types.NewAPIError
279+
for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
280+
relayInfo.RetryIndex = retryParam.GetRetry()
281+
channel, channelErr := getChannel(c, relayInfo, retryParam)
282+
if channelErr != nil {
283+
logger.LogError(c, channelErr.Error())
284+
newAPIError = channelErr
285+
break
286+
}
287+
288+
addUsedChannel(c, channel.Id)
289+
bodyStorage, bodyErr := common.GetBodyStorage(c)
290+
if bodyErr != nil {
291+
if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
292+
newAPIError = types.NewErrorWithStatusCode(bodyErr, types.ErrorCodeReadRequestBodyFailed, http.StatusRequestEntityTooLarge, types.ErrOptionWithSkipRetry())
293+
} else {
294+
newAPIError = types.NewErrorWithStatusCode(bodyErr, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
295+
}
296+
break
297+
}
298+
c.Request.Body = io.NopCloser(bodyStorage)
299+
300+
newAPIError = relay.ClaudeCountTokensHelper(c, relayInfo)
301+
if newAPIError == nil {
302+
relayInfo.LastError = nil
303+
return nil
304+
}
305+
306+
relayInfo.LastError = newAPIError
307+
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
308+
309+
if !shouldRetry(c, newAPIError, common.RetryTimes-retryParam.GetRetry()) {
310+
break
311+
}
312+
}
313+
314+
useChannel := c.GetStringSlice("use_channel")
315+
if len(useChannel) > 1 {
316+
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
317+
logger.LogInfo(c, retryLogStr)
318+
}
319+
return newAPIError
320+
}
321+
250322
var upgrader = websocket.Upgrader{
251323
Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
252324
CheckOrigin: func(r *http.Request) bool {
@@ -296,9 +368,13 @@ func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryParam *service
296368
if !autoBan {
297369
autoBanInt = 0
298370
}
371+
channelType := c.GetInt("channel_type")
372+
if !service.IsChannelTypeAllowed(channelType, retryParam.AllowedChannelTypes) {
373+
return nil, types.NewErrorWithStatusCode(fmt.Errorf("channel type %d is not allowed for this route", channelType), types.ErrorCodeGetChannelFailed, http.StatusServiceUnavailable, types.ErrOptionWithSkipRetry())
374+
}
299375
return &model.Channel{
300376
Id: c.GetInt("channel_id"),
301-
Type: c.GetInt("channel_type"),
377+
Type: channelType,
302378
Name: c.GetString("channel_name"),
303379
AutoBan: &autoBanInt,
304380
}, nil

i18n/keys.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ const (
309309
MsgDistributorInvalidChannelId = "distributor.invalid_channel_id"
310310
MsgDistributorChannelDisabled = "distributor.channel_disabled"
311311
MsgDistributorAffinityChannelDisabled = "distributor.affinity_channel_disabled"
312+
MsgDistributorChannelTypeNotAllowed = "distributor.channel_type_not_allowed"
312313
MsgDistributorTokenNoModelAccess = "distributor.token_no_model_access"
313314
MsgDistributorTokenModelForbidden = "distributor.token_model_forbidden"
314315
MsgDistributorModelNameRequired = "distributor.model_name_required"

i18n/locales/en.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ distributor.invalid_request: "Invalid request: {{.Error}}"
259259
distributor.invalid_channel_id: "Invalid channel ID"
260260
distributor.channel_disabled: "This channel has been disabled"
261261
distributor.affinity_channel_disabled: "The channel selected by channel affinity has been disabled, and retry was stopped by rule. Please contact the administrator"
262+
distributor.channel_type_not_allowed: "Channel type is not allowed for this route"
262263
distributor.token_no_model_access: "This token has no access to any models"
263264
distributor.token_model_forbidden: "This token has no access to model {{.Model}}"
264265
distributor.model_name_required: "Model name not specified, model name cannot be empty"

i18n/locales/zh-CN.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ distributor.invalid_request: "无效的请求,{{.Error}}"
260260
distributor.invalid_channel_id: "无效的渠道 Id"
261261
distributor.channel_disabled: "该渠道已被禁用"
262262
distributor.affinity_channel_disabled: "渠道亲和性命中的渠道已被禁用,已按规则停止重试,请联系管理员处理"
263+
distributor.channel_type_not_allowed: "该渠道类型不允许用于此路由"
263264
distributor.token_no_model_access: "该令牌无权访问任何模型"
264265
distributor.token_model_forbidden: "该令牌无权访问模型 {{.Model}}"
265266
distributor.model_name_required: "未指定模型名称,模型名称不能为空"

i18n/locales/zh-TW.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ distributor.invalid_request: "無效的請求,{{.Error}}"
260260
distributor.invalid_channel_id: "無效的管道 Id"
261261
distributor.channel_disabled: "該管道已被禁用"
262262
distributor.affinity_channel_disabled: "管道親和性命中的管道已被禁用,已按規則停止重試,請聯絡管理員處理"
263+
distributor.channel_type_not_allowed: "該管道類型不允許用於此路由"
263264
distributor.token_no_model_access: "該令牌無權存取任何模型"
264265
distributor.token_model_forbidden: "該令牌無權存取模型 {{.Model}}"
265266
distributor.model_name_required: "未指定模型名稱,模型名稱不能為空"
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/QuantumNous/new-api/constant"
9+
"github.com/QuantumNous/new-api/service"
10+
"github.com/gin-gonic/gin"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestLimitChannelTypesStoresAllowedTypes(t *testing.T) {
15+
gin.SetMode(gin.TestMode)
16+
17+
router := gin.New()
18+
router.Use(LimitChannelTypes(constant.ChannelTypeAws))
19+
router.GET("/test", func(c *gin.Context) {
20+
require.True(t, service.IsChannelTypeAllowed(constant.ChannelTypeAws, service.GetAllowedChannelTypes(c)))
21+
require.False(t, service.IsChannelTypeAllowed(constant.ChannelTypeOpenAI, service.GetAllowedChannelTypes(c)))
22+
c.Status(http.StatusNoContent)
23+
})
24+
25+
recorder := httptest.NewRecorder()
26+
request := httptest.NewRequest(http.MethodGet, "/test", nil)
27+
router.ServeHTTP(recorder, request)
28+
29+
require.Equal(t, http.StatusNoContent, recorder.Code)
30+
}

middleware/distributor.go

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,18 @@ type ModelRequest struct {
2929
Group string `json:"group,omitempty"`
3030
}
3131

32+
func LimitChannelTypes(channelTypes ...int) func(c *gin.Context) {
33+
return func(c *gin.Context) {
34+
common.SetContextKey(c, constant.ContextKeyAllowedChannelTypes, channelTypes)
35+
c.Next()
36+
}
37+
}
38+
3239
func Distribute() func(c *gin.Context) {
3340
return func(c *gin.Context) {
3441
var channel *model.Channel
3542
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
43+
allowedChannelTypes := service.GetAllowedChannelTypes(c)
3644
modelRequest, shouldSelectChannel, err := getModelRequest(c)
3745
if err != nil {
3846
abortWithOpenAiMessage(c, http.StatusBadRequest, i18n.T(c, i18n.MsgDistributorInvalidRequest, map[string]any{"Error": err.Error()}))
@@ -53,6 +61,10 @@ func Distribute() func(c *gin.Context) {
5361
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorChannelDisabled))
5462
return
5563
}
64+
if !service.IsChannelTypeAllowed(channel.Type, allowedChannelTypes) {
65+
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorChannelTypeNotAllowed))
66+
return
67+
}
5668
} else {
5769
// Select a channel for the user
5870
// check token model mapping
@@ -104,8 +116,16 @@ func Distribute() func(c *gin.Context) {
104116
if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
105117
affinityUsable := false
106118
preferred, err := model.CacheGetChannel(preferredChannelID)
107-
if err == nil && preferred != nil && preferred.Status == common.ChannelStatusEnabled {
108-
if usingGroup == "auto" {
119+
if err == nil && preferred != nil {
120+
if preferred.Status != common.ChannelStatusEnabled {
121+
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
122+
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorAffinityChannelDisabled))
123+
return
124+
}
125+
} else if !service.IsChannelTypeAllowed(preferred.Type, allowedChannelTypes) {
126+
// Affinity is only a preference. If it points to a disallowed channel type,
127+
// keep selecting from the normal candidate pool.
128+
} else if usingGroup == "auto" {
109129
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
110130
autoGroups := service.GetUserAutoGroup(userGroup)
111131
for _, g := range autoGroups {
@@ -132,10 +152,11 @@ func Distribute() func(c *gin.Context) {
132152

133153
if channel == nil {
134154
channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(&service.RetryParam{
135-
Ctx: c,
136-
ModelName: modelRequest.Model,
137-
TokenGroup: usingGroup,
138-
Retry: common.GetPointer(0),
155+
Ctx: c,
156+
ModelName: modelRequest.Model,
157+
TokenGroup: usingGroup,
158+
Retry: common.GetPointer(0),
159+
AllowedChannelTypes: allowedChannelTypes,
139160
})
140161
if err != nil {
141162
showGroup := usingGroup

model/ability.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package model
33
import (
44
"errors"
55
"fmt"
6+
"sort"
67
"strings"
78
"sync"
89

@@ -104,6 +105,14 @@ func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
104105
}
105106

106107
func GetChannel(group string, model string, retry int) (*Channel, error) {
108+
return GetChannelWithChannelTypes(group, model, retry, nil)
109+
}
110+
111+
func GetChannelWithChannelTypes(group string, model string, retry int, allowedChannelTypes []int) (*Channel, error) {
112+
if len(allowedChannelTypes) > 0 {
113+
return getChannelWithChannelTypes(group, model, retry, allowedChannelTypes)
114+
}
115+
107116
var abilities []Ability
108117

109118
var err error = nil
@@ -143,6 +152,74 @@ func GetChannel(group string, model string, retry int) (*Channel, error) {
143152
return &channel, err
144153
}
145154

155+
func getChannelWithChannelTypes(group string, model string, retry int, allowedChannelTypes []int) (*Channel, error) {
156+
var abilities []AbilityWithChannel
157+
err := DB.Table("abilities").
158+
Select("abilities.*, channels.type as channel_type").
159+
Joins("left join channels on abilities.channel_id = channels.id").
160+
Where("abilities."+commonGroupCol+" = ? and abilities.model = ? and abilities.enabled = ?", group, model, true).
161+
Where("channels.type IN ?", allowedChannelTypes).
162+
Scan(&abilities).Error
163+
if err != nil {
164+
return nil, err
165+
}
166+
if len(abilities) == 0 {
167+
return nil, nil
168+
}
169+
170+
uniquePriorities := make(map[int]bool)
171+
for _, ability := range abilities {
172+
priority := int64(0)
173+
if ability.Priority != nil {
174+
priority = *ability.Priority
175+
}
176+
uniquePriorities[int(priority)] = true
177+
}
178+
var sortedUniquePriorities []int
179+
for priority := range uniquePriorities {
180+
sortedUniquePriorities = append(sortedUniquePriorities, priority)
181+
}
182+
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
183+
184+
if retry >= len(uniquePriorities) {
185+
retry = len(uniquePriorities) - 1
186+
}
187+
targetPriority := int64(sortedUniquePriorities[retry])
188+
189+
targetAbilities := make([]AbilityWithChannel, 0)
190+
weightSum := uint(0)
191+
for _, ability := range abilities {
192+
priority := int64(0)
193+
if ability.Priority != nil {
194+
priority = *ability.Priority
195+
}
196+
if priority != targetPriority {
197+
continue
198+
}
199+
targetAbilities = append(targetAbilities, ability)
200+
weightSum += ability.Weight + 10
201+
}
202+
if len(targetAbilities) == 0 {
203+
return nil, errors.New(fmt.Sprintf("no channel found, group: %s, model: %s, priority: %d", group, model, targetPriority))
204+
}
205+
206+
weight := common.GetRandomInt(int(weightSum))
207+
channelID := 0
208+
for _, ability := range targetAbilities {
209+
weight -= int(ability.Weight) + 10
210+
if weight <= 0 {
211+
channelID = ability.ChannelId
212+
break
213+
}
214+
}
215+
if channelID == 0 {
216+
return nil, errors.New("channel not found")
217+
}
218+
channel := Channel{}
219+
err = DB.First(&channel, "id = ?", channelID).Error
220+
return &channel, err
221+
}
222+
146223
func (channel *Channel) AddAbilities(tx *gorm.DB) error {
147224
models_ := strings.Split(channel.Models, ",")
148225
groups_ := strings.Split(channel.Group, ",")

0 commit comments

Comments
 (0)