@@ -227,18 +227,27 @@ public:
227227 // --- Shuffled path
228228 // Every cluster contributes one contiguous part (its prefix or its suffix) to training and the remaining part to validation.
229229 // Cost: Each cluster is read twice per epoch, only when validation split is more than 0.
230- // TODO(staider) Swicth between prefix or suffix for validation randomly per cluster
230+ // For each cluster, a coin flip decides whether training takes the cluster's prefix or its suffix;
231+ // validation takes the other part. The generator is seeded with fSetSeed, so the choices are reproducible for a given seed.
232+ std::mt19937 g (fSetSeed );
233+ std::uniform_int_distribution<int > coin (0 , 1 );
234+
231235 for (const RClusterRange &c : fAllClusters ) {
232236 const std::size_t sz = c.GetNumEntries ();
233237 const std::size_t trainSz = static_cast <std::size_t >((1 .0f - fValidationSplit ) * sz);
234238 const std::size_t valSz = sz - trainSz;
235239
240+ // Randomly assign prefix or suffix to training
241+ const uint64_t trainIsPrefix = coin (g);
242+ const uint64_t trainStart = trainIsPrefix ? c.start : c.start + static_cast <std::uint64_t >(valSz);
243+ const uint64_t valStart = trainIsPrefix ? c.start + static_cast <std::uint64_t >(trainSz) : c.start ;
244+
236245 if (trainSz > 0 ) {
237- fTrainingClusters .push_back ({c.rdfIdx , c. start , c. start + static_cast <std::uint64_t >(trainSz)});
246+ fTrainingClusters .push_back ({c.rdfIdx , trainStart, trainStart + static_cast <std::uint64_t >(trainSz)});
238247 fNumTrainingEntries += trainSz;
239248 }
240249 if (valSz > 0 ) {
241- fValidationClusters .push_back ({c.rdfIdx , c. start + static_cast <std::uint64_t >(trainSz), c. end });
250+ fValidationClusters .push_back ({c.rdfIdx , valStart, valStart + static_cast <std::uint64_t >(valSz) });
242251 fNumValidationEntries += valSz;
243252 }
244253 }
@@ -392,14 +401,26 @@ public:
392401 std::min (static_cast <std::size_t >(totalFiltered * (1 .0f - fValidationSplit )), trainRemaining);
393402 const std::size_t valCount = totalFiltered - trainCount;
394403
404+ // A coin flip decides whether training takes the prefix or the suffix of this cluster's filtered entries;
405+ // validation takes the other part. The generator is rebuilt on every call, seeded with fSetSeed + fAccumulatedFilteredForTrain.
406+ std::mt19937 g (fSetSeed + fAccumulatedFilteredForTrain ); // vary per cluster
407+ std::uniform_int_distribution<int > coin (0 , 1 );
408+ const uint64_t trainIsPrefix = coin (g);
409+
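395410 // The boundary is the raw entry index where the suffix part begins: the first validation entry if training takes the prefix, otherwise the first training entry.
396411 // Stable across epochs provided the filter yields the same ordered entries and the coin flip repeats. NOTE(review): rdfEntries is indexed below before the valCount > 0 check (confirm trainCount/valCount == totalFiltered cannot read past the end), and with valCount == 0 and training on the suffix the training range becomes [endRow, endRow).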
397- const std::uint64_t boundary = (valCount > 0 ) ? rdfEntries[trainCount] : endRow;
412+ const std::uint64_t trainBoundaryEntry = trainIsPrefix ? rdfEntries[trainCount] : rdfEntries[valCount];
413+ const std::uint64_t boundary = (valCount > 0 ) ? trainBoundaryEntry : endRow;
414+
415+ const std::uint64_t trainStart = trainIsPrefix ? startRow : boundary;
416+ const std::uint64_t trainEnd = trainIsPrefix ? boundary : endRow;
417+ const std::uint64_t valStart = trainIsPrefix ? boundary : startRow;
418+ const std::uint64_t valEnd = trainIsPrefix ? endRow : boundary;
398419
399420 if (trainCount > 0 )
400- fTrainingClusters .push_back ({rdfIdx, startRow, boundary , trainCount});
421+ fTrainingClusters .push_back ({rdfIdx, trainStart, trainEnd , trainCount});
401422 if (valCount > 0 )
402- fValidationClusters .push_back ({rdfIdx, boundary, endRow , valCount});
423+ fValidationClusters .push_back ({rdfIdx, valStart, valEnd , valCount});
403424
404425 fAccumulatedFilteredForTrain += trainCount;
405426 return trainCount;
0 commit comments