diff --git a/src/client.ts b/src/client.ts index 678ebdc..2c4e80e 100644 --- a/src/client.ts +++ b/src/client.ts @@ -36,12 +36,14 @@ import type { ClientOptions, RetryConfig, ServerInfo } from "./types.js"; import { ConfigWatcher } from "./watcher.js"; /** - * Options for get() with nullable and per-call timeout support. + * Options for get() with nullable, per-call timeout, and cancellation support. */ interface GetOptions { readonly nullable?: boolean; /** Per-call timeout in ms. Overrides the client default. */ readonly timeout?: number; + /** Cancels the in-flight RPC when aborted. */ + readonly signal?: AbortSignal; } /** @@ -133,19 +135,19 @@ export class ConfigClient { tenantId: string, fieldPath: string, type: typeof Number, - options?: { timeout?: number }, + options?: { timeout?: number; signal?: AbortSignal }, ): Promise; get( tenantId: string, fieldPath: string, type: typeof Boolean, - options?: { timeout?: number }, + options?: { timeout?: number; signal?: AbortSignal }, ): Promise; get( tenantId: string, fieldPath: string, type: typeof String, - options?: { timeout?: number }, + options?: { timeout?: number; signal?: AbortSignal }, ): Promise; /** * Get a config value with nullable support. @@ -155,19 +157,19 @@ export class ConfigClient { tenantId: string, fieldPath: string, type: typeof Number, - options: { nullable: true; timeout?: number }, + options: { nullable: true; timeout?: number; signal?: AbortSignal }, ): Promise; get( tenantId: string, fieldPath: string, type: typeof Boolean, - options: { nullable: true; timeout?: number }, + options: { nullable: true; timeout?: number; signal?: AbortSignal }, ): Promise; get( tenantId: string, fieldPath: string, type: typeof String, - options: { nullable: true; timeout?: number }, + options: { nullable: true; timeout?: number; signal?: AbortSignal }, ): Promise; get( tenantId: string, @@ -182,6 +184,7 @@ export class ConfigClient { const resp = await this.callGetField( { tenantId, fieldPath, includeDescription: false }, options?.timeout, + options?.signal, ); const cv = resp.value; @@ -210,11 +213,15 @@ export class ConfigClient { * * @returns A record mapping field paths to their string values. */ - async getAll(tenantId: string, options?: { timeout?: number }): Promise> { + async getAll( + tenantId: string, + options?: { timeout?: number; signal?: AbortSignal }, + ): Promise> { const fn = async () => { const resp = await this.callGetConfig( { tenantId, includeDescriptions: false }, options?.timeout, + options?.signal, ); const result: Record = {}; @@ -237,12 +244,13 @@ export class ConfigClient { tenantId: string, fieldPath: string, value: string, - options?: { timeout?: number; idempotencyKey?: string }, + options?: { timeout?: number; idempotencyKey?: string; signal?: AbortSignal }, ): Promise { const fn = async () => { await this.callSetField( { tenantId, fieldPath, value: { stringValue: value } }, options?.timeout, + options?.signal, ); }; @@ -261,7 +269,12 @@ export class ConfigClient { async setMany( tenantId: string, values: Record, - options?: { description?: string; timeout?: number; idempotencyKey?: string }, + options?: { + description?: string; + timeout?: number; + idempotencyKey?: string; + signal?: AbortSignal; + }, ): Promise { const fn = async () => { const updates = Object.entries(values).map(([fieldPath, v]) => ({ @@ -271,6 +284,7 @@ export class ConfigClient { await this.callSetFields( { tenantId, updates, description: options?.description }, options?.timeout, + options?.signal, ); }; @@ -286,10 +300,14 @@ export class ConfigClient { async setNull( tenantId: string, fieldPath: string, - options?: { timeout?: number; idempotencyKey?: string }, + options?: { timeout?: number; idempotencyKey?: string; signal?: AbortSignal }, ): Promise { const fn = async () => { - await this.callSetField({ tenantId, fieldPath, value: undefined }, options?.timeout); + await this.callSetField( + { tenantId, fieldPath, value: undefined }, + options?.timeout, + options?.signal, + ); }; const codes = options?.idempotencyKey @@ -361,9 +379,13 @@ export class ConfigClient { } } - private callGetField(request: GetFieldRequest, timeoutMs?: number): Promise { + private callGetField( + request: GetFieldRequest, + timeoutMs?: number, + signal?: AbortSignal, + ): Promise { return new Promise((resolve, reject) => { - this.configStub.getField( + const call = this.configStub.getField( request, this.metadata, { deadline: Date.now() + (timeoutMs ?? this.timeout) }, @@ -372,12 +394,17 @@ export class ConfigClient { else resolve(resp); }, ); + signal?.addEventListener("abort", () => call.cancel(), { once: true }); }); } - private callGetConfig(request: GetConfigRequest, timeoutMs?: number): Promise { + private callGetConfig( + request: GetConfigRequest, + timeoutMs?: number, + signal?: AbortSignal, + ): Promise { return new Promise((resolve, reject) => { - this.configStub.getConfig( + const call = this.configStub.getConfig( request, this.metadata, { deadline: Date.now() + (timeoutMs ?? this.timeout) }, @@ -386,12 +413,17 @@ export class ConfigClient { else resolve(resp); }, ); + signal?.addEventListener("abort", () => call.cancel(), { once: true }); }); } - private callSetField(request: SetFieldRequest, timeoutMs?: number): Promise { + private callSetField( + request: SetFieldRequest, + timeoutMs?: number, + signal?: AbortSignal, + ): Promise { return new Promise((resolve, reject) => { - this.configStub.setField( + const call = this.configStub.setField( request, this.metadata, { deadline: Date.now() + (timeoutMs ?? this.timeout) }, @@ -400,12 +432,17 @@ export class ConfigClient { else resolve(resp); }, ); + signal?.addEventListener("abort", () => call.cancel(), { once: true }); }); } - private callSetFields(request: SetFieldsRequest, timeoutMs?: number): Promise { + private callSetFields( + request: SetFieldsRequest, + timeoutMs?: number, + signal?: AbortSignal, + ): Promise { return new Promise((resolve, reject) => { - this.configStub.setFields( + const call = this.configStub.setFields( request, this.metadata, { deadline: Date.now() + (timeoutMs ?? this.timeout) }, @@ -414,6 +451,7 @@ export class ConfigClient { else resolve(resp); }, ); + signal?.addEventListener("abort", () => call.cancel(), { once: true }); }); } diff --git a/test/client.test.ts b/test/client.test.ts index 3ed54f2..00d1e8b 100644 --- a/test/client.test.ts +++ b/test/client.test.ts @@ -558,6 +558,56 @@ describe("ConfigClient", () => { }); }); + describe("AbortSignal", () => { + function makeCancellableStub(mock: MockInstance) { + mock.mockImplementation( + (_req: unknown, _meta: unknown, _opts: unknown, cb: (...args: unknown[]) => void) => ({ + cancel: () => cb(makeServiceError(status.CANCELLED, "rpc cancelled"), undefined), + }), + ); + } + + it("get() cancels the in-flight call when signal is aborted", async () => { + makeCancellableStub(configStub.getField); + const controller = new AbortController(); + const p = client.get("tenant-1", "f", String, { signal: controller.signal }); + controller.abort(); + await expect(p).rejects.toThrow(DecreeError); + }); + + it("getAll() cancels the in-flight call when signal is aborted", async () => { + makeCancellableStub(configStub.getConfig); + const controller = new AbortController(); + const p = client.getAll("tenant-1", { signal: controller.signal }); + controller.abort(); + await expect(p).rejects.toThrow(DecreeError); + }); + + it("set() cancels the in-flight call when signal is aborted", async () => { + makeCancellableStub(configStub.setField); + const controller = new AbortController(); + const p = client.set("tenant-1", "f", "v", { signal: controller.signal }); + controller.abort(); + await expect(p).rejects.toThrow(DecreeError); + }); + + it("setMany() cancels the in-flight call when signal is aborted", async () => { + makeCancellableStub(configStub.setFields); + const controller = new AbortController(); + const p = client.setMany("tenant-1", { f: "v" }, { signal: controller.signal }); + controller.abort(); + await expect(p).rejects.toThrow(DecreeError); + }); + + it("setNull() cancels the in-flight call when signal is aborted", async () => { + makeCancellableStub(configStub.setField); + const controller = new AbortController(); + const p = client.setNull("tenant-1", "f", { signal: controller.signal }); + controller.abort(); + await expect(p).rejects.toThrow(DecreeError); + }); + }); + describe("TLS channel", () => { it("creates TLS channel by default", () => { const c = new ConfigClient("localhost:9090", { retry: false });