Skip to content

Commit 871b71e

Browse files
committed
fix(rng): use fixed seed for deterministic LoRA init
Replace std::random_device with 42 + omp_get_thread_num() to ensure reproducible LoRA initialization across runs.
1 parent 0e79fc9 commit 871b71e

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

infini_train/src/nn/init.cc

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,20 @@
2222

2323
namespace infini_train::nn::init {
2424
namespace {
25-
static std::random_device rd;
26-
static std::mt19937 gen(rd());
25+
constexpr int kRandomSeed = 42;
26+
27+
// FIXME: RNG design is incomplete.
28+
//
29+
// Current implementation lacks:
30+
// - unified Generator abstraction
31+
// - global default generator and seed control
32+
// - reproducible / clonable RNG state
33+
//
34+
// TODO:
35+
// - introduce Generator interface and backend impl
36+
// - add default generator management (per device)
37+
// - refactor random ops to consume Generator
38+
static std::mt19937 gen(kRandomSeed);
2739
} // namespace
2840

2941
std::shared_ptr<Tensor> Normal(const std::shared_ptr<Tensor> &tensor, float mean, float std,
@@ -34,7 +46,7 @@ std::shared_ptr<Tensor> Normal(const std::shared_ptr<Tensor> &tensor, float mean
3446
#ifdef USE_OMP
3547
#pragma omp parallel
3648
{
37-
std::mt19937 local_gen(std::random_device{}() + omp_get_thread_num());
49+
std::mt19937 local_gen(kRandomSeed + omp_get_thread_num());
3850
std::normal_distribution<float> local_dis(mean, std);
3951
#pragma omp for
4052
for (int i = 0; i < buffer.size(); ++i) {
@@ -126,7 +138,7 @@ std::shared_ptr<Tensor> Uniform(const std::shared_ptr<Tensor> &tensor, float a,
126138
#ifdef USE_OMP
127139
#pragma omp parallel
128140
{
129-
std::mt19937 local_gen(std::random_device{}() + omp_get_thread_num());
141+
std::mt19937 local_gen(kRandomSeed + omp_get_thread_num());
130142
std::uniform_real_distribution<float> local_dis(a, b);
131143
#pragma omp for
132144
for (int i = 0; i < buffer.size(); ++i) {

0 commit comments

Comments
 (0)