Skip to content

Commit ea90bfa

Browse files
committed
[SPARK-57766][SQL] Validate WKB element counts before allocation
### What changes were proposed in this pull request? This PR validates WKB collection element counts before using them as `ArrayList` initial capacities in `WkbReader`. The new `readCount` helper rejects negative counts and counts that cannot fit in the remaining WKB buffer before parsing these structures: - LineString points - Polygon ring points - Polygon rings - MultiPoint points - MultiLineString line strings - MultiPolygon polygons - GeometryCollection geometries It also adds regression coverage for negative and oversized counts, including the public `Geometry.fromWkb` path. ### Why are the changes needed? Malformed WKB can encode invalid collection counts. Before this change, those counts were passed directly to `new ArrayList<>(count)`, which could throw raw allocation-related exceptions such as `IllegalArgumentException` for negative capacities or attempt excessive allocation for very large counts. Invalid WKB should be rejected consistently as a WKB parse error before allocation. ### Does this PR introduce _any_ user-facing change? Yes. For malformed WKB with invalid collection counts, parsing now fails with Spark's normal `WKB_PARSE_ERROR` instead of raw Java allocation failures. This affects unreleased WKB parsing behavior. ### How was this patch tested? Added tests in `WkbErrorHandlingTest` for negative and oversized counts across all count-bearing WKB collection readers, plus a `Geometry.fromWkb` regression test. Result: 37 tests passed, 0 failed. ### Was this patch authored or co-authored using generative AI tooling? Generated-by: OpenAI Codex (GPT-5) Closes #56875 from szehon-ho/SPARK-57766-wkb-count-validation. Authored-by: Szehon Ho <szehon.apache@gmail.com> Signed-off-by: Szehon Ho <szehon.apache@gmail.com> (cherry picked from commit e58192f) Signed-off-by: Szehon Ho <szehon.apache@gmail.com>
1 parent 8b95cd3 commit ea90bfa

2 files changed

Lines changed: 70 additions & 7 deletions

File tree

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/geo/WkbReader.java

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,20 @@ private int readInt() {
178178
return buffer.getInt();
179179
}
180180

181+
private int readCount(String countName, int minBytesPerItem) {
182+
long countPos = buffer.position();
183+
int count = readInt();
184+
if (count < 0) {
185+
throw new WkbParseException("Invalid count for " + countName + ": " + count, countPos,
186+
currentWkb);
187+
}
188+
if (count > buffer.remaining() / minBytesPerItem) {
189+
throw new WkbParseException("Invalid count for " + countName
190+
+ ": exceeds remaining bytes", countPos, currentWkb);
191+
}
192+
return count;
193+
}
194+
181195
/**
182196
* Reads a double coordinate value, allowing NaN for empty points.
183197
*/
@@ -386,7 +400,7 @@ private Point readInternalPoint(int srid, int dimensionCount, boolean hasZ,
386400

387401
private LineString readLineString(int srid, int dimensionCount, boolean hasZ, boolean hasM) {
388402
long numPointsPos = buffer.position();
389-
int numPoints = readInt();
403+
int numPoints = readCount("LineString points", dimensionCount * WkbUtil.DOUBLE_SIZE);
390404

391405
if (validationLevel > 0 && numPoints == 1) {
392406
throw new WkbParseException("Too few points in linestring", numPointsPos, currentWkb);
@@ -402,7 +416,7 @@ private LineString readLineString(int srid, int dimensionCount, boolean hasZ, bo
402416

403417
private Ring readRing(int srid, int dimensionCount, boolean hasZ, boolean hasM) {
404418
long numPointsPos = buffer.position();
405-
int numPoints = readInt();
419+
int numPoints = readCount("ring points", dimensionCount * WkbUtil.DOUBLE_SIZE);
406420

407421
List<Point> points = new ArrayList<>(numPoints);
408422

@@ -425,7 +439,7 @@ private Ring readRing(int srid, int dimensionCount, boolean hasZ, boolean hasM)
425439
}
426440

427441
private Polygon readPolygon(int srid, int dimensionCount, boolean hasZ, boolean hasM) {
428-
int numRings = readInt();
442+
int numRings = readCount("polygon rings", WkbUtil.INT_SIZE);
429443
List<Ring> rings = new ArrayList<>(numRings);
430444

431445
for (int i = 0; i < numRings; i++) {
@@ -436,7 +450,7 @@ private Polygon readPolygon(int srid, int dimensionCount, boolean hasZ, boolean
436450
}
437451

438452
private MultiPoint readMultiPoint(int srid, boolean hasZ, boolean hasM) {
439-
int numPoints = readInt();
453+
int numPoints = readCount("MultiPoint points", WkbUtil.BYTE_SIZE + WkbUtil.TYPE_SIZE);
440454
List<Point> points = new ArrayList<>(numPoints);
441455

442456
for (int i = 0; i < numPoints; i++) {
@@ -452,7 +466,8 @@ private MultiPoint readMultiPoint(int srid, boolean hasZ, boolean hasM) {
452466
}
453467

454468
private MultiLineString readMultiLineString(int srid, boolean hasZ, boolean hasM) {
455-
int numLineStrings = readInt();
469+
int numLineStrings =
470+
readCount("MultiLineString line strings", WkbUtil.BYTE_SIZE + WkbUtil.TYPE_SIZE);
456471
List<LineString> lineStrings = new ArrayList<>(numLineStrings);
457472

458473
for (int i = 0; i < numLineStrings; i++) {
@@ -468,7 +483,8 @@ private MultiLineString readMultiLineString(int srid, boolean hasZ, boolean hasM
468483
}
469484

470485
private MultiPolygon readMultiPolygon(int srid, boolean hasZ, boolean hasM) {
471-
int numPolygons = readInt();
486+
int numPolygons =
487+
readCount("MultiPolygon polygons", WkbUtil.BYTE_SIZE + WkbUtil.TYPE_SIZE);
472488
List<Polygon> polygons = new ArrayList<>(numPolygons);
473489

474490
for (int i = 0; i < numPolygons; i++) {
@@ -484,7 +500,8 @@ private MultiPolygon readMultiPolygon(int srid, boolean hasZ, boolean hasM) {
484500
}
485501

486502
private GeometryCollection readGeometryCollection(int srid, boolean hasZ, boolean hasM) {
487-
int numGeometries = readInt();
503+
int numGeometries =
504+
readCount("GeometryCollection geometries", WkbUtil.BYTE_SIZE + WkbUtil.TYPE_SIZE);
488505
List<GeometryModel> geometries = new ArrayList<>(numGeometries);
489506

490507
for (int i = 0; i < numGeometries; i++) {

sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/util/geo/WkbErrorHandlingTest.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.util.geo;
1919

20+
import org.apache.spark.SparkIllegalArgumentException;
21+
import org.apache.spark.sql.catalyst.util.Geometry;
2022
import org.junit.jupiter.api.Assertions;
2123
import org.junit.jupiter.api.Test;
2224

@@ -118,6 +120,50 @@ public void testTruncatedLineString() {
118120
Assertions.assertSame(truncated, ex.getWkb());
119121
}
120122

123+
@Test
124+
public void testNegativeElementCounts() {
125+
String[] invalidCounts = new String[] {
126+
"0102000000ffffffff", // LineString with -1 points
127+
"010300000001000000ffffffff", // Polygon ring with -1 points
128+
"0103000000ffffffff", // Polygon with -1 rings
129+
"0104000000ffffffff", // MultiPoint with -1 points
130+
"0105000000ffffffff", // MultiLineString with -1 linestrings
131+
"0106000000ffffffff", // MultiPolygon with -1 polygons
132+
"0107000000ffffffff" // GeometryCollection with -1 geometries
133+
};
134+
135+
for (String invalidCount : invalidCounts) {
136+
assertParseError(invalidCount, "Invalid count");
137+
}
138+
}
139+
140+
@Test
141+
public void testElementCountsExceedRemainingBytes() {
142+
String[] invalidCounts = new String[] {
143+
"0102000000ffffff7f", // LineString with too many points
144+
"010300000001000000ffffff7f", // Polygon ring with too many points
145+
"0103000000ffffff7f", // Polygon with too many rings
146+
"0104000000ffffff7f", // MultiPoint with too many points
147+
"0105000000ffffff7f", // MultiLineString with too many linestrings
148+
"0106000000ffffff7f", // MultiPolygon with too many polygons
149+
"0107000000ffffff7f" // GeometryCollection with too many geometries
150+
};
151+
152+
for (String invalidCount : invalidCounts) {
153+
assertParseError(invalidCount, "Invalid count");
154+
}
155+
}
156+
157+
@Test
158+
public void testGeometryFromWkbRejectsInvalidCount() {
159+
byte[] wkb = hexToBytes("0102000000ffffffff");
160+
SparkIllegalArgumentException ex = Assertions.assertThrows(
161+
SparkIllegalArgumentException.class, () -> Geometry.fromWkb(wkb));
162+
163+
Assertions.assertEquals("WKB_PARSE_ERROR", ex.getCondition());
164+
Assertions.assertTrue(ex.getMessage().contains("Invalid count"));
165+
}
166+
121167
@Test
122168
public void testValidationLevels() {
123169
// With validation level 0, invalid geometries might be accepted

0 commit comments

Comments
 (0)