Skip to content

Commit 9abc04b

Browse files
Rollback prepared statements (clockworklabs#4979)
This reverts commit c98bdf9.
1 parent 6108b79 commit 9abc04b

1 file changed

Lines changed: 26 additions & 70 deletions

File tree

templates/keynote-2/src/rpc-servers/postgres-rpc-server.ts

Lines changed: 26 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import http from 'node:http';
33
import { Pool } from 'pg';
44
import { drizzle } from 'drizzle-orm/node-postgres';
55
import { pgTable, integer, bigint as pgBigint } from 'drizzle-orm/pg-core';
6-
import { sql } from 'drizzle-orm';
6+
import { eq, inArray, sql } from 'drizzle-orm';
77
import { RpcRequest, RpcResponse } from '../connectors/rpc/rpc_common.ts';
88
import { getSharedRuntimeDefaults } from '../config.ts';
99

@@ -27,36 +27,6 @@ const pool = new Pool({
2727

2828
const db = drizzle(pool, { schema: { accounts } });
2929

30-
const PREPARED = {
31-
getAccountById: {
32-
name: 'get_account',
33-
text: `
34-
SELECT id, balance
35-
FROM accounts
36-
WHERE id = $1
37-
LIMIT 1
38-
`,
39-
},
40-
transferSelectForUpdate: {
41-
name: 'transfer_select',
42-
text: `
43-
SELECT id, balance
44-
FROM accounts
45-
WHERE id IN ($1, $2)
46-
ORDER BY id
47-
FOR UPDATE
48-
`,
49-
},
50-
transferUpdateBalance: {
51-
name: 'transfer_update',
52-
text: `
53-
UPDATE accounts
54-
SET balance = $1::bigint
55-
WHERE id = $2
56-
`,
57-
},
58-
} as const;
59-
6030
async function rpcTransfer(args: Record<string, unknown>) {
6131
const fromId = Number(args.from_id ?? args.from);
6232
const toId = Number(args.to_id ?? args.to);
@@ -72,17 +42,14 @@ async function rpcTransfer(args: Record<string, unknown>) {
7242
if (fromId === toId || amount <= 0) return;
7343

7444
const delta = BigInt(amount);
75-
const client = await pool.connect();
7645

77-
try {
78-
await client.query('BEGIN');
79-
80-
const rowsResult = await client.query<{ id: number; balance: string }>({
81-
name: PREPARED.transferSelectForUpdate.name,
82-
text: PREPARED.transferSelectForUpdate.text,
83-
values: [fromId, toId],
84-
});
85-
const rows = rowsResult.rows;
46+
await db.transaction(async (tx) => {
47+
const rows = await tx
48+
.select()
49+
.from(accounts)
50+
.where(inArray(accounts.id, [fromId, toId]))
51+
.for('update')
52+
.orderBy(accounts.id);
8653

8754
if (rows.length !== 2) {
8855
throw new Error('account_missing');
@@ -91,43 +58,32 @@ async function rpcTransfer(args: Record<string, unknown>) {
9158
const [first, second] = rows;
9259
const fromRow = first.id === fromId ? first : second;
9360
const toRow = first.id === fromId ? second : first;
94-
const fromBalance = BigInt(fromRow.balance);
95-
96-
if (fromBalance >= delta) {
97-
const toBalance = BigInt(toRow.balance);
98-
99-
await client.query({
100-
name: PREPARED.transferUpdateBalance.name,
101-
text: PREPARED.transferUpdateBalance.text,
102-
values: [(fromBalance - delta).toString(), fromId],
103-
});
104-
105-
await client.query({
106-
name: PREPARED.transferUpdateBalance.name,
107-
text: PREPARED.transferUpdateBalance.text,
108-
values: [(toBalance + delta).toString(), toId],
109-
});
61+
62+
if (fromRow.balance < delta) {
63+
return;
11064
}
11165

112-
await client.query('COMMIT');
113-
} catch (err) {
114-
await client.query('ROLLBACK').catch(() => {});
115-
throw err;
116-
} finally {
117-
client.release();
118-
}
66+
await tx
67+
.update(accounts)
68+
.set({ balance: fromRow.balance - delta })
69+
.where(eq(accounts.id, fromId));
70+
71+
await tx
72+
.update(accounts)
73+
.set({ balance: toRow.balance + delta })
74+
.where(eq(accounts.id, toId));
75+
});
11976
}
12077

12178
async function rpcGetAccount(args: Record<string, unknown>) {
12279
const id = Number(args.id);
12380
if (!Number.isInteger(id)) throw new Error('invalid id');
12481

125-
const rowsResult = await pool.query<{ id: number; balance: string }>({
126-
name: PREPARED.getAccountById.name,
127-
text: PREPARED.getAccountById.text,
128-
values: [id],
129-
});
130-
const rows = rowsResult.rows;
82+
const rows = await db
83+
.select()
84+
.from(accounts)
85+
.where(eq(accounts.id, id))
86+
.limit(1);
13187

13288
if (rows.length === 0) return null;
13389
const row = rows[0]!;

0 commit comments

Comments
 (0)