File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -401,11 +401,14 @@ public:
401401 std::min (static_cast <std::size_t >(totalFiltered * (1 .0f - fValidationSplit )), trainRemaining);
402402 const std::size_t valCount = totalFiltered - trainCount;
403403
404- // We generate a random boolean value to decide whether the training set gets the prefix
405- // or suffix of each cluster to ensure better shuffling across runs when splitting.
406- std::mt19937 g (fSetSeed + fAccumulatedFilteredForTrain ); // vary per cluster
407- std::uniform_int_distribution<int > coin (0 , 1 );
408- const uint64_t trainIsPrefix = coin (g);
404+ uint64_t trainIsPrefix = true ;
405+ if (fShuffle ) {
406+ // If shuffling is enabled, we draw a seeded random boolean per cluster to decide whether the training set
407+ // gets the prefix or the suffix of that cluster, so the split side varies; otherwise it always gets the prefix.
408+ std::mt19937 g (fSetSeed + fAccumulatedFilteredForTrain ); // vary per cluster
409+ std::uniform_int_distribution<int > coin (0 , 1 );
410+ trainIsPrefix = coin (g);
411+ }
409412
410413 // The boundary is the raw entry index of the first entry assigned to validation.
411414 // Stable across epochs since the same filter always produces the same ordered entries.
You can’t perform that action at this time.
0 commit comments