Skip to content

Commit 51ce3eb

Browse files
committed
Enforce decoded message size limit for permessage-deflate
1 parent 0beb0f2 commit 51ce3eb

File tree

5 files changed

+100
-3
lines changed

5 files changed

+100
-3
lines changed

httpclient5-websocket/src/main/java/org/apache/hc/client5/http/websocket/transport/WebSocketSessionEngine.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,11 @@ private void handleFrame() {
404404
final byte[] comp = WebSocketBufferOps.toBytes(payload);
405405
final byte[] plain;
406406
try {
407-
plain = decChain.decode(comp);
407+
plain = decChain.decode(comp, cfg.getMaxMessageSize());
408+
} catch (final WebSocketProtocolException wspe) {
409+
initiateClose(wspe.closeCode, wspe.getMessage());
410+
inbuf.clear();
411+
return;
408412
} catch (final Exception e) {
409413
initiateClose(1007, "Extension decode failed");
410414
inbuf.clear();
@@ -506,7 +510,10 @@ private void deliverAssembledMessage() {
506510
byte[] data = body;
507511
if (compressed && decChain != null) {
508512
try {
509-
data = decChain.decode(body);
513+
data = decChain.decode(body, cfg.getMaxMessageSize());
514+
} catch (final WebSocketProtocolException wspe) {
515+
initiateClose(wspe.closeCode, wspe.getMessage());
516+
return;
510517
} catch (final Exception e) {
511518
try {
512519
listener.onError(e);

httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/ExtensionChain.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,21 @@ public DecodeChain(final List<WebSocketExtensionChain.Decoder> decs) {
125125
* Decode a full message (reverse order if stacking).
126126
*/
127127
public byte[] decode(final byte[] data) throws Exception {
128+
return decode(data, 0L);
129+
}
130+
131+
/**
132+
* Decode a full message (reverse order if stacking), enforcing a hard cap on
133+
* the decoded payload size at every step. A non-positive {@code maxDecodedSize}
134+
* disables the cap. The cap is propagated into each extension so that expanding
135+
* extensions (e.g. permessage-deflate) abort during expansion rather than after.
136+
*
137+
* @since 5.7
138+
*/
139+
public byte[] decode(final byte[] data, final long maxDecodedSize) throws Exception {
128140
byte[] out = data;
129141
for (int i = decs.size() - 1; i >= 0; i--) {
130-
out = decs.get(i).decode(out);
142+
out = decs.get(i).decode(out, maxDecodedSize);
131143
}
132144
return out;
133145
}

httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/PerMessageDeflate.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.zip.Inflater;
3232

3333
import org.apache.hc.core5.annotation.Internal;
34+
import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException;
3435
import org.apache.hc.core5.websocket.frame.FrameHeaderBits;
3536

3637
/**
@@ -144,6 +145,11 @@ public Decoder newDecoder() {
144145

145146
@Override
146147
public byte[] decode(final byte[] compressedMessage) throws Exception {
148+
return decode(compressedMessage, 0L);
149+
}
150+
151+
@Override
152+
public byte[] decode(final byte[] compressedMessage, final long maxDecodedSize) throws Exception {
147153
final byte[] withTail;
148154
if (compressedMessage == null || compressedMessage.length == 0) {
149155
withTail = TAIL.clone();
@@ -156,10 +162,17 @@ public byte[] decode(final byte[] compressedMessage) throws Exception {
156162
inf.setInput(withTail);
157163
final ByteArrayOutputStream out = new ByteArrayOutputStream(Math.max(128, withTail.length * 2));
158164
final byte[] buf = new byte[Math.min(16384, Math.max(1024, withTail.length * 2))];
165+
long produced = 0L;
159166
while (!inf.needsInput()) {
160167
final int n = inf.inflate(buf);
161168
if (n > 0) {
169+
// Enforce the decoded size cap during inflation, not after, so a small
170+
// compressed payload cannot expand into a huge buffer before we react.
171+
if (maxDecodedSize > 0L && produced + n > maxDecodedSize) {
172+
throw new WebSocketProtocolException(1009, "Message too big");
173+
}
162174
out.write(buf, 0, n);
175+
produced += n;
163176
} else {
164177
break;
165178
}

httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/extension/WebSocketExtensionChain.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,5 +76,17 @@ interface Decoder {
7676
* Decode a full message produced with this extension.
7777
*/
7878
byte[] decode(byte[] payload) throws Exception;
79+
80+
/**
81+
* Decode a full message, aborting as soon as the produced output exceeds
82+
* {@code maxDecodedSize}. A non-positive limit means no limit. Implementations
83+
* that may expand input (e.g. permessage-deflate) MUST honour the limit during
84+
* the expansion step, not only after it, to prevent decompression-bomb attacks.
85+
*
86+
* @since 5.7
87+
*/
88+
default byte[] decode(final byte[] payload, final long maxDecodedSize) throws Exception {
89+
return decode(payload);
90+
}
7991
}
8092
}

httpclient5-websocket/src/test/java/org/apache/hc/core5/websocket/extension/MessageDeflateTest.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,13 @@
3030
import static org.junit.jupiter.api.Assertions.assertEquals;
3131
import static org.junit.jupiter.api.Assertions.assertFalse;
3232
import static org.junit.jupiter.api.Assertions.assertNotEquals;
33+
import static org.junit.jupiter.api.Assertions.assertThrows;
3334
import static org.junit.jupiter.api.Assertions.assertTrue;
3435

3536
import java.nio.charset.StandardCharsets;
37+
import java.util.Arrays;
3638

39+
import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException;
3740
import org.apache.hc.core5.websocket.frame.FrameHeaderBits;
3841
import org.junit.jupiter.api.Test;
3942

@@ -81,6 +84,56 @@ void roundTrip_message() throws Exception {
8184
assertArrayEquals(plain, roundTrip);
8285
}
8386

87+
@Test
88+
void decode_withinLimit_succeeds() throws Exception {
89+
final PerMessageDeflate pmce = new PerMessageDeflate(true, true, true, null, null);
90+
final WebSocketExtensionChain.Encoder enc = pmce.newEncoder();
91+
final WebSocketExtensionChain.Decoder dec = pmce.newDecoder();
92+
93+
final byte[] plain = "hello world hello world hello world".getBytes(StandardCharsets.UTF_8);
94+
final byte[] wire = enc.encode(plain, true, true).payload;
95+
96+
// Limit comfortably above the inflated size.
97+
final byte[] roundTrip = dec.decode(wire, plain.length + 16);
98+
assertArrayEquals(plain, roundTrip);
99+
}
100+
101+
@Test
102+
void decode_inflationBomb_isRejectedDuringInflate() {
103+
// A small, highly compressible payload that inflates to a much larger plaintext.
104+
final byte[] plain = new byte[64 * 1024];
105+
Arrays.fill(plain, (byte) 'A');
106+
107+
final PerMessageDeflate pmce = new PerMessageDeflate(true, true, true, null, null);
108+
final WebSocketExtensionChain.Encoder enc = pmce.newEncoder();
109+
final WebSocketExtensionChain.Decoder dec = pmce.newDecoder();
110+
111+
final byte[] wire = enc.encode(plain, true, true).payload;
112+
// Sanity: the compressed wire form is far smaller than the inflated payload.
113+
assertTrue(wire.length < plain.length / 4,
114+
"test setup: payload should be highly compressible, was " + wire.length + " vs " + plain.length);
115+
116+
// maxDecodedSize is well below the inflated size; decode must abort with 1009.
117+
final WebSocketProtocolException ex = assertThrows(WebSocketProtocolException.class,
118+
() -> dec.decode(wire, 1024L));
119+
assertEquals(1009, ex.closeCode);
120+
assertEquals("Message too big", ex.getMessage());
121+
}
122+
123+
@Test
124+
void decode_zeroLimitMeansUnlimited() throws Exception {
125+
final PerMessageDeflate pmce = new PerMessageDeflate(true, true, true, null, null);
126+
final WebSocketExtensionChain.Encoder enc = pmce.newEncoder();
127+
final WebSocketExtensionChain.Decoder dec = pmce.newDecoder();
128+
129+
final byte[] plain = new byte[8 * 1024];
130+
Arrays.fill(plain, (byte) 'B');
131+
final byte[] wire = enc.encode(plain, true, true).payload;
132+
133+
final byte[] roundTrip = dec.decode(wire, 0L);
134+
assertArrayEquals(plain, roundTrip);
135+
}
136+
84137
private static boolean endsWithTail(final byte[] b) {
85138
if (b.length < 4) {
86139
return false;

0 commit comments

Comments
 (0)