@@ -3,7 +3,7 @@ import http from 'node:http';
33import { Pool } from 'pg' ;
44import { drizzle } from 'drizzle-orm/node-postgres' ;
55import { pgTable , integer , bigint as pgBigint } from 'drizzle-orm/pg-core' ;
6- import { sql } from 'drizzle-orm' ;
6+ import { eq , inArray , sql } from 'drizzle-orm' ;
77import { RpcRequest , RpcResponse } from '../connectors/rpc/rpc_common.ts' ;
88import { getSharedRuntimeDefaults } from '../config.ts' ;
99
@@ -27,36 +27,6 @@ const pool = new Pool({
2727
2828const db = drizzle ( pool , { schema : { accounts } } ) ;
2929
30- const PREPARED = {
31- getAccountById : {
32- name : 'get_account' ,
33- text : `
34- SELECT id, balance
35- FROM accounts
36- WHERE id = $1
37- LIMIT 1
38- ` ,
39- } ,
40- transferSelectForUpdate : {
41- name : 'transfer_select' ,
42- text : `
43- SELECT id, balance
44- FROM accounts
45- WHERE id IN ($1, $2)
46- ORDER BY id
47- FOR UPDATE
48- ` ,
49- } ,
50- transferUpdateBalance : {
51- name : 'transfer_update' ,
52- text : `
53- UPDATE accounts
54- SET balance = $1::bigint
55- WHERE id = $2
56- ` ,
57- } ,
58- } as const ;
59-
6030async function rpcTransfer ( args : Record < string , unknown > ) {
6131 const fromId = Number ( args . from_id ?? args . from ) ;
6232 const toId = Number ( args . to_id ?? args . to ) ;
@@ -72,17 +42,14 @@ async function rpcTransfer(args: Record<string, unknown>) {
7242 if ( fromId === toId || amount <= 0 ) return ;
7343
7444 const delta = BigInt ( amount ) ;
75- const client = await pool . connect ( ) ;
7645
77- try {
78- await client . query ( 'BEGIN' ) ;
79-
80- const rowsResult = await client . query < { id : number ; balance : string } > ( {
81- name : PREPARED . transferSelectForUpdate . name ,
82- text : PREPARED . transferSelectForUpdate . text ,
83- values : [ fromId , toId ] ,
84- } ) ;
85- const rows = rowsResult . rows ;
46+ await db . transaction ( async ( tx ) => {
47+ const rows = await tx
48+ . select ( )
49+ . from ( accounts )
50+ . where ( inArray ( accounts . id , [ fromId , toId ] ) )
51+ . for ( 'update' )
52+ . orderBy ( accounts . id ) ;
8653
8754 if ( rows . length !== 2 ) {
8855 throw new Error ( 'account_missing' ) ;
@@ -91,43 +58,32 @@ async function rpcTransfer(args: Record<string, unknown>) {
9158 const [ first , second ] = rows ;
9259 const fromRow = first . id === fromId ? first : second ;
9360 const toRow = first . id === fromId ? second : first ;
94- const fromBalance = BigInt ( fromRow . balance ) ;
95-
96- if ( fromBalance >= delta ) {
97- const toBalance = BigInt ( toRow . balance ) ;
98-
99- await client . query ( {
100- name : PREPARED . transferUpdateBalance . name ,
101- text : PREPARED . transferUpdateBalance . text ,
102- values : [ ( fromBalance - delta ) . toString ( ) , fromId ] ,
103- } ) ;
104-
105- await client . query ( {
106- name : PREPARED . transferUpdateBalance . name ,
107- text : PREPARED . transferUpdateBalance . text ,
108- values : [ ( toBalance + delta ) . toString ( ) , toId ] ,
109- } ) ;
61+
62+ if ( fromRow . balance < delta ) {
63+ return ;
11064 }
11165
112- await client . query ( 'COMMIT' ) ;
113- } catch ( err ) {
114- await client . query ( 'ROLLBACK' ) . catch ( ( ) => { } ) ;
115- throw err ;
116- } finally {
117- client . release ( ) ;
118- }
66+ await tx
67+ . update ( accounts )
68+ . set ( { balance : fromRow . balance - delta } )
69+ . where ( eq ( accounts . id , fromId ) ) ;
70+
71+ await tx
72+ . update ( accounts )
73+ . set ( { balance : toRow . balance + delta } )
74+ . where ( eq ( accounts . id , toId ) ) ;
75+ } ) ;
11976}
12077
12178async function rpcGetAccount ( args : Record < string , unknown > ) {
12279 const id = Number ( args . id ) ;
12380 if ( ! Number . isInteger ( id ) ) throw new Error ( 'invalid id' ) ;
12481
125- const rowsResult = await pool . query < { id : number ; balance : string } > ( {
126- name : PREPARED . getAccountById . name ,
127- text : PREPARED . getAccountById . text ,
128- values : [ id ] ,
129- } ) ;
130- const rows = rowsResult . rows ;
82+ const rows = await db
83+ . select ( )
84+ . from ( accounts )
85+ . where ( eq ( accounts . id , id ) )
86+ . limit ( 1 ) ;
13187
13288 if ( rows . length === 0 ) return null ;
13389 const row = rows [ 0 ] ! ;
0 commit comments