Skip to content

Commit 4adc1a4

Browse files
committed
improve aiModel
1 parent 55c4c29 commit 4adc1a4

6 files changed

Lines changed: 33 additions & 21 deletions

File tree

web/src/core/adapters/ai/mock.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ export function createAi(params: { webUiUrl: string }): Ai {
77
webUiUrl,
88
apiBase: `${webUiUrl}/api`,
99
getToken: async () => ({ status: "success" as const, token: "mock-ai-token" }),
10-
listModels: async () => ["llama3.2", "mistral-7b", "codestral"]
10+
listModels: async () => [
11+
{ id: "llama3.2", name: "Llama 3.2" },
12+
{ id: "mistral-7b", name: "Mistral 7B" },
13+
{ id: "codestral", name: "Codestral" }
14+
]
1115
};
1216
}

web/src/core/adapters/ai/openWebUi.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import type { Ai, GetTokenResult } from "core/ports/Ai";
22
import { oidcTokenExchange, OidcTokenExchangeError } from "core/tools/oidcTokenExchange";
3+
import { z } from "zod";
34

45
export function createAi(params: {
56
webUiUrl: string;
@@ -37,9 +38,11 @@ export function createAi(params: {
3738
throw new Error(`Failed to list models (${response.status})`);
3839
}
3940

40-
const data = await response.json();
41+
const { data } = z
42+
.object({ data: z.array(z.object({ id: z.string(), name: z.string() })) })
43+
.parse(await response.json());
4144

42-
return (data.data as { id: string }[]).map(m => m.id);
45+
return data.map(({ id, name }) => ({ id, name }));
4346
}
4447
};
4548
}

web/src/core/ports/Ai.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ export type Ai = {
22
webUiUrl: string;
33
apiBase: string;
44
getToken: () => Promise<GetTokenResult>;
5-
listModels: (token: string) => Promise<string[]>;
5+
listModels: (token: string) => Promise<{ id: string; name: string }[]>;
66
};
77

88
export type GetTokenResult =

web/src/core/usecases/ai/state.ts

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ import { id } from "tsafe/id";
33

44
export const name = "ai";
55

6+
export type AiModel = { id: string; name: string };
7+
68
export type CustomAiProvider = {
79
id: string;
810
label: string;
911
apiBase: string;
1012
apiKey: string;
11-
availableModels: string[];
13+
availableModels: AiModel[];
1214
selectedModel: string | undefined;
1315
modelsFetchStatus: "fetching" | "success" | "error";
1416
};
@@ -26,7 +28,7 @@ export declare namespace State {
2628
webUiUrl: string;
2729
apiBase: string;
2830
token: string | undefined;
29-
availableModels: string[];
31+
availableModels: AiModel[];
3032
selectedModel: string | undefined;
3133
customProviders: CustomAiProvider[];
3234
};
@@ -59,7 +61,7 @@ export const { reducer, actions } = createUsecaseActions({
5961
webUiUrl: string;
6062
apiBase: string;
6163
token: string;
62-
availableModels: string[];
64+
availableModels: AiModel[];
6365
selectedModel: string | undefined;
6466
customProviders: CustomAiProvider[];
6567
};
@@ -80,7 +82,7 @@ export const { reducer, actions } = createUsecaseActions({
8082
apiBase,
8183
token,
8284
availableModels,
83-
selectedModel: selectedModel ?? availableModels[0],
85+
selectedModel: selectedModel ?? availableModels[0]?.id,
8486
customProviders
8587
});
8688
},
@@ -107,15 +109,15 @@ export const { reducer, actions } = createUsecaseActions({
107109
},
108110
customProviderModelsLoaded: (
109111
state,
110-
{ payload }: { payload: { id: string; models: string[] } }
112+
{ payload }: { payload: { id: string; models: AiModel[] } }
111113
) => {
112114
if (!state.isEnabled) return;
113115
const provider = state.customProviders.find(p => p.id === payload.id);
114116
if (provider === undefined) return;
115117
provider.availableModels = payload.models;
116118
provider.modelsFetchStatus = "success";
117119
if (provider.selectedModel === undefined && payload.models.length > 0) {
118-
provider.selectedModel = payload.models[0];
120+
provider.selectedModel = payload.models[0].id;
119121
}
120122
},
121123
customProviderModelsFetchFailed: (

web/src/core/usecases/ai/thunks.ts

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import type { Thunks } from "core/bootstrap";
22
import { actions } from "./state";
3-
import type { CustomAiProvider } from "./state";
3+
import type { AiModel, CustomAiProvider } from "./state";
4+
import { z } from "zod";
45
import { getLocalStorage } from "core/tools/safeLocalStorage";
56
import * as deploymentRegionManagement from "core/usecases/deploymentRegionManagement";
67
import { assert } from "tsafe";
@@ -38,15 +39,17 @@ function writePersistedProviders(
3839
localStorage.setItem(CUSTOM_PROVIDERS_STORAGE_KEY, JSON.stringify(providers));
3940
}
4041

41-
async function fetchModels(apiBase: string, apiKey: string): Promise<string[]> {
42+
async function fetchModels(apiBase: string, apiKey: string): Promise<AiModel[]> {
4243
const response = await fetch(`${apiBase}/models`, {
4344
headers: { Authorization: `Bearer ${apiKey}` }
4445
});
4546
if (!response.ok) {
4647
throw new Error(`Failed to fetch models (${response.status})`);
4748
}
48-
const data = await response.json();
49-
return (data.data as { id: string }[]).map(m => m.id);
49+
const { data } = z
50+
.object({ data: z.array(z.object({ id: z.string(), name: z.string() })) })
51+
.parse(await response.json());
52+
return data.map(({ id, name }) => ({ id, name }));
5053
}
5154

5255
export const thunks = {
@@ -144,7 +147,7 @@ export const thunks = {
144147
},
145148
testCustomProvider:
146149
(params: { apiBase: string; apiKey: string }) =>
147-
async (..._args): Promise<string[]> => {
150+
async (..._args): Promise<AiModel[]> => {
148151
const { apiBase, apiKey } = params;
149152
return fetchModels(apiBase, apiKey);
150153
},

web/src/ui/pages/account/AccountAiTab.tsx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,9 @@ const AccountAiGatewayTab = memo((props: Props) => {
180180
onChange={onModelChange}
181181
size="small"
182182
>
183-
{availableModels.map(model => (
184-
<MenuItem key={model} value={model}>
185-
{model}
183+
{availableModels.map(({ id, name }) => (
184+
<MenuItem key={id} value={id}>
185+
{name}
186186
</MenuItem>
187187
))}
188188
</Select>
@@ -343,9 +343,9 @@ const CustomProviderCard = memo((props: CustomProviderCardProps) => {
343343
onChange={onModelChange}
344344
size="small"
345345
>
346-
{provider.availableModels.map(model => (
347-
<MenuItem key={model} value={model}>
348-
{model}
346+
{provider.availableModels.map(({ id, name }) => (
347+
<MenuItem key={id} value={id}>
348+
{name}
349349
</MenuItem>
350350
))}
351351
</Select>

0 commit comments

Comments
 (0)