Skip to content

Commit 64730cb

Browse files
committed
Flat Gas external generator under test
1 parent 2d944cc commit 64730cb

2 files changed

Lines changed: 194 additions & 18 deletions

File tree

MC/config/common/external/generator/TPCLoopers.C

Lines changed: 190 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,22 @@ class GenTPCLoopers : public Generator
207207
mScaler_compton->load(scaler_compton);
208208
Generator::setTimeUnit(1.0);
209209
Generator::setPositionUnit(1.0);
210+
mContextFile = std::filesystem::exists("collisioncontext.root") ? TFile::Open("collisioncontext.root") : nullptr;
211+
mCollisionContext = mContextFile ? (o2::steer::DigitizationContext *)mContextFile->Get("DigitizationContext") : nullptr;
212+
mInteractionTimeRecords = mCollisionContext ? mCollisionContext->getEventRecords() : std::vector<o2::InteractionTimeRecord>{};
213+
if (mInteractionTimeRecords.empty())
214+
{
215+
LOG(warn) << "Error: No interaction time records found in the collision context!";
216+
exit(1);
217+
} else {
218+
LOG(info) << "Interaction Time records has " << mInteractionTimeRecords.size() << " entries.";
219+
mCollisionContext->printCollisionSummary();
220+
}
221+
for (int c = 0; c < mInteractionTimeRecords.size() - 1; c++)
222+
{
223+
mIntTimeRecMean += mInteractionTimeRecords[c + 1].bc2ns() - mInteractionTimeRecords[c].bc2ns();
224+
}
225+
mIntTimeRecMean /= (mInteractionTimeRecords.size() - 1); // Average interaction time record used for the reference
210226
}
211227

212228
Bool_t generateEvent() override
@@ -215,21 +231,61 @@ class GenTPCLoopers : public Generator
215231
mGenPairs.clear();
216232
// Clear the vector of compton electrons
217233
mGenElectrons.clear();
218-
// Set number of loopers if poissonian params are available
219-
if (mPoissonSet)
234+
if (mFlatGas)
220235
{
221-
mNLoopersPairs = static_cast<short int>(std::round(mMultiplier[0] * PoissonPairs()));
222-
}
223-
if (mGaussSet)
224-
{
225-
mNLoopersCompton = static_cast<short int>(std::round(mMultiplier[1] * GaussianElectrons()));
226-
}
236+
unsigned int nLoopers, nLoopersPairs, nLoopersCompton;
237+
LOG(info) << "mCurrentEvent is " << mCurrentEvent;
238+
LOG(info) << "Current event time: " << ((mCurrentEvent < mInteractionTimeRecords.size() - 1) ? std::to_string(mInteractionTimeRecords[mCurrentEvent + 1].bc2ns() - mInteractionTimeRecords[mCurrentEvent].bc2ns()) : std::to_string(mIntTimeRecMean)) << " ns";
239+
LOG(info) << "Current time offset wrt BC: " << mInteractionTimeRecords[mCurrentEvent].getTimeOffsetWrtBC() << " ns";
240+
mTimeLimit = (mCurrentEvent < mInteractionTimeRecords.size() - 1) ? mInteractionTimeRecords[mCurrentEvent + 1].bc2ns() - mInteractionTimeRecords[mCurrentEvent].bc2ns() : mIntTimeRecMean;
241+
// With Flat Gas number of loopers are adapted based on time interval widths
242+
nLoopers = mFlatGasNumber * (mTimeLimit / mIntTimeRecMean);
243+
nLoopersPairs = static_cast<unsigned int>(std::round(nLoopers * mLoopsFractionPairs));
244+
nLoopersCompton = nLoopers - nLoopersPairs;
245+
SetNLoopers(nLoopersPairs, nLoopersCompton);
246+
LOG(info) << "Flat gas loopers: " << nLoopers << " (pairs: " << nLoopersPairs << ", compton: " << nLoopersCompton << ")";
247+
generateEvent(mTimeLimit);
248+
mCurrentEvent++;
249+
} else {
250+
// Set number of loopers if poissonian params are available
251+
if (mPoissonSet)
252+
{
253+
mNLoopersPairs = static_cast<unsigned int>(std::round(mMultiplier[0] * PoissonPairs()));
254+
}
255+
if (mGaussSet)
256+
{
257+
mNLoopersCompton = static_cast<unsigned int>(std::round(mMultiplier[1] * GaussianElectrons()));
258+
}
259+
// Generate pairs
260+
for (int i = 0; i < mNLoopersPairs; ++i)
261+
{
262+
std::vector<double> pair = mONNX_pair->generate_sample();
263+
// Apply the inverse transformation using the scaler
264+
std::vector<double> transformed_pair = mScaler_pair->inverse_transform(pair);
265+
mGenPairs.push_back(transformed_pair);
266+
}
267+
// Generate compton electrons
268+
for (int i = 0; i < mNLoopersCompton; ++i)
269+
{
270+
std::vector<double> electron = mONNX_compton->generate_sample();
271+
// Apply the inverse transformation using the scaler
272+
std::vector<double> transformed_electron = mScaler_compton->inverse_transform(electron);
273+
mGenElectrons.push_back(transformed_electron);
274+
}
275+
}
276+
return true;
277+
}
278+
279+
Bool_t generateEvent(double &time_limit)
280+
{
281+
LOG(info) << "Time constraint for loopers: " << time_limit << " ns";
227282
// Generate pairs
228283
for (int i = 0; i < mNLoopersPairs; ++i)
229284
{
230285
std::vector<double> pair = mONNX_pair->generate_sample();
231286
// Apply the inverse transformation using the scaler
232287
std::vector<double> transformed_pair = mScaler_pair->inverse_transform(pair);
288+
transformed_pair[9] = gRandom->Uniform(0., time_limit); // Regenerate time, scaling is not needed because time_limit is already in nanoseconds
233289
mGenPairs.push_back(transformed_pair);
234290
}
235291
// Generate compton electrons
@@ -238,8 +294,10 @@ class GenTPCLoopers : public Generator
238294
std::vector<double> electron = mONNX_compton->generate_sample();
239295
// Apply the inverse transformation using the scaler
240296
std::vector<double> transformed_electron = mScaler_compton->inverse_transform(electron);
297+
transformed_electron[6] = gRandom->Uniform(0., time_limit); // Regenerate time, scaling is not needed because time_limit is already in nanoseconds
241298
mGenElectrons.push_back(transformed_electron);
242299
}
300+
LOG(info) << "Generated Particles with time limit";
243301
return true;
244302
}
245303

@@ -301,9 +359,9 @@ class GenTPCLoopers : public Generator
301359
return true;
302360
}
303361

304-
short int PoissonPairs()
362+
unsigned int PoissonPairs()
305363
{
306-
short int poissonValue;
364+
unsigned int poissonValue;
307365
do
308366
{
309367
// Generate a Poisson-distributed random number with mean mPoisson[0]
@@ -313,9 +371,9 @@ class GenTPCLoopers : public Generator
313371
return poissonValue;
314372
}
315373

316-
short int GaussianElectrons()
374+
unsigned int GaussianElectrons()
317375
{
318-
short int gaussValue;
376+
unsigned int gaussValue;
319377
do
320378
{
321379
// Generate a Normal-distributed random number with mean mGass[0] and stddev mGauss[1]
@@ -325,7 +383,7 @@ class GenTPCLoopers : public Generator
325383
return gaussValue;
326384
}
327385

328-
void SetNLoopers(short int &nsig_pair, short int &nsig_compton)
386+
void SetNLoopers(unsigned int &nsig_pair, unsigned int &nsig_compton)
329387
{
330388
if(mPoissonSet) {
331389
LOG(info) << "Poissonian parameters correctly loaded.";
@@ -354,6 +412,40 @@ class GenTPCLoopers : public Generator
354412
}
355413
}
356414

415+
void setFlatGas(Bool_t &flat, const Int_t &number = -1)
416+
{
417+
mFlatGas = flat;
418+
if (mFlatGas)
419+
{
420+
if (number < 0)
421+
{
422+
LOG(warn) << "Warning: Number of loopers per event must be non-negative! Switching option off.";
423+
mFlatGas = false;
424+
mFlatGasNumber = -1;
425+
}
426+
else
427+
{
428+
mFlatGasNumber = number;
429+
}
430+
}
431+
else
432+
{
433+
mFlatGasNumber = -1;
434+
}
435+
LOG(info) << "Flat gas loopers: " << (mFlatGas ? "ON" : "OFF") << ", Reference loopers number per event: " << mFlatGasNumber;
436+
}
437+
438+
void setFractionPairs(float &fractionPairs)
439+
{
440+
if (fractionPairs < 0 || fractionPairs > 1)
441+
{
442+
LOG(fatal) << "Error: Loops fraction for pairs must be in the range [0, 1].";
443+
exit(1);
444+
}
445+
mLoopsFractionPairs = fractionPairs;
446+
LOG(info) << "Pairs fraction set to: " << mLoopsFractionPairs;
447+
}
448+
357449
private:
358450
std::unique_ptr<ONNXGenerator> mONNX_pair = nullptr;
359451
std::unique_ptr<ONNXGenerator> mONNX_compton = nullptr;
@@ -363,8 +455,8 @@ class GenTPCLoopers : public Generator
363455
double mGauss[4] = {0.0, 0.0, 0.0, 0.0}; // Mean, Std, Min, Max
364456
std::vector<std::vector<double>> mGenPairs;
365457
std::vector<std::vector<double>> mGenElectrons;
366-
short int mNLoopersPairs = -1;
367-
short int mNLoopersCompton = -1;
458+
unsigned int mNLoopersPairs = -1;
459+
unsigned int mNLoopersCompton = -1;
368460
std::array<float, 2> mMultiplier = {1., 1.};
369461
bool mPoissonSet = false;
370462
bool mGaussSet = false;
@@ -374,6 +466,15 @@ class GenTPCLoopers : public Generator
374466
TDatabasePDG *mPDG = TDatabasePDG::Instance();
375467
double mMass_e = mPDG->GetParticle(11)->Mass();
376468
double mMass_p = mPDG->GetParticle(-11)->Mass();
469+
int mCurrentEvent = 0; // Current event number, used for adaptive loopers
470+
TFile *mContextFile = nullptr; // Input collision context file
471+
o2::steer::DigitizationContext *mCollisionContext = nullptr; // Pointer to the digitization context
472+
std::vector<o2::InteractionTimeRecord> mInteractionTimeRecords; // Interaction time records from collision context
473+
Bool_t mFlatGas = false; // Flag to indicate if flat gas loopers are used
474+
Int_t mFlatGasNumber = -1; // Number of flat gas loopers per event
475+
double mIntTimeRecMean = 1.0; // Average interaction time record used for the reference
476+
double mTimeLimit = 0.0; // Time limit for the current event
477+
float mLoopsFractionPairs = 0.08; // Fraction of loopers from Pairs
377478
};
378479

379480
} // namespace eventgen
@@ -387,8 +488,8 @@ class GenTPCLoopers : public Generator
387488
FairGenerator *
388489
Generator_TPCLoopers(std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
389490
std::string poisson = "poisson.csv", std::string gauss = "gauss.csv", std::string scaler_pair = "scaler_pair.json",
390-
std::string scaler_compton = "scaler_compton.json", std::array<float, 2> mult = {1., 1.}, short int nloopers_pairs = 1,
391-
short int nloopers_compton = 1)
491+
std::string scaler_compton = "scaler_compton.json", std::array<float, 2> mult = {1., 1.}, unsigned int nloopers_pairs = 1,
492+
unsigned int nloopers_compton = 1)
392493
{
393494
// Expand all environment paths
394495
model_pairs = gSystem->ExpandPathName(model_pairs.c_str());
@@ -450,4 +551,75 @@ FairGenerator *
450551
generator->SetNLoopers(nloopers_pairs, nloopers_compton);
451552
generator->SetMultiplier(mult);
452553
return generator;
453-
}
554+
}
555+
556+
// Generator with flat gas loopers. Number of loopers starts from a reference value and changes
557+
// based on the BC time intervals in each event.
558+
FairGenerator *
559+
Generator_TPCLoopersFlat(std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
560+
std::string scaler_pair = "scaler_pair.json", std::string scaler_compton = "scaler_compton.json",
561+
bool flat_gas = true, const int loops_num = 500, float fraction_pairs = 0.08)
562+
{
563+
// Expand all environment paths
564+
model_pairs = gSystem->ExpandPathName(model_pairs.c_str());
565+
model_compton = gSystem->ExpandPathName(model_compton.c_str());
566+
scaler_pair = gSystem->ExpandPathName(scaler_pair.c_str());
567+
scaler_compton = gSystem->ExpandPathName(scaler_compton.c_str());
568+
const std::array<std::string, 2> models = {model_pairs, model_compton};
569+
const std::array<std::string, 2> local_names = {"WGANpair.onnx", "WGANcompton.onnx"};
570+
const std::array<bool, 2> isAlien = {models[0].starts_with("alien://"), models[1].starts_with("alien://")};
571+
const std::array<bool, 2> isCCDB = {models[0].starts_with("ccdb://"), models[1].starts_with("ccdb://")};
572+
if (std::any_of(isAlien.begin(), isAlien.end(), [](bool v)
573+
{ return v; }))
574+
{
575+
if (!gGrid)
576+
{
577+
TGrid::Connect("alien://");
578+
if (!gGrid)
579+
{
580+
LOG(fatal) << "AliEn connection failed, check token.";
581+
exit(1);
582+
}
583+
}
584+
for (size_t i = 0; i < models.size(); ++i)
585+
{
586+
if (isAlien[i] && !TFile::Cp(models[i].c_str(), local_names[i].c_str()))
587+
{
588+
LOG(fatal) << "Error: Model file " << models[i] << " does not exist!";
589+
exit(1);
590+
}
591+
}
592+
}
593+
if (std::any_of(isCCDB.begin(), isCCDB.end(), [](bool v)
594+
{ return v; }))
595+
{
596+
o2::ccdb::CcdbApi ccdb_api;
597+
ccdb_api.init("http://alice-ccdb.cern.ch");
598+
for (size_t i = 0; i < models.size(); ++i)
599+
{
600+
if (isCCDB[i])
601+
{
602+
auto model_path = models[i].substr(7); // Remove "ccdb://"
603+
// Treat filename if provided in the CCDB path
604+
auto extension = model_path.find(".onnx");
605+
if (extension != std::string::npos)
606+
{
607+
auto last_slash = model_path.find_last_of('/');
608+
model_path = model_path.substr(0, last_slash);
609+
}
610+
std::map<std::string, std::string> filter;
611+
if (!ccdb_api.retrieveBlob(model_path, "./", filter, o2::ccdb::getCurrentTimestamp(), false, local_names[i].c_str()))
612+
{
613+
LOG(fatal) << "Error: issues in retrieving " << model_path << " from CCDB!";
614+
exit(1);
615+
}
616+
}
617+
}
618+
}
619+
model_pairs = isAlien[0] || isCCDB[0] ? local_names[0] : model_pairs;
620+
model_compton = isAlien[1] || isCCDB[1] ? local_names[1] : model_compton;
621+
auto generator = new o2::eventgen::GenTPCLoopers(model_pairs, model_compton, "", "", scaler_pair, scaler_compton);
622+
generator->setFractionPairs(fraction_pairs);
623+
generator->setFlatGas(flat_gas, loops_num);
624+
return generator;
625+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# TPC loopers injector
2+
[GeneratorExternal]
3+
fileName = ${O2DPG_MC_CONFIG_ROOT}/MC/config/common/external/generator/TPCLoopers.C
4+
funcName = Generator_TPCLoopersFlat("ccdb://Users/m/mgiacalo/WGAN_ExtGenPair", "ccdb://Users/m/mgiacalo/WGAN_ExtGenCompton", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerPairParams.json", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerComptonParams.json")

0 commit comments

Comments
 (0)