Skip to content

Commit 818294b

Browse files
committed
feat(java/driver/jni): implement executePartitioned
Closes #4242.
1 parent d812b8f commit 818294b

11 files changed

Lines changed: 368 additions & 12 deletions

File tree

.github/workflows/java.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ jobs:
233233
- name: Start Dependencies
234234
if: matrix.os == 'Linux' && matrix.arch == 'amd64'
235235
run: |
236-
docker compose up --detach --wait mssql-test postgres-test
236+
docker compose up --detach --wait flightsql-sqlite-test mssql-test postgres-test
237237
cat .env | grep -v -e '^#' | grep -e '^ADBC_' | awk NF | sed 's/"//g' | tee -a $GITHUB_ENV
238238
239239
- name: Download thirdparty driver

go/adbc/driver/flightsql/flightsql_connection.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,15 +1145,15 @@ func (c *connectionImpl) ReadPartition(ctx context.Context, serializedPartition
11451145
var info flight.FlightInfo
11461146
if err := proto.Unmarshal(serializedPartition, &info); err != nil {
11471147
return nil, adbc.Error{
1148-
Msg: err.Error(),
1148+
Msg: fmt.Sprintf("[flightsql] could not unmarshal partition as FlightInfo: %v", err),
11491149
Code: adbc.StatusInvalidArgument,
11501150
}
11511151
}
11521152

11531153
// The driver only ever returns one endpoint.
11541154
if len(info.Endpoint) != 1 {
11551155
return nil, adbc.Error{
1156-
Msg: fmt.Sprintf("Invalid partition: expected 1 endpoint, got %d", len(info.Endpoint)),
1156+
Msg: fmt.Sprintf("[flightsql] invalid partition: expected 1 endpoint, got %d", len(info.Endpoint)),
11571157
Code: adbc.StatusInvalidArgument,
11581158
}
11591159
}

go/adbc/driver/flightsql/flightsql_statement.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -811,7 +811,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
811811
data, err := proto.Marshal(partition)
812812
if err != nil {
813813
return sc, out, -1, adbc.Error{
814-
Msg: err.Error(),
814+
Msg: fmt.Sprintf("[flightsql] could not marshal partition as FlightInfo: %v", err),
815815
Code: adbc.StatusInternal,
816816
}
817817
}

java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.arrow.vector.VectorSchemaRoot;
2525
import org.apache.arrow.vector.ipc.ArrowReader;
2626
import org.apache.arrow.vector.types.pojo.Schema;
27+
import org.checkerframework.checker.nullness.qual.Nullable;
2728

2829
/**
2930
* A container for all state needed to execute a database query, such as the query itself,
@@ -231,19 +232,21 @@ public String toString() {
231232

232233
/** The partitions of a result set. */
233234
class PartitionResult {
234-
private final Schema schema;
235+
private final @Nullable Schema schema;
235236
private final long affectedRows;
236237
private final List<PartitionDescriptor> partitionDescriptors;
237238

238239
public PartitionResult(
239-
Schema schema, long affectedRows, List<PartitionDescriptor> partitionDescriptors) {
240+
@Nullable Schema schema,
241+
long affectedRows,
242+
List<PartitionDescriptor> partitionDescriptors) {
240243
this.schema = schema;
241244
this.affectedRows = affectedRows;
242245
this.partitionDescriptors = partitionDescriptors;
243246
}
244247

245248
/** Get the schema of the eventual result set. */
246-
public Schema getSchema() {
249+
public @Nullable Schema getSchema() {
247250
return schema;
248251
}
249252

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.arrow.adbc.driver.jni;
19+
20+
import static org.assertj.core.api.Assertions.assertThat;
21+
22+
import java.util.HashMap;
23+
import java.util.Map;
24+
import org.apache.arrow.adbc.core.AdbcConnection;
25+
import org.apache.arrow.adbc.core.AdbcDatabase;
26+
import org.apache.arrow.adbc.core.AdbcDriver;
27+
import org.apache.arrow.memory.BufferAllocator;
28+
import org.apache.arrow.memory.RootAllocator;
29+
import org.junit.jupiter.api.AfterEach;
30+
import org.junit.jupiter.api.Assumptions;
31+
import org.junit.jupiter.api.BeforeAll;
32+
import org.junit.jupiter.api.BeforeEach;
33+
import org.junit.jupiter.api.Test;
34+
35+
public class FlightSqlIntegrationTest {
36+
public static final String URI_ENV = "ADBC_SQLITE_FLIGHTSQL_URI";
37+
static String URI = System.getenv(URI_ENV);
38+
39+
BufferAllocator allocator;
40+
JniDriver driver;
41+
AdbcDatabase db;
42+
AdbcConnection conn;
43+
44+
@BeforeAll
45+
static void beforeAll() {
46+
Assumptions.assumeFalse(
47+
URI == null || URI.isEmpty(),
48+
String.format("Must set %s to run Flight SQL integration tests", URI_ENV));
49+
}
50+
51+
@BeforeEach
52+
void beforeEach() throws Exception {
53+
allocator = new RootAllocator();
54+
driver = new JniDriver(allocator);
55+
Map<String, Object> parameters = new HashMap<>();
56+
JniDriver.PARAM_DRIVER.set(parameters, "adbc_driver_flightsql");
57+
AdbcDriver.PARAM_URI.set(parameters, URI);
58+
db = driver.open(parameters);
59+
conn = db.connect();
60+
}
61+
62+
@AfterEach
63+
void afterEach() throws Exception {
64+
conn.close();
65+
db.close();
66+
allocator.close();
67+
}
68+
69+
@Test
70+
void simple() throws Exception {
71+
try (var stmt = conn.createStatement()) {
72+
stmt.setSqlQuery("SELECT 1 + 1 AS sum");
73+
try (var reader = stmt.executeQuery()) {
74+
assertThat(reader.getReader().loadNextBatch()).isTrue();
75+
assertThat(reader.getReader().getVectorSchemaRoot().getVector("sum").getObject(0))
76+
.isEqualTo(2L);
77+
}
78+
}
79+
}
80+
81+
@Test
82+
void partitioned() throws Exception {
83+
try (var stmt = conn.createStatement()) {
84+
stmt.setSqlQuery("SELECT 1 + 1 AS sum");
85+
var partitions = stmt.executePartitioned();
86+
assertThat(partitions.getPartitionDescriptors().size()).isEqualTo(1);
87+
assertThat(partitions.getAffectedRows()).isEqualTo(-1);
88+
// The test server doesn't give a schema.
89+
assertThat(partitions.getSchema()).isNull();
90+
91+
try (var reader =
92+
conn.readPartition(partitions.getPartitionDescriptors().get(0).getDescriptor())) {
93+
assertThat(reader.loadNextBatch()).isTrue();
94+
assertThat(reader.getVectorSchemaRoot().getVector("sum").getObject(0)).isEqualTo(2L);
95+
}
96+
}
97+
}
98+
}

java/driver/jni/src/main/cpp/jni_wrapper.cc

Lines changed: 168 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@
3434

3535
namespace {
3636

37+
void ThrowJavaException(JNIEnv* env, const std::string& klass,
38+
const std::string& message) {
39+
jclass exception_klass = env->FindClass(klass.c_str());
40+
assert(exception_klass != nullptr);
41+
jmethodID exception_ctor =
42+
env->GetMethodID(exception_klass, "<init>", "(Ljava/lang/String)V");
43+
assert(exception_ctor != nullptr);
44+
jstring message_jni = env->NewStringUTF(message.c_str());
45+
auto exc = static_cast<jthrowable>(
46+
env->NewObject(exception_klass, exception_ctor, message_jni));
47+
env->Throw(exc);
48+
}
49+
3750
/// Internal exception. Meant to be used with RaiseAdbcException and
3851
/// CHECK_ADBC_ERROR.
3952
struct AdbcException {
@@ -112,19 +125,26 @@ void RaiseAdbcException(AdbcStatusCode code, const AdbcError& error) {
112125
} while (0)
113126

114127
/// Require that a Java class exists or error.
115-
jclass RequireImplClass(JNIEnv* env, std::string_view name) {
116-
static std::string kPrefix = "org/apache/arrow/adbc/driver/jni/impl/";
117-
std::string full_name = kPrefix + std::string(name);
118-
jclass klass = env->FindClass(full_name.c_str());
128+
jclass RequireClass(JNIEnv* env, const std::string& name) {
129+
jclass klass = env->FindClass(name.c_str());
119130
if (klass == nullptr) {
131+
std::string message = "[JNI] Could not find class ";
132+
message += name;
120133
throw AdbcException{
121134
.code = ADBC_STATUS_INTERNAL,
122-
.message = "[JNI] Could not find class " + full_name,
135+
.message = std::move(message),
123136
};
124137
}
125138
return klass;
126139
}
127140

141+
/// Require that a Java class exists or error.
142+
jclass RequireImplClass(JNIEnv* env, std::string_view name) {
143+
static std::string kPrefix = "org/apache/arrow/adbc/driver/jni/impl/";
144+
std::string full_name = kPrefix + std::string(name);
145+
return RequireClass(env, full_name);
146+
}
147+
128148
/// Require that a Java method exists or error.
129149
jmethodID RequireMethod(JNIEnv* env, jclass klass, std::string_view name,
130150
std::string_view signature) {
@@ -381,6 +401,60 @@ Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementGetParameterSchem
381401
return nullptr;
382402
}
383403

404+
JNIEXPORT jobject JNICALL
405+
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementExecutePartitions(
406+
JNIEnv* env, [[maybe_unused]] jclass self, jlong handle) {
407+
struct AdbcError error = ADBC_ERROR_INIT;
408+
auto* ptr = reinterpret_cast<struct AdbcStatement*>(static_cast<uintptr_t>(handle));
409+
struct ArrowSchema schema = {};
410+
struct AdbcPartitions partitions = {};
411+
int64_t rows_affected = 0;
412+
jobject result = nullptr;
413+
414+
try {
415+
jclass native_result_class = RequireImplClass(env, "NativePartitionResult");
416+
jmethodID native_result_ctor =
417+
RequireMethod(env, native_result_class, "<init>", "(JJ)V");
418+
jmethodID native_result_add_partition =
419+
RequireMethod(env, native_result_class, "addPartition", "([B)V");
420+
421+
CHECK_ADBC_ERROR(
422+
AdbcStatementExecutePartitions(ptr, &schema, &partitions, &rows_affected, &error),
423+
error);
424+
425+
result = env->NewObject(native_result_class, native_result_ctor, rows_affected,
426+
static_cast<jlong>(reinterpret_cast<uintptr_t>(&schema)));
427+
if (env->ExceptionCheck()) goto cleanupall;
428+
429+
for (size_t i = 0; i < partitions.num_partitions; i++) {
430+
size_t length = partitions.partition_lengths[i];
431+
jbyteArray partition = env->NewByteArray(static_cast<jsize>(length));
432+
env->SetByteArrayRegion(partition, 0, static_cast<jsize>(length),
433+
reinterpret_cast<const jbyte*>(partitions.partitions[i]));
434+
if (env->ExceptionCheck()) goto cleanupall;
435+
env->CallObjectMethod(result, native_result_add_partition, partition);
436+
if (env->ExceptionCheck()) goto cleanupall;
437+
}
438+
} catch (const AdbcException& e) {
439+
e.ThrowJavaException(env);
440+
}
441+
442+
// We can't release schema, but we copied out the partitions
443+
if (partitions.release != nullptr) {
444+
partitions.release(&partitions);
445+
}
446+
return result;
447+
448+
cleanupall:
449+
if (schema.release != nullptr) {
450+
schema.release(&schema);
451+
}
452+
if (partitions.release != nullptr) {
453+
partitions.release(&partitions);
454+
}
455+
return nullptr;
456+
}
457+
384458
JNIEXPORT jobject JNICALL
385459
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementExecuteQuery(
386460
JNIEnv* env, [[maybe_unused]] jclass self, jlong handle) {
@@ -979,6 +1053,95 @@ Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_connectionRollback(
9791053
}
9801054
}
9811055

1056+
JNIEXPORT jobject JNICALL
1057+
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_connectionReadPartition(
1058+
JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jobject partition) {
1059+
struct AdbcError error = ADBC_ERROR_INIT;
1060+
auto* conn = reinterpret_cast<struct AdbcConnection*>(static_cast<uintptr_t>(handle));
1061+
struct ArrowArrayStream out = {};
1062+
size_t serialized_length = 0;
1063+
const uint8_t* serialized_partition = nullptr;
1064+
std::vector<uint8_t> allocated_partition;
1065+
1066+
try {
1067+
jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
1068+
jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
1069+
1070+
if (!env->IsInstanceOf(partition, bb_class)) {
1071+
ThrowJavaException(env, "java/lang/IllegalArgumentException",
1072+
"Partition must be a ByteBuffer");
1073+
return nullptr;
1074+
}
1075+
jint remaining = env->CallIntMethod(partition, bb_remaining);
1076+
if (remaining < 0) {
1077+
ThrowJavaException(env, "java/lang/IllegalArgumentException",
1078+
"ByteBuffer remaining() must be non-negative");
1079+
return nullptr;
1080+
}
1081+
serialized_length = static_cast<size_t>(remaining);
1082+
1083+
// fast path (if direct buffer)
1084+
void* buf = env->GetDirectBufferAddress(partition);
1085+
if (buf) {
1086+
serialized_partition = static_cast<const uint8_t*>(buf);
1087+
}
1088+
1089+
// middle path (backing array)
1090+
if (!serialized_partition) {
1091+
jmethodID bb_has_array = RequireMethod(env, bb_class, "hasArray", "()Z");
1092+
jmethodID bb_array = RequireMethod(env, bb_class, "array", "()[B");
1093+
jmethodID bb_array_offset = RequireMethod(env, bb_class, "arrayOffset", "()I");
1094+
jboolean has_array = env->CallBooleanMethod(partition, bb_has_array);
1095+
if (env->ExceptionCheck()) return nullptr;
1096+
if (has_array) {
1097+
jint array_offset = env->CallIntMethod(partition, bb_array_offset);
1098+
if (env->ExceptionCheck()) return nullptr;
1099+
1100+
auto array =
1101+
reinterpret_cast<jbyteArray>(env->CallObjectMethod(partition, bb_array));
1102+
if (env->ExceptionCheck()) return nullptr;
1103+
1104+
assert(serialized_length <= static_cast<size_t>(env->GetArrayLength(array)));
1105+
allocated_partition.resize(serialized_length);
1106+
env->GetByteArrayRegion(array, array_offset,
1107+
static_cast<jsize>(serialized_length),
1108+
reinterpret_cast<jbyte*>(allocated_partition.data()));
1109+
serialized_partition = allocated_partition.data();
1110+
}
1111+
}
1112+
1113+
// slow path (copy)
1114+
if (!serialized_partition) {
1115+
jmethodID bb_get = RequireMethod(env, bb_class, "get", "([B)Ljava/nio/ByteBuffer;");
1116+
jbyteArray temp = env->NewByteArray(static_cast<jsize>(serialized_length));
1117+
if (!temp) {
1118+
ThrowJavaException(env, "java/lang/OutOfMemoryError",
1119+
"Failed to allocate byte array for partition");
1120+
return nullptr;
1121+
}
1122+
1123+
env->CallVoidMethod(partition, bb_get, temp);
1124+
if (env->ExceptionCheck()) return nullptr;
1125+
1126+
allocated_partition.resize(serialized_length);
1127+
env->GetByteArrayRegion(temp, 0, static_cast<jsize>(serialized_length),
1128+
reinterpret_cast<jbyte*>(allocated_partition.data()));
1129+
serialized_partition = allocated_partition.data();
1130+
}
1131+
1132+
assert(serialized_partition != nullptr);
1133+
1134+
CHECK_ADBC_ERROR(AdbcConnectionReadPartition(conn, serialized_partition,
1135+
serialized_length, &out, &error),
1136+
error);
1137+
1138+
return MakeNativeQueryResult(env, -1, &out);
1139+
} catch (const AdbcException& e) {
1140+
e.ThrowJavaException(env);
1141+
}
1142+
return nullptr;
1143+
}
1144+
9821145
JNIEXPORT jbyteArray JNICALL
9831146
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_databaseGetOptionBytes(
9841147
JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jstring key) {

java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniConnection.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.arrow.adbc.driver.jni;
1919

20+
import java.nio.ByteBuffer;
2021
import org.apache.arrow.adbc.core.AdbcConnection;
2122
import org.apache.arrow.adbc.core.AdbcException;
2223
import org.apache.arrow.adbc.core.AdbcStatement;
@@ -254,6 +255,11 @@ public void setCurrentDbSchema(String dbSchema) throws AdbcException {
254255
setOption(JniDriver.CURRENT_DB_SCHEMA, dbSchema);
255256
}
256257

258+
@Override
259+
public ArrowReader readPartition(ByteBuffer descriptor) throws AdbcException {
260+
return JniLoader.INSTANCE.connectionReadPartition(handle, descriptor).importStream(allocator);
261+
}
262+
257263
@Override
258264
public void close() {
259265
handle.close();

0 commit comments

Comments
 (0)