Skip to content

Commit 7798dbc

Browse files
Merge pull request #18 from ai-action/refactor/select
refactor(components): improve select-based model picker flows
2 parents 7521c69 + f2449b8 commit 7798dbc

11 files changed

Lines changed: 325 additions & 83 deletions

src/components/App.test.tsx

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ const capturedCallbacks = vi.hoisted(() => ({
1818
onCommand: null as ((command: string) => void) | null,
1919
onModeChange: null as ((mode: string) => void) | null,
2020
onSelect: null as ((model: string) => void) | null,
21-
onCancel: null as (() => void) | null,
21+
onClose: null as (() => void) | null,
2222
onToggleMode: null as (() => void) | null,
2323
}));
2424

@@ -47,14 +47,14 @@ vi.mock('./Chat', () => ({
4747
vi.mock('./ModelPicker', () => ({
4848
ModelPicker: ({
4949
onSelect,
50-
onCancel,
50+
onClose,
5151
}: {
5252
currentModel: string;
5353
onSelect: (model: string) => void;
54-
onCancel: () => void;
54+
onClose: () => void;
5555
}) => {
5656
capturedCallbacks.onSelect = onSelect;
57-
capturedCallbacks.onCancel = onCancel;
57+
capturedCallbacks.onClose = onClose;
5858
return <Text>ModelPicker</Text>;
5959
},
6060
}));
@@ -80,7 +80,7 @@ describe('App', () => {
8080
capturedCallbacks.onCommand = null;
8181
capturedCallbacks.onModeChange = null;
8282
capturedCallbacks.onSelect = null;
83-
capturedCallbacks.onCancel = null;
83+
capturedCallbacks.onClose = null;
8484
capturedCallbacks.onToggleMode = null;
8585
});
8686

@@ -114,12 +114,12 @@ describe('App', () => {
114114
expect(lastFrame()).not.toContain('ModelPicker');
115115
});
116116

117-
it('returns to chat when onCancel is called', async () => {
117+
it('returns to chat when onClose is called', async () => {
118118
const { lastFrame, rerender } = render(<App />);
119119
capturedCallbacks.onCommand?.('/model');
120120
rerender(<App />);
121121
await test.tick();
122-
capturedCallbacks.onCancel?.();
122+
capturedCallbacks.onClose?.();
123123
rerender(<App />);
124124
await test.tick();
125125
expect(lastFrame()).not.toContain('ModelPicker');

src/components/App.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ export function App() {
2525
setPicking(false);
2626
}, []);
2727

28-
const handleCancel = useCallback(() => {
28+
const handleClose = useCallback(() => {
2929
setPicking(false);
3030
}, []);
3131

@@ -37,7 +37,7 @@ export function App() {
3737
<ModelPicker
3838
currentModel={model}
3939
onSelect={handleSelect}
40-
onCancel={handleCancel}
40+
onClose={handleClose}
4141
/>
4242
) : (
4343
<Chat

src/components/ModelPicker.test.tsx

Lines changed: 118 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,22 @@ import { ModelPicker } from './ModelPicker';
4343

4444
describe('ModelPicker', () => {
4545
beforeEach(() => {
46+
mockListModels.mockReset();
47+
mockOnChange.mockReset();
4648
mockListModels.mockResolvedValue(['gemma4', 'llama3', 'codellama']);
4749
});
4850

51+
afterEach(() => {
52+
vi.useRealTimers();
53+
});
54+
4955
it('shows loading state before models arrive', () => {
5056
mockListModels.mockReturnValue(new Promise(() => undefined));
5157
const { lastFrame } = render(
5258
<ModelPicker
5359
currentModel="gemma4"
5460
onSelect={vi.fn()}
55-
onCancel={vi.fn()}
61+
onClose={vi.fn()}
5662
/>,
5763
);
5864
expect(lastFrame()).toContain('Loading models');
@@ -63,7 +69,7 @@ describe('ModelPicker', () => {
6369
<ModelPicker
6470
currentModel="gemma4"
6571
onSelect={vi.fn()}
66-
onCancel={vi.fn()}
72+
onClose={vi.fn()}
6773
/>,
6874
);
6975
await test.tick(10);
@@ -78,40 +84,137 @@ describe('ModelPicker', () => {
7884
<ModelPicker
7985
currentModel="llama3"
8086
onSelect={vi.fn()}
81-
onCancel={vi.fn()}
87+
onClose={vi.fn()}
8288
/>,
8389
);
8490
await test.tick(10);
8591
expect(lastFrame()).toContain('llama3');
8692
});
8793

94+
it('renders current model first in the list', async () => {
95+
const { lastFrame } = render(
96+
<ModelPicker
97+
currentModel="llama3"
98+
onSelect={vi.fn()}
99+
onClose={vi.fn()}
100+
/>,
101+
);
102+
103+
await test.tick(10);
104+
105+
const frame = lastFrame() ?? '';
106+
expect(frame.indexOf('llama3')).toBeLessThan(frame.indexOf('gemma4'));
107+
expect(frame.indexOf('llama3')).toBeLessThan(frame.indexOf('codellama'));
108+
});
109+
110+
it('does not inject the current model when it is not in the fetched list', async () => {
111+
mockListModels.mockResolvedValue(['gemma4', 'codellama']);
112+
113+
const { lastFrame } = render(
114+
<ModelPicker
115+
currentModel="llama3"
116+
onSelect={vi.fn()}
117+
onClose={vi.fn()}
118+
/>,
119+
);
120+
121+
await test.tick(10);
122+
123+
const frame = lastFrame() ?? '';
124+
expect(frame).toContain('gemma4');
125+
expect(frame).toContain('codellama');
126+
expect(frame).not.toContain('llama3');
127+
});
128+
129+
it('reloads and reorders options when currentModel changes', async () => {
130+
const { lastFrame, rerender } = render(
131+
<ModelPicker
132+
currentModel="gemma4"
133+
onSelect={vi.fn()}
134+
onClose={vi.fn()}
135+
/>,
136+
);
137+
138+
await test.tick(10);
139+
140+
rerender(
141+
<ModelPicker
142+
currentModel="llama3"
143+
onSelect={vi.fn()}
144+
onClose={vi.fn()}
145+
/>,
146+
);
147+
148+
await test.tick(10);
149+
150+
const frame = lastFrame() ?? '';
151+
expect(mockListModels).toHaveBeenCalledTimes(2);
152+
expect(frame.indexOf('llama3')).toBeLessThan(frame.indexOf('gemma4'));
153+
});
154+
88155
it('calls onSelect when a model is chosen', async () => {
89156
const onSelect = vi.fn();
90157
render(
91158
<ModelPicker
92159
currentModel="gemma4"
93160
onSelect={onSelect}
94-
onCancel={vi.fn()}
161+
onClose={vi.fn()}
95162
/>,
96163
);
97164
await test.tick(10);
98165
mockOnChange('llama3');
99166
expect(onSelect).toHaveBeenCalledWith('llama3');
100167
});
101168

102-
it('calls onCancel on Escape', async () => {
103-
const onCancel = vi.fn();
169+
it('calls onClose on Escape', async () => {
170+
const onClose = vi.fn();
104171
const { stdin } = render(
105172
<ModelPicker
106173
currentModel="gemma4"
107174
onSelect={vi.fn()}
108-
onCancel={onCancel}
175+
onClose={onClose}
109176
/>,
110177
);
111178
await test.tick(10);
112179
stdin.write(KEY.ESCAPE);
113-
await test.tick(50);
114-
expect(onCancel).toHaveBeenCalled();
180+
await test.tick(20);
181+
expect(onClose).toHaveBeenCalled();
182+
});
183+
184+
it('does not call onClose on Enter while models are loading', async () => {
185+
vi.useFakeTimers();
186+
mockListModels.mockReturnValue(new Promise(() => undefined));
187+
188+
const onClose = vi.fn();
189+
const { stdin } = render(
190+
<ModelPicker
191+
currentModel="gemma4"
192+
onSelect={vi.fn()}
193+
onClose={onClose}
194+
/>,
195+
);
196+
197+
stdin.write(KEY.ENTER);
198+
await vi.runAllTimersAsync();
199+
200+
expect(onClose).not.toHaveBeenCalled();
201+
});
202+
203+
it('calls onClose on Enter after models load', async () => {
204+
const onClose = vi.fn();
205+
const { stdin } = render(
206+
<ModelPicker
207+
currentModel="gemma4"
208+
onSelect={vi.fn()}
209+
onClose={onClose}
210+
/>,
211+
);
212+
213+
await test.tick(10);
214+
stdin.write(KEY.ENTER);
215+
await test.tick(10);
216+
217+
expect(onClose).toHaveBeenCalledTimes(1);
115218
});
116219

117220
it('shows error when listModels fails', async () => {
@@ -120,7 +223,7 @@ describe('ModelPicker', () => {
120223
<ModelPicker
121224
currentModel="gemma4"
122225
onSelect={vi.fn()}
123-
onCancel={vi.fn()}
226+
onClose={vi.fn()}
124227
/>,
125228
);
126229
await test.tick(10);
@@ -133,25 +236,25 @@ describe('ModelPicker', () => {
133236
<ModelPicker
134237
currentModel="gemma4"
135238
onSelect={vi.fn()}
136-
onCancel={vi.fn()}
239+
onClose={vi.fn()}
137240
/>,
138241
);
139242
await test.tick(10);
140243
expect(lastFrame()).toContain('Error loading models: network timeout');
141244
});
142245

143-
it('does not call onCancel for non-escape keys', async () => {
144-
const onCancel = vi.fn();
246+
it('does not call onClose for non-enter keys', async () => {
247+
const onClose = vi.fn();
145248
const { stdin } = render(
146249
<ModelPicker
147250
currentModel="gemma4"
148251
onSelect={vi.fn()}
149-
onCancel={onCancel}
252+
onClose={onClose}
150253
/>,
151254
);
152255
await test.tick(10);
153256
stdin.write('a');
154257
await test.tick(10);
155-
expect(onCancel).not.toHaveBeenCalled();
258+
expect(onClose).not.toHaveBeenCalled();
156259
});
157260
});

src/components/ModelPicker.tsx

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,47 @@
1-
import { Select, Spinner } from '@inkjs/ui';
2-
import { Box, Text, useInput } from 'ink';
1+
import { Spinner } from '@inkjs/ui';
2+
import { Text, useInput } from 'ink';
33
import { useEffect, useState } from 'react';
44

55
import { ollama } from '../utils';
6+
import { SelectPrompt } from './SelectPrompt';
67

78
interface Props {
89
currentModel: string;
910
onSelect: (model: string) => void;
10-
onCancel: () => void;
11+
onClose: () => void;
1112
}
1213

13-
export function ModelPicker({ currentModel, onSelect, onCancel }: Props) {
14+
export function ModelPicker({ currentModel, onSelect, onClose }: Props) {
1415
const [options, setOptions] = useState<{ label: string; value: string }[]>(
1516
[],
1617
);
1718
const [error, setError] = useState<string | null>(null);
1819

20+
// close select prompt if current model is chosen
21+
useInput((_, key) => {
22+
if (options.length && key.return) {
23+
setTimeout(onClose);
24+
}
25+
});
26+
1927
useEffect(() => {
2028
async function load() {
2129
try {
22-
const list = await ollama.listModels();
23-
setOptions(list.map((name) => ({ label: name, value: name })));
30+
const models = await ollama.listModels();
31+
if (models.includes(currentModel)) {
32+
models.splice(models.indexOf(currentModel), 1);
33+
models.unshift(currentModel);
34+
}
35+
36+
const options = models.map((model) => ({ label: model, value: model }));
37+
setOptions(options);
2438
} catch (error: unknown) {
2539
setError(error instanceof Error ? error.message : String(error));
2640
}
2741
}
2842

2943
void load();
30-
}, []);
31-
32-
useInput((_, key) => {
33-
if (key.escape) {
34-
onCancel();
35-
}
36-
});
44+
}, [currentModel]);
3745

3846
if (error) {
3947
return <Text color="red">Error loading models: {error}</Text>;
@@ -44,16 +52,15 @@ export function ModelPicker({ currentModel, onSelect, onCancel }: Props) {
4452
}
4553

4654
return (
47-
<Box flexDirection="column">
55+
<SelectPrompt
56+
options={options}
57+
defaultValue={currentModel}
58+
onChange={onSelect}
59+
onEscape={onClose}
60+
>
4861
<Text dimColor>
4962
Select a model (↑↓ + Enter to confirm, Esc to cancel)
5063
</Text>
51-
52-
<Select
53-
options={options}
54-
defaultValue={currentModel}
55-
onChange={onSelect}
56-
/>
57-
</Box>
64+
</SelectPrompt>
5865
);
5966
}

src/components/PlanApproval.test.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ describe('PlanApproval', () => {
8484
);
8585

8686
stdin.write(KEY.ESCAPE);
87-
await test.tick(50);
87+
await test.tick(20);
8888

8989
expect(onModeChange).toHaveBeenCalledWith(MODE.NAME.PLAN);
9090
});
@@ -96,7 +96,7 @@ describe('PlanApproval', () => {
9696
);
9797

9898
stdin.write(KEY.ENTER);
99-
await test.tick(50);
99+
await test.tick(20);
100100

101101
expect(onModeChange).not.toHaveBeenCalled();
102102
});

0 commit comments

Comments
 (0)