Skip to content

Commit f49bf26

Browse files
committed
fix: clean up fetch timeout abort listener
1 parent 6c11a74 commit f49bf26

4 files changed

Lines changed: 171 additions & 26 deletions

File tree

src/client.ts

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ import { isEmptyObj } from './internal/utils/values';
246246

247247
const WORKLOAD_IDENTITY_API_KEY_PLACEHOLDER = 'workload-identity-auth';
248248

249+
type FetchWithTimeoutResult = { response: Response; cleanup: () => void };
250+
249251
export type ApiKeySetter = () => Promise<string>;
250252

251253
export interface ClientOptions {
@@ -741,10 +743,11 @@ export class OpenAI {
741743

742744
const security = options.__security ?? { bearerAuth: true };
743745
const controller = new AbortController();
744-
const response = await this.fetchWithAuth(url, req, timeout, controller, security).catch(castToError);
746+
const fetchResult = await this.fetchWithAuth(url, req, timeout, controller, security).catch(castToError);
745747
const headersTime = Date.now();
746748

747-
if (response instanceof globalThis.Error) {
749+
if (fetchResult instanceof globalThis.Error) {
750+
const response = fetchResult;
748751
const retryMessage = `retrying, ${retriesRemaining} attempts remaining`;
749752
if (options.signal?.aborted) {
750753
throw new Errors.APIUserAbortError();
@@ -795,6 +798,8 @@ export class OpenAI {
795798
});
796799
}
797800

801+
const { response, cleanup } = fetchResult;
802+
798803
const specialHeaders = [...response.headers.entries()]
799804
.filter(([name]) => name === 'x-request-id')
800805
.map(([name, value]) => ', ' + name + ': ' + JSON.stringify(value))
@@ -812,6 +817,7 @@ export class OpenAI {
812817
!options.__metadata?.['workloadIdentityTokenRefreshed']
813818
) {
814819
await Shims.CancelReadableStream(response.body);
820+
cleanup();
815821
this._workloadIdentityAuth.invalidateToken();
816822

817823
return this.makeRequest(
@@ -833,6 +839,7 @@ export class OpenAI {
833839

834840
// We don't need the body of this response.
835841
await Shims.CancelReadableStream(response.body);
842+
cleanup();
836843
loggerFor(this).info(`${responseInfo} - ${retryMessage}`);
837844
loggerFor(this).debug(
838845
`[${requestLogID}] response error (${retryMessage})`,
@@ -856,7 +863,10 @@ export class OpenAI {
856863

857864
loggerFor(this).info(`${responseInfo} - ${retryMessage}`);
858865

859-
const errText = await response.text().catch((err: any) => castToError(err).message);
866+
const errText = await response
867+
.text()
868+
.catch((err: any) => castToError(err).message)
869+
.finally(cleanup);
860870
const errJSON = safeJSON(errText) as any;
861871
const errMessage = errJSON ? undefined : errText;
862872

@@ -888,7 +898,7 @@ export class OpenAI {
888898
}),
889899
);
890900

891-
return { response, options, controller, requestLogID, retryOfRequestLogID, startTime };
901+
return { response, options, controller, requestLogID, retryOfRequestLogID, startTime, cleanup };
892902
}
893903

894904
getAPIList<Item, PageClass extends Pagination.AbstractPage<Item> = Pagination.AbstractPage<Item>>(
@@ -924,7 +934,7 @@ export class OpenAI {
924934
bearerAuth: true,
925935
adminAPIKeyAuth: true,
926936
},
927-
): Promise<Response> {
937+
): Promise<FetchWithTimeoutResult> {
928938
if (this._workloadIdentityAuth && schemes.bearerAuth) {
929939
const headers = init.headers as Headers;
930940
const authHeader = headers.get('Authorization');
@@ -944,12 +954,19 @@ export class OpenAI {
944954
init: RequestInit | undefined,
945955
ms: number,
946956
controller: AbortController,
947-
): Promise<Response> {
957+
): Promise<FetchWithTimeoutResult> {
948958
const { signal, method, ...options } = init || {};
949959
const abort = this._makeAbort(controller);
950960
if (signal) signal.addEventListener('abort', abort, { once: true });
951961

952962
const timeout = setTimeout(abort, ms);
963+
let cleanedUp = false;
964+
const cleanup = () => {
965+
if (cleanedUp) return;
966+
cleanedUp = true;
967+
clearTimeout(timeout);
968+
if (signal) signal.removeEventListener('abort', abort);
969+
};
953970

954971
const isReadableBody =
955972
((globalThis as any).ReadableStream && options.body instanceof (globalThis as any).ReadableStream) ||
@@ -969,9 +986,13 @@ export class OpenAI {
969986

970987
try {
971988
// use undefined this binding; fetch errors if bound to something else in browser/cloudflare
972-
return await this.fetch.call(undefined, url, fetchOptions);
973-
} finally {
989+
const response = await this.fetch.call(undefined, url, fetchOptions);
974990
clearTimeout(timeout);
991+
if (!response.body) cleanup();
992+
return { response, cleanup };
993+
} catch (err) {
994+
cleanup();
995+
throw err;
975996
}
976997
}
977998

src/core/streaming.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ export class Stream<Item> implements AsyncIterable<Item> {
3636
controller: AbortController,
3737
client?: OpenAI,
3838
synthesizeEventData?: boolean,
39+
cleanup?: () => void,
3940
): Stream<Item> {
4041
let consumed = false;
4142
const logger = client ? loggerFor(client) : console;
@@ -95,6 +96,7 @@ export class Stream<Item> implements AsyncIterable<Item> {
9596
} finally {
9697
// If the user `break`s, abort the ongoing request.
9798
if (!done) controller.abort();
99+
cleanup?.();
98100
}
99101
}
100102

src/internal/parse.ts

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,37 +13,40 @@ export type APIResponseProps = {
1313
requestLogID: string;
1414
retryOfRequestLogID: string | undefined;
1515
startTime: number;
16+
cleanup?: () => void;
1617
};
1718

1819
export async function defaultParseResponse<T>(
1920
client: OpenAI,
2021
props: APIResponseProps,
2122
): Promise<WithRequestID<T>> {
22-
const { response, requestLogID, retryOfRequestLogID, startTime } = props;
23-
const body = await (async () => {
24-
if (props.options.stream) {
25-
loggerFor(client).debug('response', response.status, response.url, response.headers, response.body);
26-
27-
// Note: there is an invariant here that isn't represented in the type system
28-
// that if you set `stream: true` the response type must also be `Stream<T>`
23+
const { response, requestLogID, retryOfRequestLogID, startTime, cleanup } = props;
24+
if (props.options.stream) {
25+
loggerFor(client).debug('response', response.status, response.url, response.headers, response.body);
2926

30-
if (props.options.__streamClass) {
31-
return props.options.__streamClass.fromSSEResponse(
32-
response,
33-
props.controller,
34-
client,
35-
props.options.__synthesizeEventData,
36-
) as any;
37-
}
27+
// Note: there is an invariant here that isn't represented in the type system
28+
// that if you set `stream: true` the response type must also be `Stream<T>`
3829

39-
return Stream.fromSSEResponse(
30+
if (props.options.__streamClass) {
31+
return props.options.__streamClass.fromSSEResponse(
4032
response,
4133
props.controller,
4234
client,
4335
props.options.__synthesizeEventData,
36+
cleanup,
4437
) as any;
4538
}
4639

40+
return Stream.fromSSEResponse(
41+
response,
42+
props.controller,
43+
client,
44+
props.options.__synthesizeEventData,
45+
cleanup,
46+
) as any;
47+
}
48+
49+
const body = await (async () => {
4750
// fetch refuses to read the body when the status code is 204.
4851
if (response.status === 204) {
4952
return null as T;
@@ -69,7 +72,7 @@ export async function defaultParseResponse<T>(
6972

7073
const text = await response.text();
7174
return text as unknown as T;
72-
})();
75+
})().finally(() => cleanup?.());
7376
loggerFor(client).debug(
7477
`[${requestLogID}] response parsed`,
7578
formatRequestDetails({
@@ -80,7 +83,7 @@ export async function defaultParseResponse<T>(
8083
durationMs: Date.now() - startTime,
8184
}),
8285
);
83-
return body;
86+
return body as WithRequestID<T>;
8487
}
8588

8689
export type WithRequestID<T> =

tests/index.test.ts

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,125 @@ describe('retries', () => {
685685
expect(count).toEqual(3);
686686
});
687687

688+
test('removes caller abort listener after successful response parsing', async () => {
689+
const callerController = new AbortController();
690+
const addEventListenerSpy = jest.spyOn(callerController.signal, 'addEventListener');
691+
const removeEventListenerSpy = jest.spyOn(callerController.signal, 'removeEventListener');
692+
const testFetch = async (): Promise<Response> =>
693+
new Response(JSON.stringify({ a: 1 }), { headers: { 'Content-Type': 'application/json' } });
694+
695+
const client = new OpenAI({
696+
apiKey: 'My API Key',
697+
adminAPIKey: 'My Admin API Key',
698+
timeout: 1000,
699+
fetch: testFetch,
700+
});
701+
702+
expect(
703+
await client.request({
704+
path: '/foo',
705+
method: 'get',
706+
signal: callerController.signal,
707+
}),
708+
).toEqual({ a: 1 });
709+
710+
const abortListener = addEventListenerSpy.mock.calls[0]?.[1];
711+
expect(addEventListenerSpy).toHaveBeenCalledWith('abort', abortListener, { once: true });
712+
expect(removeEventListenerSpy).toHaveBeenCalledWith('abort', abortListener);
713+
});
714+
715+
test('keeps caller abort forwarding until response body parsing settles', async () => {
716+
const callerController = new AbortController();
717+
const addEventListenerSpy = jest.spyOn(callerController.signal, 'addEventListener');
718+
const removeEventListenerSpy = jest.spyOn(callerController.signal, 'removeEventListener');
719+
const encoder = new TextEncoder();
720+
let fetchSignal: AbortSignal | undefined;
721+
722+
const testFetch = async (url: string | URL | Request, init: RequestInit = {}): Promise<Response> => {
723+
fetchSignal = init.signal as AbortSignal;
724+
return new Response(
725+
new ReadableStream({
726+
start(controller) {
727+
controller.enqueue(encoder.encode('{"a":'));
728+
fetchSignal?.addEventListener('abort', () => controller.error(new Error('body aborted')), {
729+
once: true,
730+
});
731+
},
732+
}),
733+
{ headers: { 'Content-Type': 'application/json' } },
734+
);
735+
};
736+
737+
const client = new OpenAI({
738+
apiKey: 'My API Key',
739+
adminAPIKey: 'My Admin API Key',
740+
timeout: 1000,
741+
maxRetries: 0,
742+
fetch: testFetch,
743+
});
744+
745+
const parsePromise = client
746+
.request({
747+
path: '/foo',
748+
method: 'get',
749+
signal: callerController.signal,
750+
})
751+
.then(
752+
() => undefined,
753+
(err) => err,
754+
);
755+
756+
for (let i = 0; i < 5 && !fetchSignal; i++) {
757+
await new Promise((resolve) => setTimeout(resolve, 0));
758+
}
759+
760+
expect(fetchSignal?.aborted).toBe(false);
761+
expect(removeEventListenerSpy).not.toHaveBeenCalled();
762+
763+
callerController.abort();
764+
765+
expect(fetchSignal?.aborted).toBe(true);
766+
await expect(parsePromise).resolves.toBeInstanceOf(Error);
767+
768+
const abortListener = addEventListenerSpy.mock.calls[0]?.[1];
769+
expect(addEventListenerSpy).toHaveBeenCalledWith('abort', abortListener, { once: true });
770+
expect(removeEventListenerSpy).toHaveBeenCalledWith('abort', abortListener);
771+
});
772+
773+
test('removes caller abort listener after streaming response is consumed', async () => {
774+
const callerController = new AbortController();
775+
const addEventListenerSpy = jest.spyOn(callerController.signal, 'addEventListener');
776+
const removeEventListenerSpy = jest.spyOn(callerController.signal, 'removeEventListener');
777+
const testFetch = async (): Promise<Response> =>
778+
new Response('data: {"a":1}\n\ndata: [DONE]\n\n', {
779+
headers: { 'Content-Type': 'text/event-stream' },
780+
});
781+
782+
const client = new OpenAI({
783+
apiKey: 'My API Key',
784+
adminAPIKey: 'My Admin API Key',
785+
timeout: 1000,
786+
fetch: testFetch,
787+
});
788+
789+
const stream = await client.request<any>({
790+
path: '/foo',
791+
method: 'get',
792+
stream: true,
793+
signal: callerController.signal,
794+
});
795+
const chunks: unknown[] = [];
796+
797+
for await (const chunk of stream as AsyncIterable<unknown>) {
798+
chunks.push(chunk);
799+
}
800+
801+
expect(chunks).toEqual([{ a: 1 }]);
802+
const abortListener = addEventListenerSpy.mock.calls[0]?.[1];
803+
expect(addEventListenerSpy).toHaveBeenCalledWith('abort', abortListener, { once: true });
804+
expect(removeEventListenerSpy).toHaveBeenCalledWith('abort', abortListener);
805+
});
806+
688807
test('retry count header', async () => {
689808
let count = 0;
690809
let capturedRequest: RequestInit | undefined;

0 commit comments

Comments
 (0)