Skip to content

Commit 288024e

Browse files
committed
feat(ai-proxy): validate model before starting server
1 parent f881fd3 commit 288024e

2 files changed

Lines changed: 124 additions & 7 deletions

File tree

src/commands/ai.js

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,21 @@ const buildModelListResponse = (models, fallbackModel) => {
123123
return { object: 'list', data };
124124
};
125125

126+
// Extract plain string model ids from a heterogeneous model list.
// Each entry may be an id string or an object carrying an `id` property;
// anything else (null, malformed entries, falsy ids such as '') is dropped,
// mirroring a map-then-filter(Boolean) pipeline.
const normalizeModelIds = (models) => {
  if (!Array.isArray(models)) return [];
  const ids = [];
  for (const entry of models) {
    const id = typeof entry === 'string' ? entry : entry?.id;
    if (id) ids.push(id);
  }
  return ids;
};
134+
135+
// Fetch the raw model list from the Puter AI surface, best-effort.
// Resolves to [] when `puter.ai` or `listModels` is absent, or when the
// API yields something other than an array; rejections from listModels
// propagate to the caller unchanged.
const resolveAvailableModelsRaw = async (puter) => {
  if (typeof puter.ai?.listModels !== 'function') return [];
  const result = await puter.ai.listModels();
  return Array.isArray(result) ? result : [];
};
140+
126141
export const createAIProxyServer = (options = {}) => {
127142
const defaults = {
128143
host: options.host || '127.0.0.1',
@@ -132,11 +147,18 @@ export const createAIProxyServer = (options = {}) => {
132147
maxTokens: normalizeNumber(options.maxTokens, 1024),
133148
temperature: normalizeNumber(options.temperature, 1)
134149
};
150+
const availableModelsRaw = options.availableModelsRaw;
151+
const availableModelsNormalized = Array.isArray(availableModelsRaw)
152+
? normalizeModelIds(availableModelsRaw)
153+
: null;
135154

136155
const modelsHandler = async ({ res }) => {
137156
try {
157+
if (Array.isArray(availableModelsRaw)) {
158+
return sendJson(res, 200, buildModelListResponse(availableModelsRaw, defaults.model));
159+
}
138160
const puter = getPuter();
139-
const models = typeof puter.ai?.listModels === 'function' ? await puter.ai.listModels() : [];
161+
const models = await resolveAvailableModelsRaw(puter);
140162
return sendJson(res, 200, buildModelListResponse(models, defaults.model));
141163
} catch (error) {
142164
return sendJson(res, 500, { error: { message: error.message || 'Failed to list models' } });
@@ -148,7 +170,7 @@ export const createAIProxyServer = (options = {}) => {
148170
method: 'GET',
149171
path: '/',
150172
handler: async ({ res }) => {
151-
return sendJson(res, 200, { status: 'ok', message: 'Puter AI proxy running on /v1' });
173+
return sendJson(res, 200, { status: 'ok', message: 'Puter AI running on /v1' });
152174
}
153175
},
154176
{
@@ -192,6 +214,17 @@ export const createAIProxyServer = (options = {}) => {
192214
return sendJson(res, 500, { error: { message: 'AI service not available', type: 'service_unavailable' } });
193215
}
194216

217+
if (availableModelsNormalized) {
218+
if (availableModelsNormalized.length > 0 && !availableModelsNormalized.includes(model)) {
219+
return sendJson(res, 400, { error: { message: `Unknown model: ${model}`, type: 'invalid_request_error' } });
220+
}
221+
} else if (typeof puter.ai.listModels === 'function') {
222+
const availableModels = normalizeModelIds(await puter.ai.listModels());
223+
if (availableModels.length > 0 && !availableModels.includes(model)) {
224+
return sendJson(res, 400, { error: { message: `Unknown model: ${model}`, type: 'invalid_request_error' } });
225+
}
226+
}
227+
195228
const result = await puter.ai.chat(prompt, {
196229
model,
197230
temperature,
@@ -224,15 +257,48 @@ export const createAIProxyServer = (options = {}) => {
224257
};
225258

226259
export const startAIProxyServer = async (options = {}) => {
260+
const requestedModel = typeof options.model === 'string'
261+
? options.model.trim()
262+
: (options.model ? String(options.model).trim() : '');
227263
const defaults = {
228264
host: options.host || '127.0.0.1',
229265
port: normalizeNumber(options.port, 8080),
230-
model: options.model || process.env.PUTER_AI_MODEL || 'gpt-5-nano',
266+
model: requestedModel || process.env.PUTER_AI_MODEL || 'gpt-5-nano',
231267
system: options.system ?? process.env.PUTER_AI_SYSTEM ?? '',
232268
maxTokens: normalizeNumber(options.maxTokens, 1024),
233269
temperature: normalizeNumber(options.temperature, 1)
234270
};
235-
const server = createAIProxyServer(defaults);
271+
const profileModule = getProfileModule();
272+
const authToken = profileModule.getAuthToken();
273+
if (!authToken) {
274+
throw new Error('Not authenticated. Run: puter login');
275+
}
276+
277+
const puter = getPuter();
278+
const availableModelsRaw = await resolveAvailableModelsRaw(puter);
279+
const availableModels = normalizeModelIds(availableModelsRaw);
280+
if (requestedModel && availableModels.length > 0 && !availableModels.includes(requestedModel)) {
281+
console.error(chalk.red(`Unknown model: ${requestedModel}`));
282+
const normalizedQuery = requestedModel.toLowerCase();
283+
const tokens = normalizedQuery.split(/[-_/]/).filter(Boolean);
284+
const primaryToken = tokens[0];
285+
const prefix = normalizedQuery.slice(0, 3);
286+
const suggestedModels = Array.from(new Set(availableModels.filter((model) => {
287+
const lower = model.toLowerCase();
288+
if (primaryToken && lower.includes(primaryToken)) return true;
289+
if (!primaryToken && normalizedQuery.length > 3 && lower.includes(prefix)) return true;
290+
return false;
291+
})));
292+
if (suggestedModels.length > 0) {
293+
console.log(chalk.cyan('Try one of the following:'));
294+
for (const suggestedModel of suggestedModels) {
295+
console.log(chalk.dim(` ${suggestedModel}`));
296+
}
297+
}
298+
return null;
299+
}
300+
301+
const server = createAIProxyServer({ ...defaults, availableModelsRaw });
236302
const { host, port } = await server.start();
237303
const trimmedSystem = String(defaults.system || '').trim();
238304
const systemPreview = trimmedSystem

tests/ai.test.js

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
2-
import { createAIProxyServer } from '../src/commands/ai.js';
2+
import { createAIProxyServer, startAIProxyServer } from '../src/commands/ai.js';
33
import { getPuter } from '../src/modules/PuterModule.js';
44
import { getProfileModule } from '../src/modules/ProfileModule.js';
55

@@ -65,7 +65,8 @@ describe('AI proxy server', () => {
6565
it('serves non-streaming chat completion', async () => {
6666
const puterMock = {
6767
ai: {
68-
chat: vi.fn().mockResolvedValue('Hello there')
68+
chat: vi.fn().mockResolvedValue('Hello there'),
69+
listModels: vi.fn().mockResolvedValue(['gpt-5-nano'])
6970
}
7071
};
7172
const { port } = await startServer(puterMock);
@@ -86,7 +87,8 @@ describe('AI proxy server', () => {
8687
it('serves streaming chat completion', async () => {
8788
const puterMock = {
8889
ai: {
89-
chat: vi.fn().mockResolvedValue('Hello world')
90+
chat: vi.fn().mockResolvedValue('Hello world'),
91+
listModels: vi.fn().mockResolvedValue(['gpt-5-nano'])
9092
}
9193
};
9294
const { port } = await startServer(puterMock);
@@ -104,4 +106,53 @@ describe('AI proxy server', () => {
104106
expect(text).toContain('data: ');
105107
expect(text).toContain('[DONE]');
106108
});
109+
110+
it('rejects unknown model', async () => {
111+
const puterMock = {
112+
ai: {
113+
chat: vi.fn().mockResolvedValue('Hello world'),
114+
listModels: vi.fn().mockResolvedValue(['gpt-5-nano'])
115+
}
116+
};
117+
const { port } = await startServer(puterMock);
118+
const response = await fetch(`http://127.0.0.1:${port}/v1/chat/completions`, {
119+
method: 'POST',
120+
headers: { 'content-type': 'application/json' },
121+
body: JSON.stringify({
122+
model: 'missing-model',
123+
messages: [{ role: 'user', content: 'Hi' }]
124+
})
125+
});
126+
const data = await response.json();
127+
expect(response.status).toBe(400);
128+
expect(data.error.message).toContain('Unknown model');
129+
});
130+
131+
it('rejects unknown model before startup', async () => {
132+
vi.mocked(getProfileModule).mockReturnValue({
133+
getAuthToken: vi.fn(() => 'test-token')
134+
});
135+
const listModels = vi.fn().mockResolvedValue(['gpt-5-nano']);
136+
vi.mocked(getPuter).mockReturnValue({
137+
ai: {
138+
listModels
139+
}
140+
});
141+
const server = await startAIProxyServer({ model: 'missing-model', port: 0 });
142+
expect(server).toBeNull();
143+
expect(listModels).toHaveBeenCalled();
144+
});
145+
146+
it('starts server when model exists', async () => {
147+
vi.mocked(getProfileModule).mockReturnValue({
148+
getAuthToken: vi.fn(() => 'test-token')
149+
});
150+
vi.mocked(getPuter).mockReturnValue({
151+
ai: {
152+
listModels: vi.fn().mockResolvedValue(['gpt-5-nano'])
153+
}
154+
});
155+
const server = await startAIProxyServer({ model: 'gpt-5-nano', port: 0 });
156+
await server.stop();
157+
});
107158
});

0 commit comments

Comments
 (0)