-
Notifications
You must be signed in to change notification settings - Fork 334
Expand file tree
/
Copy pathopenai.ts
More file actions
327 lines (304 loc) · 11.3 KB
/
openai.ts
File metadata and controls
327 lines (304 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
import type { ChatStreamEvent, ChatRequest, ContentBlock } from "../types";
import type { AgentModelConfig } from "../types";
import { isContentBlocks } from "../content_utils";
import {
generateAttachmentId,
convertTextBlock,
convertFileBlock,
imageBlockFallback,
audioBlockFallback,
readSSEStream,
} from "./content_utils";
/**
 * Maps an audio MIME type to an OpenAI `input_audio.format` value.
 *
 * The Chat Completions API documents "wav" and "mp3" as formats, so common MIME
 * aliases ("audio/mpeg", "audio/x-wav", ...) are normalized onto those. Any other
 * subtype is passed through unchanged, and a missing subtype falls back to "wav".
 */
function audioFormatFromMimeType(mimeType: string): string {
  // Ignore MIME parameters such as "; codecs=1" and compare case-insensitively.
  const essence = mimeType.split(";")[0] ?? "";
  const subtype = (essence.split("/")[1] ?? "").trim().toLowerCase();
  switch (subtype) {
    case "":
      return "wav";
    case "mpeg":
    case "mp3":
    case "mpeg3":
    case "x-mpeg":
    case "x-mp3":
      return "mp3";
    case "wav":
    case "wave":
    case "x-wav":
    case "vnd.wave":
      return "wav";
    default:
      return subtype;
  }
}

/**
 * Converts internal ContentBlock[] into the OpenAI Chat Completions `content` array.
 *
 * - text  -> convertTextBlock(block)
 * - image -> `{ type: "image_url", image_url: { url } }` when the resolver returns data,
 *            otherwise imageBlockFallback(block)
 * - file  -> convertFileBlock(block)
 * - audio -> `{ type: "input_audio", input_audio: { data, format } }` when the resolver
 *            returns a base64 data URL, otherwise audioBlockFallback(block)
 *
 * @param blocks - content blocks of a single message.
 * @param attachmentResolver - maps an attachment id to its data, or null when unavailable.
 *   For audio the value must be a `data:<mime>;base64,<payload>` URL to be sent inline.
 * @returns one OpenAI content part per handled block, in input order.
 */
function convertContentBlocks(
  blocks: ContentBlock[],
  attachmentResolver?: (id: string) => string | null
): Array<Record<string, unknown>> {
  const result: Array<Record<string, unknown>> = [];
  for (const block of blocks) {
    switch (block.type) {
      case "text":
        result.push(convertTextBlock(block));
        break;
      case "image": {
        const data = attachmentResolver?.(block.attachmentId);
        if (data) {
          // OpenAI format: image_url (the resolved value is used as the URL as-is).
          result.push({ type: "image_url", image_url: { url: data } });
        } else {
          result.push(imageBlockFallback(block));
        }
        break;
      }
      case "file":
        result.push(convertFileBlock(block));
        break;
      case "audio": {
        const data = attachmentResolver?.(block.attachmentId);
        // input_audio needs the raw base64 payload, so only data URLs can be sent inline.
        const match = data ? data.match(/^data:([^;]+);base64,(.+)$/s) : null;
        if (match) {
          // Prefer the block's declared MIME type; fall back to the data URL's own MIME.
          const format = audioFormatFromMimeType(block.mimeType || match[1] || "");
          result.push({ type: "input_audio", input_audio: { data: match[2], format } });
        } else {
          result.push(audioBlockFallback(block));
        }
        break;
      }
    }
  }
  return result;
}
/**
 * Builds a streaming Chat Completions request for an OpenAI-compatible endpoint.
 *
 * @param config - model settings: `apiBaseUrl` (defaults to "https://api.openai.com/v1";
 *   trailing slashes are ignored), `apiKey` (sent as a Bearer token when set), `model`,
 *   and optional `maxTokens` (sent as `max_tokens` when truthy).
 * @param request - conversation messages and optional tool definitions.
 * @param attachmentResolver - maps an attachment id to its data, or null when unavailable;
 *   used when a message's content is a ContentBlock[].
 * @returns the endpoint URL and a `fetch` init for a JSON POST with `stream: true`
 *   and `stream_options.include_usage` enabled.
 */
export function buildOpenAIRequest(
  config: AgentModelConfig,
  request: ChatRequest,
  attachmentResolver?: (id: string) => string | null
): { url: string; init: RequestInit } {
  // Strip trailing slashes so "https://host/v1/" does not become "https://host/v1//chat/completions".
  const baseUrl = (config.apiBaseUrl || "https://api.openai.com/v1").replace(/\/+$/, "");
  const url = `${baseUrl}/chat/completions`;

  const headers: Record<string, string> = {
    "Content-Type": "application/json",
  };
  if (config.apiKey) {
    headers["Authorization"] = `Bearer ${config.apiKey}`;
  }

  const messages = request.messages.map((m) => {
    const msg: Record<string, unknown> = { role: m.role };
    // ContentBlock[] content is converted to OpenAI content parts; strings pass through.
    msg.content = isContentBlocks(m.content)
      ? convertContentBlocks(m.content, attachmentResolver)
      : m.content;
    if (m.toolCallId) {
      msg.tool_call_id = m.toolCallId;
    }
    // Assistant messages that issued tool calls carry them in OpenAI's `tool_calls` shape.
    if (m.toolCalls && m.toolCalls.length > 0) {
      msg.tool_calls = m.toolCalls.map((tc) => ({
        id: tc.id,
        type: "function",
        function: { name: tc.name, arguments: tc.arguments },
      }));
    }
    return msg;
  });

  const body: Record<string, unknown> = {
    model: config.model,
    messages,
    stream: true,
    // Ask the server to append a usage chunk to the stream.
    stream_options: { include_usage: true },
  };
  if (config.maxTokens) {
    body.max_tokens = config.maxTokens;
  }
  // Tool definitions, wrapped as OpenAI "function" tools.
  if (request.tools && request.tools.length > 0) {
    body.tools = request.tools.map((t) => ({
      type: "function",
      function: {
        name: t.name,
        description: t.description,
        parameters: t.parameters,
      },
    }));
  }

  return {
    url,
    init: {
      method: "POST",
      headers,
      body: JSON.stringify(body),
    },
  };
}
/**
 * Returns the length of the longest proper prefix of `tag` that `input` ends with.
 *
 * This lets the stream parser hold back a tag fragment split across SSE chunks
 * (e.g. a chunk ending in "<th" before the next one supplies "ink>"). The candidate
 * length is capped at `tag.length - 1`, so a complete tag is never reported here.
 * Returns 0 when no prefix matches.
 */
function longestTagPrefixSuffix(input: string, tag: string): number {
  for (let len = Math.min(input.length, tag.length - 1); len > 0; len -= 1) {
    const tail = input.slice(input.length - len);
    if (tail === tag.slice(0, len)) {
      return len;
    }
  }
  return 0;
}
/**
 * Parses an OpenAI-compatible Chat Completions SSE stream and forwards it as ChatStreamEvents.
 *
 * Events emitted by this function:
 * - "thinking_delta": from `delta.reasoning_content`, and from text inside inline
 *   <think>...</think> tags in string `delta.content`;
 * - "content_delta": plain text (string content outside think tags, or `type: "text"` array parts);
 * - "content_block_complete": `type: "image_url"` array parts (model-generated images);
 * - "tool_call_start" / "tool_call_delta": from `delta.tool_calls`;
 * - "error": a JSON payload with an `error` field, or readSSEStream's error callback;
 * - "done" (with the most recent usage): on "[DONE]", or after the stream ends without
 *   "[DONE]" when no terminal event was emitted and the signal was not aborted.
 *
 * @param reader - byte reader for the response body, consumed through readSSEStream.
 * @param onEvent - receives every emitted event.
 * @param signal - abort signal; when aborted, the fallback "done" after stream end is suppressed.
 * @returns readSSEStream's promise chained with the fallback "done" check.
 */
export function parseOpenAIStream(
reader: ReadableStreamDefaultReader<Uint8Array>,
onEvent: (event: ChatStreamEvent) => void,
signal: AbortSignal
): Promise<void> {
// Latest usage seen. Some APIs (e.g. Grok) attach usage to every chunk, not only the last one,
// so each usage payload overwrites the previous value and it is reported on "done".
let lastUsage:
| { inputTokens: number; outputTokens: number; cacheCreationInputTokens?: number; cacheReadInputTokens?: number }
| undefined;
// Set once a terminal event has been emitted ("done" on [DONE], or an "error"),
// so the trailing .then() does not emit a second "done".
let doneSent = false;
// Whether we are currently inside a <think>...</think> block, tracked across chunks
// (for models that mix their reasoning into `content`).
let inThinkBlock = false;
// Trailing fragment that may be the start of a tag split across chunks
// (e.g. a chunk ending in "<th" while the next chunk starts with "ink>").
let thinkTagCarry = "";
// At stream end, emit any carried fragment that never completed into a tag, so no text is lost.
// It is routed by the current think state.
const flushThinkCarry = () => {
if (thinkTagCarry.length > 0) {
onEvent({
type: inThinkBlock ? "thinking_delta" : "content_delta",
delta: thinkTagCarry,
});
thinkTagCarry = "";
}
};
return readSSEStream(
reader,
signal,
(sseEvent) => {
// Returning true marks the stream as finished; presumably readSSEStream stops reading then
// (NOTE(review): confirm in content_utils).
if (sseEvent.data === "[DONE]") {
flushThinkCarry();
doneSent = true;
onEvent({ type: "done", usage: lastUsage });
return true;
}
// NOTE(review): every onEvent call below runs inside this try, so an exception thrown by a
// listener is swallowed by the catch, the same way as a JSON parse failure.
try {
const json = JSON.parse(sseEvent.data);
// API error payload: report it and finish. Unlike [DONE], this path does not flush thinkTagCarry.
if (json.error) {
doneSent = true;
onEvent({
type: "error",
message: json.error.message || JSON.stringify(json.error),
});
return true;
}
// Only the first choice is consumed.
const choice = json.choices?.[0];
if (choice) {
const delta = choice.delta;
if (delta) {
// Reasoning delta (`reasoning_content`, as sent by e.g. DeepSeek-style APIs).
if (delta.reasoning_content) {
onEvent({ type: "thinking_delta", delta: delta.reasoning_content });
}
// Content delta: either a string or an array of parts (the array form carries generated images).
if (delta.content) {
if (Array.isArray(delta.content)) {
for (const part of delta.content) {
if (part.type === "text" && part.text) {
onEvent({ type: "content_delta", delta: part.text });
} else if (part.type === "image_url" && part.image_url?.url) {
// Model-generated image: delivered via content_block_complete together with its URL.
// The MIME type is read from a data URL prefix, defaulting to image/png.
const dataUrl: string = part.image_url.url;
const mimeMatch = dataUrl.match(/^data:([^;]+);/);
const mimeType = mimeMatch ? mimeMatch[1] : "image/png";
const ext = mimeType.split("/")[1] || "png";
onEvent({
type: "content_block_complete",
block: {
type: "image",
attachmentId: generateAttachmentId(ext),
mimeType,
name: "generated-image",
},
data: dataUrl,
});
}
}
} else {
// Handle inline <think>...</think> tags (reasoning models).
// Text inside the tags is routed as thinking_delta so bare tags never appear in the chat.
// Tags can be split across SSE chunks (e.g. "<th" + "ink>"), so a possible tag prefix at the
// end of the text is held in thinkTagCarry and prepended to the next chunk.
let remaining: string = thinkTagCarry + delta.content;
thinkTagCarry = "";
while (remaining.length > 0) {
// Look for the tag that would toggle the current state.
const tag = inThinkBlock ? "</think>" : "<think>";
const idx = remaining.indexOf(tag);
if (idx === -1) {
// No complete tag: emit everything except a trailing fragment that could start the tag.
const carryLen = longestTagPrefixSuffix(remaining, tag);
const emittable = remaining.slice(0, remaining.length - carryLen);
if (emittable.length > 0) {
onEvent({
type: inThinkBlock ? "thinking_delta" : "content_delta",
delta: emittable,
});
}
thinkTagCarry = remaining.slice(remaining.length - carryLen);
remaining = "";
} else {
// Found the tag: emit the text before it under the current state, then flip the state
// and continue after the tag.
if (idx > 0) {
onEvent({
type: inThinkBlock ? "thinking_delta" : "content_delta",
delta: remaining.slice(0, idx),
});
}
inThinkBlock = !inThinkBlock;
remaining = remaining.slice(idx + tag.length);
}
}
}
}
// Tool calls.
if (delta.tool_calls) {
for (const tc of delta.tool_calls) {
// OpenAI convention: the first chunk carries id + function.name; later chunks carry only
// index + function.arguments.
if (tc.function?.name) {
onEvent({
type: "tool_call_start",
toolCall: {
id: tc.id || `tc_${Date.now()}_${tc.index ?? 0}`,
name: tc.function.name,
arguments: "", // always start empty so a "{}" in the first chunk cannot pollute the accumulated prefix
},
});
}
// Arguments in the first chunk are also emitted as a delta (deliberately not an else-if).
if (tc.function?.arguments !== undefined && tc.function.arguments !== "") {
onEvent({
type: "tool_call_delta",
id: tc.id || "", // later chunks usually have no id; kept only for interface compatibility
index: tc.index, // the field consumers match on (NOTE(review): may be undefined if a provider omits it)
delta: tc.function.arguments,
});
}
}
}
}
}
// Record usage. This is not treated as an end-of-stream signal, which keeps APIs that send usage on every chunk working.
// Only cacheReadInputTokens is filled here; cacheCreationInputTokens is never set by this parser.
if (json.usage) {
const cachedTokens = json.usage.prompt_tokens_details?.cached_tokens;
lastUsage = {
inputTokens: json.usage.prompt_tokens || 0,
outputTokens: json.usage.completion_tokens || 0,
...(cachedTokens ? { cacheReadInputTokens: cachedTokens } : {}),
};
}
} catch {
// Ignore events whose data fails to parse (see NOTE above: listener exceptions land here too).
}
return false;
},
(message) => {
// Error reported by readSSEStream itself; counts as a terminal event.
doneSent = true;
onEvent({ type: "error", message });
}
).then(() => {
// The stream ended normally without [DONE] (some APIs do this): flush and emit "done" ourselves,
// unless the caller aborted or a terminal event was already emitted.
if (!signal.aborted && !doneSent) {
flushThinkCarry();
onEvent({ type: "done", usage: lastUsage });
}
});
}
// ---- LLMProvider 接口适配 ----
import type { LLMProvider } from "./types";
/**
 * LLMProvider adapter for OpenAI-compatible Chat Completions APIs
 * (registered in providers/index.ts).
 *
 * - buildRequest: unpacks the provider input and delegates to buildOpenAIRequest.
 * - parseStream: delegates directly to parseOpenAIStream.
 */
export const openaiProvider: LLMProvider = {
  name: "openai",
  buildRequest({ model, request, resolver }) {
    return buildOpenAIRequest(model, request, resolver);
  },
  parseStream: parseOpenAIStream,
};