Skip to content

Commit dbc1113

Browse files
authored
fix(core): ensure correct flash model steering in plan mode implementation phase (#21871)
1 parent 3b0f730 commit dbc1113

10 files changed

Lines changed: 127 additions & 28 deletions

packages/core/src/availability/policyHelpers.test.ts

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,21 @@ import {
2020
} from '../config/models.js';
2121
import { AuthType } from '../core/contentGenerator.js';
2222

23-
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
24-
({
23+
const createMockConfig = (overrides: Partial<Config> = {}): Config => {
24+
const config = {
2525
getUserTier: () => undefined,
2626
getModel: () => 'gemini-2.5-pro',
2727
getGemini31LaunchedSync: () => false,
28+
getUseCustomToolModelSync: () => {
29+
const useGemini31 = config.getGemini31LaunchedSync();
30+
const authType = config.getContentGeneratorConfig().authType;
31+
return useGemini31 && authType === AuthType.USE_GEMINI;
32+
},
2833
getContentGeneratorConfig: () => ({ authType: undefined }),
2934
...overrides,
30-
}) as unknown as Config;
35+
} as unknown as Config;
36+
return config;
37+
};
3138

3239
describe('policyHelpers', () => {
3340
describe('resolvePolicyChain', () => {

packages/core/src/availability/policyHelpers.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import type { GenerateContentConfig } from '@google/genai';
88
import type { Config } from '../config/config.js';
9-
import { AuthType } from '../core/contentGenerator.js';
109
import type {
1110
FailureKind,
1211
FallbackAction,
@@ -46,9 +45,7 @@ export function resolvePolicyChain(
4645

4746
let chain;
4847
const useGemini31 = config.getGemini31LaunchedSync?.() ?? false;
49-
const useCustomToolModel =
50-
useGemini31 &&
51-
config.getContentGeneratorConfig?.()?.authType === AuthType.USE_GEMINI;
48+
const useCustomToolModel = config.getUseCustomToolModelSync?.() ?? false;
5249
const hasAccessToPreview = config.getHasAccessToPreviewModel?.() ?? true;
5350

5451
const resolvedModel = resolveModel(

packages/core/src/config/config.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2529,6 +2529,26 @@ export class Config implements McpContext, AgentLoopContext {
25292529
return this.getGemini31LaunchedSync();
25302530
}
25312531

2532+
/**
2533+
* Returns whether the custom tool model should be used.
2534+
*/
2535+
async getUseCustomToolModel(): Promise<boolean> {
2536+
const useGemini3_1 = await this.getGemini31Launched();
2537+
const authType = this.contentGeneratorConfig?.authType;
2538+
return useGemini3_1 && authType === AuthType.USE_GEMINI;
2539+
}
2540+
2541+
/**
2542+
* Returns whether the custom tool model should be used.
2543+
*
2544+
* Note: This method should only be called after startup, once experiments have been loaded.
2545+
*/
2546+
getUseCustomToolModelSync(): boolean {
2547+
const useGemini3_1 = this.getGemini31LaunchedSync();
2548+
const authType = this.contentGeneratorConfig?.authType;
2549+
return useGemini3_1 && authType === AuthType.USE_GEMINI;
2550+
}
2551+
25322552
/**
25332553
* Returns whether Gemini 3.1 has been launched.
25342554
*

packages/core/src/config/models.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ export function isPreviewModel(model: string): boolean {
168168
model === PREVIEW_GEMINI_3_1_MODEL ||
169169
model === PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL ||
170170
model === PREVIEW_GEMINI_FLASH_MODEL ||
171-
model === PREVIEW_GEMINI_MODEL_AUTO
171+
model === PREVIEW_GEMINI_MODEL_AUTO ||
172+
model === GEMINI_MODEL_ALIAS_AUTO
172173
);
173174
}
174175

packages/core/src/routing/strategies/approvalModeStrategy.test.ts

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ import {
1515
PREVIEW_GEMINI_FLASH_MODEL,
1616
DEFAULT_GEMINI_MODEL_AUTO,
1717
PREVIEW_GEMINI_MODEL_AUTO,
18+
GEMINI_MODEL_ALIAS_AUTO,
1819
} from '../../config/models.js';
20+
import { AuthType } from '../../core/contentGenerator.js';
1921
import { ApprovalMode } from '../../policy/types.js';
2022
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
2123

@@ -40,6 +42,15 @@ describe('ApprovalModeStrategy', () => {
4042
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
4143
getApprovedPlanPath: vi.fn().mockReturnValue(undefined),
4244
getPlanModeRoutingEnabled: vi.fn().mockResolvedValue(true),
45+
getGemini31Launched: vi.fn().mockResolvedValue(false),
46+
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
47+
const launched = await mockConfig.getGemini31Launched();
48+
const authType = mockConfig.getContentGeneratorConfig?.()?.authType;
49+
return launched && authType === AuthType.USE_GEMINI;
50+
}),
51+
getContentGeneratorConfig: vi.fn().mockReturnValue({
52+
authType: AuthType.LOGIN_WITH_GOOGLE,
53+
}),
4354
} as unknown as Config;
4455

4556
mockBaseLlmClient = {} as BaseLlmClient;
@@ -184,4 +195,50 @@ describe('ApprovalModeStrategy', () => {
184195

185196
expect(decision?.model).toBe(PREVIEW_GEMINI_MODEL);
186197
});
198+
199+
it('should route to Preview models when using "auto" alias', async () => {
200+
vi.mocked(mockConfig.getModel).mockReturnValue(GEMINI_MODEL_ALIAS_AUTO);
201+
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);
202+
203+
const decision = await strategy.route(
204+
mockContext,
205+
mockConfig,
206+
mockBaseLlmClient,
207+
);
208+
209+
expect(decision?.model).toBe(PREVIEW_GEMINI_MODEL);
210+
211+
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT);
212+
vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(
213+
'/path/to/plan.md',
214+
);
215+
216+
const implementationDecision = await strategy.route(
217+
mockContext,
218+
mockConfig,
219+
mockBaseLlmClient,
220+
);
221+
222+
expect(implementationDecision?.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);
223+
});
224+
225+
it('should route to Preview Flash model when an approved plan exists and Gemini 3.1 is launched', async () => {
226+
vi.mocked(mockConfig.getModel).mockReturnValue(GEMINI_MODEL_ALIAS_AUTO);
227+
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
228+
229+
// Exit plan mode with approved plan
230+
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT);
231+
vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(
232+
'/path/to/plan.md',
233+
);
234+
235+
const decision = await strategy.route(
236+
mockContext,
237+
mockConfig,
238+
mockBaseLlmClient,
239+
);
240+
241+
// Should resolve to Preview Flash (3.0) because resolveClassifierModel uses preview variants for Gemini 3
242+
expect(decision?.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);
243+
});
187244
});

packages/core/src/routing/strategies/approvalModeStrategy.ts

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@
66

77
import type { Config } from '../../config/config.js';
88
import {
9-
DEFAULT_GEMINI_MODEL,
10-
DEFAULT_GEMINI_FLASH_MODEL,
11-
PREVIEW_GEMINI_MODEL,
12-
PREVIEW_GEMINI_FLASH_MODEL,
139
isAutoModel,
14-
isPreviewModel,
10+
resolveClassifierModel,
11+
GEMINI_MODEL_ALIAS_FLASH,
12+
GEMINI_MODEL_ALIAS_PRO,
1513
} from '../../config/models.js';
1614
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
1715
import { ApprovalMode } from '../../policy/types.js';
@@ -50,11 +48,19 @@ export class ApprovalModeStrategy implements RoutingStrategy {
5048
const approvalMode = config.getApprovalMode();
5149
const approvedPlanPath = config.getApprovedPlanPath();
5250

53-
const isPreview = isPreviewModel(model);
51+
const [useGemini3_1, useCustomToolModel] = await Promise.all([
52+
config.getGemini31Launched(),
53+
config.getUseCustomToolModel(),
54+
]);
5455

5556
// 1. Planning Phase: If ApprovalMode === PLAN, explicitly route to the Pro model.
5657
if (approvalMode === ApprovalMode.PLAN) {
57-
const proModel = isPreview ? PREVIEW_GEMINI_MODEL : DEFAULT_GEMINI_MODEL;
58+
const proModel = resolveClassifierModel(
59+
model,
60+
GEMINI_MODEL_ALIAS_PRO,
61+
useGemini3_1,
62+
useCustomToolModel,
63+
);
5864
return {
5965
model: proModel,
6066
metadata: {
@@ -65,9 +71,12 @@ export class ApprovalModeStrategy implements RoutingStrategy {
6571
};
6672
} else if (approvedPlanPath) {
6773
// 2. Implementation Phase: If ApprovalMode !== PLAN AND an approved plan path is set, prefer the Flash model.
68-
const flashModel = isPreview
69-
? PREVIEW_GEMINI_FLASH_MODEL
70-
: DEFAULT_GEMINI_FLASH_MODEL;
74+
const flashModel = resolveClassifierModel(
75+
model,
76+
GEMINI_MODEL_ALIAS_FLASH,
77+
useGemini3_1,
78+
useCustomToolModel,
79+
);
7180
return {
7281
model: flashModel,
7382
metadata: {

packages/core/src/routing/strategies/classifierStrategy.test.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ describe('ClassifierStrategy', () => {
5959
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
6060
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
6161
getGemini31Launched: vi.fn().mockResolvedValue(false),
62+
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
63+
const launched = await mockConfig.getGemini31Launched();
64+
const authType = mockConfig.getContentGeneratorConfig().authType;
65+
return launched && authType === AuthType.USE_GEMINI;
66+
}),
6267
getContentGeneratorConfig: vi.fn().mockReturnValue({
6368
authType: AuthType.LOGIN_WITH_GOOGLE,
6469
}),

packages/core/src/routing/strategies/classifierStrategy.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import {
2222
import { debugLogger } from '../../utils/debugLogger.js';
2323
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
2424
import { LlmRole } from '../../telemetry/types.js';
25-
import { AuthType } from '../../core/contentGenerator.js';
2625

2726
// The number of recent history turns to provide to the router for context.
2827
const HISTORY_TURNS_FOR_CONTEXT = 4;
@@ -172,10 +171,10 @@ export class ClassifierStrategy implements RoutingStrategy {
172171

173172
const reasoning = routerResponse.reasoning;
174173
const latencyMs = Date.now() - startTime;
175-
const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false;
176-
const useCustomToolModel =
177-
useGemini3_1 &&
178-
config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI;
174+
const [useGemini3_1, useCustomToolModel] = await Promise.all([
175+
config.getGemini31Launched(),
176+
config.getUseCustomToolModel(),
177+
]);
179178
const selectedModel = resolveClassifierModel(
180179
model,
181180
routerResponse.model_choice,

packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ describe('NumericalClassifierStrategy', () => {
5858
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
5959
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
6060
getGemini31Launched: vi.fn().mockResolvedValue(false),
61+
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
62+
const launched = await mockConfig.getGemini31Launched();
63+
const authType = mockConfig.getContentGeneratorConfig().authType;
64+
return launched && authType === AuthType.USE_GEMINI;
65+
}),
6166
getContentGeneratorConfig: vi.fn().mockReturnValue({
6267
authType: AuthType.LOGIN_WITH_GOOGLE,
6368
}),

packages/core/src/routing/strategies/numericalClassifierStrategy.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import type { Config } from '../../config/config.js';
1818
import { debugLogger } from '../../utils/debugLogger.js';
1919
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
2020
import { LlmRole } from '../../telemetry/types.js';
21-
import { AuthType } from '../../core/contentGenerator.js';
2221

2322
// The number of recent history turns to provide to the router for context.
2423
const HISTORY_TURNS_FOR_CONTEXT = 8;
@@ -185,10 +184,10 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
185184
config,
186185
config.getSessionId() || 'unknown-session',
187186
);
188-
const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false;
189-
const useCustomToolModel =
190-
useGemini3_1 &&
191-
config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI;
187+
const [useGemini3_1, useCustomToolModel] = await Promise.all([
188+
config.getGemini31Launched(),
189+
config.getUseCustomToolModel(),
190+
]);
192191
const selectedModel = resolveClassifierModel(
193192
model,
194193
modelAlias,

0 commit comments

Comments
 (0)