Skip to content

Commit 9908183

Browse files
adamreeve and jimexist authored
Allow streaming batches of data from a DataFrame (#75)
* Implement executeStream method

  This returns a new RecordBatchStream class that acts similarly to ArrowReader, but record batches can be retrieved asynchronously.
* Use C data interface for reading stream data
* Support reading Arrow dictionary data with executeStream
* Update datafusion-jni/src/dataframe.rs
* Update datafusion-jni/src/stream.rs
* Update datafusion-jni/src/stream.rs
* Apply suggestions from code review
* Fixes for upgraded JNI and datafusion

---------

Co-authored-by: Jiayu Liu <Jimexist@users.noreply.github.com>
1 parent a3e34d4 commit 9908183

13 files changed

Lines changed: 466 additions & 1 deletion

File tree

datafusion-java/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies {
1111
implementation 'org.slf4j:slf4j-api:1.7.36'
1212
implementation 'org.apache.arrow:arrow-format:13.0.0'
1313
implementation 'org.apache.arrow:arrow-vector:13.0.0'
14+
implementation 'org.apache.arrow:arrow-c-data:13.0.0'
1415
runtimeOnly 'org.apache.arrow:arrow-memory-unsafe:13.0.0'
1516
testImplementation 'org.junit.jupiter:junit-jupiter:5.8.1'
1617
testImplementation 'org.apache.hadoop:hadoop-client:3.3.5'

datafusion-java/src/main/java/org/apache/arrow/datafusion/DataFrame.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ public interface DataFrame extends NativeProxy {
2121
*/
2222
CompletableFuture<ArrowReader> collect(BufferAllocator allocator);
2323

24+
/**
25+
* Execute this DataFrame and return a stream of the result data
26+
*
27+
* @param allocator {@link BufferAllocator buffer allocator} to allocate vectors for the stream
28+
* @return Stream of results
29+
*/
30+
CompletableFuture<RecordBatchStream> executeStream(BufferAllocator allocator);
31+
2432
/**
2533
* Print results.
2634
*

datafusion-java/src/main/java/org/apache/arrow/datafusion/DataFrames.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ private DataFrames() {}
1515
static native void collectDataframe(
1616
long runtime, long dataframe, BiConsumer<String, byte[]> callback);
1717

18+
/**
 * Native entry point for streaming execution of a DataFrame.
 *
 * @param runtime pointer of the runtime, as obtained from {@code Runtime.getPointer()}
 * @param dataframe pointer of the DataFrame to execute
 * @param callback receives an error string and a numeric stream handle; the caller in
 *     DefaultDataFrame treats a non-empty error string as failure and otherwise wraps the handle
 *     in a DefaultRecordBatchStream
 */
static native void executeStream(long runtime, long dataframe, ObjectResultCallback callback);
19+
1820
static native void writeParquet(
1921
long runtime, long dataframe, String path, Consumer<String> callback);
2022

datafusion-java/src/main/java/org/apache/arrow/datafusion/DefaultDataFrame.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,25 @@ public CompletableFuture<ArrowReader> collect(BufferAllocator allocator) {
4141
return result;
4242
}
4343

44+
@Override
45+
public CompletableFuture<RecordBatchStream> executeStream(BufferAllocator allocator) {
46+
CompletableFuture<RecordBatchStream> result = new CompletableFuture<>();
47+
Runtime runtime = context.getRuntime();
48+
long runtimePointer = runtime.getPointer();
49+
long dataframe = getPointer();
50+
DataFrames.executeStream(
51+
runtimePointer,
52+
dataframe,
53+
(errString, streamId) -> {
54+
if (containsError(errString)) {
55+
result.completeExceptionally(new RuntimeException(errString));
56+
} else {
57+
result.complete(new DefaultRecordBatchStream(context, streamId, allocator));
58+
}
59+
});
60+
return result;
61+
}
62+
4463
private boolean containsError(String errString) {
4564
return errString != null && !errString.isEmpty();
4665
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package org.apache.arrow.datafusion;
2+
3+
import java.util.Set;
4+
import java.util.concurrent.CompletableFuture;
5+
import org.apache.arrow.c.ArrowArray;
6+
import org.apache.arrow.c.ArrowSchema;
7+
import org.apache.arrow.c.CDataDictionaryProvider;
8+
import org.apache.arrow.c.Data;
9+
import org.apache.arrow.memory.BufferAllocator;
10+
import org.apache.arrow.vector.VectorSchemaRoot;
11+
import org.apache.arrow.vector.dictionary.Dictionary;
12+
import org.apache.arrow.vector.types.pojo.Schema;
13+
14+
class DefaultRecordBatchStream extends AbstractProxy implements RecordBatchStream {
15+
private final SessionContext context;
16+
private final BufferAllocator allocator;
17+
private final CDataDictionaryProvider dictionaryProvider;
18+
private VectorSchemaRoot vectorSchemaRoot = null;
19+
private boolean initialized = false;
20+
21+
DefaultRecordBatchStream(SessionContext context, long pointer, BufferAllocator allocator) {
22+
super(pointer);
23+
this.context = context;
24+
this.allocator = allocator;
25+
this.dictionaryProvider = new CDataDictionaryProvider();
26+
}
27+
28+
@Override
29+
void doClose(long pointer) {
30+
destroy(pointer);
31+
dictionaryProvider.close();
32+
if (initialized) {
33+
vectorSchemaRoot.close();
34+
}
35+
}
36+
37+
@Override
38+
public VectorSchemaRoot getVectorSchemaRoot() {
39+
ensureInitialized();
40+
return vectorSchemaRoot;
41+
}
42+
43+
@Override
44+
public CompletableFuture<Boolean> loadNextBatch() {
45+
ensureInitialized();
46+
Runtime runtime = context.getRuntime();
47+
long runtimePointer = runtime.getPointer();
48+
long recordBatchStream = getPointer();
49+
CompletableFuture<Boolean> result = new CompletableFuture<>();
50+
next(
51+
runtimePointer,
52+
recordBatchStream,
53+
(errString, arrowArrayAddress) -> {
54+
if (containsError(errString)) {
55+
result.completeExceptionally(new RuntimeException(errString));
56+
} else if (arrowArrayAddress == 0) {
57+
// Reached end of stream
58+
result.complete(false);
59+
} else {
60+
try {
61+
ArrowArray arrowArray = ArrowArray.wrap(arrowArrayAddress);
62+
Data.importIntoVectorSchemaRoot(
63+
allocator, arrowArray, vectorSchemaRoot, dictionaryProvider);
64+
result.complete(true);
65+
} catch (Exception e) {
66+
result.completeExceptionally(e);
67+
}
68+
}
69+
});
70+
return result;
71+
}
72+
73+
@Override
74+
public Dictionary lookup(long id) {
75+
return dictionaryProvider.lookup(id);
76+
}
77+
78+
@Override
79+
public Set<Long> getDictionaryIds() {
80+
return dictionaryProvider.getDictionaryIds();
81+
}
82+
83+
private void ensureInitialized() {
84+
if (!initialized) {
85+
Schema schema = getSchema();
86+
this.vectorSchemaRoot = VectorSchemaRoot.create(schema, allocator);
87+
}
88+
initialized = true;
89+
}
90+
91+
private Schema getSchema() {
92+
long recordBatchStream = getPointer();
93+
// Native method is not async, but use a future to store the result for convenience
94+
CompletableFuture<Schema> result = new CompletableFuture<>();
95+
getSchema(
96+
recordBatchStream,
97+
(errString, arrowSchemaAddress) -> {
98+
if (containsError(errString)) {
99+
result.completeExceptionally(new RuntimeException(errString));
100+
} else {
101+
try {
102+
ArrowSchema arrowSchema = ArrowSchema.wrap(arrowSchemaAddress);
103+
Schema schema = Data.importSchema(allocator, arrowSchema, dictionaryProvider);
104+
result.complete(schema);
105+
// The FFI schema will be released from rust when it is dropped
106+
} catch (Exception e) {
107+
result.completeExceptionally(e);
108+
}
109+
}
110+
});
111+
return result.join();
112+
}
113+
114+
private static boolean containsError(String errString) {
115+
return errString != null && !"".equals(errString);
116+
}
117+
118+
private static native void getSchema(long pointer, ObjectResultCallback callback);
119+
120+
private static native void next(long runtime, long pointer, ObjectResultCallback callback);
121+
122+
private static native void destroy(long pointer);
123+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package org.apache.arrow.datafusion;
2+
3+
import java.util.concurrent.CompletableFuture;
4+
import org.apache.arrow.vector.VectorSchemaRoot;
5+
import org.apache.arrow.vector.dictionary.DictionaryProvider;
6+
7+
/**
8+
* A record batch stream is a stream of tabular Arrow data that can be iterated over asynchronously
9+
*/
10+
public interface RecordBatchStream extends AutoCloseable, NativeProxy, DictionaryProvider {
11+
/**
12+
* Get the VectorSchemaRoot that will be populated with data as the stream is iterated over
13+
*
14+
* @return the stream's VectorSchemaRoot
15+
*/
16+
VectorSchemaRoot getVectorSchemaRoot();
17+
18+
/**
19+
* Load the next record batch in the stream into the VectorSchemaRoot
20+
*
21+
* @return Future that will complete with true if a batch was loaded or false if the end of the
22+
* stream has been reached
23+
*/
24+
CompletableFuture<Boolean> loadNextBatch();
25+
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package org.apache.arrow.datafusion;
2+
3+
import static org.junit.jupiter.api.Assertions.*;
4+
5+
import java.net.URL;
6+
import java.nio.charset.StandardCharsets;
7+
import java.nio.file.Files;
8+
import java.nio.file.Path;
9+
import java.nio.file.Paths;
10+
import java.util.Arrays;
11+
import java.util.List;
12+
import org.apache.arrow.memory.BufferAllocator;
13+
import org.apache.arrow.memory.RootAllocator;
14+
import org.apache.arrow.vector.BigIntVector;
15+
import org.apache.arrow.vector.Float8Vector;
16+
import org.apache.arrow.vector.VarCharVector;
17+
import org.apache.arrow.vector.VectorSchemaRoot;
18+
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
19+
import org.apache.arrow.vector.types.pojo.Schema;
20+
import org.junit.jupiter.api.Test;
21+
import org.junit.jupiter.api.io.TempDir;
22+
23+
public class TestExecuteStream {
24+
@Test
25+
public void executeStream(@TempDir Path tempDir) throws Exception {
26+
try (SessionContext context = SessionContexts.create();
27+
BufferAllocator allocator = new RootAllocator()) {
28+
Path csvFilePath = tempDir.resolve("data.csv");
29+
30+
List<String> lines = Arrays.asList("x,y,z", "1,2,3.5", "4,5,6.5", "7,8,9.5");
31+
Files.write(csvFilePath, lines);
32+
33+
context.registerCsv("test", csvFilePath).join();
34+
35+
try (RecordBatchStream stream =
36+
context
37+
.sql("SELECT y,z FROM test WHERE x > 3")
38+
.thenComposeAsync(df -> df.executeStream(allocator))
39+
.join()) {
40+
VectorSchemaRoot root = stream.getVectorSchemaRoot();
41+
Schema schema = root.getSchema();
42+
assertEquals(2, schema.getFields().size());
43+
assertEquals("y", schema.getFields().get(0).getName());
44+
assertEquals("z", schema.getFields().get(1).getName());
45+
46+
assertTrue(stream.loadNextBatch().join());
47+
assertEquals(2, root.getRowCount());
48+
BigIntVector yValues = (BigIntVector) root.getVector(0);
49+
assertEquals(5, yValues.get(0));
50+
assertEquals(8, yValues.get(1));
51+
Float8Vector zValues = (Float8Vector) root.getVector(1);
52+
assertEquals(6.5, zValues.get(0));
53+
assertEquals(9.5, zValues.get(1));
54+
55+
assertFalse(stream.loadNextBatch().join());
56+
}
57+
}
58+
}
59+
60+
@Test
61+
public void readDictionaryData() throws Exception {
62+
try (SessionContext context = SessionContexts.create();
63+
BufferAllocator allocator = new RootAllocator()) {
64+
65+
URL fileUrl = this.getClass().getResource("/dictionary_data.parquet");
66+
Path parquetFilePath = Paths.get(fileUrl.getPath());
67+
68+
context.registerParquet("test", parquetFilePath).join();
69+
70+
try (RecordBatchStream stream =
71+
context
72+
.sql("SELECT x,y FROM test")
73+
.thenComposeAsync(df -> df.executeStream(allocator))
74+
.join()) {
75+
VectorSchemaRoot root = stream.getVectorSchemaRoot();
76+
Schema schema = root.getSchema();
77+
assertEquals(2, schema.getFields().size());
78+
assertEquals("x", schema.getFields().get(0).getName());
79+
assertEquals("y", schema.getFields().get(1).getName());
80+
81+
int rowsRead = 0;
82+
while (stream.loadNextBatch().join()) {
83+
int batchNumRows = root.getRowCount();
84+
BigIntVector xValuesEncoded = (BigIntVector) root.getVector(0);
85+
long xDictionaryId = xValuesEncoded.getField().getDictionary().getId();
86+
try (VarCharVector xValues =
87+
(VarCharVector)
88+
DictionaryEncoder.decode(xValuesEncoded, stream.lookup(xDictionaryId))) {
89+
String[] expected = {"one", "two", "three"};
90+
for (int i = 0; i < batchNumRows; ++i) {
91+
assertEquals(
92+
new String(xValues.get(i), StandardCharsets.UTF_8), expected[(rowsRead + i) % 3]);
93+
}
94+
}
95+
96+
BigIntVector yValuesEncoded = (BigIntVector) root.getVector(1);
97+
long yDictionaryId = yValuesEncoded.getField().getDictionary().getId();
98+
try (VarCharVector yValues =
99+
(VarCharVector)
100+
DictionaryEncoder.decode(yValuesEncoded, stream.lookup(yDictionaryId))) {
101+
String[] expected = {"four", "five", "six"};
102+
for (int i = 0; i < batchNumRows; ++i) {
103+
assertEquals(
104+
new String(yValues.get(i), StandardCharsets.UTF_8), expected[(rowsRead + i) % 3]);
105+
}
106+
}
107+
rowsRead += batchNumRows;
108+
}
109+
110+
assertEquals(100, rowsRead);
111+
}
112+
}
113+
}
114+
}
887 Bytes
Binary file not shown.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pyarrow as pa
2+
import pyarrow.parquet as pq
3+
4+
5+
# Number of rows written to the generated fixture.
num_rows = 100


def _cycling_dictionary_array(labels):
    """Build a dictionary array of num_rows entries whose indices cycle through labels."""
    indices = pa.array([row % 3 for row in range(num_rows)])
    return pa.DictionaryArray.from_arrays(indices, pa.array(labels))


dict_array_x = _cycling_dictionary_array(["one", "two", "three"])
dict_array_y = _cycling_dictionary_array(["four", "five", "six"])

# Two dictionary-encoded string columns, written relative to the current working directory.
table = pa.Table.from_arrays([dict_array_x, dict_array_y], ["x", "y"])
pq.write_table(table, "src/test/resources/dictionary_data.parquet")

datafusion-jni/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ edition = "2021"
1212
[dependencies]
1313
jni = "^0.21.0"
1414
tokio = "^1.32.0"
15-
arrow = "^36.0"
15+
arrow = { version = "^36.0", features = ["ffi"] }
1616
datafusion = "^22.0"
17+
futures = "0.3.28"
1718

1819
[lib]
1920
crate_type = ["cdylib"]

0 commit comments

Comments
 (0)