Skip to content

Commit 2c18f68

Browse files
committed
[ML] Randomize train/test cluster boundary assignment in RDataLoader
1 parent df1474a commit 2c18f68

1 file changed

Lines changed: 27 additions & 6 deletions

File tree

tree/ml/inc/ROOT/ML/RClusterLoader.hxx

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,18 +227,27 @@ public:
227227
// --- Shuffled path
228228
// Every cluster is split into two contiguous parts: one goes to training, the other to validation.
229229
// Cost: Each cluster is read twice per epoch, only when validation split is more than 0.
230-
// TODO(staider) Switch between prefix or suffix for validation randomly per cluster
230+
// We generate a random boolean value to decide whether the training set gets the prefix
231+
// or suffix of each cluster to ensure better shuffling across runs when splitting.
232+
std::mt19937 g(fSetSeed);
233+
std::uniform_int_distribution<int> coin(0, 1);
234+
231235
for (const RClusterRange &c : fAllClusters) {
232236
const std::size_t sz = c.GetNumEntries();
233237
const std::size_t trainSz = static_cast<std::size_t>((1.0f - fValidationSplit) * sz);
234238
const std::size_t valSz = sz - trainSz;
235239

240+
// Randomly assign prefix or suffix to training
241+
const uint64_t trainIsPrefix = coin(g);
242+
const uint64_t trainStart = trainIsPrefix ? c.start : c.start + static_cast<std::uint64_t>(valSz);
243+
const uint64_t valStart = trainIsPrefix ? c.start + static_cast<std::uint64_t>(trainSz) : c.start;
244+
236245
if (trainSz > 0) {
237-
fTrainingClusters.push_back({c.rdfIdx, c.start, c.start + static_cast<std::uint64_t>(trainSz)});
246+
fTrainingClusters.push_back({c.rdfIdx, trainStart, trainStart + static_cast<std::uint64_t>(trainSz)});
238247
fNumTrainingEntries += trainSz;
239248
}
240249
if (valSz > 0) {
241-
fValidationClusters.push_back({c.rdfIdx, c.start + static_cast<std::uint64_t>(trainSz), c.end});
250+
fValidationClusters.push_back({c.rdfIdx, valStart, valStart + static_cast<std::uint64_t>(valSz)});
242251
fNumValidationEntries += valSz;
243252
}
244253
}
@@ -392,14 +401,26 @@ public:
392401
std::min(static_cast<std::size_t>(totalFiltered * (1.0f - fValidationSplit)), trainRemaining);
393402
const std::size_t valCount = totalFiltered - trainCount;
394403

404+
// We generate a random boolean value to decide whether the training set gets the prefix
405+
// or suffix of each cluster to ensure better shuffling across runs when splitting.
406+
std::mt19937 g(fSetSeed + fAccumulatedFilteredForTrain); // vary per cluster
407+
std::uniform_int_distribution<int> coin(0, 1);
408+
const uint64_t trainIsPrefix = coin(g);
409+
395410
// The boundary is the raw entry index of the first entry of the suffix part
// (validation when training takes the prefix, training otherwise).
396411
// Stable across epochs since the same filter always produces the same ordered entries.
397-
const std::uint64_t boundary = (valCount > 0) ? rdfEntries[trainCount] : endRow;
412+
const std::uint64_t trainBoundaryEntry = trainIsPrefix ? rdfEntries[trainCount] : rdfEntries[valCount];
413+
const std::uint64_t boundary = (valCount > 0) ? trainBoundaryEntry : endRow;
414+
415+
const std::uint64_t trainStart = trainIsPrefix ? startRow : boundary;
416+
const std::uint64_t trainEnd = trainIsPrefix ? boundary : endRow;
417+
const std::uint64_t valStart = trainIsPrefix ? boundary : startRow;
418+
const std::uint64_t valEnd = trainIsPrefix ? endRow : boundary;
398419

399420
if (trainCount > 0)
400-
fTrainingClusters.push_back({rdfIdx, startRow, boundary, trainCount});
421+
fTrainingClusters.push_back({rdfIdx, trainStart, trainEnd, trainCount});
401422
if (valCount > 0)
402-
fValidationClusters.push_back({rdfIdx, boundary, endRow, valCount});
423+
fValidationClusters.push_back({rdfIdx, valStart, valEnd, valCount});
403424

404425
fAccumulatedFilteredForTrain += trainCount;
405426
return trainCount;

0 commit comments

Comments
 (0)