Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/docs/api/Errors.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import { errors } from 'undici'
| `InformationalError` | `UND_ERR_INFO` | expected error with reason |
| `ResponseExceededMaxSizeError` | `UND_ERR_RES_EXCEEDED_MAX_SIZE` | response body exceed the max size allowed |
| `SecureProxyConnectionError` | `UND_ERR_PRX_TLS` | tls connection to a proxy failed |
| `MessageSizeExceededError` | `UND_ERR_WS_MESSAGE_SIZE_EXCEEDED` | WebSocket decompressed message exceeded the maximum allowed size |

### `SocketError`

Expand Down
23 changes: 23 additions & 0 deletions docs/docs/api/WebSocket.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ Arguments:
* **url** `URL | string` - The url's protocol *must* be `ws` or `wss`.
* **protocol** `string | string[] | WebSocketInit` (optional) - Subprotocol(s) to request the server use, or a [`Dispatcher`](./Dispatcher.md).

### WebSocketInit

When passing an object as the second argument, the following options are available:

* **protocols** `string | string[]` (optional) - Subprotocol(s) to request the server use.
* **dispatcher** `Dispatcher` (optional) - A custom [`Dispatcher`](/docs/docs/api/Dispatcher.md) to use for the connection.
* **headers** `HeadersInit` (optional) - Custom headers to include in the WebSocket handshake request.
* **maxDecompressedMessageSize** `number` (optional) - Maximum allowed size in bytes for decompressed messages when using the `permessage-deflate` extension. **Default:** `4194304` (4 MB).

### Example:

This example will not work in browsers or other platforms that don't allow passing an object.
Expand All @@ -36,6 +45,20 @@ import { WebSocket } from 'undici'
const ws = new WebSocket('wss://echo.websocket.events', ['echo', 'chat'])
```

### Example with custom decompression limit:

To protect against decompression bombs (small compressed payloads that expand to very large sizes), you can set a custom limit:

```mjs
import { WebSocket } from 'undici'

const ws = new WebSocket('wss://echo.websocket.events', {
maxDecompressedMessageSize: 1 * 1024 * 1024
})
```

> ⚠️ **Security Note**: The `maxDecompressedMessageSize` option protects against memory exhaustion attacks where a malicious server sends a small compressed payload that decompresses to an extremely large size. If you increase this limit significantly above the default, ensure your application can handle the increased memory usage.

## Read More

- [MDN - WebSocket](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket)
Expand Down
21 changes: 20 additions & 1 deletion lib/core/errors.js
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,24 @@ class SecureProxyConnectionError extends UndiciError {
[kSecureProxyConnectionError] = true
}

const kMessageSizeExceededError = Symbol.for('undici.error.UND_ERR_WS_MESSAGE_SIZE_EXCEEDED')
class MessageSizeExceededError extends UndiciError {
constructor (message) {
super(message)
this.name = 'MessageSizeExceededError'
this.message = message || 'Max decompressed message size exceeded'
this.code = 'UND_ERR_WS_MESSAGE_SIZE_EXCEEDED'
}

static [Symbol.hasInstance] (instance) {
return instance && instance[kMessageSizeExceededError] === true
}

get [kMessageSizeExceededError] () {
return true
}
}

module.exports = {
AbortError,
HTTPParserError,
Expand All @@ -402,5 +420,6 @@ module.exports = {
ResponseExceededMaxSizeError,
RequestRetryError,
ResponseError,
SecureProxyConnectionError
SecureProxyConnectionError,
MessageSizeExceededError
}
14 changes: 12 additions & 2 deletions lib/core/request.js
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ class Request {
throw new InvalidArgumentError('upgrade must be a string')
}

if (upgrade && !isValidHeaderValue(upgrade)) {
throw new InvalidArgumentError('invalid upgrade header')
}

if (headersTimeout != null && (!Number.isFinite(headersTimeout) || headersTimeout < 0)) {
throw new InvalidArgumentError('invalid headersTimeout')
}
Expand Down Expand Up @@ -360,13 +364,19 @@ function processHeader (request, key, val) {
val = `${val}`
}

if (request.host === null && headerName === 'host') {
if (headerName === 'host') {
if (request.host !== null) {
throw new InvalidArgumentError('duplicate host header')
}
if (typeof val !== 'string') {
throw new InvalidArgumentError('invalid host header')
}
// Consumed by Client
request.host = val
} else if (request.contentLength === null && headerName === 'content-length') {
} else if (headerName === 'content-length') {
if (request.contentLength !== null) {
throw new InvalidArgumentError('duplicate content-length header')
}
request.contentLength = parseInt(val, 10)
if (!Number.isFinite(request.contentLength)) {
throw new InvalidArgumentError('invalid content-length header')
Expand Down
59 changes: 56 additions & 3 deletions lib/web/websocket/permessage-deflate.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,38 @@

const { createInflateRaw, Z_DEFAULT_WINDOWBITS } = require('node:zlib')
const { isValidClientWindowBits } = require('./util')
const { MessageSizeExceededError } = require('../../core/errors')

const tail = Buffer.from([0x00, 0x00, 0xff, 0xff])
const kBuffer = Symbol('kBuffer')
const kLength = Symbol('kLength')

// Default maximum decompressed message size: 4 MB
const kDefaultMaxDecompressedSize = 4 * 1024 * 1024

class PerMessageDeflate {
/** @type {import('node:zlib').InflateRaw} */
#inflate

#options = {}

constructor (extensions) {
/** @type {number} */
#maxDecompressedSize

/** @type {boolean} */
#aborted = false

/** @type {Function|null} */
#currentCallback = null

/**
* @param {Map<string, string>} extensions
* @param {{ maxDecompressedMessageSize?: number }} [options]
*/
constructor (extensions, options = {}) {
this.#options.serverNoContextTakeover = extensions.has('server_no_context_takeover')
this.#options.serverMaxWindowBits = extensions.get('server_max_window_bits')
this.#maxDecompressedSize = options.maxDecompressedMessageSize ?? kDefaultMaxDecompressedSize
}

decompress (chunk, fin, callback) {
Expand All @@ -24,6 +42,11 @@ class PerMessageDeflate {
// payload of the message.
// 2. Decompress the resulting data using DEFLATE.

if (this.#aborted) {
callback(new MessageSizeExceededError())
return
}

if (!this.#inflate) {
let windowBits = Z_DEFAULT_WINDOWBITS

Expand All @@ -36,13 +59,37 @@ class PerMessageDeflate {
windowBits = Number.parseInt(this.#options.serverMaxWindowBits)
}

this.#inflate = createInflateRaw({ windowBits })
try {
this.#inflate = createInflateRaw({ windowBits })
} catch (err) {
callback(err)
return
}
this.#inflate[kBuffer] = []
this.#inflate[kLength] = 0

this.#inflate.on('data', (data) => {
this.#inflate[kBuffer].push(data)
if (this.#aborted) {
return
}

this.#inflate[kLength] += data.length

if (this.#inflate[kLength] > this.#maxDecompressedSize) {
this.#aborted = true
this.#inflate.removeAllListeners()
this.#inflate.destroy()
this.#inflate = null

if (this.#currentCallback) {
const cb = this.#currentCallback
this.#currentCallback = null
cb(new MessageSizeExceededError())
}
return
}

this.#inflate[kBuffer].push(data)
})

this.#inflate.on('error', (err) => {
Expand All @@ -51,16 +98,22 @@ class PerMessageDeflate {
})
}

this.#currentCallback = callback
this.#inflate.write(chunk)
if (fin) {
this.#inflate.write(tail)
}

this.#inflate.flush(() => {
if (this.#aborted || !this.#inflate) {
return
}

const full = Buffer.concat(this.#inflate[kBuffer], this.#inflate[kLength])

this.#inflate[kBuffer].length = 0
this.#inflate[kLength] = 0
this.#currentCallback = null

callback(null, full)
})
Expand Down
22 changes: 15 additions & 7 deletions lib/web/websocket/receiver.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,23 @@ class ByteParser extends Writable {
/** @type {Map<string, PerMessageDeflate>} */
#extensions

constructor (ws, extensions) {
/** @type {{ maxDecompressedMessageSize?: number }} */
#options

/**
* @param {import('./websocket').WebSocket} ws
* @param {Map<string, string>|null} extensions
* @param {{ maxDecompressedMessageSize?: number }} [options]
*/
constructor (ws, extensions, options = {}) {
super()

this.ws = ws
this.#extensions = extensions == null ? new Map() : extensions
this.#options = options

if (this.#extensions.has('permessage-deflate')) {
this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions))
this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions, options))
}
}

Expand Down Expand Up @@ -179,21 +188,20 @@ class ByteParser extends Writable {

const buffer = this.consume(8)
const upper = buffer.readUInt32BE(0)
const lower = buffer.readUInt32BE(4)

// 2^31 is the maximum bytes an arraybuffer can contain
// on 32-bit systems. Although, on 64-bit systems, this is
// 2^53-1 bytes.
// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Errors/Invalid_array_length
// https://source.chromium.org/chromium/chromium/src/+/main:v8/src/common/globals.h;drc=1946212ac0100668f14eb9e2843bdd846e510a1e;bpv=1;bpt=1;l=1275
// https://source.chromium.org/chromium/chromium/src/+/main:v8/src/objects/js-array-buffer.h;l=34;drc=1946212ac0100668f14eb9e2843bdd846e510a1e
if (upper > 2 ** 31 - 1) {
if (upper !== 0 || lower > 2 ** 31 - 1) {
failWebsocketConnection(this.ws, 'Received payload length > 2^31 bytes.')
return
}

const lower = buffer.readUInt32BE(4)

this.#info.payloadLength = (upper << 8) + lower
this.#info.payloadLength = lower
this.#state = parserStates.READ_DATA
} else if (this.#state === parserStates.READ_DATA) {
if (this.#byteOffset < this.#info.payloadLength) {
Expand Down Expand Up @@ -223,7 +231,7 @@ class ByteParser extends Writable {
} else {
this.#extensions.get('permessage-deflate').decompress(body, this.#info.fin, (error, data) => {
if (error) {
closeWebSocketConnection(this.ws, 1007, error.message, error.message.length)
failWebsocketConnection(this.ws, error.message)
return
}

Expand Down
10 changes: 9 additions & 1 deletion lib/web/websocket/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,12 @@ function parseExtensions (extensions) {
* @param {string} value
*/
function isValidClientWindowBits (value) {
// Must have at least one character
if (value.length === 0) {
return false
}

// Check all characters are ASCII digits
for (let i = 0; i < value.length; i++) {
const byte = value.charCodeAt(i)

Expand All @@ -274,7 +280,9 @@ function isValidClientWindowBits (value) {
}
}

return true
// Check numeric range: zlib requires windowBits in range 8-15
const num = Number.parseInt(value, 10)
return num >= 8 && num <= 15
}

// https://nodejs.org/api/intl.html#detecting-internationalization-support
Expand Down
25 changes: 23 additions & 2 deletions lib/web/websocket/websocket.js
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class WebSocket extends EventTarget {
/** @type {SendQueue} */
#sendQueue

/** @type {{ maxDecompressedMessageSize?: number }} */
#options

/**
* @param {string} url
* @param {string|string[]} protocols
Expand Down Expand Up @@ -117,6 +120,11 @@ class WebSocket extends EventTarget {
// 10. Set this's url to urlRecord.
this[kWebSocketURL] = new URL(urlRecord.href)

// Store options for later use (e.g., maxDecompressedMessageSize)
this.#options = {
maxDecompressedMessageSize: options.maxDecompressedMessageSize
}

// 11. Let client be this's relevant settings object.
const client = environmentSettingsObject.settingsObject

Expand Down Expand Up @@ -431,11 +439,11 @@ class WebSocket extends EventTarget {
* @see https://websockets.spec.whatwg.org/#feedback-from-the-protocol
*/
#onConnectionEstablished (response, parsedExtensions) {
// processResponse is called when the "responses header list has been received and initialized."
// processResponse is called when the "response's header list has been received and initialized."
// once this happens, the connection is open
this[kResponse] = response

const parser = new ByteParser(this, parsedExtensions)
const parser = new ByteParser(this, parsedExtensions, this.#options)
parser.on('drain', onParserDrain)
parser.on('error', onParserError.bind(this))

Expand Down Expand Up @@ -538,6 +546,19 @@ webidl.converters.WebSocketInit = webidl.dictionaryConverter([
{
key: 'headers',
converter: webidl.nullableConverter(webidl.converters.HeadersInit)
},
{
key: 'maxDecompressedMessageSize',
converter: webidl.nullableConverter((V) => {
V = webidl.converters['unsigned long long'](V)
if (V <= 0) {
throw webidl.errors.exception({
header: 'WebSocket constructor',
message: 'maxDecompressedMessageSize must be greater than 0'
})
}
return V
})
}
])

Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
"devDependencies": {
"@fastify/busboy": "2.1.1",
"@matteo.collina/tspl": "^0.1.1",
"@metcoder95/https-pem": "^1.0.0",
"@sinonjs/fake-timers": "^11.1.0",
"@types/node": "~18.19.50",
"abort-controller": "^3.0.0",
Expand All @@ -117,7 +118,6 @@
"fast-check": "^3.17.1",
"form-data": "^4.0.0",
"formdata-node": "^6.0.3",
"https-pem": "^3.0.0",
"husky": "^9.0.7",
"jest": "^29.0.2",
"jsdom": "^24.0.0",
Expand Down
2 changes: 1 addition & 1 deletion scripts/generate-pem.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/* istanbul ignore file */

require('https-pem/install')
require('@metcoder95/https-pem/install')
2 changes: 1 addition & 1 deletion test/connect-pre-shared-session.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ const { tspl } = require('@matteo.collina/tspl')
const { test, after, mock } = require('node:test')
const { Client } = require('..')
const { createServer } = require('node:https')
const pem = require('https-pem')
const pem = require('@metcoder95/https-pem')
const tls = require('node:tls')

test('custom session passed to client will be used in tls connect call', async (t) => {
Expand Down
Loading
Loading