Skip to content

Commit c2ce406

Browse files
authored
refactor(rivetkit): extract shared db utilities and improve sqlite bindings (#4278)
* chore(rivetkit): enforce limits on fs driver kv api to match engine * refactor(rivetkit): extract shared db utilities and improve sqlite bindings
1 parent 6aa3833 commit c2ce406

5 files changed

Lines changed: 455 additions & 238 deletions

File tree

rivetkit-typescript/packages/rivetkit/src/db/drizzle/mod.ts

Lines changed: 29 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import {
44
type SqliteRemoteDatabase,
55
} from "drizzle-orm/sqlite-proxy";
66
import type { DatabaseProvider, RawAccess } from "../config";
7-
import type { KvVfsOptions } from "../sqlite-vfs";
7+
import { AsyncMutex, createActorKvStore, toSqliteBindings } from "../shared";
88

99
export * from "./sqlite-core";
1010

@@ -27,73 +27,12 @@ interface DatabaseFactoryConfig<
2727
migrations?: any;
2828
}
2929

30-
/**
31-
* Create a KV store wrapper that uses the actor driver's KV operations
32-
*/
33-
function createActorKvStore(kv: {
34-
batchPut: (entries: [Uint8Array, Uint8Array][]) => Promise<void>;
35-
batchGet: (keys: Uint8Array[]) => Promise<(Uint8Array | null)[]>;
36-
batchDelete: (keys: Uint8Array[]) => Promise<void>;
37-
}): KvVfsOptions {
38-
return {
39-
get: async (key: Uint8Array) => {
40-
const results = await kv.batchGet([key]);
41-
return results[0];
42-
},
43-
getBatch: async (keys: Uint8Array[]) => {
44-
return await kv.batchGet(keys);
45-
},
46-
put: async (key: Uint8Array, value: Uint8Array) => {
47-
await kv.batchPut([[key, value]]);
48-
},
49-
putBatch: async (entries: [Uint8Array, Uint8Array][]) => {
50-
await kv.batchPut(entries);
51-
},
52-
deleteBatch: async (keys: Uint8Array[]) => {
53-
await kv.batchDelete(keys);
54-
},
55-
};
56-
}
57-
58-
/**
59-
* Mutex to serialize async operations on a @rivetkit/sqlite database handle.
60-
* @rivetkit/sqlite is not safe for concurrent operations on the same handle.
61-
*/
62-
class DbMutex {
63-
#locked = false;
64-
#waiting: (() => void)[] = [];
65-
66-
async acquire(): Promise<void> {
67-
while (this.#locked) {
68-
await new Promise<void>((resolve) => this.#waiting.push(resolve));
69-
}
70-
this.#locked = true;
71-
}
72-
73-
release(): void {
74-
this.#locked = false;
75-
const next = this.#waiting.shift();
76-
if (next) {
77-
next();
78-
}
79-
}
80-
81-
async run<T>(fn: () => Promise<T>): Promise<T> {
82-
await this.acquire();
83-
try {
84-
return await fn();
85-
} finally {
86-
this.release();
87-
}
88-
}
89-
}
90-
9130
/**
9231
* Create a sqlite-proxy async callback from a @rivetkit/sqlite Database
9332
*/
9433
function createProxyCallback(
9534
waDb: Database,
96-
mutex: DbMutex,
35+
mutex: AsyncMutex,
9736
isClosed: () => boolean,
9837
) {
9938
return async (
@@ -107,12 +46,12 @@ function createProxyCallback(
10746
}
10847

10948
if (method === "run") {
110-
await waDb.run(sql, params);
49+
await waDb.run(sql, toSqliteBindings(params));
11150
return { rows: [] };
11251
}
11352

11453
// For all/get/values, use parameterized query
115-
const result = await waDb.query(sql, params);
54+
const result = await waDb.query(sql, toSqliteBindings(params));
11655

11756
// drizzle's mapResultRow accesses rows by column index (positional arrays)
11857
// so we return raw arrays for all methods
@@ -131,7 +70,7 @@ function createProxyCallback(
13170
*/
13271
async function runInlineMigrations(
13372
waDb: Database,
134-
mutex: DbMutex,
73+
mutex: AsyncMutex,
13574
migrations: any,
13675
): Promise<void> {
13776
// Create migrations table
@@ -174,8 +113,9 @@ async function runInlineMigrations(
174113

175114
// Record migration
176115
await mutex.run(() =>
177-
waDb.exec(
178-
`INSERT INTO __drizzle_migrations (hash, created_at) VALUES ('${entry.tag}', ${entry.when})`,
116+
waDb.run(
117+
"INSERT INTO __drizzle_migrations (hash, created_at) VALUES (?, ?)",
118+
[entry.tag, entry.when],
179119
),
180120
);
181121
}
@@ -188,7 +128,7 @@ export function db<
188128
): DatabaseProvider<SqliteRemoteDatabase<TSchema> & RawAccess> {
189129
// Store the @rivetkit/sqlite Database instance alongside the drizzle client
190130
let waDbInstance: Database | null = null;
191-
const mutex = new DbMutex();
131+
const mutex = new AsyncMutex();
192132

193133
return {
194134
createClient: async (ctx) => {
@@ -226,18 +166,17 @@ export function db<
226166
ensureOpen();
227167

228168
if (args.length > 0) {
229-
const result = await waDb.query(query, args);
169+
const result = await waDb.query(
170+
query,
171+
toSqliteBindings(args),
172+
);
230173
return result.rows.map((row: unknown[]) => {
231174
const obj: Record<string, unknown> = {};
232-
for (
233-
let i = 0;
234-
i < result.columns.length;
235-
i++
236-
) {
237-
obj[result.columns[i]] = row[i];
238-
}
239-
return obj;
240-
}) as TRow[];
175+
for (let i = 0; i < result.columns.length; i++) {
176+
obj[result.columns[i]] = row[i];
177+
}
178+
return obj;
179+
}) as TRow[];
241180
}
242181
// Use exec for non-parameterized queries since
243182
// @rivetkit/sqlite's query() can crash on some statements.
@@ -246,17 +185,17 @@ export function db<
246185
await waDb.exec(
247186
query,
248187
(row: unknown[], columns: string[]) => {
249-
if (!columnNames) {
250-
columnNames = columns;
251-
}
252-
const obj: Record<string, unknown> = {};
253-
for (let i = 0; i < row.length; i++) {
254-
obj[columnNames[i]] = row[i];
255-
}
256-
results.push(obj);
257-
},
258-
);
259-
return results as TRow[];
188+
if (!columnNames) {
189+
columnNames = columns;
190+
}
191+
const obj: Record<string, unknown> = {};
192+
for (let i = 0; i < row.length; i++) {
193+
obj[columnNames[i]] = row[i];
194+
}
195+
results.push(obj);
196+
},
197+
);
198+
return results as TRow[];
260199
});
261200
},
262201
close: async () => {

rivetkit-typescript/packages/rivetkit/src/db/mod.ts

Lines changed: 7 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,12 @@
1-
import type { KvVfsOptions } from "./sqlite-vfs";
21
import type { DatabaseProvider, RawAccess } from "./config";
2+
import { AsyncMutex, createActorKvStore, toSqliteBindings } from "./shared";
33

44
export type { RawAccess } from "./config";
55

66
interface DatabaseFactoryConfig {
77
onMigrate?: (db: RawAccess) => Promise<void> | void;
88
}
99

10-
/**
11-
* Create a KV store wrapper that uses the actor driver's KV operations
12-
*/
13-
function createActorKvStore(kv: {
14-
batchPut: (entries: [Uint8Array, Uint8Array][]) => Promise<void>;
15-
batchGet: (keys: Uint8Array[]) => Promise<(Uint8Array | null)[]>;
16-
batchDelete: (keys: Uint8Array[]) => Promise<void>;
17-
}): KvVfsOptions {
18-
return {
19-
get: async (key: Uint8Array) => {
20-
const results = await kv.batchGet([key]);
21-
return results[0];
22-
},
23-
getBatch: async (keys: Uint8Array[]) => {
24-
return await kv.batchGet(keys);
25-
},
26-
put: async (key: Uint8Array, value: Uint8Array) => {
27-
await kv.batchPut([[key, value]]);
28-
},
29-
putBatch: async (entries: [Uint8Array, Uint8Array][]) => {
30-
await kv.batchPut(entries);
31-
},
32-
deleteBatch: async (keys: Uint8Array[]) => {
33-
await kv.batchDelete(keys);
34-
},
35-
};
36-
}
37-
3810
export function db({
3911
onMigrate,
4012
}: DatabaseFactoryConfig = {}): DatabaseProvider<RawAccess> {
@@ -70,23 +42,12 @@ export function db({
7042
const kvStore = createActorKvStore(ctx.kv);
7143
const db = await ctx.sqliteVfs.open(ctx.actorId, kvStore);
7244
let closed = false;
45+
const mutex = new AsyncMutex();
7346
const ensureOpen = () => {
7447
if (closed) {
7548
throw new Error("database is closed");
7649
}
7750
};
78-
let op: Promise<void> = Promise.resolve();
79-
80-
const serialize = async <T>(fn: () => Promise<T>): Promise<T> => {
81-
// Ensure @rivetkit/sqlite calls are not concurrent. Actors can process multiple
82-
// actions concurrently, and @rivetkit/sqlite is not re-entrant.
83-
const next = op.then(fn, fn);
84-
op = next.then(
85-
() => undefined,
86-
() => undefined,
87-
);
88-
return await next;
89-
};
9051

9152
return {
9253
execute: async <
@@ -95,7 +56,7 @@ export function db({
9556
query: string,
9657
...args: unknown[]
9758
): Promise<TRow[]> => {
98-
return await serialize(async () => {
59+
return await mutex.run(async () => {
9960
ensureOpen();
10061

10162
// `db.exec` does not support binding `?` placeholders.
@@ -104,14 +65,15 @@ export function db({
10465
// Keep using `db.exec` for non-parameterized SQL because it
10566
// supports multi-statement migrations.
10667
if (args.length > 0) {
68+
const bindings = toSqliteBindings(args);
10769
const token = query.trimStart().slice(0, 16).toUpperCase();
10870
const returnsRows =
10971
token.startsWith("SELECT") ||
11072
token.startsWith("PRAGMA") ||
11173
token.startsWith("WITH");
11274

11375
if (returnsRows) {
114-
const { rows, columns } = await db.query(query, args);
76+
const { rows, columns } = await db.query(query, bindings);
11577
return rows.map((row: unknown[]) => {
11678
const rowObj: Record<string, unknown> = {};
11779
for (let i = 0; i < columns.length; i++) {
@@ -121,7 +83,7 @@ export function db({
12183
}) as TRow[];
12284
}
12385

124-
await db.run(query, args);
86+
await db.run(query, bindings);
12587
return [] as TRow[];
12688
}
12789

@@ -141,7 +103,7 @@ export function db({
141103
});
142104
},
143105
close: async () => {
144-
await serialize(async () => {
106+
await mutex.run(async () => {
145107
if (closed) {
146108
return;
147109
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import type { DatabaseProviderContext } from "./config";
2+
import type { Database } from "@rivetkit/sqlite-vfs";
3+
import type { KvVfsOptions } from "./sqlite-vfs";
4+
5+
type ActorKvOperations = DatabaseProviderContext["kv"];
6+
type SqliteBindings = NonNullable<Parameters<Database["run"]>[1]>;
7+
8+
function isSqliteBindingValue(value: unknown): boolean {
9+
if (
10+
value === null ||
11+
typeof value === "number" ||
12+
typeof value === "string" ||
13+
typeof value === "bigint" ||
14+
value instanceof Uint8Array
15+
) {
16+
return true;
17+
}
18+
19+
if (Array.isArray(value)) {
20+
return value.every((item) => typeof item === "number");
21+
}
22+
23+
return false;
24+
}
25+
26+
export function toSqliteBindings(args: unknown[]): SqliteBindings {
27+
for (const value of args) {
28+
if (!isSqliteBindingValue(value)) {
29+
throw new Error(
30+
`unsupported sqlite binding type: ${typeof value}`,
31+
);
32+
}
33+
}
34+
35+
return args as SqliteBindings;
36+
}
37+
38+
/**
39+
* Create a KV store wrapper that uses the actor driver's KV operations.
40+
*/
41+
export function createActorKvStore(kv: ActorKvOperations): KvVfsOptions {
42+
return {
43+
get: async (key: Uint8Array) => {
44+
const results = await kv.batchGet([key]);
45+
return results[0] ?? null;
46+
},
47+
getBatch: async (keys: Uint8Array[]) => {
48+
return await kv.batchGet(keys);
49+
},
50+
put: async (key: Uint8Array, value: Uint8Array) => {
51+
await kv.batchPut([[key, value]]);
52+
},
53+
putBatch: async (entries: [Uint8Array, Uint8Array][]) => {
54+
await kv.batchPut(entries);
55+
},
56+
deleteBatch: async (keys: Uint8Array[]) => {
57+
await kv.batchDelete(keys);
58+
},
59+
};
60+
}
61+
62+
/**
63+
* Serialize async operations on a shared non-reentrant resource.
64+
*/
65+
export class AsyncMutex {
66+
#locked = false;
67+
#waiting: (() => void)[] = [];
68+
69+
async acquire(): Promise<void> {
70+
while (this.#locked) {
71+
await new Promise<void>((resolve) => this.#waiting.push(resolve));
72+
}
73+
this.#locked = true;
74+
}
75+
76+
release(): void {
77+
this.#locked = false;
78+
const next = this.#waiting.shift();
79+
if (next) {
80+
next();
81+
}
82+
}
83+
84+
async run<T>(fn: () => Promise<T>): Promise<T> {
85+
await this.acquire();
86+
try {
87+
return await fn();
88+
} finally {
89+
this.release();
90+
}
91+
}
92+
}

0 commit comments

Comments
 (0)