Skip to content

Commit 2a99bec

Browse files
committed
fix(tests): prevent GlobalEnv double-init crash in multi-suite binaries
Add IsInitialized() to GlobalEnv and guard SetUpTestSuite so a second test class in the same process skips re-initialization instead of hitting CHECK(!initialized_). Also print try_compile output on compile-fail test to surface header-not-found vs real type errors.
1 parent e77e3cb commit 2a99bec

4 files changed

Lines changed: 15 additions & 1 deletion

File tree

infini_train/include/nn/parallel/global.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class GlobalEnv {
3131
void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
3232
int pipeline_parallel_size, int virtual_pipeline_parallel_size);
3333

34+
bool IsInitialized() const;
35+
3436
int nnodes() const;
3537

3638
int nproc_per_node() const;

infini_train/src/nn/parallel/global.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ GlobalEnv &GlobalEnv::Instance() {
9191
return instance;
9292
}
9393

94+
bool GlobalEnv::IsInitialized() const {
95+
std::lock_guard<std::mutex> lock(mutex_);
96+
return initialized_;
97+
}
98+
9499
void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
95100
int pipeline_parallel_size, int virtual_pipeline_parallel_size) {
96101
std::lock_guard<std::mutex> lock(mutex_);

tests/common/test_utils.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,12 @@ inline void FillConstantTensor(const std::shared_ptr<Tensor> &tensor, float valu
103103

104104
class InfiniTrainTest : public ::testing::TestWithParam<Device::DeviceType> {
105105
protected:
106-
static void SetUpTestSuite() { nn::parallel::global::GlobalEnv::Instance().Init(1, 1, false, 1, 1); }
106+
static void SetUpTestSuite() {
107+
auto &env = nn::parallel::global::GlobalEnv::Instance();
108+
if (!env.IsInitialized()) {
109+
env.Init(1, 1, false, 1, 1);
110+
}
111+
}
107112
Device GetDevice() const { return Device(GetParam(), 0); }
108113
std::shared_ptr<Tensor> createTensor(const std::vector<int64_t> &shape, DataType dtype = DataType::kFLOAT32,
109114
bool requires_grad = false) {

tests/dtype/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ if(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED)
2929
"dtype dispatch compile-fail test unexpectedly succeeded.\n"
3030
"Source: ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE}\n"
3131
"Output:\n${DTYPE_DISPATCH_TRY_COMPILE_OUTPUT}")
32+
else()
33+
message(STATUS "compile-fail output:\n${DTYPE_DISPATCH_TRY_COMPILE_OUTPUT}")
3234
endif()
3335

3436
add_custom_target(test_dtype_dispatch_compile_fail

0 commit comments

Comments
 (0)