Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion packages/cli/src/ui/AppContainer.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import {
UIActionsContext,
type UIActions,
} from './contexts/UIActionsContext.js';
import { useContext } from 'react';
import { useContext, act } from 'react';

// Mock useStdout to capture terminal title writes
let mockStdout: { write: ReturnType<typeof vi.fn> };
Expand Down Expand Up @@ -1395,5 +1395,41 @@ describe('AppContainer State Management', () => {
expect.any(Number),
);
});

it('updates currentModel when ModelChanged event is received', async () => {
// Arrange: Mock initial model
vi.spyOn(mockConfig, 'getModel').mockReturnValue('initial-model');

const { unmount } = render(
<AppContainer
config={mockConfig}
settings={mockSettings}
version="1.0.0"
initializationResult={mockInitResult}
/>,
);

// Verify initial model
await act(async () => {
await vi.waitFor(() => {
expect(capturedUIState?.currentModel).toBe('initial-model');
});
});

// Get the registered handler for ModelChanged
const handler = mockCoreEvents.on.mock.calls.find(
(call: unknown[]) => call[0] === CoreEvent.ModelChanged,
)?.[1];
expect(handler).toBeDefined();

// Act: Simulate ModelChanged event
act(() => {
handler({ model: 'new-model' });
});

// Assert: Verify model is updated
expect(capturedUIState.currentModel).toBe('new-model');
unmount();
});
});
});
9 changes: 8 additions & 1 deletion packages/cli/src/ui/AppContainer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import {
debugLogger,
coreEvents,
CoreEvent,
type ModelChangedPayload,
} from '@google/gemini-cli-core';
import { validateAuthMethod } from '../config/auth.js';
import { loadHierarchicalGeminiMemory } from '../config/config.js';
Expand Down Expand Up @@ -253,16 +254,22 @@ export const AppContainer = (props: AppContainerProps) => {
[historyManager.addItem],
);

// Subscribe to fallback mode changes from core
// Subscribe to fallback mode and model changes from core
useEffect(() => {
const handleFallbackModeChanged = () => {
const effectiveModel = getEffectiveModel();
setCurrentModel(effectiveModel);
};

const handleModelChanged = (payload: ModelChangedPayload) => {
setCurrentModel(payload.model);
};

coreEvents.on(CoreEvent.FallbackModeChanged, handleFallbackModeChanged);
coreEvents.on(CoreEvent.ModelChanged, handleModelChanged);
return () => {
coreEvents.off(CoreEvent.FallbackModeChanged, handleFallbackModeChanged);
coreEvents.off(CoreEvent.ModelChanged, handleModelChanged);
};
}, [getEffectiveModel]);

Expand Down
15 changes: 15 additions & 0 deletions packages/core/src/config/config.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,21 @@ vi.mock('../agents/subagent-tool-wrapper.js', () => ({
SubagentToolWrapper: vi.fn(),
}));

const mockCoreEvents = vi.hoisted(() => ({
emitFeedback: vi.fn(),
emitModelChanged: vi.fn(),
}));

const mockSetGlobalProxy = vi.hoisted(() => vi.fn());

vi.mock('../utils/events.js', () => ({
coreEvents: mockCoreEvents,
}));

vi.mock('../utils/fetch.js', () => ({
setGlobalProxy: mockSetGlobalProxy,
}));

import { BaseLlmClient } from '../core/baseLlmClient.js';
import { tokenLimit } from '../core/tokenLimits.js';
import { uiTelemetryService } from '../telemetry/index.js';
Expand Down
6 changes: 5 additions & 1 deletion packages/core/src/config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import {
DEFAULT_OTLP_ENDPOINT,
uiTelemetryService,
} from '../telemetry/index.js';
import { coreEvents } from '../utils/events.js';
import { tokenLimit } from '../core/tokenLimits.js';
import {
DEFAULT_GEMINI_EMBEDDING_MODEL,
Expand Down Expand Up @@ -638,7 +639,10 @@ export class Config {
return;
}

this.model = newModel;
if (this.model !== newModel) {
this.model = newModel;
coreEvents.emitModelChanged(newModel);
}
}

isInFallbackMode(): boolean {
Expand Down
13 changes: 13 additions & 0 deletions packages/core/src/utils/events.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,17 @@ describe('CoreEventEmitter', () => {
});
expect(listener.mock.calls[2][0]).toMatchObject({ message: 'Buffered 2' });
});

describe('ModelChanged Event', () => {
it('should emit ModelChanged event with correct payload', () => {
const listener = vi.fn();
events.on(CoreEvent.ModelChanged, listener);

const newModel = 'gemini-2.5-pro';
events.emitModelChanged(newModel);

expect(listener).toHaveBeenCalledTimes(1);
expect(listener).toHaveBeenCalledWith({ model: newModel });
});
});
});
31 changes: 31 additions & 0 deletions packages/core/src/utils/events.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,20 @@ export interface FallbackModeChangedPayload {
isInFallbackMode: boolean;
}

/**
* Payload for the 'model-changed' event.
*/
export interface ModelChangedPayload {
/**
* The new model that was set.
*/
model: string;
}

export enum CoreEvent {
UserFeedback = 'user-feedback',
FallbackModeChanged = 'fallback-mode-changed',
ModelChanged = 'model-changed',
}

export class CoreEventEmitter extends EventEmitter {
Expand Down Expand Up @@ -86,6 +97,14 @@ export class CoreEventEmitter extends EventEmitter {
this.emit(CoreEvent.FallbackModeChanged, payload);
}

/**
* Notifies subscribers that the model has changed.
*/
emitModelChanged(model: string): void {
const payload: ModelChangedPayload = { model };
this.emit(CoreEvent.ModelChanged, payload);
}

/**
* Flushes buffered messages. Call this immediately after primary UI listener
* subscribes.
Expand All @@ -106,6 +125,10 @@ export class CoreEventEmitter extends EventEmitter {
event: CoreEvent.FallbackModeChanged,
listener: (payload: FallbackModeChangedPayload) => void,
): this;
override on(
event: CoreEvent.ModelChanged,
listener: (payload: ModelChangedPayload) => void,
): this;
override on(
event: string | symbol,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand All @@ -122,6 +145,10 @@ export class CoreEventEmitter extends EventEmitter {
event: CoreEvent.FallbackModeChanged,
listener: (payload: FallbackModeChangedPayload) => void,
): this;
override off(
event: CoreEvent.ModelChanged,
listener: (payload: ModelChangedPayload) => void,
): this;
override off(
event: string | symbol,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand All @@ -138,6 +165,10 @@ export class CoreEventEmitter extends EventEmitter {
event: CoreEvent.FallbackModeChanged,
payload: FallbackModeChangedPayload,
): boolean;
override emit(
event: CoreEvent.ModelChanged,
payload: ModelChangedPayload,
): boolean;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
override emit(event: string | symbol, ...args: any[]): boolean {
return super.emit(event, ...args);
Expand Down
Loading