
Commit 36c0fcc

feat: Implement model fetching functionality in AIHandler and update UI for model selection
1 parent 0cd8763 commit 36c0fcc

4 files changed

Lines changed: 115 additions & 15 deletions


src/main/aiHandler.ts

Lines changed: 37 additions & 2 deletions
@@ -284,6 +284,38 @@ class AIHandler {
   }
 
   public async testConnection(config: LLMConfig): Promise<{ success: boolean; error?: string }> {
+    try {
+      const response = await fetch(`${config.apiHost.replace(/\/$/, '')}/chat/completions`, {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json',
+          Authorization: `Bearer ${config.apiKey}`
+        },
+        body: JSON.stringify({
+          model: config.modelName,
+          messages: [{ role: 'user', content: 'Hi' }],
+          max_tokens: 5
+        })
+      })
+
+      if (response.ok) {
+        return { success: true }
+      } else {
+        const errorText = await response.text()
+        return {
+          success: false,
+          error: `HTTP ${response.status}: ${errorText || response.statusText}`
+        }
+      }
+    } catch (error) {
+      return {
+        success: false,
+        error: error instanceof Error ? error.message : 'Connection failed'
+      }
+    }
+  }
+
+  public async getModels(config: LLMConfig): Promise<{ success: boolean; models?: string[]; error?: string }> {
     try {
       const response = await fetch(`${config.apiHost.replace(/\/$/, '')}/models`, {
         method: 'GET',
@@ -293,7 +325,9 @@ class AIHandler {
       })
 
       if (response.ok) {
-        return { success: true }
+        const data = await response.json()
+        const models = data.data?.map((model: any) => model.id) || []
+        return { success: true, models }
       } else {
         return {
           success: false,
@@ -303,7 +337,7 @@ class AIHandler {
     } catch (error) {
       return {
         success: false,
-        error: error instanceof Error ? error.message : 'Connection failed'
+        error: error instanceof Error ? error.message : 'Failed to fetch models'
       }
     }
   }
@@ -317,4 +351,5 @@ export function setupAIHandlers() {
   )
   ipcMain.handle('ai:stop-streaming', (event, requestId: string) => aiHandler.stopStreaming(requestId))
  ipcMain.handle('ai:test-connection', (event, config: LLMConfig) => aiHandler.testConnection(config))
+  ipcMain.handle('ai:get-models', (event, config: LLMConfig) => aiHandler.getModels(config))
 }
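Note on the parsing in getModels: `data.data?.map((model: any) => model.id)` assumes the server answers GET {apiHost}/models with an OpenAI-compatible payload, i.e. a JSON object whose `data` array holds model entries. A minimal sketch of that assumed shape (only `id` is actually read; the other fields are typical of the OpenAI schema but unused here):

// Assumed OpenAI-compatible response from GET {apiHost}/models.
// getModels only reads data[].id.
interface ModelsResponse {
  object: 'list'
  data: Array<{
    id: string // e.g. 'gpt-4' — collected into the returned models array
    object?: string // usually 'model'
    owned_by?: string
  }>
}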

src/preload/index.d.ts

Lines changed: 7 additions & 0 deletions
@@ -37,6 +37,12 @@ export interface TestConnectionResult {
   error?: string
 }
 
+export interface GetModelsResult {
+  success: boolean
+  models?: string[]
+  error?: string
+}
+
 declare global {
   interface Window {
     electron: ElectronAPI
@@ -45,6 +51,7 @@ declare global {
       sendMessageStreaming: (request: AIRequest) => Promise<void>
      stopStreaming: (requestId: string) => Promise<void>
       testConnection: (config: LLMConfig) => Promise<TestConnectionResult>
+      getModels: (config: LLMConfig) => Promise<GetModelsResult>
       onStreamData: (requestId: string, callback: (data: AIStreamChunk) => void) => void
       removeStreamListener: (requestId: string) => void
     }
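For reference, results under this contract branch on `success`; illustrative values (model names and error text are examples, not fixed strings):

const ok: GetModelsResult = { success: true, models: ['gpt-4', 'gpt-4o-mini'] }
const failed: GetModelsResult = { success: false, error: 'HTTP 401: Unauthorized' }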

src/preload/index.ts

Lines changed: 4 additions & 1 deletion
@@ -4,7 +4,8 @@ import type {
   AIRequest,
   AIStreamChunk,
   LLMConfig,
-  TestConnectionResult
+  TestConnectionResult,
+  GetModelsResult
 } from './index.d'
 
 // Manage multiple streaming listeners
@@ -17,6 +18,8 @@ const api = {
       ipcRenderer.invoke('ai:send-message-streaming', request),
     testConnection: (config: LLMConfig): Promise<TestConnectionResult> =>
       ipcRenderer.invoke('ai:test-connection', config),
+    getModels: (config: LLMConfig): Promise<GetModelsResult> =>
+      ipcRenderer.invoke('ai:get-models', config),
     stopStreaming: (requestId: string): Promise<void> =>
       ipcRenderer.invoke('ai:stop-streaming', requestId),
     onStreamData: (requestId: string, callback: (data: AIStreamChunk) => void): void => {
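With the bridge method in place, renderer code can fetch the model list without touching ipcRenderer directly. A minimal usage sketch (assuming `config` is a populated LLMConfig, as in the settings form below):

// Sketch: renderer-side call through the preload bridge.
const result = await window.api.ai.getModels(config)
if (result.success && result.models) {
  // e.g. hand the list to a model picker
  console.log(`fetched ${result.models.length} models`, result.models)
} else {
  console.error('model fetch failed:', result.error)
}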

src/renderer/src/components/settings/LLMSettings.tsx

Lines changed: 67 additions & 12 deletions
@@ -42,6 +42,8 @@ interface LLMConfigFormProps {
 function LLMConfigForm({ open, config, onSave, onCancel }: LLMConfigFormProps) {
   const [form] = Form.useForm()
   const [loading, setLoading] = useState(false)
+  const [modelLoading, setModelLoading] = useState(false)
+  const [models, setModels] = useState<string[]>([])
   const { message } = App.useApp()
   const { settings } = useSettingsStore()
 
@@ -53,6 +55,7 @@ function LLMConfigForm({ open, config, onSave, onCancel }: LLMConfigFormProps) {
       } else {
         form.resetFields()
       }
+      setModels([])
     }
   }, [open, config, form])
 
@@ -79,31 +82,59 @@ function LLMConfigForm({ open, config, onSave, onCancel }: LLMConfigFormProps) {
     }
   }
 
+  const fetchModels = async () => {
+    try {
+      const apiHost = form.getFieldValue('apiHost')
+      const apiKey = form.getFieldValue('apiKey')
+
+      if (!apiHost || !apiKey) {
+        message.warning('请先输入 API Host 和 API Key')
+        return
+      }
+
+      setModelLoading(true)
+      const result = await window.api.ai.getModels({
+        id: 'temp',
+        name: 'temp',
+        apiHost,
+        apiKey,
+        modelName: '',
+        createdAt: Date.now()
+      })
+
+      if (result.success && result.models) {
+        setModels(result.models)
+        message.success(`成功获取 ${result.models.length} 个模型`)
+      } else {
+        message.error(result.error || '获取模型列表失败')
+        setModels([])
+      }
+    } catch (error) {
+      message.error('获取模型列表失败')
+      setModels([])
+    } finally {
+      setModelLoading(false)
+    }
+  }
+
   const testConnection = async () => {
     try {
       const values = await form.validateFields()
       setLoading(true)
 
-      const tempConfig: LLMConfig = {
+      const result = await window.api.ai.testConnection({
         id: 'temp',
         name: values.name,
         apiHost: values.apiHost,
         apiKey: values.apiKey,
         modelName: values.modelName,
         createdAt: Date.now()
-      }
-
-      const defaultModelConfig =
-        settings.modelConfigs.find((c) => c.id === settings.defaultModelConfigId) ||
-        settings.modelConfigs[0]
-      const aiService = createAIService(tempConfig, defaultModelConfig)
-      const isConnected = await aiService.testConnection()
+      })
 
-      console.log('isConnected', isConnected)
-      if (isConnected) {
+      if (result.success) {
         message.success('连接测试成功')
       } else {
-        message.error('连接测试失败,请检查配置')
+        message.error(result.error || '连接测试失败,请检查配置')
       }
     } catch (error) {
       message.error('连接测试失败,请检查配置')
@@ -167,8 +198,32 @@ function LLMConfigForm({ open, config, onSave, onCancel }: LLMConfigFormProps) {
           name="modelName"
          label="模型名称"
           rules={[{ required: true, message: '请输入模型名称' }]}
+          extra={
+            <Button
+              type="link"
+              size="small"
+              onClick={fetchModels}
+              loading={modelLoading}
+              style={{ padding: 0 }}
+            >
+              从服务器获取模型列表
+            </Button>
+          }
         >
-          <Input placeholder="gpt-4" />
+          {models.length > 0 ? (
+            <Select
+              placeholder="选择或输入模型名称"
+              showSearch
+              allowClear
+              loading={modelLoading}
+              options={models.map((model) => ({ label: model, value: model }))}
+              filterOption={(input, option) =>
+                (option?.label ?? '').toLowerCase().includes(input.toLowerCase())
+              }
+            />
+          ) : (
+            <Input placeholder="请输入模型名称,例如: gpt-4" />
+          )}
         </Form.Item>
 
         <Form.Item
