diff --git a/src/client.ts b/src/client.ts index ac082a5e13..1839726c64 100644 --- a/src/client.ts +++ b/src/client.ts @@ -246,6 +246,8 @@ import { isEmptyObj } from './internal/utils/values'; const WORKLOAD_IDENTITY_API_KEY_PLACEHOLDER = 'workload-identity-auth'; +type FetchWithTimeoutResult = { response: Response; cleanup: () => void }; + export type ApiKeySetter = () => Promise; export interface ClientOptions { @@ -741,10 +743,11 @@ export class OpenAI { const security = options.__security ?? { bearerAuth: true }; const controller = new AbortController(); - const response = await this.fetchWithAuth(url, req, timeout, controller, security).catch(castToError); + const fetchResult = await this.fetchWithAuth(url, req, timeout, controller, security).catch(castToError); const headersTime = Date.now(); - if (response instanceof globalThis.Error) { + if (fetchResult instanceof globalThis.Error) { + const response = fetchResult; const retryMessage = `retrying, ${retriesRemaining} attempts remaining`; if (options.signal?.aborted) { throw new Errors.APIUserAbortError(); @@ -795,6 +798,8 @@ export class OpenAI { }); } + const { response, cleanup } = fetchResult; + const specialHeaders = [...response.headers.entries()] .filter(([name]) => name === 'x-request-id') .map(([name, value]) => ', ' + name + ': ' + JSON.stringify(value)) @@ -812,6 +817,7 @@ export class OpenAI { !options.__metadata?.['workloadIdentityTokenRefreshed'] ) { await Shims.CancelReadableStream(response.body); + cleanup(); this._workloadIdentityAuth.invalidateToken(); return this.makeRequest( @@ -833,6 +839,7 @@ export class OpenAI { // We don't need the body of this response. await Shims.CancelReadableStream(response.body); + cleanup(); loggerFor(this).info(`${responseInfo} - ${retryMessage}`); loggerFor(this).debug( `[${requestLogID}] response error (${retryMessage})`, @@ -856,7 +863,10 @@ export class OpenAI { loggerFor(this).info(`${responseInfo} - ${retryMessage}`); - const errText = await response.text().catch((err: any) => castToError(err).message); + const errText = await response + .text() + .catch((err: any) => castToError(err).message) + .finally(cleanup); const errJSON = safeJSON(errText) as any; const errMessage = errJSON ? undefined : errText; @@ -888,7 +898,7 @@ export class OpenAI { }), ); - return { response, options, controller, requestLogID, retryOfRequestLogID, startTime }; + return { response, options, controller, requestLogID, retryOfRequestLogID, startTime, cleanup }; } getAPIList = Pagination.AbstractPage>( @@ -924,7 +934,7 @@ export class OpenAI { bearerAuth: true, adminAPIKeyAuth: true, }, - ): Promise { + ): Promise { if (this._workloadIdentityAuth && schemes.bearerAuth) { const headers = init.headers as Headers; const authHeader = headers.get('Authorization'); @@ -944,12 +954,19 @@ export class OpenAI { init: RequestInit | undefined, ms: number, controller: AbortController, - ): Promise { + ): Promise { const { signal, method, ...options } = init || {}; const abort = this._makeAbort(controller); if (signal) signal.addEventListener('abort', abort, { once: true }); const timeout = setTimeout(abort, ms); + let cleanedUp = false; + const cleanup = () => { + if (cleanedUp) return; + cleanedUp = true; + clearTimeout(timeout); + if (signal) signal.removeEventListener('abort', abort); + }; const isReadableBody = ((globalThis as any).ReadableStream && options.body instanceof (globalThis as any).ReadableStream) || @@ -969,9 +986,13 @@ export class OpenAI { try { // use undefined this binding; fetch errors if bound to something else in browser/cloudflare - return await this.fetch.call(undefined, url, fetchOptions); - } finally { + const response = await this.fetch.call(undefined, url, fetchOptions); clearTimeout(timeout); + if (!response.body) cleanup(); + return { response, cleanup }; + } catch (err) { + cleanup(); + throw err; } } diff --git a/src/core/streaming.ts b/src/core/streaming.ts index 3ea5f21524..086b65dc8b 100644 --- a/src/core/streaming.ts +++ b/src/core/streaming.ts @@ -36,6 +36,7 @@ export class Stream implements AsyncIterable { controller: AbortController, client?: OpenAI, synthesizeEventData?: boolean, + cleanup?: () => void, ): Stream { let consumed = false; const logger = client ? loggerFor(client) : console; @@ -95,6 +96,7 @@ export class Stream implements AsyncIterable { } finally { // If the user `break`s, abort the ongoing request. if (!done) controller.abort(); + cleanup?.(); } } diff --git a/src/internal/parse.ts b/src/internal/parse.ts index 39174399e8..4eed3a6c44 100644 --- a/src/internal/parse.ts +++ b/src/internal/parse.ts @@ -13,46 +13,60 @@ export type APIResponseProps = { requestLogID: string; retryOfRequestLogID: string | undefined; startTime: number; + cleanup?: () => void; }; export async function defaultParseResponse( client: OpenAI, props: APIResponseProps, ): Promise> { - const { response, requestLogID, retryOfRequestLogID, startTime } = props; - const body = await (async () => { - if (props.options.stream) { - loggerFor(client).debug('response', response.status, response.url, response.headers, response.body); - - // Note: there is an invariant here that isn't represented in the type system - // that if you set `stream: true` the response type must also be `Stream` - - if (props.options.__streamClass) { - return props.options.__streamClass.fromSSEResponse( - response, - props.controller, - client, - props.options.__synthesizeEventData, - ) as any; - } + const { response, requestLogID, retryOfRequestLogID, startTime, cleanup } = props; + if (props.options.stream) { + loggerFor(client).debug('response', response.status, response.url, response.headers, response.body); + + // Note: there is an invariant here that isn't represented in the type system + // that if you set `stream: true` the response type must also be `Stream` - return Stream.fromSSEResponse( + if (props.options.__streamClass) { + return props.options.__streamClass.fromSSEResponse( response, props.controller, client, props.options.__synthesizeEventData, + cleanup, ) as any; } + return Stream.fromSSEResponse( + response, + props.controller, + client, + props.options.__synthesizeEventData, + cleanup, + ) as any; + } + + if (props.options.__binaryResponse) { + const body = wrapResponseBodyWithCleanup(response, cleanup); + loggerFor(client).debug( + `[${requestLogID}] response parsed`, + formatRequestDetails({ + retryOfRequestLogID, + url: response.url, + status: response.status, + body, + durationMs: Date.now() - startTime, + }), + ); + return body as WithRequestID; + } + + const body = await (async () => { // fetch refuses to read the body when the status code is 204. if (response.status === 204) { return null as T; } - if (props.options.__binaryResponse) { - return response as unknown as T; - } - const contentType = response.headers.get('content-type'); const mediaType = contentType?.split(';')[0]?.trim(); const isJSON = mediaType?.includes('application/json') || mediaType?.endsWith('+json'); @@ -69,7 +83,7 @@ export async function defaultParseResponse( const text = await response.text(); return text as unknown as T; - })(); + })().finally(() => cleanup?.()); loggerFor(client).debug( `[${requestLogID}] response parsed`, formatRequestDetails({ @@ -80,7 +94,7 @@ export async function defaultParseResponse( durationMs: Date.now() - startTime, }), ); - return body; + return body as WithRequestID; } export type WithRequestID = @@ -88,6 +102,49 @@ export type WithRequestID = : T extends Record ? T & { _request_id?: string | null } : T; +function wrapResponseBodyWithCleanup(response: Response, cleanup: (() => void) | undefined): Response { + if (!cleanup) return response; + if (!response.body) { + cleanup(); + return response; + } + + const reader = response.body.getReader(); + const body = new ReadableStream({ + async pull(controller) { + try { + const { done, value } = await reader.read(); + if (done) { + cleanup(); + controller.close(); + return; + } + controller.enqueue(value); + } catch (err) { + cleanup(); + controller.error(err); + } + }, + async cancel(reason) { + cleanup(); + await reader.cancel(reason); + }, + }); + const wrapped = new Response(body, response); + + try { + Object.defineProperties(wrapped, { + redirected: { value: response.redirected }, + type: { value: response.type }, + url: { value: response.url }, + }); + } catch { + // Some fetch implementations may expose non-configurable Response fields. + } + + return wrapped; +} + export function addRequestID(value: T, response: Response): WithRequestID { if (!value || typeof value !== 'object' || Array.isArray(value)) { return value as WithRequestID; diff --git a/tests/index.test.ts b/tests/index.test.ts index 028eccb17a..ed4c74511f 100644 --- a/tests/index.test.ts +++ b/tests/index.test.ts @@ -685,6 +685,206 @@ describe('retries', () => { expect(count).toEqual(3); }); + test('removes caller abort listener after successful response parsing', async () => { + const callerController = new AbortController(); + const addEventListenerSpy = jest.spyOn(callerController.signal, 'addEventListener'); + const removeEventListenerSpy = jest.spyOn(callerController.signal, 'removeEventListener'); + const testFetch = async (): Promise => + new Response(JSON.stringify({ a: 1 }), { headers: { 'Content-Type': 'application/json' } }); + + const client = new OpenAI({ + apiKey: 'My API Key', + adminAPIKey: 'My Admin API Key', + timeout: 1000, + fetch: testFetch, + }); + + expect( + await client.request({ + path: '/foo', + method: 'get', + signal: callerController.signal, + }), + ).toEqual({ a: 1 }); + + const abortListener = addEventListenerSpy.mock.calls[0]?.[1]; + expect(addEventListenerSpy).toHaveBeenCalledWith('abort', abortListener, { once: true }); + expect(removeEventListenerSpy).toHaveBeenCalledWith('abort', abortListener); + }); + + test('keeps caller abort forwarding until response body parsing settles', async () => { + const callerController = new AbortController(); + const addEventListenerSpy = jest.spyOn(callerController.signal, 'addEventListener'); + const removeEventListenerSpy = jest.spyOn(callerController.signal, 'removeEventListener'); + const encoder = new TextEncoder(); + let fetchSignal: AbortSignal | undefined; + + const testFetch = async (url: string | URL | Request, init: RequestInit = {}): Promise => { + fetchSignal = init.signal as AbortSignal; + return new Response( + new ReadableStream({ + start(controller) { + controller.enqueue(encoder.encode('{"a":')); + fetchSignal?.addEventListener('abort', () => controller.error(new Error('body aborted')), { + once: true, + }); + }, + }), + { headers: { 'Content-Type': 'application/json' } }, + ); + }; + + const client = new OpenAI({ + apiKey: 'My API Key', + adminAPIKey: 'My Admin API Key', + timeout: 1000, + maxRetries: 0, + fetch: testFetch, + }); + + const parsePromise = client + .request({ + path: '/foo', + method: 'get', + signal: callerController.signal, + }) + .then( + () => undefined, + (err) => err, + ); + + for (let i = 0; i < 5 && !fetchSignal; i++) { + await new Promise((resolve) => setTimeout(resolve, 0)); + } + + expect(fetchSignal?.aborted).toBe(false); + expect(removeEventListenerSpy).not.toHaveBeenCalled(); + + callerController.abort(); + + expect(fetchSignal?.aborted).toBe(true); + await expect(parsePromise).resolves.toBeInstanceOf(Error); + + const abortListener = addEventListenerSpy.mock.calls[0]?.[1]; + expect(addEventListenerSpy).toHaveBeenCalledWith('abort', abortListener, { once: true }); + expect(removeEventListenerSpy).toHaveBeenCalledWith('abort', abortListener); + }); + + test('removes caller abort listener after streaming response is consumed', async () => { + const callerController = new AbortController(); + const addEventListenerSpy = jest.spyOn(callerController.signal, 'addEventListener'); + const removeEventListenerSpy = jest.spyOn(callerController.signal, 'removeEventListener'); + const testFetch = async (): Promise => + new Response('data: {"a":1}\n\ndata: [DONE]\n\n', { + headers: { 'Content-Type': 'text/event-stream' }, + }); + + const client = new OpenAI({ + apiKey: 'My API Key', + adminAPIKey: 'My Admin API Key', + timeout: 1000, + fetch: testFetch, + }); + + const stream = await client.request({ + path: '/foo', + method: 'get', + stream: true, + signal: callerController.signal, + }); + const chunks: unknown[] = []; + + for await (const chunk of stream as AsyncIterable) { + chunks.push(chunk); + } + + expect(chunks).toEqual([{ a: 1 }]); + const abortListener = addEventListenerSpy.mock.calls[0]?.[1]; + expect(addEventListenerSpy).toHaveBeenCalledWith('abort', abortListener, { once: true }); + expect(removeEventListenerSpy).toHaveBeenCalledWith('abort', abortListener); + }); + + test('removes caller abort listener after binary response body is consumed', async () => { + const callerController = new AbortController(); + const addEventListenerSpy = jest.spyOn(callerController.signal, 'addEventListener'); + const removeEventListenerSpy = jest.spyOn(callerController.signal, 'removeEventListener'); + const testFetch = async (): Promise => new Response('binary data'); + + const client = new OpenAI({ + apiKey: 'My API Key', + adminAPIKey: 'My Admin API Key', + timeout: 1000, + fetch: testFetch, + }); + + const response = await client.request({ + path: '/foo', + method: 'get', + signal: callerController.signal, + __binaryResponse: true, + }); + + expect(removeEventListenerSpy).not.toHaveBeenCalled(); + expect(await response.text()).toBe('binary data'); + + const abortListener = addEventListenerSpy.mock.calls[0]?.[1]; + expect(addEventListenerSpy).toHaveBeenCalledWith('abort', abortListener, { once: true }); + expect(removeEventListenerSpy).toHaveBeenCalledWith('abort', abortListener); + }); + + test('keeps caller abort forwarding until binary response body settles', async () => { + const callerController = new AbortController(); + const addEventListenerSpy = jest.spyOn(callerController.signal, 'addEventListener'); + const removeEventListenerSpy = jest.spyOn(callerController.signal, 'removeEventListener'); + const encoder = new TextEncoder(); + let fetchSignal: AbortSignal | undefined; + + const testFetch = async (url: string | URL | Request, init: RequestInit = {}): Promise => { + fetchSignal = init.signal as AbortSignal; + return new Response( + new ReadableStream({ + start(controller) { + controller.enqueue(encoder.encode('partial')); + fetchSignal?.addEventListener('abort', () => controller.error(new Error('body aborted')), { + once: true, + }); + }, + }), + ); + }; + + const client = new OpenAI({ + apiKey: 'My API Key', + adminAPIKey: 'My Admin API Key', + timeout: 1000, + maxRetries: 0, + fetch: testFetch, + }); + + const response = await client.request({ + path: '/foo', + method: 'get', + signal: callerController.signal, + __binaryResponse: true, + }); + const readPromise = response.text().then( + () => undefined, + (err) => err, + ); + + expect(fetchSignal?.aborted).toBe(false); + expect(removeEventListenerSpy).not.toHaveBeenCalled(); + + callerController.abort(); + + expect(fetchSignal?.aborted).toBe(true); + await expect(readPromise).resolves.toBeInstanceOf(Error); + + const abortListener = addEventListenerSpy.mock.calls[0]?.[1]; + expect(addEventListenerSpy).toHaveBeenCalledWith('abort', abortListener, { once: true }); + expect(removeEventListenerSpy).toHaveBeenCalledWith('abort', abortListener); + }); + test('retry count header', async () => { let count = 0; let capturedRequest: RequestInit | undefined;