Skip to content

Commit 5b65d6c

Browse files
authored
perf: cache offsetBufferAddress in CometPlainVector for variable-width vectors (apache#4364)
1 parent 48f7b03 commit 5b65d6c

2 files changed

Lines changed: 90 additions & 8 deletions

File tree

spark/src/main/java/org/apache/comet/vector/CometPlainVector.java

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
/** A column vector whose elements are plainly decoded. */
3434
public class CometPlainVector extends CometDecodedVector {
3535
private final long valueBufferAddress;
36+
private final long offsetBufferAddress;
3637
private final boolean isBaseFixedWidthVector;
3738

3839
private byte booleanByteCache;
@@ -58,6 +59,12 @@ public CometPlainVector(ValueVector vector, boolean isUuid, boolean isReused) {
5859
}
5960

6061
isBaseFixedWidthVector = valueVector instanceof BaseFixedWidthVector;
62+
if (vector instanceof BaseVariableWidthVector) {
63+
this.offsetBufferAddress =
64+
((BaseVariableWidthVector) vector).getOffsetBuffer().memoryAddress();
65+
} else {
66+
this.offsetBufferAddress = -1;
67+
}
6168
this.isReused = isReused;
6269
}
6370

@@ -123,13 +130,11 @@ public double getDouble(int rowId) {
123130
@Override
124131
public UTF8String getUTF8String(int rowId) {
125132
if (isNullAt(rowId)) return null;
126-
if (!isBaseFixedWidthVector) {
127-
BaseVariableWidthVector varWidthVector = (BaseVariableWidthVector) valueVector;
128-
long offsetBufferAddress = varWidthVector.getOffsetBuffer().memoryAddress();
133+
if (offsetBufferAddress != -1) {
129134
int offset = Platform.getInt(null, offsetBufferAddress + rowId * 4L);
130135
int length = Platform.getInt(null, offsetBufferAddress + (rowId + 1L) * 4L) - offset;
131136
return UTF8String.fromAddress(null, valueBufferAddress + offset, length);
132-
} else {
137+
} else if (isBaseFixedWidthVector) {
133138
BaseFixedWidthVector fixedWidthVector = (BaseFixedWidthVector) valueVector;
134139
int length = fixedWidthVector.getTypeWidth();
135140
int offset = rowId * length;
@@ -142,6 +147,8 @@ public UTF8String getUTF8String(int rowId) {
142147
} else {
143148
return UTF8String.fromString(convertToUuid(result).toString());
144149
}
150+
} else {
151+
throw new IllegalStateException("Unsupported UTF8 vector type: " + valueVector.getName());
145152
}
146153
}
147154

@@ -150,17 +157,15 @@ public byte[] getBinary(int rowId) {
150157
if (isNullAt(rowId)) return null;
151158
int offset;
152159
int length;
153-
if (valueVector instanceof BaseVariableWidthVector) {
154-
BaseVariableWidthVector varWidthVector = (BaseVariableWidthVector) valueVector;
155-
long offsetBufferAddress = varWidthVector.getOffsetBuffer().memoryAddress();
160+
if (offsetBufferAddress != -1) {
156161
offset = Platform.getInt(null, offsetBufferAddress + rowId * 4L);
157162
length = Platform.getInt(null, offsetBufferAddress + (rowId + 1L) * 4L) - offset;
158163
} else if (valueVector instanceof BaseFixedWidthVector) {
159164
BaseFixedWidthVector fixedWidthVector = (BaseFixedWidthVector) valueVector;
160165
length = fixedWidthVector.getTypeWidth();
161166
offset = rowId * length;
162167
} else {
163-
throw new RuntimeException("Unsupported binary vector type: " + valueVector.getName());
168+
throw new IllegalStateException("Unsupported binary vector type: " + valueVector.getName());
164169
}
165170
byte[] result = new byte[length];
166171
Platform.copyMemory(
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.vector;
21+
22+
import java.nio.charset.StandardCharsets;
23+
24+
import org.junit.Test;
25+
26+
import org.apache.arrow.memory.RootAllocator;
27+
import org.apache.arrow.vector.VarBinaryVector;
28+
import org.apache.arrow.vector.VarCharVector;
29+
30+
import static org.junit.Assert.assertArrayEquals;
31+
import static org.junit.Assert.assertEquals;
32+
import static org.junit.Assert.assertNull;
33+
34+
public class TestCometPlainVector {
35+
36+
@Test
37+
public void testGetUTF8StringWithVariableWidthVector() {
38+
try (RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
39+
VarCharVector vector = new VarCharVector("strings", allocator);
40+
vector.allocateNew();
41+
vector.setSafe(0, bytes("alpha"));
42+
vector.setSafe(1, bytes(""));
43+
vector.setSafe(2, bytes("spark"));
44+
vector.setValueCount(4); // row 3 is null (validity bit not set)
45+
46+
try (CometPlainVector cv = new CometPlainVector(vector, false)) {
47+
assertEquals("alpha", cv.getUTF8String(0).toString());
48+
assertEquals("", cv.getUTF8String(1).toString());
49+
assertEquals("spark", cv.getUTF8String(2).toString());
50+
assertNull(cv.getUTF8String(3));
51+
}
52+
}
53+
}
54+
55+
@Test
56+
public void testGetBinaryWithVariableWidthVector() {
57+
try (RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
58+
VarBinaryVector vector = new VarBinaryVector("bytes", allocator);
59+
vector.allocateNew();
60+
vector.setSafe(0, new byte[] {1, 2, 3}, 0, 3);
61+
vector.setSafe(1, new byte[0], 0, 0);
62+
vector.setSafe(2, new byte[] {4, 5}, 0, 2);
63+
vector.setValueCount(4); // row 3 is null (validity bit not set)
64+
65+
try (CometPlainVector cv = new CometPlainVector(vector, false)) {
66+
assertArrayEquals(new byte[] {1, 2, 3}, cv.getBinary(0));
67+
assertArrayEquals(new byte[0], cv.getBinary(1));
68+
assertArrayEquals(new byte[] {4, 5}, cv.getBinary(2));
69+
assertNull(cv.getBinary(3));
70+
}
71+
}
72+
}
73+
74+
private static byte[] bytes(String s) {
75+
return s.getBytes(StandardCharsets.UTF_8);
76+
}
77+
}

0 commit comments

Comments
 (0)