Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ class VeloxConfig(conf: SQLConf) extends GlutenConfig(conf) {
def hashProbeDynamicFilterPushdownEnabled: Boolean =
getConf(HASH_PROBE_DYNAMIC_FILTER_PUSHDOWN_ENABLED)

def parallelExecutionEnabled: Boolean =
getConf(PARALLEL_EXECUTION_ENABLED)

def parallelExecutionThreadPoolSize: Option[Int] =
getConf(PARALLEL_EXECUTION_THREAD_POOL_SIZE)

def parallelExecutionMaxDrivers: Int =
getConf(PARALLEL_EXECUTION_MAX_DRIVERS)

def valueStreamDynamicFilterEnabled: Boolean =
getConf(VALUE_STREAM_DYNAMIC_FILTER_ENABLED)

Expand Down Expand Up @@ -474,6 +483,32 @@ object VeloxConfig extends ConfigRegistry {
.booleanConf
.createWithDefault(true)

val PARALLEL_EXECUTION_ENABLED =
buildStaticConf("spark.gluten.sql.columnar.backend.velox.parallelExecution.enabled")
.doc(
"Whether to enable parallel execution of Velox task drivers for whole-stage execution. " +
"Default is false (serial execution).")
.booleanConf
.createWithDefault(false)

val PARALLEL_EXECUTION_THREAD_POOL_SIZE =
buildStaticConf("spark.gluten.sql.columnar.backend.velox.parallelExecution.threadPoolSize")
.doc(
"Size of the thread pool used for parallel execution of Velox task drivers. " +
"If not set, defaults to 2 * spark.gluten.numTaskSlotsPerExecutor.")
.intConf
.checkValue(_ > 0, "must be a positive number")
.createOptional

val PARALLEL_EXECUTION_MAX_DRIVERS =
buildConf("spark.gluten.sql.columnar.backend.velox.parallelExecution.maxDrivers")
.doc(
"Maximum number of parallel Velox task drivers to use for whole-stage execution. " +
"Default is 4.")
.intConf
.checkValue(_ > 0, "must be a positive number")
.createWithDefault(4)

val VALUE_STREAM_DYNAMIC_FILTER_ENABLED =
buildConf("spark.gluten.sql.columnar.backend.velox.valueStream.dynamicFilter.enabled")
.doc(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ object MetricsUtil extends Logging {
}
ju.updateJoinMetrics(operatorMetrics, metrics.getSingleMetrics, joinParams)
case u: UnionMetricsUpdater =>
// JoinRel outputs two suites of metrics respectively for hash build and hash probe.
// Union outputs two suites of metrics respectively.
// Therefore, fetch one more suite of metrics here.
operatorMetrics.add(metrics.getOperatorMetrics(curMetricsIdx))
curMetricsIdx -= 1
Expand Down
2 changes: 2 additions & 0 deletions cpp/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ set(SPARK_COLUMNAR_PLUGIN_SRCS
memory/MemoryManager.cc
memory/ArrowMemoryPool.cc
memory/ColumnarBatch.cc
threads/ThreadInitializer.cc
threads/ThreadManager.cc
shuffle/Dictionary.cc
shuffle/FallbackRangePartitioner.cc
shuffle/HashPartitioner.cc
Expand Down
3 changes: 2 additions & 1 deletion cpp/core/compute/Runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ void Runtime::registerFactory(const std::string& kind, Runtime::Factory factory,
Runtime* Runtime::create(
const std::string& kind,
MemoryManager* memoryManager,
ThreadManager* threadManager,
const std::unordered_map<std::string, std::string>& sessionConf) {
auto& factory = runtimeFactories().get(kind);
return factory(kind, std::move(memoryManager), sessionConf);
return factory(kind, std::move(memoryManager), std::move(threadManager), sessionConf);
}

void Runtime::release(Runtime* runtime) {
Expand Down
11 changes: 10 additions & 1 deletion cpp/core/compute/Runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "shuffle/ShuffleReader.h"
#include "shuffle/ShuffleWriter.h"
#include "substrait/plan.pb.h"
#include "threads/ThreadManager.h"
#include "utils/ObjectStore.h"
#include "utils/WholeStageDumper.h"

Expand Down Expand Up @@ -61,12 +62,14 @@ class Runtime : public std::enable_shared_from_this<Runtime> {
using Factory = std::function<Runtime*(
const std::string& kind,
MemoryManager* memoryManager,
ThreadManager* threadManager,
const std::unordered_map<std::string, std::string>& sessionConf)>;
using Releaser = std::function<void(Runtime*)>;
static void registerFactory(const std::string& kind, Factory factory, Releaser releaser);
static Runtime* create(
const std::string& kind,
MemoryManager* memoryManager,
ThreadManager* threadManager,
const std::unordered_map<std::string, std::string>& sessionConf = {});
static void release(Runtime*);
static std::optional<std::string>* localWriteFilesTempPath();
Expand All @@ -75,8 +78,9 @@ class Runtime : public std::enable_shared_from_this<Runtime> {
Runtime(
const std::string& kind,
MemoryManager* memoryManager,
ThreadManager* threadManager,
const std::unordered_map<std::string, std::string>& confMap)
: kind_(kind), memoryManager_(memoryManager), confMap_(confMap) {}
: kind_(kind), memoryManager_(memoryManager), threadManager_(threadManager), confMap_(confMap) {}

virtual ~Runtime() = default;

Expand Down Expand Up @@ -126,6 +130,10 @@ class Runtime : public std::enable_shared_from_this<Runtime> {
return memoryManager_;
};

virtual ThreadManager* threadManager() {
return threadManager_;
};

/// This function is used to create certain converter from the format used by
/// the backend to Spark unsafe row.
virtual std::shared_ptr<ColumnarToRowConverter> createColumnar2RowConverter(int64_t column2RowMemThreshold) {
Expand Down Expand Up @@ -184,6 +192,7 @@ class Runtime : public std::enable_shared_from_this<Runtime> {
protected:
std::string kind_;
MemoryManager* memoryManager_;
ThreadManager* threadManager_;
std::unique_ptr<ObjectStore> objStore_ = ObjectStore::create();
std::unordered_map<std::string, std::string> confMap_; // Session conf map

Expand Down
2 changes: 1 addition & 1 deletion cpp/core/jni/JniCommon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#include "JniCommon.h"
#include <folly/system/ThreadName.h>

void gluten::JniCommonState::ensureInitialized(JNIEnv* env) {
std::lock_guard<std::mutex> lockGuard(mtx_);
Expand Down Expand Up @@ -120,7 +121,6 @@ std::shared_ptr<gluten::ColumnarBatch> gluten::JniColumnarBatchIterator::next()
std::shared_ptr<gluten::ColumnarBatch> gluten::JniColumnarBatchIterator::nextInternal() const {
JNIEnv* env = nullptr;
attachCurrentThreadAsDaemonOrThrow(vm_, &env);

if (!env->CallBooleanMethod(jColumnarBatchItr_, serializedColumnarBatchIteratorHasNext_)) {
checkException(env);
return nullptr; // stream ended
Expand Down
71 changes: 71 additions & 0 deletions cpp/core/jni/JniCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "compute/Runtime.h"
#include "memory/AllocationListener.h"
#include "shuffle/rss/RssClient.h"
#include "threads/ThreadInitializer.h"
#include "utils/Compression.h"
#include "utils/Exception.h"
#include "utils/ResourceMap.h"
Expand Down Expand Up @@ -549,3 +550,73 @@ class JavaRssClient : public RssClient {
jmethodID javaPushPartitionData_;
jbyteArray array_;
};

class SparkThreadInitializer final : public gluten::ThreadInitializer {
public:
SparkThreadInitializer(JavaVM* vm, jobject jInitializerLocalRef) : vm_(vm) {
JNIEnv* env;
attachCurrentThreadAsDaemonOrThrow(vm_, &env);
jInitializerGlobalRef_ = env->NewGlobalRef(jInitializerLocalRef);
GLUTEN_CHECK(jInitializerGlobalRef_ != nullptr, "Failed to create global reference for native thread initializer.");
(void)initializeMethod(env);
}

SparkThreadInitializer(const SparkThreadInitializer&) = delete;
SparkThreadInitializer(SparkThreadInitializer&&) = delete;
SparkThreadInitializer& operator=(const SparkThreadInitializer&) = delete;
SparkThreadInitializer& operator=(SparkThreadInitializer&&) = delete;

~SparkThreadInitializer() override {
JNIEnv* env;
if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
LOG(WARNING) << "SparkThreadInitializer#~SparkThreadInitializer(): "
<< "JNIEnv was not attached to current thread";
return;
}
env->DeleteGlobalRef(jInitializerGlobalRef_);
}

void initialize(const std::string& threadName) override {
JNIEnv* env;
attachCurrentThreadAsDaemonOrThrow(vm_, &env);
jstring jThreadName = env->NewStringUTF(threadName.c_str());
env->CallVoidMethod(jInitializerGlobalRef_, initializeMethod(env), jThreadName);
env->DeleteLocalRef(jThreadName);
checkException(env);
}

void destroy(const std::string& threadName) override {
// IMPORTANT: Do not call vm_.DetachCurrentThread here, otherwise Java side thread
// object might be dereferenced and garbage-collected, to break the reuse of thread
// resources.
JNIEnv* env;
attachCurrentThreadAsDaemonOrThrow(vm_, &env);
jstring jThreadName = env->NewStringUTF(threadName.c_str());
env->CallVoidMethod(jInitializerGlobalRef_, destroyMethod(env), jThreadName);
env->DeleteLocalRef(jThreadName);
checkException(env);
}

private:
jmethodID initializeMethod(JNIEnv* env) {
static jmethodID initializeMethod =
getMethodIdOrError(env, nativeThreadInitializerClass(env), "initialize", "(Ljava/lang/String;)V");
return initializeMethod;
}

jmethodID destroyMethod(JNIEnv* env) {
static jmethodID destroyMethod =
getMethodIdOrError(env, nativeThreadInitializerClass(env), "destroy", "(Ljava/lang/String;)V");
return destroyMethod;
}

jclass nativeThreadInitializerClass(JNIEnv* env) {
static jclass javaInitializerClass =
createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/threads/NativeThreadInitializer;");
return javaInitializerClass;
}

private:
JavaVM* vm_;
jobject jInitializerGlobalRef_;
};
59 changes: 56 additions & 3 deletions cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,9 @@ class InternalRuntime : public Runtime {
InternalRuntime(
const std::string& kind,
MemoryManager* memoryManager,
ThreadManager* threadManager,
const std::unordered_map<std::string, std::string>& confMap)
: Runtime(kind, memoryManager, confMap) {}
: Runtime(kind, memoryManager, threadManager, confMap) {}
};

MemoryManager* internalMemoryManagerFactory(const std::string& kind, std::unique_ptr<AllocationListener> listener) {
Expand All @@ -193,11 +194,33 @@ void internalMemoryManagerReleaser(MemoryManager* memoryManager) {
delete memoryManager;
}

class InternalThreadManager : public ThreadManager {
public:
InternalThreadManager(const std::string& kind, std::unique_ptr<ThreadInitializer> initializer)
: ThreadManager(kind), initializer_(std::shared_ptr<ThreadInitializer>(std::move(initializer))) {}

ThreadInitializer* getThreadInitializer() override {
return initializer_.get();
}

private:
std::shared_ptr<ThreadInitializer> initializer_;
};

ThreadManager* internalThreadManagerFactory(const std::string& kind, std::unique_ptr<ThreadInitializer> initializer) {
return new InternalThreadManager(kind, std::move(initializer));
}

void internalThreadManagerReleaser(ThreadManager* threadManager) {
delete threadManager;
}

Runtime* internalRuntimeFactory(
const std::string& kind,
MemoryManager* memoryManager,
ThreadManager* threadManager,
const std::unordered_map<std::string, std::string>& sessionConf) {
return new InternalRuntime(kind, memoryManager, sessionConf);
return new InternalRuntime(kind, memoryManager, threadManager, sessionConf);
}

void internalRuntimeReleaser(Runtime* runtime) {
Expand Down Expand Up @@ -252,6 +275,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
getJniErrorState()->ensureInitialized(env);

MemoryManager::registerFactory(kInternalBackendKind, internalMemoryManagerFactory, internalMemoryManagerReleaser);
ThreadManager::registerFactory(kInternalBackendKind, internalThreadManagerFactory, internalThreadManagerReleaser);
Runtime::registerFactory(kInternalBackendKind, internalRuntimeFactory, internalRuntimeReleaser);

byteArrayClass = createGlobalClassReferenceOrError(env, "[B");
Expand Down Expand Up @@ -319,14 +343,16 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_runtime_RuntimeJniWrapper_createR
jclass,
jstring jBackendType,
jlong nmmHandle,
jlong ntmHandle,
jbyteArray sessionConf) {
JNI_METHOD_START
MemoryManager* memoryManager = jniCastOrThrow<MemoryManager>(nmmHandle);
ThreadManager* threadManager = jniCastOrThrow<ThreadManager>(ntmHandle);
auto safeArray = getByteArrayElementsSafe(env, sessionConf);
auto sparkConf = parseConfMap(env, safeArray.elems(), safeArray.length());
auto backendType = jStringToCString(env, jBackendType);

auto runtime = Runtime::create(backendType, memoryManager, sparkConf);
auto runtime = Runtime::create(backendType, memoryManager, threadManager, sparkConf);
return reinterpret_cast<jlong>(runtime);
JNI_METHOD_END(kInvalidObjectHandle)
}
Expand Down Expand Up @@ -370,6 +396,33 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_memory_NativeMemoryManagerJniWrap
JNI_METHOD_END(-1L)
}

JNIEXPORT jlong JNICALL Java_org_apache_gluten_threads_NativeThreadManagerJniWrapper_create( // NOLINT
JNIEnv* env,
jclass,
jstring jBackendType,
jobject jInitializer) {
JNI_METHOD_START
JavaVM* vm;
if (env->GetJavaVM(&vm) != JNI_OK) {
throw GlutenException("Unable to get JavaVM instance");
}
auto backendType = jStringToCString(env, jBackendType);
std::unique_ptr<ThreadInitializer> initializer = std::make_unique<SparkThreadInitializer>(vm, jInitializer);
ThreadManager* tm = ThreadManager::create(backendType, std::move(initializer));
return reinterpret_cast<jlong>(tm);
JNI_METHOD_END(-1L)
}

JNIEXPORT void JNICALL Java_org_apache_gluten_threads_NativeThreadManagerJniWrapper_release( // NOLINT
JNIEnv* env,
jclass,
jlong ntmHandle) {
JNI_METHOD_START
auto* threadManager = jniCastOrThrow<ThreadManager>(ntmHandle);
ThreadManager::release(threadManager);
JNI_METHOD_END()
}

JNIEXPORT jbyteArray JNICALL Java_org_apache_gluten_memory_NativeMemoryManagerJniWrapper_collectUsage( // NOLINT
JNIEnv* env,
jclass,
Expand Down
35 changes: 35 additions & 0 deletions cpp/core/threads/ThreadInitializer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ThreadInitializer.h"

namespace gluten {
namespace {

class NoopThreadInitializer final : public ThreadInitializer {
public:
void initialize(const std::string& threadName) override {}
void destroy(const std::string& threadName) override{};
};

} // namespace

std::unique_ptr<ThreadInitializer> ThreadInitializer::noop() {
return std::make_unique<NoopThreadInitializer>();
}

} // namespace gluten
Loading
Loading