Skip to content

Commit e83c6dd

Browse files
committed
New extgen for loopers inclusion to be used in common workflows
1 parent 075a360 commit e83c6dd

3 files changed

Lines changed: 177 additions & 1 deletion

File tree

MC/bin/o2dpg_sim_workflow.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@
155155
parser.add_argument('--fwdmatching-assessment-full', action='store_true', help='enables complete assessment of global forward reco')
156156
parser.add_argument('--fwdmatching-4-param', action='store_true', help='excludes q/pt from matching parameters')
157157
parser.add_argument('--fwdmatching-cut-4-param', action='store_true', help='apply selection cuts on position and angular parameters')
158+
# TPC loopers with external WGAN generator
159+
parser.add_argument('--disable-loopers', action='store_true', help='disables fast simulated TPC loopers')
158160

159161
# Matching training for machine learning
160162
parser.add_argument('--fwdmatching-save-trainingdata', action='store_true', help='enables saving parameters at plane for matching training with machine learning')
@@ -857,14 +859,30 @@ def getDPL_global_options(bigshm=False, ccdbbackend=True):
857859
# GeneratorFromO2Kine parameters are needed only before the transport
858860
CONFKEY = re.sub(r'GeneratorFromO2Kine.*?;', '', CONFKEY)
859861

862+
kineFileName = 'genevents_Kine.root'
863+
864+
# Include fast simulated TPC loopers
865+
if isActive('TPC') and not args.disable_loopers:
866+
LOOPStask = createTask(name='sgngenloops_' + str(tf), needs=signalneeds, tf=tf, cwd='tf' + str(tf), lab=["GEN"], cpu=1, mem=1000)
867+
LOOPScfgbase = "GeneratorFromO2Kine.randomize=true"
868+
LOOPSinicfg = " --configFile $O2DPG_MC_CONFIG_ROOT/MC/config/common/ini/GeneratorLoopersInjector.ini"
869+
LOOPSCONFKEY = constructConfigKeyArg(create_geant_config(args, LOOPScfgbase))
870+
LOOPStask['cmd'] = '${O2_ROOT}/bin/o2-sim --noGeant --field ccdb -j 1 --vertexMode kNoVertex' \
871+
+ ' --run ' + str(args.run) + ' ' + str(LOOPSCONFKEY) + ' -g external' \
872+
+ ' -n ' + str(NSIGEVENTS) + ' --seed ' + str(TFSEED) + ' -o loops ' \
873+
+ embeddinto + ' --fromCollContext collisioncontext.root:' + signalprefix + LOOPSinicfg
874+
kineFileName = 'loops_Kine.root' # Kine file now has injected loopers
875+
signalneeds = signalneeds + [LOOPStask['name']]
876+
workflow['stages'].append(LOOPStask)
877+
860878
sgnmem = 6000 if COLTYPE == 'PbPb' else 4000
861879
SGNtask=createTask(name='sgnsim_'+str(tf), needs=signalneeds, tf=tf, cwd='tf'+str(tf), lab=["GEANT"],
862880
relative_cpu=7/8, n_workers=NWORKERS_TF, mem=str(sgnmem))
863881
sgncmdbase = '${O2_ROOT}/bin/o2-sim -e ' + str(SIMENGINE) + ' ' + str(MODULES) + ' -n ' + str(NSIGEVENTS) + ' --seed ' + str(TFSEED) \
864882
+ ' --field ccdb -j ' + str(NWORKERS_TF) + ' ' + str(CONFKEY) + ' ' + str(INIFILE) + ' -o ' + signalprefix + ' ' + embeddinto \
865883
+ ('', ' --timestamp ' + str(args.timestamp))[args.timestamp!=-1] + ' --run ' + str(args.run)
866884
if sep_event_mode:
867-
SGNtask['cmd'] = sgncmdbase + ' -g extkinO2 --extKinFile genevents_Kine.root ' + ' --vertexMode kNoVertex'
885+
SGNtask['cmd'] = sgncmdbase + ' -g extkinO2 --extKinFile ' + kineFileName + ' --vertexMode kNoVertex'
868886
else:
869887
SGNtask['cmd'] = sgncmdbase + ' -g ' + str(GENERATOR) + ' ' + str(TRIGGER) + ' --vertexMode kCCDB '
870888
if not isActive('all'):

MC/config/common/external/generator/TPCLoopers.C

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,91 @@ class GenTPCLoopers : public Generator
376376
double mMass_p = mPDG->GetParticle(-11)->Mass();
377377
};
378378

379+
class GenLoopersInjector : public Generator
380+
{
381+
public:
382+
GenLoopersInjector(std::string kineFN = "genevents_Kine.root", std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
383+
std::string scaler_pair = "scaler_pair.json", std::string scaler_compton = "scaler_compton.json", std::string poisson = "", std::string gauss = "")
384+
{
385+
mKineGen = std::make_unique<GeneratorFromO2Kine>(kineFN.c_str());
386+
mGenTPCLoopers = std::make_unique<GenTPCLoopers>(model_pairs, model_compton, poisson, gauss, scaler_pair, scaler_compton);
387+
Generator::setTimeUnit(1.0);
388+
Generator::setPositionUnit(1.0);
389+
Generator::setMomentumUnit(1.0);
390+
Generator::setEnergyUnit(1.0);
391+
}
392+
393+
void setAdaptiveLoopers(Bool_t adaptive)
394+
{
395+
mAdaptiveLoopers = adaptive;
396+
LOG(info) << "Adaptive loopers: " << (mAdaptiveLoopers ? "ON" : "OFF");
397+
}
398+
399+
void setLoopsFractions(float &fraction, float &fractionPairs)
400+
{
401+
if (fraction < 0 || fraction >= 1)
402+
{
403+
LOG(fatal) << "Error: Loops fraction must be in the range [0, 1).";
404+
exit(1);
405+
}
406+
mLoopsFraction = fraction;
407+
if (fractionPairs < 0 || fractionPairs > 1)
408+
{
409+
LOG(fatal) << "Error: Loops fraction for pairs must be in the range [0, 1].";
410+
exit(1);
411+
}
412+
mLoopsFractionPairs = fractionPairs;
413+
LOG(info) << "Pairs fraction set to: " << mLoopsFraction;
414+
}
415+
416+
Bool_t generateEvent() override
417+
{
418+
// Trivial, real work in importParticles
419+
return true;
420+
}
421+
422+
Bool_t importParticles() override
423+
{
424+
mParticles.clear(); // Clear the particles stack before importing new ones
425+
// Combination of import particles from GeneratorFromO2Kine and GenTPCLoopers
426+
mKineGen->clearParticles(); // Clear particles from O2 Kinematics generator
427+
mGenTPCLoopers->clearParticles(); // Clear particles from GenTPCLoopers
428+
auto stat1 = mKineGen->importParticles();
429+
// Check size of mParticles stack and set loopers accordingly if mAdaptiveLoopers is true
430+
if (mAdaptiveLoopers)
431+
{
432+
int nParticles = mParticles.size();
433+
if (nParticles > 0)
434+
{
435+
// Calculate the number of loopers to inject adaptively
436+
short int nLoopers = static_cast<short int>(std::round((nParticles * mLoopsFraction) / (1 - mLoopsFractionPairs)));
437+
short int nLoopersPairs = static_cast<short int>(std::round(nLoopers * mLoopsFractionPairs));
438+
short int nLoopersCompton = nLoopers - nLoopersPairs;
439+
mGenTPCLoopers->SetNLoopers(nLoopersPairs, nLoopersCompton);
440+
mGenTPCLoopers->generateEvent();
441+
}
442+
}
443+
auto stat2 = mGenTPCLoopers->importParticles();
444+
if (stat1 && stat2)
445+
{
446+
// Merge particles from both generators
447+
mParticles.insert(mParticles.end(), mKineGen->getParticles().begin(), mKineGen->getParticles().end());
448+
mParticles.insert(mParticles.end(), mGenTPCLoopers->getParticles().begin(), mGenTPCLoopers->getParticles().end());
449+
} else {
450+
LOG(error) << "Failed to import particles from O2 Kinematics or TPCLoopers";
451+
return false;
452+
}
453+
return true;
454+
}
455+
456+
private:
457+
std::unique_ptr<GeneratorFromO2Kine> mKineGen = nullptr; // Instance of GeneratorFromO2Kine to read particles from O2 kinematics file
458+
std::unique_ptr<GenTPCLoopers> mGenTPCLoopers = nullptr; // Instance of GenTPCLoopers to generate loopers
459+
Bool_t mAdaptiveLoopers = true; // Flag to indicate if adaptive loopers are used
460+
float mLoopsFraction = 0.1; // Fraction of loopers to be injected adaptively
461+
float mLoopsFractionPairs = 0.08; // Fraction of loopers from Pairs
462+
};
463+
379464
} // namespace eventgen
380465
} // namespace o2
381466

@@ -450,4 +535,73 @@ FairGenerator *
450535
generator->SetNLoopers(nloopers_pairs, nloopers_compton);
451536
generator->SetMultiplier(mult);
452537
return generator;
538+
}
539+
540+
// Loopers injector to O2 kinematics file
541+
// Loopers are considered adaptive by default, meaning that the number of loopers is determined by the number of particles in the kinematics file per event
542+
FairGenerator *
543+
GeneratorLoopersInjector(std::string kineFileName = "genevents_Kine.root", std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
544+
std::string scaler_pair = "scaler_pair.json", std::string scaler_compton = "scaler_compton.json", float loopers_fraction = 0.1, float fraction_pairs = 0.08)
545+
{
546+
// Expand all environment paths
547+
model_pairs = gSystem->ExpandPathName(model_pairs.c_str());
548+
model_compton = gSystem->ExpandPathName(model_compton.c_str());
549+
scaler_pair = gSystem->ExpandPathName(scaler_pair.c_str());
550+
scaler_compton = gSystem->ExpandPathName(scaler_compton.c_str());
551+
const std::array<std::string, 2> models = {model_pairs, model_compton};
552+
const std::array<std::string, 2> local_names = {"WGANpair.onnx", "WGANcompton.onnx"};
553+
const std::array<bool, 2> isAlien = {models[0].starts_with("alien://"), models[1].starts_with("alien://")};
554+
const std::array<bool, 2> isCCDB = {models[0].starts_with("ccdb://"), models[1].starts_with("ccdb://")};
555+
if (std::any_of(isAlien.begin(), isAlien.end(), [](bool v)
556+
{ return v; }))
557+
{
558+
if (!gGrid)
559+
{
560+
TGrid::Connect("alien://");
561+
if (!gGrid)
562+
{
563+
LOG(fatal) << "AliEn connection failed, check token.";
564+
exit(1);
565+
}
566+
}
567+
for (size_t i = 0; i < models.size(); ++i)
568+
{
569+
if (isAlien[i] && !TFile::Cp(models[i].c_str(), local_names[i].c_str()))
570+
{
571+
LOG(fatal) << "Error: Model file " << models[i] << " does not exist!";
572+
exit(1);
573+
}
574+
}
575+
}
576+
if (std::any_of(isCCDB.begin(), isCCDB.end(), [](bool v)
577+
{ return v; }))
578+
{
579+
o2::ccdb::CcdbApi ccdb_api;
580+
ccdb_api.init("http://alice-ccdb.cern.ch");
581+
for (size_t i = 0; i < models.size(); ++i)
582+
{
583+
if (isCCDB[i])
584+
{
585+
auto model_path = models[i].substr(7); // Remove "ccdb://"
586+
// Treat filename if provided in the CCDB path
587+
auto extension = model_path.find(".onnx");
588+
if (extension != std::string::npos)
589+
{
590+
auto last_slash = model_path.find_last_of('/');
591+
model_path = model_path.substr(0, last_slash);
592+
}
593+
std::map<std::string, std::string> filter;
594+
if (!ccdb_api.retrieveBlob(model_path, "./", filter, o2::ccdb::getCurrentTimestamp(), false, local_names[i].c_str()))
595+
{
596+
LOG(fatal) << "Error: issues in retrieving " << model_path << " from CCDB!";
597+
exit(1);
598+
}
599+
}
600+
}
601+
}
602+
model_pairs = isAlien[0] || isCCDB[0] ? local_names[0] : model_pairs;
603+
model_compton = isAlien[1] || isCCDB[1] ? local_names[1] : model_compton;
604+
auto generator = new o2::eventgen::GenLoopersInjector(kineFileName, model_pairs, model_compton, scaler_pair, scaler_compton);
605+
generator->setLoopsFractions(loopers_fraction, fraction_pairs);
606+
return generator;
453607
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# TPC loopers injector
2+
[GeneratorExternal]
3+
fileName = ${O2DPG_MC_CONFIG_ROOT}/MC/config/common/external/generator/TPCLoopers.C
4+
funcName = GeneratorLoopersInjector("genevents_Kine.root", "ccdb://Users/m/mgiacalo/WGAN_ExtGenPair", "ccdb://Users/m/mgiacalo/WGAN_ExtGenCompton", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerPairParams.json", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerComptonParams.json")

0 commit comments

Comments
 (0)