Skip to content

Commit 5a1f2ac

Browse files
Improve behavior under the risk of an int overflow and negative lengths in JavaLite
PiperOrigin-RevId: 908845731
1 parent 799e73b commit 5a1f2ac

2 files changed

Lines changed: 121 additions & 48 deletions

File tree

java/core/src/main/java/com/google/protobuf/ArrayDecoders.java

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,14 @@ static int decodePackedVarint32List(
437437
byte[] data, int position, ProtobufList<?> list, Registers registers) throws IOException {
438438
final IntArrayList output = (IntArrayList) list;
439439
position = decodeVarint32(data, position, registers);
440-
final int fieldLimit = position + registers.int1;
440+
final int packedDataByteSize = registers.int1;
441+
if (packedDataByteSize < 0) {
442+
throw InvalidProtocolBufferException.negativeSize();
443+
}
444+
if (packedDataByteSize > data.length - position) {
445+
throw InvalidProtocolBufferException.truncatedMessage();
446+
}
447+
final int fieldLimit = position + packedDataByteSize;
441448
while (position < fieldLimit) {
442449
position = decodeVarint32(data, position, registers);
443450
output.addInt(registers.int1);
@@ -453,7 +460,14 @@ static int decodePackedVarint64List(
453460
byte[] data, int position, ProtobufList<?> list, Registers registers) throws IOException {
454461
final LongArrayList output = (LongArrayList) list;
455462
position = decodeVarint32(data, position, registers);
456-
final int fieldLimit = position + registers.int1;
463+
final int packedDataByteSize = registers.int1;
464+
if (packedDataByteSize < 0) {
465+
throw InvalidProtocolBufferException.negativeSize();
466+
}
467+
if (packedDataByteSize > data.length - position) {
468+
throw InvalidProtocolBufferException.truncatedMessage();
469+
}
470+
final int fieldLimit = position + packedDataByteSize;
457471
while (position < fieldLimit) {
458472
position = decodeVarint64(data, position, registers);
459473
output.addLong(registers.long1);
@@ -471,10 +485,13 @@ static int decodePackedFixed32List(
471485
final IntArrayList output = (IntArrayList) list;
472486
position = decodeVarint32(data, position, registers);
473487
final int packedDataByteSize = registers.int1;
474-
final int fieldLimit = position + packedDataByteSize;
475-
if (fieldLimit > data.length) {
488+
if (packedDataByteSize < 0) {
489+
throw InvalidProtocolBufferException.negativeSize();
490+
}
491+
if (packedDataByteSize > data.length - position) {
476492
throw InvalidProtocolBufferException.truncatedMessage();
477493
}
494+
final int fieldLimit = position + packedDataByteSize;
478495
output.ensureCapacity(output.size() + packedDataByteSize / 4);
479496
while (position < fieldLimit) {
480497
output.addInt(decodeFixed32(data, position));
@@ -493,10 +510,13 @@ static int decodePackedFixed64List(
493510
final LongArrayList output = (LongArrayList) list;
494511
position = decodeVarint32(data, position, registers);
495512
final int packedDataByteSize = registers.int1;
496-
final int fieldLimit = position + packedDataByteSize;
497-
if (fieldLimit > data.length) {
513+
if (packedDataByteSize < 0) {
514+
throw InvalidProtocolBufferException.negativeSize();
515+
}
516+
if (packedDataByteSize > data.length - position) {
498517
throw InvalidProtocolBufferException.truncatedMessage();
499518
}
519+
final int fieldLimit = position + packedDataByteSize;
500520
output.ensureCapacity(output.size() + packedDataByteSize / 8);
501521
while (position < fieldLimit) {
502522
output.addLong(decodeFixed64(data, position));
@@ -515,10 +535,13 @@ static int decodePackedFloatList(
515535
final FloatArrayList output = (FloatArrayList) list;
516536
position = decodeVarint32(data, position, registers);
517537
final int packedDataByteSize = registers.int1;
518-
final int fieldLimit = position + packedDataByteSize;
519-
if (fieldLimit > data.length) {
538+
if (packedDataByteSize < 0) {
539+
throw InvalidProtocolBufferException.negativeSize();
540+
}
541+
if (packedDataByteSize > data.length - position) {
520542
throw InvalidProtocolBufferException.truncatedMessage();
521543
}
544+
final int fieldLimit = position + packedDataByteSize;
522545
output.ensureCapacity(output.size() + packedDataByteSize / 4);
523546
while (position < fieldLimit) {
524547
output.addFloat(decodeFloat(data, position));
@@ -537,10 +560,13 @@ static int decodePackedDoubleList(
537560
final DoubleArrayList output = (DoubleArrayList) list;
538561
position = decodeVarint32(data, position, registers);
539562
final int packedDataByteSize = registers.int1;
540-
final int fieldLimit = position + packedDataByteSize;
541-
if (fieldLimit > data.length) {
563+
if (packedDataByteSize < 0) {
564+
throw InvalidProtocolBufferException.negativeSize();
565+
}
566+
if (packedDataByteSize > data.length - position) {
542567
throw InvalidProtocolBufferException.truncatedMessage();
543568
}
569+
final int fieldLimit = position + packedDataByteSize;
544570
output.ensureCapacity(output.size() + packedDataByteSize / 8);
545571
while (position < fieldLimit) {
546572
output.addDouble(decodeDouble(data, position));
@@ -558,7 +584,14 @@ static int decodePackedBoolList(
558584
throws InvalidProtocolBufferException {
559585
final BooleanArrayList output = (BooleanArrayList) list;
560586
position = decodeVarint32(data, position, registers);
561-
final int fieldLimit = position + registers.int1;
587+
final int packedDataByteSize = registers.int1;
588+
if (packedDataByteSize < 0) {
589+
throw InvalidProtocolBufferException.negativeSize();
590+
}
591+
if (packedDataByteSize > data.length - position) {
592+
throw InvalidProtocolBufferException.truncatedMessage();
593+
}
594+
final int fieldLimit = position + packedDataByteSize;
562595
while (position < fieldLimit) {
563596
position = decodeVarint64(data, position, registers);
564597
output.addBoolean(registers.long1 != 0);
@@ -575,7 +608,14 @@ static int decodePackedSInt32List(
575608
throws InvalidProtocolBufferException {
576609
final IntArrayList output = (IntArrayList) list;
577610
position = decodeVarint32(data, position, registers);
578-
final int fieldLimit = position + registers.int1;
611+
final int packedDataByteSize = registers.int1;
612+
if (packedDataByteSize < 0) {
613+
throw InvalidProtocolBufferException.negativeSize();
614+
}
615+
if (packedDataByteSize > data.length - position) {
616+
throw InvalidProtocolBufferException.truncatedMessage();
617+
}
618+
final int fieldLimit = position + packedDataByteSize;
579619
while (position < fieldLimit) {
580620
position = decodeVarint32(data, position, registers);
581621
output.addInt(CodedInputStream.decodeZigZag32(registers.int1));
@@ -592,7 +632,14 @@ static int decodePackedSInt64List(
592632
throws InvalidProtocolBufferException {
593633
final LongArrayList output = (LongArrayList) list;
594634
position = decodeVarint32(data, position, registers);
595-
final int fieldLimit = position + registers.int1;
635+
final int packedDataByteSize = registers.int1;
636+
if (packedDataByteSize < 0) {
637+
throw InvalidProtocolBufferException.negativeSize();
638+
}
639+
if (packedDataByteSize > data.length - position) {
640+
throw InvalidProtocolBufferException.truncatedMessage();
641+
}
642+
final int fieldLimit = position + packedDataByteSize;
596643
while (position < fieldLimit) {
597644
position = decodeVarint64(data, position, registers);
598645
output.addLong(CodedInputStream.decodeZigZag64(registers.long1));

java/core/src/test/java/com/google/protobuf/ArrayDecodersTest.java

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package com.google.protobuf;
99

10+
import static com.google.common.truth.Truth.assertThat;
1011
import static org.junit.Assert.assertThrows;
1112

1213
import com.google.protobuf.ArrayDecoders.Registers;
@@ -152,11 +153,24 @@ public void testException_decodeUnknownField() {
152153

153154
@Test
154155
public void testDecodePackedFixed32List_negativeSize() {
155-
assertThrows(
156-
InvalidProtocolBufferException.class,
157-
() ->
158-
ArrayDecoders.decodePackedFixed32List(
159-
packedSizeBytesNoTag(-1), 0, new IntArrayList(), registers));
156+
InvalidProtocolBufferException e =
157+
assertThrows(
158+
InvalidProtocolBufferException.class,
159+
() ->
160+
ArrayDecoders.decodePackedFixed32List(
161+
packedSizeBytesNoTag(-1), 0, new IntArrayList(), registers));
162+
assertThat(e).hasMessageThat().contains("negative size");
163+
}
164+
165+
@Test
166+
public void testDecodePackedFixed32List_overflowsArrayEnd() {
167+
InvalidProtocolBufferException e =
168+
assertThrows(
169+
InvalidProtocolBufferException.class,
170+
() ->
171+
ArrayDecoders.decodePackedFixed32List(
172+
packedSizeBytesNoTag(Integer.MAX_VALUE - 5), 0, new IntArrayList(), registers));
173+
assertThat(e).hasMessageThat().contains("truncated");
160174
}
161175

162176
@Test
@@ -197,56 +211,68 @@ public void testDecodePackedDoubleList_2gb_beyondEndOfArray() {
197211

198212
@Test
199213
public void testDecodePackedFixed64List_negativeSize() {
200-
assertThrows(
201-
InvalidProtocolBufferException.class,
202-
() ->
203-
ArrayDecoders.decodePackedFixed64List(
204-
packedSizeBytesNoTag(-1), 0, new LongArrayList(), registers));
214+
InvalidProtocolBufferException e =
215+
assertThrows(
216+
InvalidProtocolBufferException.class,
217+
() ->
218+
ArrayDecoders.decodePackedFixed64List(
219+
packedSizeBytesNoTag(-1), 0, new LongArrayList(), registers));
220+
assertThat(e).hasMessageThat().contains("negative size");
205221
}
206222

207223
@Test
208224
public void testDecodePackedFloatList_negativeSize() {
209-
assertThrows(
210-
InvalidProtocolBufferException.class,
211-
() ->
212-
ArrayDecoders.decodePackedFloatList(
213-
packedSizeBytesNoTag(-1), 0, new FloatArrayList(), registers));
225+
InvalidProtocolBufferException e =
226+
assertThrows(
227+
InvalidProtocolBufferException.class,
228+
() ->
229+
ArrayDecoders.decodePackedFloatList(
230+
packedSizeBytesNoTag(-1), 0, new FloatArrayList(), registers));
231+
assertThat(e).hasMessageThat().contains("negative size");
214232
}
215233

216234
@Test
217235
public void testDecodePackedDoubleList_negativeSize() {
218-
assertThrows(
219-
InvalidProtocolBufferException.class,
220-
() ->
221-
ArrayDecoders.decodePackedDoubleList(
222-
packedSizeBytesNoTag(-1), 0, new DoubleArrayList(), registers));
236+
InvalidProtocolBufferException e =
237+
assertThrows(
238+
InvalidProtocolBufferException.class,
239+
() ->
240+
ArrayDecoders.decodePackedDoubleList(
241+
packedSizeBytesNoTag(-1), 0, new DoubleArrayList(), registers));
242+
assertThat(e).hasMessageThat().contains("negative size");
223243
}
224244

225245
@Test
226246
public void testDecodePackedBoolList_negativeSize() {
227-
assertThrows(
228-
InvalidProtocolBufferException.class,
229-
() ->
230-
ArrayDecoders.decodePackedBoolList(
231-
packedSizeBytesNoTag(-1), 0, new BooleanArrayList(), registers));
247+
InvalidProtocolBufferException e =
248+
assertThrows(
249+
InvalidProtocolBufferException.class,
250+
() ->
251+
ArrayDecoders.decodePackedBoolList(
252+
packedSizeBytesNoTag(-1), 0, new BooleanArrayList(), registers));
253+
assertThat(e).hasMessageThat().contains("negative size");
232254
}
233255

234256
@Test
235257
public void testDecodePackedSInt32List_negativeSize() {
236-
assertThrows(
237-
InvalidProtocolBufferException.class,
238-
() ->
239-
ArrayDecoders.decodePackedSInt32List(
240-
packedSizeBytesNoTag(-1), 0, new IntArrayList(), registers));
258+
InvalidProtocolBufferException e =
259+
assertThrows(
260+
InvalidProtocolBufferException.class,
261+
() ->
262+
ArrayDecoders.decodePackedSInt32List(
263+
packedSizeBytesNoTag(-1), 0, new IntArrayList(), registers));
264+
assertThat(e).hasMessageThat().contains("negative size");
241265
}
242266

243267
@Test
244268
public void testDecodePackedSInt64List_negativeSize() {
245-
assertThrows(
246-
InvalidProtocolBufferException.class,
247-
() ->
248-
ArrayDecoders.decodePackedSInt64List(
249-
packedSizeBytesNoTag(-1), 0, new LongArrayList(), registers));
269+
InvalidProtocolBufferException e =
270+
assertThrows(
271+
InvalidProtocolBufferException.class,
272+
() ->
273+
ArrayDecoders.decodePackedSInt64List(
274+
packedSizeBytesNoTag(-1), 0, new LongArrayList(), registers));
275+
assertThat(e).hasMessageThat().contains("negative size");
250276
}
251277

252278
@Test

0 commit comments

Comments
 (0)