@@ -123,6 +123,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
123123 return
124124 }
125125
126+ if relayFormat == types .RelayFormatClaude && relayInfo .RelayMode == relayconstant .RelayModeClaudeCountTokens {
127+ newAPIError = relayClaudeCountTokens (c , relayInfo , request )
128+ return
129+ }
130+
126131 needSensitiveCheck := setting .ShouldCheckPromptSensitive ()
127132 needCountToken := constant .CountToken
128133 // Avoid building huge CombineText (strings.Join) when token counting and sensitive check are both disabled.
@@ -179,10 +184,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
179184 }()
180185
181186 retryParam := & service.RetryParam {
182- Ctx : c ,
183- TokenGroup : relayInfo .TokenGroup ,
184- ModelName : relayInfo .OriginModelName ,
185- Retry : common .GetPointer (0 ),
187+ Ctx : c ,
188+ TokenGroup : relayInfo .TokenGroup ,
189+ ModelName : relayInfo .OriginModelName ,
190+ Retry : common .GetPointer (0 ),
191+ AllowedChannelTypes : service .GetAllowedChannelTypes (c ),
186192 }
187193 relayInfo .RetryIndex = 0
188194 relayInfo .LastError = nil
@@ -247,6 +253,72 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
247253 }
248254}
249255
256+ func relayClaudeCountTokens (c * gin.Context , relayInfo * relaycommon.RelayInfo , request dto.Request ) * types.NewAPIError {
257+ if setting .ShouldCheckPromptSensitive () {
258+ meta := request .GetTokenCountMeta ()
259+ if meta != nil {
260+ contains , words := service .CheckSensitiveText (meta .CombineText )
261+ if contains {
262+ logger .LogWarn (c , fmt .Sprintf ("user sensitive words detected: %s" , strings .Join (words , ", " )))
263+ return types .NewError (errors .New ("sensitive words detected" ), types .ErrorCodeSensitiveWordsDetected )
264+ }
265+ }
266+ }
267+
268+ retryParam := & service.RetryParam {
269+ Ctx : c ,
270+ TokenGroup : relayInfo .TokenGroup ,
271+ ModelName : relayInfo .OriginModelName ,
272+ Retry : common .GetPointer (0 ),
273+ AllowedChannelTypes : service .GetAllowedChannelTypes (c ),
274+ }
275+ relayInfo .RetryIndex = 0
276+ relayInfo .LastError = nil
277+
278+ var newAPIError * types.NewAPIError
279+ for ; retryParam .GetRetry () <= common .RetryTimes ; retryParam .IncreaseRetry () {
280+ relayInfo .RetryIndex = retryParam .GetRetry ()
281+ channel , channelErr := getChannel (c , relayInfo , retryParam )
282+ if channelErr != nil {
283+ logger .LogError (c , channelErr .Error ())
284+ newAPIError = channelErr
285+ break
286+ }
287+
288+ addUsedChannel (c , channel .Id )
289+ bodyStorage , bodyErr := common .GetBodyStorage (c )
290+ if bodyErr != nil {
291+ if common .IsRequestBodyTooLargeError (bodyErr ) || errors .Is (bodyErr , common .ErrRequestBodyTooLarge ) {
292+ newAPIError = types .NewErrorWithStatusCode (bodyErr , types .ErrorCodeReadRequestBodyFailed , http .StatusRequestEntityTooLarge , types .ErrOptionWithSkipRetry ())
293+ } else {
294+ newAPIError = types .NewErrorWithStatusCode (bodyErr , types .ErrorCodeReadRequestBodyFailed , http .StatusBadRequest , types .ErrOptionWithSkipRetry ())
295+ }
296+ break
297+ }
298+ c .Request .Body = io .NopCloser (bodyStorage )
299+
300+ newAPIError = relay .ClaudeCountTokensHelper (c , relayInfo )
301+ if newAPIError == nil {
302+ relayInfo .LastError = nil
303+ return nil
304+ }
305+
306+ relayInfo .LastError = newAPIError
307+ processChannelError (c , * types .NewChannelError (channel .Id , channel .Type , channel .Name , channel .ChannelInfo .IsMultiKey , common .GetContextKeyString (c , constant .ContextKeyChannelKey ), channel .GetAutoBan ()), newAPIError )
308+
309+ if ! shouldRetry (c , newAPIError , common .RetryTimes - retryParam .GetRetry ()) {
310+ break
311+ }
312+ }
313+
314+ useChannel := c .GetStringSlice ("use_channel" )
315+ if len (useChannel ) > 1 {
316+ retryLogStr := fmt .Sprintf ("重试:%s" , strings .Trim (strings .Join (strings .Fields (fmt .Sprint (useChannel )), "->" ), "[]" ))
317+ logger .LogInfo (c , retryLogStr )
318+ }
319+ return newAPIError
320+ }
321+
250322var upgrader = websocket.Upgrader {
251323 Subprotocols : []string {"realtime" }, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
252324 CheckOrigin : func (r * http.Request ) bool {
@@ -296,9 +368,13 @@ func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryParam *service
296368 if ! autoBan {
297369 autoBanInt = 0
298370 }
371+ channelType := c .GetInt ("channel_type" )
372+ if ! service .IsChannelTypeAllowed (channelType , retryParam .AllowedChannelTypes ) {
373+ return nil , types .NewErrorWithStatusCode (fmt .Errorf ("channel type %d is not allowed for this route" , channelType ), types .ErrorCodeGetChannelFailed , http .StatusServiceUnavailable , types .ErrOptionWithSkipRetry ())
374+ }
299375 return & model.Channel {
300376 Id : c .GetInt ("channel_id" ),
301- Type : c . GetInt ( "channel_type" ) ,
377+ Type : channelType ,
302378 Name : c .GetString ("channel_name" ),
303379 AutoBan : & autoBanInt ,
304380 }, nil
0 commit comments