|
34 | 34 |
|
35 | 35 | namespace { |
36 | 36 |
|
| 37 | +void ThrowJavaException(JNIEnv* env, const std::string& klass, |
| 38 | + const std::string& message) { |
| 39 | + jclass exception_klass = env->FindClass(klass.c_str()); |
| 40 | + assert(exception_klass != nullptr); |
| 41 | + jmethodID exception_ctor = |
| 42 | + env->GetMethodID(exception_klass, "<init>", "(Ljava/lang/String)V"); |
| 43 | + assert(exception_ctor != nullptr); |
| 44 | + jstring message_jni = env->NewStringUTF(message.c_str()); |
| 45 | + auto exc = static_cast<jthrowable>( |
| 46 | + env->NewObject(exception_klass, exception_ctor, message_jni)); |
| 47 | + env->Throw(exc); |
| 48 | +} |
| 49 | + |
37 | 50 | /// Internal exception. Meant to be used with RaiseAdbcException and |
38 | 51 | /// CHECK_ADBC_ERROR. |
39 | 52 | struct AdbcException { |
@@ -112,19 +125,26 @@ void RaiseAdbcException(AdbcStatusCode code, const AdbcError& error) { |
112 | 125 | } while (0) |
113 | 126 |
|
114 | 127 | /// Require that a Java class exists or error. |
115 | | -jclass RequireImplClass(JNIEnv* env, std::string_view name) { |
116 | | - static std::string kPrefix = "org/apache/arrow/adbc/driver/jni/impl/"; |
117 | | - std::string full_name = kPrefix + std::string(name); |
118 | | - jclass klass = env->FindClass(full_name.c_str()); |
| 128 | +jclass RequireClass(JNIEnv* env, const std::string& name) { |
| 129 | + jclass klass = env->FindClass(name.c_str()); |
119 | 130 | if (klass == nullptr) { |
| 131 | + std::string message = "[JNI] Could not find class "; |
| 132 | + message += name; |
120 | 133 | throw AdbcException{ |
121 | 134 | .code = ADBC_STATUS_INTERNAL, |
122 | | - .message = "[JNI] Could not find class " + full_name, |
| 135 | + .message = std::move(message), |
123 | 136 | }; |
124 | 137 | } |
125 | 138 | return klass; |
126 | 139 | } |
127 | 140 |
|
| 141 | +/// Require that a Java class exists or error. |
| 142 | +jclass RequireImplClass(JNIEnv* env, std::string_view name) { |
| 143 | + static std::string kPrefix = "org/apache/arrow/adbc/driver/jni/impl/"; |
| 144 | + std::string full_name = kPrefix + std::string(name); |
| 145 | + return RequireClass(env, full_name); |
| 146 | +} |
| 147 | + |
128 | 148 | /// Require that a Java method exists or error. |
129 | 149 | jmethodID RequireMethod(JNIEnv* env, jclass klass, std::string_view name, |
130 | 150 | std::string_view signature) { |
@@ -381,6 +401,60 @@ Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementGetParameterSchem |
381 | 401 | return nullptr; |
382 | 402 | } |
383 | 403 |
|
| 404 | +JNIEXPORT jobject JNICALL |
| 405 | +Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementExecutePartitions( |
| 406 | + JNIEnv* env, [[maybe_unused]] jclass self, jlong handle) { |
| 407 | + struct AdbcError error = ADBC_ERROR_INIT; |
| 408 | + auto* ptr = reinterpret_cast<struct AdbcStatement*>(static_cast<uintptr_t>(handle)); |
| 409 | + struct ArrowSchema schema = {}; |
| 410 | + struct AdbcPartitions partitions = {}; |
| 411 | + int64_t rows_affected = 0; |
| 412 | + jobject result = nullptr; |
| 413 | + |
| 414 | + try { |
| 415 | + jclass native_result_class = RequireImplClass(env, "NativePartitionResult"); |
| 416 | + jmethodID native_result_ctor = |
| 417 | + RequireMethod(env, native_result_class, "<init>", "(JJ)V"); |
| 418 | + jmethodID native_result_add_partition = |
| 419 | + RequireMethod(env, native_result_class, "addPartition", "([B)V"); |
| 420 | + |
| 421 | + CHECK_ADBC_ERROR( |
| 422 | + AdbcStatementExecutePartitions(ptr, &schema, &partitions, &rows_affected, &error), |
| 423 | + error); |
| 424 | + |
| 425 | + result = env->NewObject(native_result_class, native_result_ctor, rows_affected, |
| 426 | + static_cast<jlong>(reinterpret_cast<uintptr_t>(&schema))); |
| 427 | + if (env->ExceptionCheck()) goto cleanupall; |
| 428 | + |
| 429 | + for (size_t i = 0; i < partitions.num_partitions; i++) { |
| 430 | + size_t length = partitions.partition_lengths[i]; |
| 431 | + jbyteArray partition = env->NewByteArray(static_cast<jsize>(length)); |
| 432 | + env->SetByteArrayRegion(partition, 0, static_cast<jsize>(length), |
| 433 | + reinterpret_cast<const jbyte*>(partitions.partitions[i])); |
| 434 | + if (env->ExceptionCheck()) goto cleanupall; |
| 435 | + env->CallObjectMethod(result, native_result_add_partition, partition); |
| 436 | + if (env->ExceptionCheck()) goto cleanupall; |
| 437 | + } |
| 438 | + } catch (const AdbcException& e) { |
| 439 | + e.ThrowJavaException(env); |
| 440 | + } |
| 441 | + |
| 442 | + // We can't release schema, but we copied out the partitions |
| 443 | + if (partitions.release != nullptr) { |
| 444 | + partitions.release(&partitions); |
| 445 | + } |
| 446 | + return result; |
| 447 | + |
| 448 | +cleanupall: |
| 449 | + if (schema.release != nullptr) { |
| 450 | + schema.release(&schema); |
| 451 | + } |
| 452 | + if (partitions.release != nullptr) { |
| 453 | + partitions.release(&partitions); |
| 454 | + } |
| 455 | + return nullptr; |
| 456 | +} |
| 457 | + |
384 | 458 | JNIEXPORT jobject JNICALL |
385 | 459 | Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementExecuteQuery( |
386 | 460 | JNIEnv* env, [[maybe_unused]] jclass self, jlong handle) { |
@@ -979,6 +1053,95 @@ Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_connectionRollback( |
979 | 1053 | } |
980 | 1054 | } |
981 | 1055 |
|
| 1056 | +JNIEXPORT jobject JNICALL |
| 1057 | +Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_connectionReadPartition( |
| 1058 | + JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jobject partition) { |
| 1059 | + struct AdbcError error = ADBC_ERROR_INIT; |
| 1060 | + auto* conn = reinterpret_cast<struct AdbcConnection*>(static_cast<uintptr_t>(handle)); |
| 1061 | + struct ArrowArrayStream out = {}; |
| 1062 | + size_t serialized_length = 0; |
| 1063 | + const uint8_t* serialized_partition = nullptr; |
| 1064 | + std::vector<uint8_t> allocated_partition; |
| 1065 | + |
| 1066 | + try { |
| 1067 | + jclass bb_class = RequireClass(env, "java/nio/ByteBuffer"); |
| 1068 | + jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I"); |
| 1069 | + |
| 1070 | + if (!env->IsInstanceOf(partition, bb_class)) { |
| 1071 | + ThrowJavaException(env, "java/lang/IllegalArgumentException", |
| 1072 | + "Partition must be a ByteBuffer"); |
| 1073 | + return nullptr; |
| 1074 | + } |
| 1075 | + jint remaining = env->CallIntMethod(partition, bb_remaining); |
| 1076 | + if (remaining < 0) { |
| 1077 | + ThrowJavaException(env, "java/lang/IllegalArgumentException", |
| 1078 | + "ByteBuffer remaining() must be non-negative"); |
| 1079 | + return nullptr; |
| 1080 | + } |
| 1081 | + serialized_length = static_cast<size_t>(remaining); |
| 1082 | + |
| 1083 | + // fast path (if direct buffer) |
| 1084 | + void* buf = env->GetDirectBufferAddress(partition); |
| 1085 | + if (buf) { |
| 1086 | + serialized_partition = static_cast<const uint8_t*>(buf); |
| 1087 | + } |
| 1088 | + |
| 1089 | + // middle path (backing array) |
| 1090 | + if (!serialized_partition) { |
| 1091 | + jmethodID bb_has_array = RequireMethod(env, bb_class, "hasArray", "()Z"); |
| 1092 | + jmethodID bb_array = RequireMethod(env, bb_class, "array", "()[B"); |
| 1093 | + jmethodID bb_array_offset = RequireMethod(env, bb_class, "arrayOffset", "()I"); |
| 1094 | + jboolean has_array = env->CallBooleanMethod(partition, bb_has_array); |
| 1095 | + if (env->ExceptionCheck()) return nullptr; |
| 1096 | + if (has_array) { |
| 1097 | + jint array_offset = env->CallIntMethod(partition, bb_array_offset); |
| 1098 | + if (env->ExceptionCheck()) return nullptr; |
| 1099 | + |
| 1100 | + auto array = |
| 1101 | + reinterpret_cast<jbyteArray>(env->CallObjectMethod(partition, bb_array)); |
| 1102 | + if (env->ExceptionCheck()) return nullptr; |
| 1103 | + |
| 1104 | + assert(serialized_length <= static_cast<size_t>(env->GetArrayLength(array))); |
| 1105 | + allocated_partition.resize(serialized_length); |
| 1106 | + env->GetByteArrayRegion(array, array_offset, |
| 1107 | + static_cast<jsize>(serialized_length), |
| 1108 | + reinterpret_cast<jbyte*>(allocated_partition.data())); |
| 1109 | + serialized_partition = allocated_partition.data(); |
| 1110 | + } |
| 1111 | + } |
| 1112 | + |
| 1113 | + // slow path (copy) |
| 1114 | + if (!serialized_partition) { |
| 1115 | + jmethodID bb_get = RequireMethod(env, bb_class, "get", "([B)Ljava/nio/ByteBuffer;"); |
| 1116 | + jbyteArray temp = env->NewByteArray(static_cast<jsize>(serialized_length)); |
| 1117 | + if (!temp) { |
| 1118 | + ThrowJavaException(env, "java/lang/OutOfMemoryError", |
| 1119 | + "Failed to allocate byte array for partition"); |
| 1120 | + return nullptr; |
| 1121 | + } |
| 1122 | + |
| 1123 | + env->CallVoidMethod(partition, bb_get, temp); |
| 1124 | + if (env->ExceptionCheck()) return nullptr; |
| 1125 | + |
| 1126 | + allocated_partition.resize(serialized_length); |
| 1127 | + env->GetByteArrayRegion(temp, 0, static_cast<jsize>(serialized_length), |
| 1128 | + reinterpret_cast<jbyte*>(allocated_partition.data())); |
| 1129 | + serialized_partition = allocated_partition.data(); |
| 1130 | + } |
| 1131 | + |
| 1132 | + assert(serialized_partition != nullptr); |
| 1133 | + |
| 1134 | + CHECK_ADBC_ERROR(AdbcConnectionReadPartition(conn, serialized_partition, |
| 1135 | + serialized_length, &out, &error), |
| 1136 | + error); |
| 1137 | + |
| 1138 | + return MakeNativeQueryResult(env, -1, &out); |
| 1139 | + } catch (const AdbcException& e) { |
| 1140 | + e.ThrowJavaException(env); |
| 1141 | + } |
| 1142 | + return nullptr; |
| 1143 | +} |
| 1144 | + |
982 | 1145 | JNIEXPORT jbyteArray JNICALL |
983 | 1146 | Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_databaseGetOptionBytes( |
984 | 1147 | JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jstring key) { |
|
0 commit comments