Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 190 additions & 2 deletions backend/internal/service/openai_gateway_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,12 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact)

if selected == nil {
if recovered := s.tryRecoverOpenAIRateLimitedAccountForProbe(ctx, groupID, requestedModel, excludedIDs, requireCompact); recovered != nil {
if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, recovered.ID, openaiStickySessionTTL)
}
return recovered, nil
}
return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
}

Expand Down Expand Up @@ -1593,6 +1599,21 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
return nil, err
}
if len(accounts) == 0 {
if recovered := s.tryRecoverOpenAIRateLimitedAccountForProbe(ctx, groupID, requestedModel, excludedIDs, requireCompact); recovered != nil {
result, err := s.tryAcquireAccountSlot(ctx, recovered.ID, recovered.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, recovered.ID, openaiStickySessionTTL)
}
return newOpenAISelectionResult(recovered, true, result.ReleaseFunc, nil), nil
}
return newOpenAISelectionResult(recovered, false, nil, &AccountWaitPlan{
AccountID: recovered.ID,
MaxConcurrency: recovered.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
}), nil
}
return nil, ErrNoAvailableAccounts
}

Expand Down Expand Up @@ -1667,6 +1688,21 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
}

if len(candidates) == 0 {
if recovered := s.tryRecoverOpenAIRateLimitedAccountForProbe(ctx, groupID, requestedModel, excludedIDs, requireCompact); recovered != nil {
result, err := s.tryAcquireAccountSlot(ctx, recovered.ID, recovered.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, recovered.ID, openaiStickySessionTTL)
}
return newOpenAISelectionResult(recovered, true, result.ReleaseFunc, nil), nil
}
return newOpenAISelectionResult(recovered, false, nil, &AccountWaitPlan{
AccountID: recovered.ID,
MaxConcurrency: recovered.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
}), nil
}
return nil, ErrNoAvailableAccounts
}

Expand Down Expand Up @@ -1812,9 +1848,157 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
if requireCompact && baseCandidateCount > 0 {
return nil, ErrNoAvailableCompactAccounts
}
if recovered := s.tryRecoverOpenAIRateLimitedAccountForProbe(ctx, groupID, requestedModel, excludedIDs, requireCompact); recovered != nil {
result, err := s.tryAcquireAccountSlot(ctx, recovered.ID, recovered.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, recovered.ID, openaiStickySessionTTL)
}
return newOpenAISelectionResult(recovered, true, result.ReleaseFunc, nil), nil
}
return newOpenAISelectionResult(recovered, false, nil, &AccountWaitPlan{
AccountID: recovered.ID,
MaxConcurrency: recovered.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
}), nil
}
return nil, ErrNoAvailableAccounts
}

// tryRecoverOpenAIRateLimitedAccountForProbe is the last-resort fallback used
// when normal selection found no account. It picks one rate-limited OpenAI
// account, clears its rate-limit state and returns the freshly reloaded
// account so the caller can send a probe request through it.
//
// Steps:
//  1. List the accounts visible to groupID via
//     listActiveOpenAIAccountsForRateLimitProbe.
//  2. Drop accounts in excludedIDs, accounts rejected by
//     isOpenAIAccountProbeRecoveryCandidate, and accounts whose model is
//     restricted by the group's upstream channel.
//  3. Keep the best remaining account according to
//     openAIRateLimitProbeCandidateLess.
//  4. Clear its rate limit through rateLimitService, reload it through
//     accountRepo, and re-validate it with isOpenAIAccountEligibleForRequest.
//
// It returns nil when the service is not wired up (nil receiver, accountRepo
// or rateLimitService), when no candidate exists, or when any step fails.
// Failures are logged rather than returned.
func (s *OpenAIGatewayService) tryRecoverOpenAIRateLimitedAccountForProbe(ctx context.Context, groupID *int64, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) *Account {
	if s == nil || s.accountRepo == nil || s.rateLimitService == nil {
		return nil
	}

	accounts, err := s.listActiveOpenAIAccountsForRateLimitProbe(ctx, groupID)
	if err != nil {
		slog.Warn("openai_rate_limit_probe_list_failed", "group_id", derefGroupID(groupID), "error", err)
		return nil
	}

	// Whether the group needs a channel restriction check depends only on ctx
	// and groupID, both constant for this call. Resolve it lazily and at most
	// once instead of once per surviving account.
	// NOTE(review): assumes needsUpstreamChannelRestrictionCheck gives a stable
	// answer within a single call — confirm.
	var (
		restrictionResolved bool
		restrictionNeeded   bool
	)
	needsRestrictionCheck := func() bool {
		if !restrictionResolved {
			restrictionNeeded = groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
			restrictionResolved = true
		}
		return restrictionNeeded
	}

	var candidate *Account
	for i := range accounts {
		acc := &accounts[i]
		if _, excluded := excludedIDs[acc.ID]; excluded {
			continue
		}
		if !isOpenAIAccountProbeRecoveryCandidate(acc, requestedModel, requireCompact) {
			continue
		}
		if needsRestrictionCheck() &&
			s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel, requireCompact) {
			continue
		}
		if candidate == nil || openAIRateLimitProbeCandidateLess(acc, candidate) {
			candidate = acc
		}
	}
	if candidate == nil {
		return nil
	}

	if err := s.rateLimitService.ClearRateLimit(ctx, candidate.ID); err != nil {
		slog.Warn("openai_rate_limit_probe_clear_failed", "account_id", candidate.ID, "error", err)
		return nil
	}

	// Reload instead of reusing the listed copy so the returned account
	// reflects the state after the clear.
	recovered, err := s.accountRepo.GetByID(ctx, candidate.ID)
	if err != nil || recovered == nil {
		slog.Warn("openai_rate_limit_probe_get_recovered_failed", "account_id", candidate.ID, "error", err)
		return nil
	}
	// The clear issued above is not rolled back if the reloaded account turns
	// out to be ineligible.
	if !isOpenAIAccountEligibleForRequest(recovered, requestedModel, requireCompact) {
		return nil
	}
	slog.Info("openai_rate_limit_probe_account_recovered", "account_id", recovered.ID, "group_id", derefGroupID(groupID))
	return recovered
}

// listActiveOpenAIAccountsForRateLimitProbe returns the OpenAI accounts that
// may be considered for rate-limit probe recovery in the given scope:
//
//   - simple run mode: every account ListByPlatform returns for OpenAI;
//   - groupID set: the group's accounts, narrowed to the OpenAI platform;
//   - groupID nil: OpenAI accounts that belong to no group at all.
//
// Filtering reuses the backing array of the slice returned by the repository.
func (s *OpenAIGatewayService) listActiveOpenAIAccountsForRateLimitProbe(ctx context.Context, groupID *int64) ([]Account, error) {
	simpleMode := s.cfg != nil && s.cfg.RunMode == config.RunModeSimple
	if simpleMode {
		return s.accountRepo.ListByPlatform(ctx, PlatformOpenAI)
	}

	if groupID == nil {
		all, err := s.accountRepo.ListByPlatform(ctx, PlatformOpenAI)
		if err != nil {
			return nil, err
		}
		ungrouped := all[:0]
		for i := range all {
			if len(all[i].AccountGroups) != 0 || len(all[i].GroupIDs) != 0 {
				continue
			}
			ungrouped = append(ungrouped, all[i])
		}
		return ungrouped, nil
	}

	members, err := s.accountRepo.ListByGroup(ctx, *groupID)
	if err != nil {
		return nil, err
	}
	openAIOnly := members[:0]
	for i := range members {
		if members[i].Platform != PlatformOpenAI {
			continue
		}
		openAIOnly = append(openAIOnly, members[i])
	}
	return openAIOnly, nil
}

// isOpenAIAccountProbeRecoveryCandidate reports whether account is an OpenAI
// account whose only obstacle to serving requestedModel is an account-level
// rate limit that has not yet reached its reset time.
//
// The account must be non-nil, on the OpenAI platform, active and schedulable.
// It is rejected when any other blocker applies; the checks run in the order
// listed in the switch below. When requireCompact is set, the account must
// also have a non-zero compact support tier.
func isOpenAIAccountProbeRecoveryCandidate(account *Account, requestedModel string, requireCompact bool) bool {
	if account == nil {
		return false
	}
	if account.Platform != PlatformOpenAI || account.Status != StatusActive || !account.Schedulable {
		return false
	}

	now := time.Now()
	// stillPending reports whether deadline is set and lies in the future.
	stillPending := func(deadline *time.Time) bool {
		return deadline != nil && now.Before(*deadline)
	}

	switch {
	case account.AutoPauseOnExpired && account.ExpiresAt != nil && !now.Before(*account.ExpiresAt):
		// Expired and configured to pause on expiry.
		return false
	case account.IsAPIKeyOrBedrock() && account.IsQuotaExceeded():
		return false
	case stillPending(account.OverloadUntil):
		return false
	case stillPending(account.TempUnschedulableUntil):
		return false
	case !stillPending(account.RateLimitResetAt):
		// Not rate limited, or the limit has already lapsed: nothing to probe.
		return false
	case requestedModel != "" && !account.IsModelSupported(requestedModel):
		return false
	case requireCompact && openAICompactSupportTier(account) == 0:
		return false
	case account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel) > 0:
		// Presumably a model-scoped cooldown that is still running — verify
		// against the Account implementation.
		return false
	}
	return true
}

// openAIRateLimitProbeCandidateLess reports whether candidate should be
// preferred over current as the account to probe. Ordering keys, in turn:
//
//  1. earlier RateLimitResetAt (only when both are set and differ);
//  2. lower Priority value;
//  3. never-used accounts (nil LastUsedAt) before used ones, then the earlier
//     LastUsedAt;
//  4. lower ID as the final tie-breaker.
func openAIRateLimitProbeCandidateLess(candidate, current *Account) bool {
	if a, b := candidate.RateLimitResetAt, current.RateLimitResetAt; a != nil && b != nil && !a.Equal(*b) {
		return a.Before(*b)
	}
	if candidate.Priority != current.Priority {
		return candidate.Priority < current.Priority
	}

	lhs, rhs := candidate.LastUsedAt, current.LastUsedAt
	if lhs == nil || rhs == nil {
		if (lhs == nil) != (rhs == nil) {
			// Exactly one side has never been used; that side wins.
			return lhs == nil
		}
		return candidate.ID < current.ID
	}
	if lhs.Equal(*rhs) {
		return candidate.ID < current.ID
	}
	return lhs.Before(*rhs)
}

func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
if s.schedulerSnapshot != nil {
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false)
Expand Down Expand Up @@ -1918,12 +2102,16 @@ func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account *
if err != nil {
return nil, err
}
return newOpenAISelectionResult(hydrated, acquired, release, waitPlan), nil
}

func newOpenAISelectionResult(account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) *AccountSelectionResult {
return &AccountSelectionResult{
Account: hydrated,
Account: account,
Acquired: acquired,
ReleaseFunc: release,
WaitPlan: waitPlan,
}, nil
}
}

func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
Expand Down
131 changes: 129 additions & 2 deletions backend/internal/service/openai_gateway_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account,
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, acc)
}
}
Expand All @@ -68,7 +68,7 @@ func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.C
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, acc)
}
}
Expand All @@ -79,6 +79,68 @@ func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Co
return r.ListSchedulableByPlatform(ctx, platform)
}

// ListByGroup returns the stub's active accounts that belong to groupID,
// matched through either AccountGroups or GroupIDs. Accounts with no group
// membership at all are treated as belonging to every group.
func (r stubOpenAIAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
	inGroup := func(acc *Account) bool {
		if len(acc.AccountGroups) == 0 && len(acc.GroupIDs) == 0 {
			return true
		}
		for i := range acc.AccountGroups {
			if acc.AccountGroups[i].GroupID == groupID {
				return true
			}
		}
		for i := range acc.GroupIDs {
			if acc.GroupIDs[i] == groupID {
				return true
			}
		}
		return false
	}

	var result []Account
	for i := range r.accounts {
		acc := &r.accounts[i]
		if acc.Status == StatusActive && inGroup(acc) {
			result = append(result, *acc)
		}
	}
	return result, nil
}

// ListByPlatform returns the stub's active accounts on the given platform,
// regardless of whether they are currently schedulable.
func (r stubOpenAIAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
	var result []Account
	for i := range r.accounts {
		if r.accounts[i].Status != StatusActive || r.accounts[i].Platform != platform {
			continue
		}
		result = append(result, r.accounts[i])
	}
	return result, nil
}

// ClearRateLimit resets RateLimitedAt, RateLimitResetAt and OverloadUntil on
// the stub account with the given id. It returns an error when no such
// account exists. The change is made in place on the accounts slice.
func (r stubOpenAIAccountRepo) ClearRateLimit(ctx context.Context, id int64) error {
	for i := range r.accounts {
		acc := &r.accounts[i]
		if acc.ID != id {
			continue
		}
		acc.RateLimitedAt, acc.RateLimitResetAt, acc.OverloadUntil = nil, nil, nil
		return nil
	}
	return errors.New("account not found")
}

// ClearAntigravityQuotaScopes is a no-op in this stub; it always returns nil.
func (r stubOpenAIAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
	return nil
}

// ClearModelRateLimits is a no-op in this stub; it always returns nil.
func (r stubOpenAIAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error {
	return nil
}

// ClearTempUnschedulable is a no-op in this stub; it always returns nil.
func (r stubOpenAIAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error {
	return nil
}

type stubConcurrencyCache struct {
ConcurrencyCache
loadBatchErr error
Expand Down Expand Up @@ -807,6 +869,35 @@ func TestOpenAISelectAccountForModelWithExclusions_NoAccounts(t *testing.T) {
}
}

func TestOpenAISelectAccountForModelWithExclusions_RecoversRateLimitedProbeAccount(t *testing.T) {
groupID := int64(42)
resetAt := time.Now().Add(30 * time.Minute)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 9, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &resetAt, AccountGroups: []AccountGroup{{GroupID: groupID}}},
},
}
cache := &stubGatewayCache{}
rateLimitService := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)

svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
rateLimitService: rateLimitService,
}

acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "recover-session", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
}
if acc == nil || acc.ID != 9 {
t.Fatalf("expected recovered account 9, got %#v", acc)
}
if acc.RateLimitResetAt != nil {
t.Fatalf("expected rate limit to be cleared")
}
}

func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) {
groupID := int64(1)
resetAt := time.Now().Add(1 * time.Hour)
Expand All @@ -833,6 +924,42 @@ func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) {
}
}

func TestOpenAISelectAccountWithLoadAwareness_RecoversRateLimitedProbeAccount(t *testing.T) {
groupID := int64(1)
soon := time.Now().Add(5 * time.Minute)
later := time.Now().Add(1 * time.Hour)
repo := stubOpenAIAccountRepo{
accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2, RateLimitResetAt: &later, AccountGroups: []AccountGroup{{GroupID: groupID}}},
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &soon, AccountGroups: []AccountGroup{{GroupID: groupID}}},
},
}
cache := &stubGatewayCache{}
concurrencyCache := stubConcurrencyCache{}
rateLimitService := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)

svc := &OpenAIGatewayService{
accountRepo: repo,
cache: cache,
concurrencyService: NewConcurrencyService(concurrencyCache),
rateLimitService: rateLimitService,
}

selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
t.Fatalf("expected recovered account 2, got %#v", selection)
}
if selection.Account.RateLimitResetAt != nil {
t.Fatalf("expected recovered account rate limit to be cleared")
}
if !selection.Acquired {
t.Fatalf("expected recovered account slot to be acquired")
}
}

func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan(t *testing.T) {
groupID := int64(1)
repo := stubOpenAIAccountRepo{
Expand Down
Loading
Loading