|
| 1 | +import type { Context, NextFunction } from "grammy"; |
| 2 | +import type { FilePartInput, Model } from "@opencode-ai/sdk/v2"; |
| 3 | +import { config } from "../../config.js"; |
| 4 | +import { t } from "../../i18n/index.js"; |
| 5 | +import { getModelCapabilities, supportsInput } from "../../model/capabilities.js"; |
| 6 | +import { getStoredModel } from "../../model/manager.js"; |
| 7 | +import { logger } from "../../utils/logger.js"; |
| 8 | +import { |
| 9 | + downloadTelegramFile, |
| 10 | + isFileSizeAllowed, |
| 11 | + isTextMimeType, |
| 12 | + toDataUri, |
| 13 | +} from "../utils/file-download.js"; |
| 14 | +import { processUserPrompt, type ProcessPromptDeps } from "./prompt.js"; |
| 15 | + |
| 16 | +const DEFAULT_MEDIA_GROUP_DEBOUNCE_MS = 1_000; |
| 17 | + |
| 18 | +type TelegramDocument = NonNullable<NonNullable<Context["message"]>["document"]>; |
| 19 | +type TelegramPhoto = NonNullable<NonNullable<Context["message"]>["photo"]>; |
| 20 | + |
| 21 | +type PendingMediaGroupItem = { |
| 22 | + ctx: Context; |
| 23 | + messageId: number; |
| 24 | + caption: string; |
| 25 | +} & ( |
| 26 | + | { |
| 27 | + kind: "photo"; |
| 28 | + photos: TelegramPhoto; |
| 29 | + } |
| 30 | + | { |
| 31 | + kind: "document"; |
| 32 | + document: TelegramDocument; |
| 33 | + } |
| 34 | + | { |
| 35 | + kind: "unsupported"; |
| 36 | + } |
| 37 | +); |
| 38 | + |
| 39 | +type ValidMediaGroupItem = |
| 40 | + | { |
| 41 | + kind: "file"; |
| 42 | + ctx: Context; |
| 43 | + messageId: number; |
| 44 | + fileId: string; |
| 45 | + mime: string; |
| 46 | + filename: string; |
| 47 | + } |
| 48 | + | { |
| 49 | + kind: "text"; |
| 50 | + ctx: Context; |
| 51 | + fileId: string; |
| 52 | + filename: string; |
| 53 | + }; |
| 54 | + |
| 55 | +interface MediaGroupBatch { |
| 56 | + timer: ReturnType<typeof setTimeout>; |
| 57 | + items: PendingMediaGroupItem[]; |
| 58 | +} |
| 59 | + |
| 60 | +interface ValidatedMediaGroup { |
| 61 | + items: ValidMediaGroupItem[]; |
| 62 | +} |
| 63 | + |
| 64 | +interface MediaGroupValidationError { |
| 65 | + reason: string; |
| 66 | +} |
| 67 | + |
| 68 | +export interface MediaGroupHandlerDeps extends ProcessPromptDeps { |
| 69 | + downloadFile?: ( |
| 70 | + api: Context["api"], |
| 71 | + fileId: string, |
| 72 | + ) => Promise<{ buffer: Buffer; filePath: string }>; |
| 73 | + getModelCapabilities?: ( |
| 74 | + providerId: string, |
| 75 | + modelId: string, |
| 76 | + ) => Promise<Model["capabilities"] | null>; |
| 77 | + getStoredModel?: () => { providerID: string; modelID: string }; |
| 78 | + processPrompt?: ( |
| 79 | + ctx: Context, |
| 80 | + text: string, |
| 81 | + deps: ProcessPromptDeps, |
| 82 | + fileParts?: FilePartInput[], |
| 83 | + ) => Promise<boolean>; |
| 84 | +} |
| 85 | + |
| 86 | +export interface MediaGroupHandlerOptions { |
| 87 | + debounceMs?: number; |
| 88 | +} |
| 89 | + |
| 90 | +export class MediaGroupAttachmentHandler { |
| 91 | + private readonly deps: MediaGroupHandlerDeps; |
| 92 | + private readonly debounceMs: number; |
| 93 | + private readonly batches = new Map<string, MediaGroupBatch>(); |
| 94 | + |
| 95 | + constructor(deps: MediaGroupHandlerDeps, options: MediaGroupHandlerOptions = {}) { |
| 96 | + this.deps = deps; |
| 97 | + this.debounceMs = options.debounceMs ?? DEFAULT_MEDIA_GROUP_DEBOUNCE_MS; |
| 98 | + } |
| 99 | + |
| 100 | + async handle(ctx: Context, next: NextFunction): Promise<void> { |
| 101 | + const item = this.createPendingItem(ctx); |
| 102 | + |
| 103 | + if (!item) { |
| 104 | + await next(); |
| 105 | + return; |
| 106 | + } |
| 107 | + |
| 108 | + const mediaGroupId = ctx.message?.media_group_id; |
| 109 | + const chatId = ctx.chat?.id; |
| 110 | + if (!mediaGroupId || chatId === undefined) { |
| 111 | + await next(); |
| 112 | + return; |
| 113 | + } |
| 114 | + |
| 115 | + const key = this.getBatchKey(chatId, mediaGroupId); |
| 116 | + const existingBatch = this.batches.get(key); |
| 117 | + |
| 118 | + if (existingBatch) { |
| 119 | + clearTimeout(existingBatch.timer); |
| 120 | + existingBatch.items.push(item); |
| 121 | + existingBatch.timer = this.createFlushTimer(key); |
| 122 | + return; |
| 123 | + } |
| 124 | + |
| 125 | + this.batches.set(key, { |
| 126 | + items: [item], |
| 127 | + timer: this.createFlushTimer(key), |
| 128 | + }); |
| 129 | + } |
| 130 | + |
| 131 | + async flushAll(): Promise<void> { |
| 132 | + const keys = Array.from(this.batches.keys()); |
| 133 | + await Promise.all(keys.map((key) => this.flushBatch(key))); |
| 134 | + } |
| 135 | + |
| 136 | + private createPendingItem(ctx: Context): PendingMediaGroupItem | null { |
| 137 | + const message = ctx.message; |
| 138 | + const mediaGroupId = message?.media_group_id; |
| 139 | + |
| 140 | + if (!message || !mediaGroupId || !ctx.chat) { |
| 141 | + return null; |
| 142 | + } |
| 143 | + |
| 144 | + const baseItem = { |
| 145 | + ctx, |
| 146 | + messageId: message.message_id, |
| 147 | + caption: message.caption || "", |
| 148 | + }; |
| 149 | + |
| 150 | + if (message.photo && message.photo.length > 0) { |
| 151 | + return { |
| 152 | + ...baseItem, |
| 153 | + kind: "photo", |
| 154 | + photos: message.photo, |
| 155 | + }; |
| 156 | + } |
| 157 | + |
| 158 | + if (message.document) { |
| 159 | + return { |
| 160 | + ...baseItem, |
| 161 | + kind: "document", |
| 162 | + document: message.document, |
| 163 | + }; |
| 164 | + } |
| 165 | + |
| 166 | + return { |
| 167 | + ...baseItem, |
| 168 | + kind: "unsupported", |
| 169 | + }; |
| 170 | + } |
| 171 | + |
| 172 | + private getBatchKey(chatId: number | string, mediaGroupId: string): string { |
| 173 | + return `${chatId}:${mediaGroupId}`; |
| 174 | + } |
| 175 | + |
| 176 | + private createFlushTimer(key: string): ReturnType<typeof setTimeout> { |
| 177 | + return setTimeout(() => { |
| 178 | + void this.flushBatch(key); |
| 179 | + }, this.debounceMs); |
| 180 | + } |
| 181 | + |
| 182 | + private async flushBatch(key: string): Promise<void> { |
| 183 | + const batch = this.batches.get(key); |
| 184 | + if (!batch) { |
| 185 | + return; |
| 186 | + } |
| 187 | + |
| 188 | + clearTimeout(batch.timer); |
| 189 | + this.batches.delete(key); |
| 190 | + |
| 191 | + const items = [...batch.items].sort((left, right) => left.messageId - right.messageId); |
| 192 | + const replyCtx = items[0]?.ctx; |
| 193 | + |
| 194 | + if (!replyCtx) { |
| 195 | + return; |
| 196 | + } |
| 197 | + |
| 198 | + logger.info(`[MediaGroup] Processing Telegram media group: key=${key}, items=${items.length}`); |
| 199 | + |
| 200 | + try { |
| 201 | + const validationResult = await this.validateItems(items); |
| 202 | + if ("reason" in validationResult) { |
| 203 | + logger.warn( |
| 204 | + `[MediaGroup] Rejecting media group: key=${key}, reason=${validationResult.reason}`, |
| 205 | + ); |
| 206 | + await replyCtx.reply(t("bot.media_group_not_processed")); |
| 207 | + return; |
| 208 | + } |
| 209 | + |
| 210 | + await replyCtx.reply(t("bot.files_downloading")); |
| 211 | + |
| 212 | + const { promptText, fileParts } = await this.preparePrompt(validationResult.items, items); |
| 213 | + const processPrompt = this.deps.processPrompt ?? processUserPrompt; |
| 214 | + |
| 215 | + logger.info( |
| 216 | + `[MediaGroup] Sending media group as one prompt: key=${key}, files=${fileParts.length}, textLength=${promptText.length}`, |
| 217 | + ); |
| 218 | + |
| 219 | + await processPrompt(replyCtx, promptText, this.deps, fileParts); |
| 220 | + } catch (err) { |
| 221 | + logger.error(`[MediaGroup] Failed to process media group: key=${key}`, err); |
| 222 | + await replyCtx.reply(t("bot.media_group_download_error")); |
| 223 | + } |
| 224 | + } |
| 225 | + |
| 226 | + private async validateItems( |
| 227 | + items: PendingMediaGroupItem[], |
| 228 | + ): Promise<ValidatedMediaGroup | MediaGroupValidationError> { |
| 229 | + const storedModel = (this.deps.getStoredModel ?? getStoredModel)(); |
| 230 | + const validItems: ValidMediaGroupItem[] = []; |
| 231 | + let needsImageSupport = false; |
| 232 | + let needsPdfSupport = false; |
| 233 | + |
| 234 | + for (const item of items) { |
| 235 | + if (item.kind === "unsupported") { |
| 236 | + return { reason: "unsupported_media_kind" }; |
| 237 | + } |
| 238 | + |
| 239 | + if (item.kind === "photo") { |
| 240 | + needsImageSupport = true; |
| 241 | + const largestPhoto = item.photos[item.photos.length - 1]; |
| 242 | + validItems.push({ |
| 243 | + kind: "file", |
| 244 | + ctx: item.ctx, |
| 245 | + messageId: item.messageId, |
| 246 | + fileId: largestPhoto.file_id, |
| 247 | + mime: "image/jpeg", |
| 248 | + filename: `photo-${item.messageId}.jpg`, |
| 249 | + }); |
| 250 | + continue; |
| 251 | + } |
| 252 | + |
| 253 | + const document = item.document; |
| 254 | + const mimeType = document.mime_type || ""; |
| 255 | + const filename = document.file_name || "document"; |
| 256 | + |
| 257 | + if (isTextMimeType(mimeType)) { |
| 258 | + if (!isFileSizeAllowed(document.file_size, config.files.maxFileSizeKb)) { |
| 259 | + return { reason: "text_file_too_large" }; |
| 260 | + } |
| 261 | + |
| 262 | + validItems.push({ |
| 263 | + kind: "text", |
| 264 | + ctx: item.ctx, |
| 265 | + fileId: document.file_id, |
| 266 | + filename, |
| 267 | + }); |
| 268 | + continue; |
| 269 | + } |
| 270 | + |
| 271 | + if (mimeType.startsWith("image/")) { |
| 272 | + needsImageSupport = true; |
| 273 | + validItems.push({ |
| 274 | + kind: "file", |
| 275 | + ctx: item.ctx, |
| 276 | + messageId: item.messageId, |
| 277 | + fileId: document.file_id, |
| 278 | + mime: mimeType, |
| 279 | + filename, |
| 280 | + }); |
| 281 | + continue; |
| 282 | + } |
| 283 | + |
| 284 | + if (mimeType === "application/pdf") { |
| 285 | + needsPdfSupport = true; |
| 286 | + validItems.push({ |
| 287 | + kind: "file", |
| 288 | + ctx: item.ctx, |
| 289 | + messageId: item.messageId, |
| 290 | + fileId: document.file_id, |
| 291 | + mime: mimeType, |
| 292 | + filename, |
| 293 | + }); |
| 294 | + continue; |
| 295 | + } |
| 296 | + |
| 297 | + return { reason: `unsupported_document_mime:${mimeType || "unknown"}` }; |
| 298 | + } |
| 299 | + |
| 300 | + if (needsImageSupport || needsPdfSupport) { |
| 301 | + const getCapabilities = this.deps.getModelCapabilities ?? getModelCapabilities; |
| 302 | + const capabilities = await getCapabilities(storedModel.providerID, storedModel.modelID); |
| 303 | + |
| 304 | + if (needsImageSupport && !supportsInput(capabilities, "image")) { |
| 305 | + return { reason: `model_no_image:${storedModel.providerID}/${storedModel.modelID}` }; |
| 306 | + } |
| 307 | + |
| 308 | + if (needsPdfSupport && !supportsInput(capabilities, "pdf")) { |
| 309 | + return { reason: `model_no_pdf:${storedModel.providerID}/${storedModel.modelID}` }; |
| 310 | + } |
| 311 | + } |
| 312 | + |
| 313 | + return { items: validItems }; |
| 314 | + } |
| 315 | + |
| 316 | + private async preparePrompt( |
| 317 | + validItems: ValidMediaGroupItem[], |
| 318 | + originalItems: PendingMediaGroupItem[], |
| 319 | + ): Promise<{ promptText: string; fileParts: FilePartInput[] }> { |
| 320 | + const downloadFile = this.deps.downloadFile ?? downloadTelegramFile; |
| 321 | + const textSections: string[] = []; |
| 322 | + const fileParts: FilePartInput[] = []; |
| 323 | + |
| 324 | + for (const item of validItems) { |
| 325 | + const downloadedFile = await downloadFile(item.ctx.api, item.fileId); |
| 326 | + |
| 327 | + if (item.kind === "text") { |
| 328 | + const textContent = downloadedFile.buffer.toString("utf-8"); |
| 329 | + textSections.push( |
| 330 | + `--- Content of ${item.filename} ---\n${textContent}\n--- End of file ---`, |
| 331 | + ); |
| 332 | + continue; |
| 333 | + } |
| 334 | + |
| 335 | + fileParts.push({ |
| 336 | + type: "file", |
| 337 | + mime: item.mime, |
| 338 | + filename: item.filename, |
| 339 | + url: toDataUri(downloadedFile.buffer, item.mime), |
| 340 | + }); |
| 341 | + } |
| 342 | + |
| 343 | + const captions = originalItems |
| 344 | + .map((item) => item.caption.trim()) |
| 345 | + .filter((caption) => caption.length > 0); |
| 346 | + |
| 347 | + return { |
| 348 | + promptText: [...textSections, ...captions].join("\n\n"), |
| 349 | + fileParts, |
| 350 | + }; |
| 351 | + } |
| 352 | +} |
| 353 | + |
| 354 | +export function createMediaGroupAttachmentMiddleware( |
| 355 | + deps: MediaGroupHandlerDeps, |
| 356 | + options: MediaGroupHandlerOptions = {}, |
| 357 | +): (ctx: Context, next: NextFunction) => Promise<void> { |
| 358 | + const handler = new MediaGroupAttachmentHandler(deps, options); |
| 359 | + return (ctx, next) => handler.handle(ctx, next); |
| 360 | +} |
0 commit comments