Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions packages/core/src/availability/policyHelpers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,21 @@ import {
} from '../config/models.js';
import { AuthType } from '../core/contentGenerator.js';

const createMockConfig = (overrides: Partial<Config> = {}): Config =>
({
const createMockConfig = (overrides: Partial<Config> = {}): Config => {
const config = {
getUserTier: () => undefined,
getModel: () => 'gemini-2.5-pro',
getGemini31LaunchedSync: () => false,
getUseCustomToolModelSync: () => {
const useGemini31 = config.getGemini31LaunchedSync();
const authType = config.getContentGeneratorConfig().authType;
return useGemini31 && authType === AuthType.USE_GEMINI;
},
getContentGeneratorConfig: () => ({ authType: undefined }),
...overrides,
}) as unknown as Config;
} as unknown as Config;
return config;
};

describe('policyHelpers', () => {
describe('resolvePolicyChain', () => {
Expand Down
5 changes: 1 addition & 4 deletions packages/core/src/availability/policyHelpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import type { GenerateContentConfig } from '@google/genai';
import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import type {
FailureKind,
FallbackAction,
Expand Down Expand Up @@ -46,9 +45,7 @@ export function resolvePolicyChain(

let chain;
const useGemini31 = config.getGemini31LaunchedSync?.() ?? false;
const useCustomToolModel =
useGemini31 &&
config.getContentGeneratorConfig?.()?.authType === AuthType.USE_GEMINI;
const useCustomToolModel = config.getUseCustomToolModelSync?.() ?? false;

const resolvedModel = resolveModel(
modelFromConfig,
Expand Down
20 changes: 20 additions & 0 deletions packages/core/src/config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2476,6 +2476,26 @@ export class Config implements McpContext {
return this.getGemini31LaunchedSync();
}

/**
* Returns whether the custom tool model should be used.
*/
async getUseCustomToolModel(): Promise<boolean> {
const useGemini3_1 = await this.getGemini31Launched();
const authType = this.contentGeneratorConfig?.authType;
return useGemini3_1 && authType === AuthType.USE_GEMINI;
}

/**
* Returns whether the custom tool model should be used.
*
* Note: This method should only be called after startup, once experiments have been loaded.
*/
getUseCustomToolModelSync(): boolean {
const useGemini3_1 = this.getGemini31LaunchedSync();
const authType = this.contentGeneratorConfig?.authType;
return useGemini3_1 && authType === AuthType.USE_GEMINI;
}

/**
* Returns whether Gemini 3.1 has been launched.
*
Expand Down
3 changes: 2 additions & 1 deletion packages/core/src/config/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
*/

export const PREVIEW_GEMINI_MODEL = 'gemini-3-pro-preview';
export const PREVIEW_GEMINI_3_1_MODEL = 'gemini-3.1-pro-preview';

Check warning on line 8 in packages/core/src/config/models.ts

View workflow job for this annotation

GitHub Actions / Lint

Found sensitive keyword "gemini-3.1". Please make sure this change is appropriate to submit.
export const PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL =
'gemini-3.1-pro-preview-customtools';

Check warning on line 10 in packages/core/src/config/models.ts

View workflow job for this annotation

GitHub Actions / Lint

Found sensitive keyword "gemini-3.1". Please make sure this change is appropriate to submit.
export const PREVIEW_GEMINI_FLASH_MODEL = 'gemini-3-flash-preview';
export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro';
export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash';
Expand Down Expand Up @@ -136,7 +136,8 @@
model === PREVIEW_GEMINI_3_1_MODEL ||
model === PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL ||
model === PREVIEW_GEMINI_FLASH_MODEL ||
model === PREVIEW_GEMINI_MODEL_AUTO
model === PREVIEW_GEMINI_MODEL_AUTO ||
model === GEMINI_MODEL_ALIAS_AUTO
);
}

Expand Down Expand Up @@ -168,7 +169,7 @@
* @returns True if the model is a Gemini-2.x model.
*/
export function isGemini2Model(model: string): boolean {
return /^gemini-2(\.|$)/.test(model);

Check warning on line 172 in packages/core/src/config/models.ts

View workflow job for this annotation

GitHub Actions / Lint

Found sensitive keyword "gemini-2". Please make sure this change is appropriate to submit.
}

/**
Expand Down
57 changes: 57 additions & 0 deletions packages/core/src/routing/strategies/approvalModeStrategy.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import {
PREVIEW_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_MODEL_AUTO,
GEMINI_MODEL_ALIAS_AUTO,
} from '../../config/models.js';
import { AuthType } from '../../core/contentGenerator.js';
import { ApprovalMode } from '../../policy/types.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';

Expand All @@ -40,6 +42,15 @@ describe('ApprovalModeStrategy', () => {
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
getApprovedPlanPath: vi.fn().mockReturnValue(undefined),
getPlanModeRoutingEnabled: vi.fn().mockResolvedValue(true),
getGemini31Launched: vi.fn().mockResolvedValue(false),
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
const launched = await mockConfig.getGemini31Launched();
const authType = mockConfig.getContentGeneratorConfig?.()?.authType;
return launched && authType === AuthType.USE_GEMINI;
}),
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: AuthType.LOGIN_WITH_GOOGLE,
}),
} as unknown as Config;

mockBaseLlmClient = {} as BaseLlmClient;
Expand Down Expand Up @@ -184,4 +195,50 @@ describe('ApprovalModeStrategy', () => {

expect(decision?.model).toBe(PREVIEW_GEMINI_MODEL);
});

it('should route to Preview models when using "auto" alias', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(GEMINI_MODEL_ALIAS_AUTO);
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);

const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);

expect(decision?.model).toBe(PREVIEW_GEMINI_MODEL);

vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT);
vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(
'/path/to/plan.md',
);

const implementationDecision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);

expect(implementationDecision?.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);
});

it('should route to Preview Flash model when an approved plan exists and Gemini 3.1 is launched', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(GEMINI_MODEL_ALIAS_AUTO);
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);

// Exit plan mode with approved plan
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT);
vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(
'/path/to/plan.md',
);

const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);

// Should resolve to Preview Flash (3.0) because resolveClassifierModel uses preview variants for Gemini 3
expect(decision?.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);
});
});
29 changes: 19 additions & 10 deletions packages/core/src/routing/strategies/approvalModeStrategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@

import type { Config } from '../../config/config.js';
import {
DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
isAutoModel,
isPreviewModel,
resolveClassifierModel,
GEMINI_MODEL_ALIAS_FLASH,
GEMINI_MODEL_ALIAS_PRO,
} from '../../config/models.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import { ApprovalMode } from '../../policy/types.js';
Expand Down Expand Up @@ -50,11 +48,19 @@ export class ApprovalModeStrategy implements RoutingStrategy {
const approvalMode = config.getApprovalMode();
const approvedPlanPath = config.getApprovedPlanPath();

const isPreview = isPreviewModel(model);
const [useGemini3_1, useCustomToolModel] = await Promise.all([
config.getGemini31Launched(),
config.getUseCustomToolModel(),
]);

// 1. Planning Phase: If ApprovalMode === PLAN, explicitly route to the Pro model.
if (approvalMode === ApprovalMode.PLAN) {
const proModel = isPreview ? PREVIEW_GEMINI_MODEL : DEFAULT_GEMINI_MODEL;
const proModel = resolveClassifierModel(
model,
GEMINI_MODEL_ALIAS_PRO,
useGemini3_1,
useCustomToolModel,
);
return {
model: proModel,
metadata: {
Expand All @@ -65,9 +71,12 @@ export class ApprovalModeStrategy implements RoutingStrategy {
};
} else if (approvedPlanPath) {
// 2. Implementation Phase: If ApprovalMode !== PLAN AND an approved plan path is set, prefer the Flash model.
const flashModel = isPreview
? PREVIEW_GEMINI_FLASH_MODEL
: DEFAULT_GEMINI_FLASH_MODEL;
const flashModel = resolveClassifierModel(
model,
GEMINI_MODEL_ALIAS_FLASH,
useGemini3_1,
useCustomToolModel,
);
return {
model: flashModel,
metadata: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ describe('ClassifierStrategy', () => {
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
getGemini31Launched: vi.fn().mockResolvedValue(false),
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
const launched = await mockConfig.getGemini31Launched();
const authType = mockConfig.getContentGeneratorConfig().authType;
return launched && authType === AuthType.USE_GEMINI;
}),
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: AuthType.LOGIN_WITH_GOOGLE,
}),
Expand Down
9 changes: 4 additions & 5 deletions packages/core/src/routing/strategies/classifierStrategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import {
import { debugLogger } from '../../utils/debugLogger.js';
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
import { LlmRole } from '../../telemetry/types.js';
import { AuthType } from '../../core/contentGenerator.js';

// The number of recent history turns to provide to the router for context.
const HISTORY_TURNS_FOR_CONTEXT = 4;
Expand Down Expand Up @@ -172,10 +171,10 @@ export class ClassifierStrategy implements RoutingStrategy {

const reasoning = routerResponse.reasoning;
const latencyMs = Date.now() - startTime;
const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false;
const useCustomToolModel =
useGemini3_1 &&
config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI;
const [useGemini3_1, useCustomToolModel] = await Promise.all([
config.getGemini31Launched(),
config.getUseCustomToolModel(),
]);
const selectedModel = resolveClassifierModel(
model,
routerResponse.model_choice,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ describe('NumericalClassifierStrategy', () => {
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
getGemini31Launched: vi.fn().mockResolvedValue(false),
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
const launched = await mockConfig.getGemini31Launched();
const authType = mockConfig.getContentGeneratorConfig().authType;
return launched && authType === AuthType.USE_GEMINI;
}),
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: AuthType.LOGIN_WITH_GOOGLE,
}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import type { Config } from '../../config/config.js';
import { debugLogger } from '../../utils/debugLogger.js';
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
import { LlmRole } from '../../telemetry/types.js';
import { AuthType } from '../../core/contentGenerator.js';

// The number of recent history turns to provide to the router for context.
const HISTORY_TURNS_FOR_CONTEXT = 8;
Expand Down Expand Up @@ -185,10 +184,10 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
config,
config.getSessionId() || 'unknown-session',
);
const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false;
const useCustomToolModel =
useGemini3_1 &&
config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI;
const [useGemini3_1, useCustomToolModel] = await Promise.all([
config.getGemini31Launched(),
config.getUseCustomToolModel(),
]);
const selectedModel = resolveClassifierModel(
model,
modelAlias,
Expand Down
Loading