Skip to content

Commit cefab3d

Browse files
committed
fix(rivetkit): preserve query gateway skip ready wait
1 parent 395aa83 commit cefab3d

3 files changed

Lines changed: 105 additions & 33 deletions

File tree

rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ import {
5656
type QueueSendResult,
5757
type QueueSendWaitOptions,
5858
} from "./queue";
59-
import { resolveGatewayTarget } from "./resolve-gateway-target";
6059
import {
6160
type WebSocketMessage as ConnMessage,
6261
messageLength,
@@ -578,9 +577,7 @@ export class ActorConnRaw {
578577

579578
async #connectWebSocket() {
580579
const params = await this.#resolveConnectionParams();
581-
const target = this.#gatewayOptions.skipReadyWait
582-
? await this.#resolveGatewayTargetForSkipReadyWait()
583-
: getGatewayTarget(this.#actorResolutionState);
580+
const target = getGatewayTarget(this.#actorResolutionState);
584581
const ws = await this.#driver.openWebSocket(
585582
PATH_CONNECT,
586583
target,
@@ -634,25 +631,6 @@ export class ActorConnRaw {
634631
});
635632
}
636633

637-
async #resolveGatewayTargetForSkipReadyWait() {
638-
if ("getForId" in this.#actorResolutionState) {
639-
return {
640-
directId: this.#actorResolutionState.getForId.actorId,
641-
} as const;
642-
}
643-
644-
if (this.#actorId) {
645-
return { directId: this.#actorId } as const;
646-
}
647-
648-
return {
649-
directId: await resolveGatewayTarget(
650-
this.#driver,
651-
this.#actorResolutionState,
652-
),
653-
} as const;
654-
}
655-
656634
/** Called by the onopen event from drivers. */
657635
#handleOnOpen() {
658636
// Connection was disposed before Init message arrived - close the websocket to avoid leak

rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,13 @@ export class ActorHandleRaw {
139139
for (let attempt = 0; attempt < maxAttempts; attempt++) {
140140
let actorId: string | undefined;
141141
try {
142-
const target = await this.#resolveActionTarget(useQueryTarget);
142+
const gatewayOptions = resolveActorGatewayOptions(
143+
this.#gatewayOptions,
144+
);
145+
const target = await this.#resolveGatewayRequestTarget(
146+
useQueryTarget,
147+
gatewayOptions,
148+
);
143149
actorId = "directId" in target ? target.directId : undefined;
144150

145151
return await createQueueSender({
@@ -149,9 +155,7 @@ export class ActorHandleRaw {
149155
return await this.#driver.sendRequest(
150156
target,
151157
request,
152-
resolveActorGatewayOptions(
153-
this.#gatewayOptions,
154-
),
158+
gatewayOptions,
155159
);
156160
},
157161
}).send(name, body, options as any);
@@ -270,7 +274,10 @@ export class ActorHandleRaw {
270274
for (let attempt = 0; attempt < maxAttempts; attempt++) {
271275
let actorId: string | undefined;
272276
try {
273-
const target = await this.#resolveActionTarget(useQueryTarget);
277+
const target = await this.#resolveGatewayRequestTarget(
278+
useQueryTarget,
279+
gatewayOptions,
280+
);
274281
actorId = "directId" in target ? target.directId : undefined;
275282

276283
logger().debug(
@@ -561,6 +568,17 @@ export class ActorHandleRaw {
561568
}
562569
}
563570

571+
async #resolveGatewayRequestTarget(
572+
useQueryTarget: boolean,
573+
gatewayOptions: ActorGatewayOptions,
574+
) {
575+
if (gatewayOptions.skipReadyWait) {
576+
return getGatewayTarget(this.#actorResolutionState);
577+
}
578+
579+
return await this.#resolveActionTarget(useQueryTarget);
580+
}
581+
564582
/**
565583
* Establishes a persistent connection to the actor.
566584
*
@@ -619,7 +637,10 @@ export class ActorHandleRaw {
619637
for (let attempt = 0; attempt < maxAttempts; attempt++) {
620638
let actorId: string | undefined;
621639
try {
622-
const target = await this.#resolveActionTarget(useQueryTarget);
640+
const target = await this.#resolveGatewayRequestTarget(
641+
useQueryTarget,
642+
gatewayOptions,
643+
);
623644
actorId = "directId" in target ? target.directId : undefined;
624645
const response = await rawHttpFetch(
625646
this.#driver,
@@ -824,9 +845,10 @@ export class ActorHandleRaw {
824845
this.#gatewayOptions,
825846
options,
826847
);
827-
const target = gatewayOptions.skipReadyWait
828-
? await this.#resolveActionTarget(false)
829-
: getGatewayTarget(this.#actorResolutionState);
848+
const target = await this.#resolveGatewayRequestTarget(
849+
false,
850+
gatewayOptions,
851+
);
830852
return await rawWebSocket(
831853
this.#driver,
832854
target,

rivetkit-typescript/packages/rivetkit/tests/remote-engine-client-public-token.test.ts

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
22
import { ClientConfigSchema } from "@/client/config";
3+
import { createClient } from "@/client/mod";
34
import {
45
HEADER_RIVET_ACTOR,
56
HEADER_RIVET_SKIP_READY_WAIT,
@@ -10,7 +11,6 @@ import {
1011
WS_PROTOCOL_TARGET,
1112
WS_PROTOCOL_TOKEN,
1213
} from "@/common/actor-router-consts";
13-
import { createClient } from "@/client/mod";
1414
import { RemoteEngineControlClient } from "@/engine-client/mod";
1515

1616
describe.sequential("RemoteEngineControlClient public token usage", () => {
@@ -162,6 +162,48 @@ describe.sequential("RemoteEngineControlClient public token usage", () => {
162162
);
163163
});
164164

165+
test("query handle fetch keeps skip ready wait on gateway URL", async () => {
166+
const fetchCalls: Request[] = [];
167+
const fetchMock = vi.fn(async (input: Request | URL | string) => {
168+
const request = normalizeRequest(input);
169+
fetchCalls.push(request);
170+
return new Response("ok");
171+
});
172+
vi.stubGlobal("fetch", fetchMock);
173+
174+
const client = createClient({
175+
endpoint: "https://api.rivet.dev",
176+
disableMetadataLookup: true,
177+
gateway: { skipReadyWait: true },
178+
});
179+
const handle = client.getOrCreate("mockAgenticLoop", [
180+
"query-http-skip-ready-wait",
181+
]);
182+
183+
const response = await handle.fetch("/skip-ready-wait");
184+
185+
expect(response.status).toBe(200);
186+
expect(fetchCalls).toHaveLength(1);
187+
188+
const actorRequest = fetchCalls[0];
189+
expect(actorRequest).toBeDefined();
190+
if (!actorRequest) throw new Error("missing actor request");
191+
const url = new URL(actorRequest.url);
192+
expect(url.pathname).toBe(
193+
"/gateway/mockAgenticLoop/request/skip-ready-wait",
194+
);
195+
expect(url.searchParams.get("rvt-method")).toBe("getOrCreate");
196+
expect(url.searchParams.get("rvt-key")).toBe(
197+
"query-http-skip-ready-wait",
198+
);
199+
expect(url.searchParams.get("rvt-skip-ready-wait")).toBe("true");
200+
expect(actorRequest?.headers.get(HEADER_RIVET_TARGET)).toBeNull();
201+
expect(actorRequest?.headers.get(HEADER_RIVET_ACTOR)).toBeNull();
202+
expect(actorRequest?.headers.get(HEADER_RIVET_SKIP_READY_WAIT)).toBe(
203+
"1",
204+
);
205+
});
206+
165207
test("uses metadata clientToken for actor websocket gateway requests", async () => {
166208
const fetchMock = vi.fn(async (input: Request | URL | string) => {
167209
const request = normalizeRequest(input);
@@ -258,6 +300,36 @@ describe.sequential("RemoteEngineControlClient public token usage", () => {
258300
WS_PROTOCOL_SKIP_READY_WAIT,
259301
]),
260302
);
303+
304+
const client = createClient({
305+
endpoint: "https://api.rivet.dev",
306+
disableMetadataLookup: true,
307+
gateway: { skipReadyWait: true },
308+
});
309+
const handle = client.getOrCreate("mockAgenticLoop", [
310+
"query-ws-skip-ready-wait",
311+
]);
312+
313+
await handle.webSocket("/skip-ready-wait");
314+
315+
expect(fetchMock).toHaveBeenCalledTimes(1);
316+
expect(sockets).toHaveLength(4);
317+
const querySocket = sockets[3];
318+
expect(querySocket).toBeDefined();
319+
if (!querySocket) throw new Error("missing query websocket");
320+
const url = new URL(querySocket.url);
321+
expect(url.pathname).toBe(
322+
"/gateway/mockAgenticLoop/websocket/skip-ready-wait",
323+
);
324+
expect(url.searchParams.get("rvt-method")).toBe("getOrCreate");
325+
expect(url.searchParams.get("rvt-key")).toBe(
326+
"query-ws-skip-ready-wait",
327+
);
328+
expect(url.searchParams.get("rvt-skip-ready-wait")).toBe("true");
329+
expect(querySocket.protocols).toContain(WS_PROTOCOL_SKIP_READY_WAIT);
330+
expect(querySocket.protocols).not.toContain(
331+
`${WS_PROTOCOL_TARGET}actor`,
332+
);
261333
});
262334
});
263335

0 commit comments

Comments
 (0)