import type { TextStreamPart } from 'ai';
import { expect } from 'chai';
import type { GenerateTextOptions, GenerationStats, LLM, LlmMessage } from '#shared/llm/llm.model';
import { type FlexMetricsSnapshot, OPENAI_FLEX_SERVICE, OpenAIFlex } from './openaiFlex';

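// Handler signatures the tests inject to script each fake LLM's behaviour.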
type GenerateHandler = (messages: LlmMessage[], opts?: GenerateTextOptions) => Promise<string>;
type StreamHandler = (
  messages: LlmMessage[],
  onChunk: (chunk: TextStreamPart<any>) => void,
  opts?: GenerateTextOptions,
) => Promise<GenerationStats>;

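// Baseline stats returned by the scripted stream handlers.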
const DEFAULT_STATS: GenerationStats = {
  llmId: 'test:model',
  cost: 0,
  inputTokens: 0,
  outputTokens: 0,
  totalTime: 0,
  timeToFirstToken: 0,
  requestTime: 0,
  finishReason: 'stop',
};

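/**
 * Minimal in-memory LLM double. Only generateText/streamText are scripted per
 * test; the remaining interface methods throw or return fixed values.
 */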
class TestLLM implements LLM {
  constructor(
    private readonly name: string,
    private readonly model: string,
    private readonly generateHandler: GenerateHandler,
    private readonly streamHandler: StreamHandler,
    private readonly configured = true,
  ) {}

  async generateText(
    userOrSystemOrMessages: string | LlmMessage[] | ReadonlyArray<LlmMessage>,
    userOrOpts?: string | GenerateTextOptions,
    opts?: GenerateTextOptions,
  ): Promise<string> {
    const messages = this.toMessages(userOrSystemOrMessages, userOrOpts, opts);
    return this.generateHandler(messages, this.toOptions(userOrSystemOrMessages, userOrOpts, opts));
  }

  async generateTextWithJson(): Promise<any> {
    throw new Error('Not implemented in TestLLM');
  }

  async generateJson(): Promise<any> {
    throw new Error('Not implemented in TestLLM');
  }

  async generateTextWithResult(): Promise<string> {
    throw new Error('Not implemented in TestLLM');
  }

  async generateMessage(): Promise<LlmMessage> {
    throw new Error('Not implemented in TestLLM');
  }

  streamText(
    messages: LlmMessage[] | ReadonlyArray<LlmMessage>,
    onChunk: (chunk: TextStreamPart<any>) => void,
    opts?: GenerateTextOptions,
  ): Promise<GenerationStats> {
    return this.streamHandler(messages as LlmMessage[], onChunk, opts);
  }

  getService(): string {
    return OPENAI_FLEX_SERVICE;
  }

  getModel(): string {
    return this.model;
  }

  getDisplayName(): string {
    return this.name;
  }

  getId(): string {
    return `${this.getService()}:${this.model}`;
  }

  getMaxInputTokens(): number {
    return 100_000;
  }

  getMaxOutputTokens(): number {
    return 100_000;
  }

  countTokens(): Promise<number> {
    return Promise.resolve(0);
  }

  isConfigured(): boolean {
    return this.configured;
  }

  getOldModels(): string[] {
    return [];
  }

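  // Normalise the overloaded generateText arguments into a message array.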
  private toMessages(
    userOrSystemOrMessages: string | LlmMessage[] | ReadonlyArray<LlmMessage>,
    userOrOpts?: string | GenerateTextOptions,
    opts?: GenerateTextOptions,
  ): LlmMessage[] {
    if (Array.isArray(userOrSystemOrMessages)) return [...userOrSystemOrMessages];
    if (typeof userOrOpts === 'string') {
      return [
        { role: 'system', content: userOrSystemOrMessages as string },
        { role: 'user', content: userOrOpts },
      ];
    }
    return [{ role: 'user', content: userOrSystemOrMessages as string }];
  }

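  // Pick out the options argument for whichever generateText overload was used.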
  private toOptions(
    userOrSystemOrMessages: string | LlmMessage[] | ReadonlyArray<LlmMessage>,
    userOrOpts?: string | GenerateTextOptions,
    opts?: GenerateTextOptions,
  ): GenerateTextOptions | undefined {
    if (Array.isArray(userOrSystemOrMessages)) return userOrOpts as GenerateTextOptions | undefined;
    if (typeof userOrOpts === 'string') return opts;
    return userOrOpts as GenerateTextOptions | undefined;
  }
}

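// These tests exercise the assumed OpenAIFlex contract: race a flex-tier request
// against a first-chunk timeout (the final constructor argument, in ms) and fall
// back to the standard LLM if no chunk arrives in time or the flex stream errors.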
describe('OpenAIFlex', () => {
  const messages: LlmMessage[] = [{ role: 'user', content: 'hello' }];

  it('uses flex response when first chunk arrives before timeout', async () => {
    let streamed = '';
    const flexLLM = new TestLLM(
      'flex',
      'flex-model',
      async () => 'unused',
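      // Emits a chunk immediately and records that the flex handler actually ran.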
      async (_msgs, onChunk) => {
        onChunk({ type: 'text-delta', id: '1', text: 'flex-response' });
        streamed += 'flex-response';
        return DEFAULT_STATS;
      },
    );
    const standardLLM = new TestLLM(
      'standard',
      'std-model',
      async () => 'standard-response',
      async (_msgs, _onChunk) => DEFAULT_STATS,
    );

    const flex = new OpenAIFlex('Flex Under Test', 'flex-test', standardLLM, flexLLM, 200);
    const response = await flex.generateTextFromMessages(messages);
    const metrics = flex.getMetrics();

    expect(response).to.equal('flex-response');
    expect(streamed).to.equal('flex-response');
    expect(metrics.flexAttempts).to.equal(1);
    expect(metrics.flexFallbacks).to.equal(0);
    expect(metrics.flexResponses).to.equal(1);
    expect(metrics.lastFlexResponseMs).to.be.a('number');
  });

  it('falls back to standard when flex times out before first chunk', async () => {
    const flexLLM = new TestLLM(
      'flex',
      'flex-model',
      async () => 'unused',
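      // Never emits a chunk; rejects only once OpenAIFlex aborts the flex attempt.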
      (_msgs, _onChunk, opts) =>
        new Promise<GenerationStats>((_resolve, reject) => {
          opts?.abortSignal?.addEventListener('abort', () => reject(new Error('aborted')));
        }),
    );
    const standardLLM = new TestLLM(
      'standard',
      'std-model',
      async () => 'standard-response',
      async (_msgs, _onChunk) => DEFAULT_STATS,
    );

    const flex = new OpenAIFlex('Flex Under Test', 'flex-test', standardLLM, flexLLM, 50);
    const response = await flex.generateTextFromMessages(messages);
    const metrics = flex.getMetrics();

    expect(response).to.equal('standard-response');
    expect(metrics.flexAttempts).to.equal(1);
    expect(metrics.flexFallbacks).to.equal(1);
    expect(metrics.flexResponses).to.equal(0);
  });

  it('falls back if flex fails after first chunk', async () => {
    const flexLLM = new TestLLM(
      'flex',
      'flex-model',
      async () => 'unused',
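      // Emits one chunk, then fails asynchronously to force a mid-stream fallback.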
      (_msgs, onChunk) =>
        new Promise<GenerationStats>((_resolve, reject) => {
          onChunk({ type: 'text-delta', id: '1', text: 'partial' });
          setTimeout(() => reject(new Error('boom')), 0);
        }),
    );
    const standardLLM = new TestLLM(
      'standard',
      'std-model',
      async () => 'standard-response',
      async (_msgs, _onChunk) => DEFAULT_STATS,
    );

    const flex = new OpenAIFlex('Flex Under Test', 'flex-test', standardLLM, flexLLM, 200);
    const response = await flex.generateTextFromMessages(messages);
    const metrics: FlexMetricsSnapshot = flex.getMetrics();

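    // A chunk arrived before the failure, so flexResponses is still incremented
    // even though the final text came from the standard fallback.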
    expect(response).to.equal('standard-response');
    expect(metrics.flexFallbacks).to.equal(1);
    expect(metrics.flexResponses).to.equal(1);
  });

  it('streams from standard when flex times out', async () => {
    const flexLLM = new TestLLM(
      'flex',
      'flex-model',
      async () => 'unused',
      (_msgs, _onChunk, opts) =>
        new Promise<GenerationStats>((_resolve, reject) => {
          opts?.abortSignal?.addEventListener('abort', () => reject(new Error('aborted')));
        }),
    );

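    // The standard model streams a single chunk, so both the streamed text and
    // the returned stats should come from the fallback path.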
    const standardLLM = new TestLLM(
      'standard',
      'std-model',
      async () => 'standard-response',
      async (_msgs, onChunk) => {
        onChunk({ type: 'text-delta', id: '1', text: 'S' });
        return DEFAULT_STATS;
      },
    );

    let streamed = '';
    const flex = new OpenAIFlex('Flex Under Test', 'flex-test', standardLLM, flexLLM, 30);
    const stats = await flex.streamText(messages, (chunk) => {
      if (chunk.type === 'text-delta') streamed += chunk.text;
    });

    expect(streamed).to.equal('S');
    expect(stats.llmId).to.equal('test:model');
    const metrics = flex.getMetrics();
    expect(metrics.flexFallbacks).to.equal(1);
  });
});