diff --git a/README.md b/README.md index 296ad6de..4bc27463 100644 --- a/README.md +++ b/README.md @@ -819,11 +819,20 @@ pool.close() ## Request ```javascript -const request = new sql.Request(/* [pool or transaction] */) +const request = new sql.Request(/* [pool or transaction], [options] */) ``` If you omit pool/transaction argument, global pool is used instead. +The optional `options` argument allows per-request configuration overrides: + +- **requestTimeout** - Override the pool's default request timeout (in ms) for this request only. + +```javascript +// Request with a 60-second timeout instead of the pool default +const request = new sql.Request(pool, { requestTimeout: 60000 }) +``` + ### Events - **recordset(columns)** - Dispatched when metadata for new recordset are parsed. @@ -1144,11 +1153,13 @@ request.cancel() **IMPORTANT:** always use `Transaction` class to create transactions - it ensures that all your requests are executed on one connection. Once you call `begin`, a single connection is acquired from the connection pool and all subsequent requests (initialized with the `Transaction` object) are executed exclusively on this connection. After you call `commit` or `rollback`, connection is then released back to the connection pool. ```javascript -const transaction = new sql.Transaction(/* [pool] */) +const transaction = new sql.Transaction(/* [pool], [options] */) ``` If you omit connection argument, global connection is used instead. +The optional `options` argument allows per-transaction configuration overrides (e.g. `{ requestTimeout: 60000 }`). These are inherited by any requests created from this transaction unless overridden at the request level. Note that the timeout applies to data requests only, not to the `begin`/`commit`/`rollback` operations themselves. + __Example__ ```javascript @@ -1296,11 +1307,13 @@ __Errors__ **IMPORTANT:** always use `PreparedStatement` class to create prepared statements - it ensures that all your executions of prepared statement are executed on one connection. Once you call `prepare`, a single connection is acquired from the connection pool and all subsequent executions are executed exclusively on this connection. After you call `unprepare`, the connection is then released back to the connection pool. ```javascript -const ps = new sql.PreparedStatement(/* [pool] */) +const ps = new sql.PreparedStatement(/* [pool], [options] */) ``` If you omit the connection argument, the global connection is used instead. +The optional `options` argument allows per-statement configuration overrides (e.g. `{ requestTimeout: 60000 }`). The timeout is applied to the `prepare`, `execute`, and `unprepare` operations. + __Example__ ```javascript diff --git a/lib/base/connection-pool.js b/lib/base/connection-pool.js index 1633b334..4f275677 100644 --- a/lib/base/connection-pool.js +++ b/lib/base/connection-pool.js @@ -576,21 +576,23 @@ class ConnectionPool extends EventEmitter { /** * Returns new request using this connection. * + * @param {{ requestTimeout?: number }} [conf] Per-request overrides. * @return {Request} */ - request () { - return new shared.driver.Request(this) + request (conf) { + return new shared.driver.Request(this, conf) } /** * Returns new transaction using this connection. * + * @param {{ requestTimeout?: number }} [conf] Per-transaction overrides, cascaded to child requests. * @return {Transaction} */ - transaction () { - return new shared.driver.Transaction(this) + transaction (conf) { + return new shared.driver.Transaction(this, conf) } /** diff --git a/lib/base/prepared-statement.js b/lib/base/prepared-statement.js index 0300bcf4..87c7d91d 100644 --- a/lib/base/prepared-statement.js +++ b/lib/base/prepared-statement.js @@ -20,10 +20,11 @@ class PreparedStatement extends EventEmitter { /** * Creates a new Prepared Statement. * - * @param {ConnectionPool|Transaction} [holder] + * @param {ConnectionPool|Transaction} [parent] + * @param {{ requestTimeout?: number }} [overrides] */ - constructor (parent) { + constructor (parent, overrides = {}) { super() IDS.add(this, 'PreparedStatement') @@ -33,6 +34,10 @@ class PreparedStatement extends EventEmitter { this._handle = 0 this.prepared = false this.parameters = {} + this.overrides = {} + if (Number.isFinite(overrides?.requestTimeout) && overrides.requestTimeout >= 0) { + this.overrides.requestTimeout = overrides.requestTimeout + } } get config () { @@ -232,7 +237,7 @@ class PreparedStatement extends EventEmitter { this._acquiredConnection = connection this._acquiredConfig = config - const req = new shared.driver.Request(this) + const req = new shared.driver.Request(this, this.overrides) req.stream = false req.output('handle', TYPES.Int) req.input('params', TYPES.NVarChar, ((() => { @@ -294,7 +299,7 @@ class PreparedStatement extends EventEmitter { */ _execute (values, callback) { - const req = new shared.driver.Request(this) + const req = new shared.driver.Request(this, this.overrides) req.stream = this.stream req.arrayRowMode = this.arrayRowMode req.input('handle', TYPES.Int, this._handle) @@ -362,7 +367,7 @@ class PreparedStatement extends EventEmitter { return setImmediate(callback, new TransactionError("Can't unprepare the statement. There is a request in progress.", 'EREQINPROG')) } - const req = new shared.driver.Request(this) + const req = new shared.driver.Request(this, this.overrides) req.stream = false req.input('handle', TYPES.Int, this._handle) req.execute('sp_unprepare', err => { diff --git a/lib/base/request.js b/lib/base/request.js index 9fdece91..cf58d8d9 100644 --- a/lib/base/request.js +++ b/lib/base/request.js @@ -26,10 +26,11 @@ class Request extends EventEmitter { /** * Create new Request. * - * @param {Connection|ConnectionPool|Transaction|PreparedStatement} parent If omitted, global connection is used instead. + * @param {Connection|ConnectionPool|Transaction|PreparedStatement} [parent] If omitted, global connection is used instead. + * @param {{ requestTimeout?: number }} [overrides] */ - constructor (parent) { + constructor (parent, overrides) { super() IDS.add(this, 'Request') @@ -41,6 +42,10 @@ class Request extends EventEmitter { this.parameters = {} this.stream = null this.arrayRowMode = null + this.overrides = {} + if (Number.isFinite(overrides?.requestTimeout) && overrides.requestTimeout >= 0) { + this.overrides.requestTimeout = overrides.requestTimeout + } } get paused () { diff --git a/lib/base/transaction.js b/lib/base/transaction.js index 3500453e..e50144ef 100644 --- a/lib/base/transaction.js +++ b/lib/base/transaction.js @@ -24,9 +24,10 @@ class Transaction extends EventEmitter { * Create new Transaction. * * @param {Connection} [parent] If ommited, global connection is used instead. + * @param {{ requestTimeout?: number }} [overrides] */ - constructor (parent) { + constructor (parent, overrides = {}) { super() IDS.add(this, 'Transaction') @@ -35,6 +36,10 @@ class Transaction extends EventEmitter { this.parent = parent || globalConnection.pool this.isolationLevel = Transaction.defaultIsolationLevel this.name = '' + this.overrides = {} + if (Number.isFinite(overrides?.requestTimeout) && overrides.requestTimeout >= 0) { + this.overrides.requestTimeout = overrides.requestTimeout + } } get config () { @@ -196,11 +201,12 @@ class Transaction extends EventEmitter { /** * Returns new request using this transaction. * + * @param {{ requestTimeout?: number }} [config] * @return {Request} */ - request () { - return new shared.driver.Request(this) + request (config) { + return new shared.driver.Request(this, { ...this.overrides, ...config }) } /** diff --git a/lib/msnodesqlv8/request.js b/lib/msnodesqlv8/request.js index d7c8e1e3..0fd55470 100644 --- a/lib/msnodesqlv8/request.js +++ b/lib/msnodesqlv8/request.js @@ -172,7 +172,7 @@ class Request extends BaseRequest { setImmediate(callback, new RequestError("You can't use table variables for bulk insert.", 'ENAME')) } - this.parent.acquire(this, (err, connection) => { + this.parent.acquire(this, (err, connection, config) => { let hasReturned = false if (!err) { debug('connection(%d): borrowed to request #%d', IDS.get(connection), IDS.get(this)) @@ -244,7 +244,10 @@ class Request extends BaseRequest { objectid = table.path } - return connection.queryRaw(`if object_id('${objectid.replace(/'/g, '\'\'')}') is null ${table.declare()}`, function (err) { + return connection.queryRaw({ + query_str: `if object_id('${objectid.replace(/'/g, '\'\'')}') is null ${table.declare()}`, + query_timeout: (this.overrides.requestTimeout ?? config.requestTimeout) / 1000 // msnodesqlv8 timeouts are in seconds (<1 second not supported), + }, function (err) { if (err) { return done(err) } go() }) @@ -389,7 +392,7 @@ class Request extends BaseRequest { const req = connection.queryRaw({ query_str: command, - query_timeout: config.requestTimeout / 1000 // msnodesqlv8 timeouts are in seconds (<1 second not supported) + query_timeout: (this.overrides.requestTimeout ?? config.requestTimeout) / 1000 // msnodesqlv8 timeouts are in seconds (<1 second not supported) }, params) this._setCurrentRequest(req) diff --git a/lib/tedious/request.js b/lib/tedious/request.js index 4833382e..431478df 100644 --- a/lib/tedious/request.js +++ b/lib/tedious/request.js @@ -361,6 +361,9 @@ class Request extends BaseRequest { connection.execBulkLoad(bulk, table.rows) }) + if (typeof this.overrides.requestTimeout === 'number') { + req.setTimeout(this.overrides.requestTimeout) + } this._setCurrentRequest(req) connection.execSqlBatch(req) @@ -510,6 +513,10 @@ class Request extends BaseRequest { } }) + if (typeof this.overrides.requestTimeout === 'number') { + req.setTimeout(this.overrides.requestTimeout) + } + this._setCurrentRequest(req) req.on('columnMetadata', metadata => { @@ -859,6 +866,10 @@ class Request extends BaseRequest { } }) + if (typeof this.overrides.requestTimeout === 'number') { + req.setTimeout(this.overrides.requestTimeout) + } + this._setCurrentRequest(req) req.on('columnMetadata', metadata => { diff --git a/lib/tedious/transaction.js b/lib/tedious/transaction.js index aaf27d21..ba8a5a62 100644 --- a/lib/tedious/transaction.js +++ b/lib/tedious/transaction.js @@ -6,8 +6,8 @@ const { IDS } = require('../utils') const TransactionError = require('../error/transaction-error') class Transaction extends BaseTransaction { - constructor (parent) { - super(parent) + constructor (parent, overrides) { + super(parent, overrides) this._abort = () => { if (!this._rollbackRequested) { diff --git a/test/cleanup.sql b/test/cleanup.sql index a1ccf2f8..4c572cd6 100644 --- a/test/cleanup.sql +++ b/test/cleanup.sql @@ -22,6 +22,9 @@ if exists (select * from sys.procedures where name = '__testInputOutputValue') if exists (select * from sys.procedures where name = '__testRowsAffected') exec('drop procedure [dbo].[__testRowsAffected]') +if exists (select * from sys.procedures where name = '__testDelay') + exec('drop procedure [dbo].[__testDelay]') + if exists (select * from sys.types where is_user_defined = 1 and name = 'MSSQLTestType') exec('drop type [dbo].[MSSQLTestType]') diff --git a/test/common/tests.js b/test/common/tests.js index 039112b9..8e5165b0 100644 --- a/test/common/tests.js +++ b/test/common/tests.js @@ -1198,6 +1198,108 @@ module.exports = (sql, driver) => { }) }, + 'per-request timeout overrides pool default' (done, driver, message) { + const config = readConfig() + config.driver = driver + config.requestTimeout = 15000 + + new sql.ConnectionPool(config).connect().then(conn => { + const req = new sql.Request(conn, { requestTimeout: 1000 }) + req.query('waitfor delay \'00:00:05\';select 1').catch(err => { + assert.ok((message ? (message.exec(err.message) != null) : (err instanceof sql.RequestError))) + + conn.close() + done() + }) + }).catch(done) + }, + + 'per-request timeout does not affect other requests' (done, driver, message) { + const config = readConfig() + config.driver = driver + config.requestTimeout = 15000 + + new sql.ConnectionPool(config).connect().then(conn => { + const reqFast = new sql.Request(conn, { requestTimeout: 1000 }) + const reqNormal = new sql.Request(conn) + + return Promise.allSettled([ + reqFast.query('waitfor delay \'00:00:05\';select 1'), + reqNormal.query('waitfor delay \'00:00:02\';select 1') + ]).then(([fastResult, normalResult]) => { + assert.strictEqual(fastResult.status, 'rejected') + assert.ok((message ? (message.exec(fastResult.reason.message) != null) : (fastResult.reason instanceof sql.RequestError))) + assert.strictEqual(normalResult.status, 'fulfilled') + + conn.close() + done() + }) + }).catch(done) + }, + + 'per-request timeout in transaction' (done, driver, message) { + const config = readConfig() + config.driver = driver + config.requestTimeout = 15000 + + new sql.ConnectionPool(config).connect().then(conn => { + const tx = new sql.Transaction(conn, { requestTimeout: 1000 }) + tx.begin().then(() => { + const req = tx.request() + req.query('waitfor delay \'00:00:05\';select 1').catch(err => { + assert.ok((message ? (message.exec(err.message) != null) : (err instanceof sql.RequestError))) + + tx.rollback().then(() => { + conn.close() + done() + }).catch(() => { + conn.close() + done() + }) + }) + }) + }).catch(done) + }, + + 'per-request timeout on stored procedure' (done, driver, message) { + const config = readConfig() + config.driver = driver + config.requestTimeout = 15000 + + new sql.ConnectionPool(config).connect().then(conn => { + const req = new sql.Request(conn, { requestTimeout: 1000 }) + req.execute('__testDelay').catch(err => { + assert.ok((message ? (message.exec(err.message) != null) : (err instanceof sql.RequestError))) + + conn.close() + done() + }) + }).catch(done) + }, + + 'per-request timeout on prepared statement' (done, driver, message) { + const config = readConfig() + config.driver = driver + config.requestTimeout = 15000 + + new sql.ConnectionPool(config).connect().then(conn => { + const ps = new sql.PreparedStatement(conn, { requestTimeout: 1000 }) + ps.prepare('waitfor delay \'00:00:05\';select 1').then(() => { + ps.execute().catch(err => { + assert.ok((message ? (message.exec(err.message) != null) : (err instanceof sql.RequestError))) + + ps.unprepare().then(() => { + conn.close() + done() + }).catch(() => { + conn.close() + done() + }) + }) + }) + }).catch(done) + }, + 'type validation' (mode, done) { const req = new TestRequest() req.input('image', sql.VarBinary, 'asdf') diff --git a/test/common/unit.js b/test/common/unit.js index 1f6f7e4f..39fcfb13 100644 --- a/test/common/unit.js +++ b/test/common/unit.js @@ -966,4 +966,92 @@ describe('connection string auth - tedious', () => { }) }) }) + + describe('per-request requestTimeout overrides', () => { + const BaseRequest = require('../../lib/base/request') + const BaseTransaction = require('../../lib/base/transaction') + const BasePreparedStatement = require('../../lib/base/prepared-statement') + + describe('Request', () => { + it('stores valid requestTimeout override', () => { + const req = new BaseRequest(null, { requestTimeout: 5000 }) + assert.strictEqual(req.overrides.requestTimeout, 5000) + }) + + it('accepts zero as a valid timeout', () => { + const req = new BaseRequest(null, { requestTimeout: 0 }) + assert.strictEqual(req.overrides.requestTimeout, 0) + }) + + it('ignores NaN', () => { + const req = new BaseRequest(null, { requestTimeout: NaN }) + assert.strictEqual(req.overrides.requestTimeout, undefined) + }) + + it('ignores Infinity', () => { + const req = new BaseRequest(null, { requestTimeout: Infinity }) + assert.strictEqual(req.overrides.requestTimeout, undefined) + }) + + it('ignores negative values', () => { + const req = new BaseRequest(null, { requestTimeout: -1 }) + assert.strictEqual(req.overrides.requestTimeout, undefined) + }) + + it('ignores non-number values', () => { + const req = new BaseRequest(null, { requestTimeout: '5000' }) + assert.strictEqual(req.overrides.requestTimeout, undefined) + }) + + it('defaults to empty overrides when none provided', () => { + const req = new BaseRequest(null) + assert.deepStrictEqual(req.overrides, {}) + }) + }) + + describe('Transaction', () => { + it('stores valid requestTimeout override', () => { + const tx = new BaseTransaction(null, { requestTimeout: 10000 }) + assert.strictEqual(tx.overrides.requestTimeout, 10000) + }) + + it('ignores invalid overrides', () => { + const tx = new BaseTransaction(null, { requestTimeout: NaN }) + assert.strictEqual(tx.overrides.requestTimeout, undefined) + }) + + it('cascades overrides to request when no per-request config given', () => { + const pool = new ConnectionPool({ server: 'localhost' }) + const tx = pool.transaction({ requestTimeout: 10000 }) + const req = tx.request() + assert.strictEqual(req.overrides.requestTimeout, 10000) + }) + + it('per-request config overrides transaction overrides', () => { + const pool = new ConnectionPool({ server: 'localhost' }) + const tx = pool.transaction({ requestTimeout: 10000 }) + const req = tx.request({ requestTimeout: 3000 }) + assert.strictEqual(req.overrides.requestTimeout, 3000) + }) + + it('per-request config merges with transaction overrides', () => { + const pool = new ConnectionPool({ server: 'localhost' }) + const tx = pool.transaction({ requestTimeout: 10000 }) + const req = tx.request({}) + assert.strictEqual(req.overrides.requestTimeout, 10000) + }) + }) + + describe('PreparedStatement', () => { + it('stores valid requestTimeout override', () => { + const ps = new BasePreparedStatement(null, { requestTimeout: 8000 }) + assert.strictEqual(ps.overrides.requestTimeout, 8000) + }) + + it('ignores invalid overrides', () => { + const ps = new BasePreparedStatement(null, { requestTimeout: -100 }) + assert.strictEqual(ps.overrides.requestTimeout, undefined) + }) + }) + }) }) diff --git a/test/msnodesqlv8/msnodesqlv8.js b/test/msnodesqlv8/msnodesqlv8.js index ab7a2880..8963a3db 100644 --- a/test/msnodesqlv8/msnodesqlv8.js +++ b/test/msnodesqlv8/msnodesqlv8.js @@ -94,6 +94,11 @@ describe('msnodesqlv8', function () { it('request timeout', done => TESTS['request timeout'](done)) it('BigInt parameters', done => TESTS['BigInt parameters'](done)) it('BigInt casted types', done => TESTS['BigInt casted types'](done)) + it('per-request timeout overrides pool default', done => TESTS['per-request timeout overrides pool default'](done)) + it('per-request timeout does not affect other requests', done => TESTS['per-request timeout does not affect other requests'](done)) + it('per-request timeout in transaction', done => TESTS['per-request timeout in transaction'](done)) + it('per-request timeout on stored procedure', done => TESTS['per-request timeout on stored procedure'](done)) + it('per-request timeout on prepared statement', done => TESTS['per-request timeout on prepared statement'](done)) it('dataLength type correction', done => TESTS['dataLength type correction'](done)) it('chunked xml support', done => TESTS['chunked xml support'](done)) diff --git a/test/prepare.sql b/test/prepare.sql index 84675056..ca49adff 100644 --- a/test/prepare.sql +++ b/test/prepare.sql @@ -171,6 +171,15 @@ begin end') +exec('create procedure [dbo].[__testDelay] +as +begin + + waitfor delay ''00:00:05'' + select 1 as result + +end') + ;with nums as ( select 0 AS n