Skip to content

Commit aaf5b8f

Browse files
committed
fixup! [ML] Randomize train/test cluster boundary assignment in RDataLoader
1 parent 5f78a62 commit aaf5b8f

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

tree/ml/inc/ROOT/ML/RClusterLoader.hxx

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -401,11 +401,14 @@ public:
401401
std::min(static_cast<std::size_t>(totalFiltered * (1.0f - fValidationSplit)), trainRemaining);
402402
const std::size_t valCount = totalFiltered - trainCount;
403403

404-
// We generate a random boolean value to decide whether the training set gets the prefix
405-
// or suffix of each cluster to ensure better shuffling across runs when splitting.
406-
std::mt19937 g(fSetSeed + fAccumulatedFilteredForTrain); // vary per cluster
407-
std::uniform_int_distribution<int> coin(0, 1);
408-
const uint64_t trainIsPrefix = coin(g);
404+
uint64_t trainIsPrefix = true;
405+
if (fShuffle) {
406+
// If shuffling is enabled, we generate a random boolean value to decide whether the training set
407+
// gets the prefix or suffix of each cluster to ensure better shuffling across runs when splitting.
408+
std::mt19937 g(fSetSeed + fAccumulatedFilteredForTrain); // vary per cluster
409+
std::uniform_int_distribution<int> coin(0, 1);
410+
trainIsPrefix = coin(g);
411+
}
409412

410413
// The boundary is the raw entry index of the first entry assigned to validation.
411414
// Stable across epochs since the same filter always produces the same ordered entries.

0 commit comments

Comments
 (0)