Skip to content

Commit 021c18f

Browse files
committed
fix(rivetkit): wrap onMigrate in savepoint
1 parent e1be3d6 commit 021c18f

13 files changed

Lines changed: 143 additions & 167 deletions

File tree

rivetkit-typescript/packages/rivetkit/src/common/database/config.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ export interface SqliteDatabase {
4848
): Promise<SqliteExecuteResult>;
4949
run(sql: string, params?: SqliteBindings): Promise<void>;
5050
query(sql: string, params?: SqliteBindings): Promise<SqliteQueryResult>;
51-
writeMode<T>(callback: () => Promise<T>): Promise<T>;
5251
nativeMetrics?():
5352
| SqliteNativeMetrics
5453
| Promise<SqliteNativeMetrics | null>

rivetkit-typescript/packages/rivetkit/src/common/database/mod.test.ts

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@ import type {
88
import { db } from "./mod";
99

1010
class FakeSqliteDatabase implements SqliteDatabase {
11-
writeModeDepth = 0;
1211
executeCalls: {
1312
sql: string;
1413
params?: SqliteBindings;
15-
writeMode: boolean;
1614
}[] = [];
1715

1816
async exec(): Promise<void> {}
@@ -24,7 +22,6 @@ class FakeSqliteDatabase implements SqliteDatabase {
2422
this.executeCalls.push({
2523
sql,
2624
params,
27-
writeMode: this.writeModeDepth > 0,
2825
});
2926
return {
3027
columns: [],
@@ -43,15 +40,6 @@ class FakeSqliteDatabase implements SqliteDatabase {
4340
return { columns, rows };
4441
}
4542

46-
async writeMode<T>(callback: () => Promise<T>): Promise<T> {
47-
this.writeModeDepth++;
48-
try {
49-
return await callback();
50-
} finally {
51-
this.writeModeDepth--;
52-
}
53-
}
54-
5543
async close(): Promise<void> {}
5644
}
5745

@@ -73,7 +61,7 @@ function testProviderContext(
7361
}
7462

7563
describe("db", () => {
76-
test("runs onMigrate through sqlite write mode", async () => {
64+
test("runs onMigrate inside a sqlite savepoint", async () => {
7765
const nativeDb = new FakeSqliteDatabase();
7866
const provider = db({
7967
onMigrate: async (client) => {
@@ -90,15 +78,59 @@ describe("db", () => {
9078
await provider.onMigrate(client);
9179

9280
expect(nativeDb.executeCalls).toEqual([
81+
{
82+
sql: "SAVEPOINT __rivet_on_migrate",
83+
params: undefined,
84+
},
9385
{
9486
sql: "CREATE TABLE items(id INTEGER PRIMARY KEY, value TEXT)",
9587
params: undefined,
96-
writeMode: true,
9788
},
9889
{
9990
sql: "SELECT COUNT(*) AS count FROM items",
10091
params: undefined,
101-
writeMode: true,
92+
},
93+
{
94+
sql: "RELEASE SAVEPOINT __rivet_on_migrate",
95+
params: undefined,
96+
},
97+
]);
98+
});
99+
100+
test("rolls back the migration savepoint when onMigrate fails", async () => {
101+
const nativeDb = new FakeSqliteDatabase();
102+
const provider = db({
103+
onMigrate: async (client) => {
104+
await client.execute(
105+
"CREATE TABLE items(id INTEGER PRIMARY KEY, value TEXT)",
106+
);
107+
throw new Error("migration failed");
108+
},
109+
});
110+
const client = await provider.createClient(
111+
testProviderContext(nativeDb),
112+
);
113+
114+
await expect(provider.onMigrate(client)).rejects.toThrow(
115+
"migration failed",
116+
);
117+
118+
expect(nativeDb.executeCalls).toEqual([
119+
{
120+
sql: "SAVEPOINT __rivet_on_migrate",
121+
params: undefined,
122+
},
123+
{
124+
sql: "CREATE TABLE items(id INTEGER PRIMARY KEY, value TEXT)",
125+
params: undefined,
126+
},
127+
{
128+
sql: "ROLLBACK TO SAVEPOINT __rivet_on_migrate",
129+
params: undefined,
130+
},
131+
{
132+
sql: "RELEASE SAVEPOINT __rivet_on_migrate",
133+
params: undefined,
102134
},
103135
]);
104136
});

rivetkit-typescript/packages/rivetkit/src/common/database/mod.ts

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@ interface DatabaseFactoryConfig {
77
onMigrate?: (db: RawAccess) => Promise<void> | void;
88
}
99

10-
type RawAccessWithWriteMode = RawAccess & {
11-
__rivetWriteMode: <T>(callback: () => Promise<T> | T) => Promise<T>;
12-
};
13-
1410
function hasMultipleStatements(query: string): boolean {
1511
const trimmed = query.trim().replace(/;+$/, "").trimEnd();
1612
return trimmed.includes(";");
@@ -38,7 +34,7 @@ export function db({
3834
}
3935
};
4036

41-
const client: RawAccessWithWriteMode = {
37+
const client: RawAccess = {
4238
execute: async <
4339
TRow extends Record<string, unknown> = Record<
4440
string,
@@ -103,17 +99,12 @@ export function db({
10399
}
104100
},
105101
nativeMetrics: () => db.nativeMetrics?.() ?? null,
106-
__rivetWriteMode: async <T>(
107-
callback: () => Promise<T> | T,
108-
): Promise<T> => {
109-
return await db.writeMode(async () => await callback());
110-
},
111102
};
112103
return client;
113104
},
114105
onMigrate: async (client) => {
115106
if (onMigrate) {
116-
await dbWriteMode(client, () => onMigrate(client));
107+
await withMigrationSavepoint(client, () => onMigrate(client));
117108
}
118109
},
119110
};
@@ -145,17 +136,21 @@ async function execMultiStatement<TRow extends Record<string, unknown>>(
145136
return results as TRow[];
146137
}
147138

148-
async function dbWriteMode<T>(
139+
async function withMigrationSavepoint<T>(
149140
client: RawAccess,
150141
callback: () => Promise<T> | T,
151142
): Promise<T> {
152-
const maybeClient = client as RawAccess & {
153-
__rivetWriteMode?: <TInner>(
154-
callback: () => Promise<TInner> | TInner,
155-
) => Promise<TInner>;
156-
};
157-
if (maybeClient.__rivetWriteMode) {
158-
return await maybeClient.__rivetWriteMode(callback);
143+
await client.execute("SAVEPOINT __rivet_on_migrate");
144+
try {
145+
const result = await callback();
146+
await client.execute("RELEASE SAVEPOINT __rivet_on_migrate");
147+
return result;
148+
} catch (error) {
149+
try {
150+
await client.execute("ROLLBACK TO SAVEPOINT __rivet_on_migrate");
151+
} finally {
152+
await client.execute("RELEASE SAVEPOINT __rivet_on_migrate");
153+
}
154+
throw error;
159155
}
160-
return await callback();
161156
}

rivetkit-typescript/packages/rivetkit/src/common/database/native-database.test.ts

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -104,28 +104,6 @@ describe("wrapJsNativeDatabase", () => {
104104
});
105105
});
106106

107-
test("keeps write mode on the normal native execute lane", async () => {
108-
const native = new FakeNativeDatabase();
109-
const db = wrapJsNativeDatabase(native);
110-
111-
const query = db.writeMode(async () => {
112-
const promise = db.query("SELECT 1");
113-
expect(native.executeCalls).toMatchObject([
114-
{ sql: "SELECT 1", write: false },
115-
]);
116-
native.resolveNext({
117-
columns: ["value"],
118-
rows: [[1]],
119-
});
120-
return await promise;
121-
});
122-
123-
await expect(query).resolves.toEqual({
124-
columns: ["value"],
125-
rows: [[1]],
126-
});
127-
});
128-
129107
test("normalizes supported sqlite bind values", async () => {
130108
const native = new FakeNativeDatabase();
131109
const db = wrapJsNativeDatabase(native);

rivetkit-typescript/packages/rivetkit/src/common/database/native-database.ts

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,9 +346,6 @@ export function wrapJsNativeDatabase(
346346
const { columns, rows } = await executeNative(sql, params);
347347
return { columns, rows };
348348
},
349-
async writeMode<T>(callback: () => Promise<T>): Promise<T> {
350-
return await callback();
351-
},
352349
nativeMetrics(): SqliteNativeMetrics | null {
353350
return normalizeNativeMetrics(database.metrics?.());
354351
},

rivetkit-typescript/packages/rivetkit/src/db/drizzle.ts

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -163,19 +163,14 @@ export function db<TSchema extends DrizzleSchema = Record<string, never>>({
163163
await nativeDb.close();
164164
}
165165
};
166-
(
167-
drizzleDb as DrizzleDatabase<TSchema> & {
168-
__rivetWriteMode: <T>(
169-
callback: () => Promise<T> | T,
170-
) => Promise<T>;
171-
}
172-
).__rivetWriteMode = async (callback) =>
173-
await nativeDb.writeMode(async () => await callback());
174166

175167
return drizzleDb;
176168
},
177169
onMigrate: async (client) => {
178-
await dbWriteMode(client, async () => {
170+
if (!migrations && !onMigrate) {
171+
return;
172+
}
173+
await withMigrationSavepoint(client, async () => {
179174
if (migrations) {
180175
await runMigrations(client, migrations);
181176
}
@@ -187,19 +182,23 @@ export function db<TSchema extends DrizzleSchema = Record<string, never>>({
187182
};
188183
}
189184

190-
async function dbWriteMode<T>(
185+
async function withMigrationSavepoint<T>(
191186
client: RawAccess,
192187
callback: () => Promise<T> | T,
193188
): Promise<T> {
194-
const maybeClient = client as RawAccess & {
195-
__rivetWriteMode?: <TInner>(
196-
callback: () => Promise<TInner> | TInner,
197-
) => Promise<TInner>;
198-
};
199-
if (maybeClient.__rivetWriteMode) {
200-
return await maybeClient.__rivetWriteMode(callback);
189+
await client.execute("SAVEPOINT __rivet_on_migrate");
190+
try {
191+
const result = await callback();
192+
await client.execute("RELEASE SAVEPOINT __rivet_on_migrate");
193+
return result;
194+
} catch (error) {
195+
try {
196+
await client.execute("ROLLBACK TO SAVEPOINT __rivet_on_migrate");
197+
} finally {
198+
await client.execute("RELEASE SAVEPOINT __rivet_on_migrate");
199+
}
200+
throw error;
201201
}
202-
return await callback();
203202
}
204203

205204
async function runMigrations<TSchema extends DrizzleSchema>(

rivetkit-typescript/packages/rivetkit/tests/platforms/shared-platform-harness.ts

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ ${wasmModuleSource}
112112
interface SqliteDatabase {
113113
\trun(sql: string, params?: unknown[]): Promise<void>;
114114
\tquery(sql: string, params?: unknown[]): Promise<{ rows: unknown[][] }>;
115-
\twriteMode<T>(callback: () => Promise<T>): Promise<T>;
116115
}
117116
118117
interface RegistryConfig {
@@ -137,29 +136,23 @@ const rawSqlDatabaseProvider = {
137136
};
138137
139138
async function ensureCounterTable(db: SqliteDatabase) {
140-
\tawait db.writeMode(async () => {
141-
\t\tawait db.run(
142-
\t\t\t"CREATE TABLE IF NOT EXISTS platform_counter (id INTEGER PRIMARY KEY CHECK (id = 1), count INTEGER NOT NULL)",
143-
\t\t);
144-
\t});
139+
\tawait db.run(
140+
\t\t"CREATE TABLE IF NOT EXISTS platform_counter (id INTEGER PRIMARY KEY CHECK (id = 1), count INTEGER NOT NULL)",
141+
\t);
145142
}
146143
147144
async function ensureLifecycleTable(db: SqliteDatabase) {
148-
\tawait db.writeMode(async () => {
149-
\t\tawait db.run(
150-
\t\t\t"CREATE TABLE IF NOT EXISTS platform_counter_lifecycle (event TEXT PRIMARY KEY, count INTEGER NOT NULL)",
151-
\t\t);
152-
\t});
145+
\tawait db.run(
146+
\t\t"CREATE TABLE IF NOT EXISTS platform_counter_lifecycle (event TEXT PRIMARY KEY, count INTEGER NOT NULL)",
147+
\t);
153148
}
154149
155150
async function recordLifecycleEvent(db: SqliteDatabase, event: string) {
156151
\tawait ensureLifecycleTable(db);
157-
\tawait db.writeMode(async () => {
158-
\t\tawait db.run(
159-
\t\t\t"INSERT INTO platform_counter_lifecycle (event, count) VALUES (?, 1) ON CONFLICT(event) DO UPDATE SET count = count + 1",
160-
\t\t\t[event],
161-
\t\t);
162-
\t});
152+
\tawait db.run(
153+
\t\t"INSERT INTO platform_counter_lifecycle (event, count) VALUES (?, 1) ON CONFLICT(event) DO UPDATE SET count = count + 1",
154+
\t\t[event],
155+
\t);
163156
}
164157
165158
async function readCounter(db: SqliteDatabase): Promise<number> {
@@ -201,12 +194,10 @@ const sqliteCounter = actor({
201194
\t\tincrement: async (ctx, amount = 1) => {
202195
\t\t\tconst db = ctx.sql as SqliteDatabase;
203196
\t\t\tawait ensureCounterTable(db);
204-
\t\t\tawait db.writeMode(async () => {
205-
\t\t\t\tawait db.run(
206-
\t\t\t\t\t"INSERT INTO platform_counter (id, count) VALUES (?, ?) ON CONFLICT(id) DO UPDATE SET count = count + excluded.count",
207-
\t\t\t\t\t[COUNTER_ID, amount],
208-
\t\t\t\t);
209-
\t\t\t});
197+
\t\t\tawait db.run(
198+
\t\t\t\t"INSERT INTO platform_counter (id, count) VALUES (?, ?) ON CONFLICT(id) DO UPDATE SET count = count + excluded.count",
199+
\t\t\t\t[COUNTER_ID, amount],
200+
\t\t\t);
210201
211202
\t\t\treturn await readCounter(db);
212203
\t\t},

0 commit comments

Comments
 (0)