@@ -50,7 +50,7 @@ template <typename... Args>
5050class RBatchGenerator {
5151private:
5252 std::vector<std::string> fCols ;
53- std::vector<std::size_t > fVecSizes ;
53+ std::vector<std::size_t > fVecSizes ;
5454 // clang-format on
5555 std::size_t fChunkSize ;
5656 std::size_t fMaxChunks ;
@@ -60,17 +60,17 @@ private:
6060
6161 float fValidationSplit ;
6262
63- std::unique_ptr<RDatasetLoader<Args...>> fDatasetLoader ;
63+ std::unique_ptr<RDatasetLoader<Args...>> fDatasetLoader ;
6464 std::unique_ptr<RChunkLoader<Args...>> fChunkLoader ;
6565 std::unique_ptr<RBatchLoader> fTrainingBatchLoader ;
6666 std::unique_ptr<RBatchLoader> fValidationBatchLoader ;
6767 std::unique_ptr<RSampler> fTrainingSampler ;
6868 std::unique_ptr<RSampler> fValidationSampler ;
6969
7070 std::unique_ptr<RFlat2DMatrixOperators> fTensorOperators ;
71-
71+
7272 std::vector<ROOT ::RDF ::RNode> f_rdfs;
73-
73+
7474 std::unique_ptr<std::thread> fLoadingThread ;
7575
7676 std::size_t fTrainingChunkNum ;
@@ -83,8 +83,8 @@ private:
8383 bool fLoadEager ;
8484 std::string fSampleType ;
8585 float fSampleRatio ;
86- bool fReplacement ;
87-
86+ bool fReplacement ;
87+
8888 bool fIsActive {false }; // Whether the loading thread is active
8989 bool fUseWholeFile ;
9090
@@ -104,10 +104,10 @@ private:
104104
105105 RFlat2DMatrix fTrainingDataset ;
106106 RFlat2DMatrix fValidationDataset ;
107-
107+
108108 RFlat2DMatrix fSampledTrainingDataset ;
109109 RFlat2DMatrix fSampledValidationDataset ;
110-
110+
111111 RFlat2DMatrix fTrainTensor ;
112112 RFlat2DMatrix fTrainChunkTensor ;
113113
@@ -124,7 +124,7 @@ public:
124124
125125 : f_rdfs(rdfs),
126126 fCols (cols),
127- fVecSizes(vecSizes),
127+ fVecSizes(vecSizes),
128128 fChunkSize(chunkSize),
129129 fBlockSize(blockSize),
130130 fBatchSize(batchSize),
@@ -140,57 +140,56 @@ public:
140140 fUseWholeFile(maxChunks == 0 )
141141 {
142142 fTensorOperators = std::make_unique<RFlat2DMatrixOperators>(fShuffle , fSetSeed );
143-
143+
144144 if (fLoadEager ) {
145145 fDatasetLoader = std::make_unique<RDatasetLoader<Args...>>(f_rdfs, fValidationSplit , fCols , fVecSizes ,
146- vecPadding, fShuffle , fSetSeed );
146+ vecPadding, fShuffle , fSetSeed );
147147 // split the datasets and extract the training and validation datasets
148148 fDatasetLoader ->SplitDatasets ();
149149
150150 if (fSampleType == " " ) {
151151 fDatasetLoader ->ConcatenateDatasets ();
152-
152+
153153 fTrainingDataset = fDatasetLoader ->GetTrainingDataset ();
154- fValidationDataset = fDatasetLoader ->GetValidationDataset ();
155-
154+ fValidationDataset = fDatasetLoader ->GetValidationDataset ();
155+
156156 fNumTrainingEntries = fDatasetLoader ->GetNumTrainingEntries ();
157157 fNumValidationEntries = fDatasetLoader ->GetNumValidationEntries ();
158158 }
159159
160160 else {
161161 fTrainingDatasets = fDatasetLoader ->GetTrainingDatasets ();
162- fValidationDatasets = fDatasetLoader ->GetValidationDatasets ();
163-
162+ fValidationDatasets = fDatasetLoader ->GetValidationDatasets ();
163+
164164 fTrainingSampler = std::make_unique<RSampler>(fTrainingDatasets , fSampleType , fSampleRatio , fReplacement ,
165165 fShuffle , fSetSeed );
166- fValidationSampler = std::make_unique<RSampler>(fValidationDatasets , fSampleType , fSampleRatio , fReplacement ,
167- fShuffle , fSetSeed );
166+ fValidationSampler = std::make_unique<RSampler>(fValidationDatasets , fSampleType , fSampleRatio ,
167+ fReplacement , fShuffle , fSetSeed );
168168
169- fNumTrainingEntries = fTrainingSampler ->GetNumEntries ();
169+ fNumTrainingEntries = fTrainingSampler ->GetNumEntries ();
170170 fNumValidationEntries = fValidationSampler ->GetNumEntries ();
171171 }
172172 }
173173
174174 else {
175- fChunkLoader =
176- std::make_unique<RChunkLoader<Args...>>(f_rdfs[0 ], fChunkSize , fBlockSize , fValidationSplit ,
177- fCols , fVecSizes , vecPadding, fShuffle , fSetSeed );
175+ fChunkLoader = std::make_unique<RChunkLoader<Args...>>(f_rdfs[0 ], fChunkSize , fBlockSize , fValidationSplit ,
176+ fCols , fVecSizes , vecPadding, fShuffle , fSetSeed );
178177
179178 // split the dataset into training and validation sets
180179 fChunkLoader ->SplitDataset ();
181180
182181 fNumTrainingEntries = fChunkLoader ->GetNumTrainingEntries ();
183- fNumValidationEntries = fChunkLoader ->GetNumValidationEntries ();
182+ fNumValidationEntries = fChunkLoader ->GetNumValidationEntries ();
184183
185184 // number of training and validation chunks, calculated in RChunkConstructor
186185 fNumTrainingChunks = fChunkLoader ->GetNumTrainingChunks ();
187186 fNumValidationChunks = fChunkLoader ->GetNumValidationChunks ();
188187 }
189188
190- fTrainingBatchLoader = std::make_unique<RBatchLoader>( fBatchSize , fCols , fVecSizes ,
191- fNumTrainingEntries , fDropRemainder );
192- fValidationBatchLoader = std::make_unique<RBatchLoader>( fBatchSize , fCols , fVecSizes ,
193- fNumValidationEntries , fDropRemainder );
189+ fTrainingBatchLoader =
190+ std::make_unique<RBatchLoader>( fBatchSize , fCols , fVecSizes , fNumTrainingEntries , fDropRemainder );
191+ fValidationBatchLoader =
192+ std::make_unique<RBatchLoader>( fBatchSize , fCols , fVecSizes , fNumValidationEntries , fDropRemainder );
194193 }
195194
196195 ~RBatchGenerator () { DeActivate (); }
@@ -203,7 +202,7 @@ public:
203202 }
204203
205204 fTrainingBatchLoader ->DeActivate ();
206- fValidationBatchLoader ->DeActivate ();
205+ fValidationBatchLoader ->DeActivate ();
207206
208207 if (fLoadingThread ) {
209208 if (fLoadingThread ->joinable ()) {
@@ -225,7 +224,7 @@ public:
225224 }
226225
227226 fTrainingBatchLoader ->Activate ();
228- fValidationBatchLoader ->Activate ();
227+ fValidationBatchLoader ->Activate ();
229228 // fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
230229 }
231230
@@ -241,10 +240,11 @@ public:
241240
242241 void DeActivateValidationEpoch () { fValidationEpochActive = false ; }
243242
244- // / \brief Create training batches by first loading a chunk (see RChunkLoader) and split it into batches (see RBatchLoader)
243+ // / \brief Create training batches by first loading a chunk (see RChunkLoader) and split it into batches (see
244+ // / RBatchLoader)
245245 void CreateTrainBatches ()
246246 {
247- fTrainingEpochActive = true ;
247+ fTrainingEpochActive = true ;
248248 if (fLoadEager ) {
249249 if (fSampleType == " " ) {
250250 fTensorOperators ->ShuffleTensor (fSampledTrainingDataset , fTrainingDataset );
@@ -253,10 +253,10 @@ public:
253253 else {
254254 fTrainingSampler ->Sampler (fSampledTrainingDataset );
255255 }
256-
256+
257257 fTrainingBatchLoader ->CreateBatches (fSampledTrainingDataset , 1 );
258258 }
259-
259+
260260 else {
261261 fChunkLoader ->CreateTrainingChunksIntervals ();
262262 fTrainingChunkNum = 0 ;
@@ -266,10 +266,11 @@ public:
266266 }
267267 }
268268
269- // / \brief Creates validation batches by first loading a chunk (see RChunkLoader), and then split it into batches (see RBatchLoader)
269+ // / \brief Creates validation batches by first loading a chunk (see RChunkLoader), and then split it into batches
270+ // / (see RBatchLoader)
270271 void CreateValidationBatches ()
271272 {
272- fValidationEpochActive = true ;
273+ fValidationEpochActive = true ;
273274 if (fLoadEager ) {
274275 if (fSampleType == " " ) {
275276 fTensorOperators ->ShuffleTensor (fSampledValidationDataset , fValidationDataset );
@@ -278,7 +279,7 @@ public:
278279 else {
279280 fValidationSampler ->Sampler (fSampledValidationDataset );
280281 }
281-
282+
282283 fValidationBatchLoader ->CreateBatches (fSampledValidationDataset , 1 );
283284 }
284285
@@ -294,28 +295,28 @@ public:
294295 // / \brief Loads a training batch from the queue
295296 RFlat2DMatrix GetTrainBatch ()
296297 {
297- if (!fLoadEager ) {
298- auto batchQueue = fTrainingBatchLoader ->GetNumBatchQueue ();
299-
300- // load the next chunk if the queue is empty
301- if (batchQueue < 1 && fTrainingChunkNum < fNumTrainingChunks ) {
302- fChunkLoader ->LoadTrainingChunk (fTrainChunkTensor , fTrainingChunkNum );
303- std::size_t lastTrainingBatch = fNumTrainingChunks - fTrainingChunkNum ;
304- fTrainingBatchLoader ->CreateBatches (fTrainChunkTensor , lastTrainingBatch);
305- fTrainingChunkNum ++;
306- }
307- }
308- // Get next batch if available
309- return fTrainingBatchLoader ->GetBatch ();
298+ if (!fLoadEager ) {
299+ auto batchQueue = fTrainingBatchLoader ->GetNumBatchQueue ();
300+
301+ // load the next chunk if the queue is empty
302+ if (batchQueue < 1 && fTrainingChunkNum < fNumTrainingChunks ) {
303+ fChunkLoader ->LoadTrainingChunk (fTrainChunkTensor , fTrainingChunkNum );
304+ std::size_t lastTrainingBatch = fNumTrainingChunks - fTrainingChunkNum ;
305+ fTrainingBatchLoader ->CreateBatches (fTrainChunkTensor , lastTrainingBatch);
306+ fTrainingChunkNum ++;
307+ }
308+ }
309+ // Get next batch if available
310+ return fTrainingBatchLoader ->GetBatch ();
310311 }
311312
312313 // / \brief Loads a validation batch from the queue
313314 RFlat2DMatrix GetValidationBatch ()
314315 {
315- if (!fLoadEager ) {
316+ if (!fLoadEager ) {
316317 auto batchQueue = fValidationBatchLoader ->GetNumBatchQueue ();
317318
318- // load the next chunk if the queue is empty
319+ // load the next chunk if the queue is empty
319320 if (batchQueue < 1 && fValidationChunkNum < fNumValidationChunks ) {
320321 fChunkLoader ->LoadValidationChunk (fValidationChunkTensor , fValidationChunkNum );
321322 std::size_t lastValidationBatch = fNumValidationChunks - fValidationChunkNum ;
0 commit comments