Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/cmd/server/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

76 changes: 74 additions & 2 deletions backend/internal/handler/gateway_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ var gatewayCompatibilityMetricsLogCounter atomic.Uint64
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
openAIGatewayService *service.OpenAIGatewayService
geminiCompatService *service.GeminiMessagesCompatService
antigravityGatewayService *service.AntigravityGatewayService
userService *service.UserService
Expand All @@ -56,6 +57,7 @@ type GatewayHandler struct {
// NewGatewayHandler creates a new GatewayHandler
func NewGatewayHandler(
gatewayService *service.GatewayService,
openAIGatewayService *service.OpenAIGatewayService,
geminiCompatService *service.GeminiMessagesCompatService,
antigravityGatewayService *service.AntigravityGatewayService,
userService *service.UserService,
Expand Down Expand Up @@ -90,6 +92,7 @@ func NewGatewayHandler(

return &GatewayHandler{
gatewayService: gatewayService,
openAIGatewayService: openAIGatewayService,
geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService,
userService: userService,
Expand All @@ -107,6 +110,52 @@ func NewGatewayHandler(
}
}

type messagesForwardRoute int

const (
messagesForwardRouteAnthropicNative messagesForwardRoute = iota
messagesForwardRouteOpenAICompat
messagesForwardRouteAntigravity
)

func messagesForwardRouteKind(group *service.Group, account *service.Account) messagesForwardRoute {
if group != nil && group.Platform == service.PlatformAnthropic && account != nil && account.IsOpenAIOAuth() {
return messagesForwardRouteOpenAICompat
}
if account != nil && account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
return messagesForwardRouteAntigravity
}
return messagesForwardRouteAnthropicNative
}

func openAICompatForwardResult(result *service.OpenAIForwardResult) *service.ForwardResult {
if result == nil {
return nil
}
inputTokens := result.Usage.InputTokens - result.Usage.CacheReadInputTokens
if inputTokens < 0 {
inputTokens = 0
}
return &service.ForwardResult{
RequestID: result.RequestID,
Usage: service.ClaudeUsage{
InputTokens: inputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationInputTokens: result.Usage.CacheCreationInputTokens,
CacheReadInputTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
},
Model: result.Model,
BillingModel: result.BillingModel,
UpstreamModel: result.UpstreamModel,
ServiceTier: result.ServiceTier,
Stream: result.Stream,
Duration: result.Duration,
FirstTokenMs: result.FirstTokenMs,
ReasoningEffort: result.ReasoningEffort,
}
}

// Messages handles Claude API compatible messages endpoint
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
Expand Down Expand Up @@ -684,9 +733,32 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
writerSizeBeforeForward := c.Writer.Size()
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
switch messagesForwardRouteKind(currentAPIKey.Group, account) {
case messagesForwardRouteAntigravity:
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
} else {
case messagesForwardRouteOpenAICompat:
if h.openAIGatewayService == nil {
err = errors.New("openai gateway service unavailable")
break
}
defaultMappedModel := ""
if currentAPIKey.Group != nil {
defaultMappedModel = strings.TrimSpace(currentAPIKey.Group.ResolveMessagesDispatchModel(reqModel))
if defaultMappedModel == "" {
defaultMappedModel = strings.TrimSpace(currentAPIKey.Group.ResolveMessagesDispatchModel(parsedReq.Model))
}
}
openAIResult, forwardErr := h.openAIGatewayService.ForwardAsAnthropic(
requestCtx,
c,
account,
body,
sessionKey,
defaultMappedModel,
)
err = forwardErr
result = openAICompatForwardResult(openAIResult)
default:
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
}

Expand Down
26 changes: 26 additions & 0 deletions backend/internal/handler/gateway_handler_codex_routing_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package handler

import (
"testing"

"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)

func TestMessagesForwardRouteKind(t *testing.T) {
t.Run("anthropic group routed to openai oauth uses openai compat forwarder", func(t *testing.T) {
kind := messagesForwardRouteKind(
&service.Group{Platform: service.PlatformAnthropic},
&service.Account{Platform: service.PlatformOpenAI, Type: service.AccountTypeOAuth},
)
require.Equal(t, messagesForwardRouteOpenAICompat, kind)
})

t.Run("anthropic native accounts keep anthropic forwarder", func(t *testing.T) {
kind := messagesForwardRouteKind(
&service.Group{Platform: service.PlatformAnthropic},
&service.Account{Platform: service.PlatformAnthropic, Type: service.AccountTypeOAuth},
)
require.Equal(t, messagesForwardRouteAnthropicNative, kind)
})
}
112 changes: 112 additions & 0 deletions backend/internal/service/gateway_multiplatform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2775,6 +2775,118 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.Equal(t, int64(3), cache.sessionBindings["fallback"])
})

t.Run("模型路由-可显式路由到OpenAI OAuth账号", func(t *testing.T) {
groupID := int64(24)

repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 99, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Priority: 0, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}

cache := &mockGatewayCacheForPlatform{}

groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-sonnet-4-5-20250929": {99},
},
},
},
}

cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true

concurrencyCache := &mockConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 10},
99: {AccountID: 99, LoadRate: 0},
},
}

svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}

result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "codex-route", "claude-sonnet-4-5-20250929", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(99), result.Account.ID)
})

t.Run("模型路由-未绑定分组的OpenAI OAuth账号会被跳过", func(t *testing.T) {
groupID := int64(25)

repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
{ID: 99, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Priority: 0, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}

cache := &mockGatewayCacheForPlatform{}

groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-sonnet-4-5-20250929": {99},
},
},
},
}

cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true

concurrencyCache := &mockConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 0},
99: {AccountID: 99, LoadRate: 0},
},
}

svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}

result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "codex-unbound", "claude-sonnet-4-5-20250929", nil, "", int64(0))
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID)
})

t.Run("负载批量失败且无法获取-兜底等待", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
Expand Down
Loading