Skip to content

Commit a042f91

Browse files
committed
New extgen for loopers inclusion to be used in common workflows
1 parent 075a360 commit a042f91

3 files changed

Lines changed: 175 additions & 1 deletion

File tree

MC/bin/o2dpg_sim_workflow.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@
155155
parser.add_argument('--fwdmatching-assessment-full', action='store_true', help='enables complete assessment of global forward reco')
156156
parser.add_argument('--fwdmatching-4-param', action='store_true', help='excludes q/pt from matching parameters')
157157
parser.add_argument('--fwdmatching-cut-4-param', action='store_true', help='apply selection cuts on position and angular parameters')
158+
# TPC loopers with external WGAN generator
159+
parser.add_argument('--disable-loopers', action='store_true', help='disables fast simulated TPC loopers')
158160

159161
# Matching training for machine learning
160162
parser.add_argument('--fwdmatching-save-trainingdata', action='store_true', help='enables saving parameters at plane for matching training with machine learning')
@@ -857,14 +859,30 @@ def getDPL_global_options(bigshm=False, ccdbbackend=True):
857859
# GeneratorFromO2Kine parameters are needed only before the transport
858860
CONFKEY = re.sub(r'GeneratorFromO2Kine.*?;', '', CONFKEY)
859861

862+
kineFileName = 'genevents_Kine.root'
863+
864+
# Include fast simulated TPC loopers
865+
if isActive('TPC') and not args.disable_loopers:
866+
LOOPStask = createTask(name='sgngenloops_' + str(tf), needs=signalneeds, tf=tf, cwd='tf' + str(tf), lab=["GEN"], cpu=1, mem=1000)
867+
LOOPScfgbase = "GeneratorFromO2Kine.randomize=true"
868+
LOOPSinicfg = " --configFile $O2DPG_MC_CONFIG_ROOT/MC/config/common/ini/GeneratorLoopersInjector.ini"
869+
LOOPSCONFKEY = constructConfigKeyArg(create_geant_config(args, LOOPScfgbase))
870+
LOOPStask['cmd'] = '${O2_ROOT}/bin/o2-sim --noGeant --field ccdb -j 1 --vertexMode kNoVertex' \
871+
+ ' --run ' + str(args.run) + ' ' + str(LOOPSCONFKEY) + ' -g external' \
872+
+ ' -n ' + str(NSIGEVENTS) + ' --seed ' + str(TFSEED) + ' -o loops ' \
873+
+ embeddinto + ' --fromCollContext collisioncontext.root:' + signalprefix + LOOPSinicfg
874+
kineFileName = 'loops_Kine.root' # Kine file now has injected loopers
875+
signalneeds = signalneeds + [LOOPStask['name']]
876+
workflow['stages'].append(LOOPStask)
877+
860878
sgnmem = 6000 if COLTYPE == 'PbPb' else 4000
861879
SGNtask=createTask(name='sgnsim_'+str(tf), needs=signalneeds, tf=tf, cwd='tf'+str(tf), lab=["GEANT"],
862880
relative_cpu=7/8, n_workers=NWORKERS_TF, mem=str(sgnmem))
863881
sgncmdbase = '${O2_ROOT}/bin/o2-sim -e ' + str(SIMENGINE) + ' ' + str(MODULES) + ' -n ' + str(NSIGEVENTS) + ' --seed ' + str(TFSEED) \
864882
+ ' --field ccdb -j ' + str(NWORKERS_TF) + ' ' + str(CONFKEY) + ' ' + str(INIFILE) + ' -o ' + signalprefix + ' ' + embeddinto \
865883
+ ('', ' --timestamp ' + str(args.timestamp))[args.timestamp!=-1] + ' --run ' + str(args.run)
866884
if sep_event_mode:
867-
SGNtask['cmd'] = sgncmdbase + ' -g extkinO2 --extKinFile genevents_Kine.root ' + ' --vertexMode kNoVertex'
885+
SGNtask['cmd'] = sgncmdbase + ' -g extkinO2 --extKinFile ' + kineFileName + ' --vertexMode kNoVertex'
868886
else:
869887
SGNtask['cmd'] = sgncmdbase + ' -g ' + str(GENERATOR) + ' ' + str(TRIGGER) + ' --vertexMode kCCDB '
870888
if not isActive('all'):

MC/config/common/external/generator/TPCLoopers.C

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,89 @@ class GenTPCLoopers : public Generator
376376
double mMass_p = mPDG->GetParticle(-11)->Mass();
377377
};
378378

379+
class GenLoopersInjector : public Generator
380+
{
381+
public:
382+
GenLoopersInjector(std::string kineFN = "genevents_Kine.root", std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
383+
std::string scaler_pair = "scaler_pair.json", std::string scaler_compton = "scaler_compton.json", std::string poisson = "", std::string gauss = "")
384+
{
385+
mKineGen = std::make_unique<GeneratorFromO2Kine>(kineFN.c_str());
386+
mGenTPCLoopers = std::make_unique<GenTPCLoopers>(model_pairs, model_compton, poisson, gauss, scaler_pair, scaler_compton);
387+
Generator::setTimeUnit(1.0);
388+
Generator::setPositionUnit(1.0);
389+
Generator::setMomentumUnit(1.0);
390+
Generator::setEnergyUnit(1.0);
391+
}
392+
393+
void setAdaptiveLoopers(Bool_t adaptive)
394+
{
395+
mAdaptiveLoopers = adaptive;
396+
LOG(info) << "Adaptive loopers: " << (mAdaptiveLoopers ? "ON" : "OFF");
397+
}
398+
399+
void setLoopsFractions(float &fraction, float &fractionPairs)
400+
{
401+
if (fraction < 0 || fraction >= 1)
402+
{
403+
LOG(fatal) << "Error: Loops fraction must be in the range [0, 1).";
404+
exit(1);
405+
}
406+
mLoopsFraction = fraction;
407+
if (fractionPairs < 0 || fractionPairs > 1)
408+
{
409+
LOG(fatal) << "Error: Loops fraction for pairs must be in the range [0, 1].";
410+
exit(1);
411+
}
412+
mLoopsFractionPairs = fractionPairs;
413+
LOG(info) << "Pairs fraction set to: " << mLoopsFraction;
414+
}
415+
416+
Bool_t generateEvent() override
417+
{
418+
// Trivial, real work in importParticles
419+
return true;
420+
}
421+
422+
Bool_t importParticles() override
423+
{
424+
mParticles.clear(); // Clear the particles stack before importing new ones
425+
// Combination of import particles from GeneratorFromO2Kine and GenTPCLoopers
426+
auto stat1 = mKineGen->importParticles();
427+
// Check size of mParticles stack and set loopers accordingly if mAdaptiveLoopers is true
428+
if (mAdaptiveLoopers)
429+
{
430+
int nParticles = mParticles.size();
431+
if (nParticles > 0)
432+
{
433+
// Calculate the number of loopers to inject adaptively
434+
short int nLoopers = static_cast<short int>(std::round((nParticles * mLoopsFraction) / (1 - mLoopsFractionPairs)));
435+
short int nLoopersPairs = static_cast<short int>(std::round(nLoopers * mLoopsFractionPairs));
436+
short int nLoopersCompton = nLoopers - nLoopersPairs;
437+
mGenTPCLoopers->SetNLoopers(nLoopersPairs, nLoopersCompton);
438+
mGenTPCLoopers->generateEvent();
439+
}
440+
}
441+
auto stat2 = mGenTPCLoopers->importParticles();
442+
if (stat1 && stat2)
443+
{
444+
// Merge particles from both generators
445+
mParticles.insert(mParticles.end(), mKineGen->getParticles().begin(), mKineGen->getParticles().end());
446+
mParticles.insert(mParticles.end(), mGenTPCLoopers->getParticles().begin(), mGenTPCLoopers->getParticles().end());
447+
} else {
448+
LOG(error) << "Failed to import particles from O2 Kinematics or TPCLoopers";
449+
return false;
450+
}
451+
return true;
452+
}
453+
454+
private:
455+
std::unique_ptr<GeneratorFromO2Kine> mKineGen = nullptr; // Instance of GeneratorFromO2Kine to read particles from O2 kinematics file
456+
std::unique_ptr<GenTPCLoopers> mGenTPCLoopers = nullptr; // Instance of GenTPCLoopers to generate loopers
457+
Bool_t mAdaptiveLoopers = true; // Flag to indicate if adaptive loopers are used
458+
float mLoopsFraction = 0.1; // Fraction of loopers to be injected adaptively
459+
float mLoopsFractionPairs = 0.08; // Fraction of loopers from Pairs
460+
};
461+
379462
} // namespace eventgen
380463
} // namespace o2
381464

@@ -450,4 +533,73 @@ FairGenerator *
450533
generator->SetNLoopers(nloopers_pairs, nloopers_compton);
451534
generator->SetMultiplier(mult);
452535
return generator;
536+
}
537+
538+
// Loopers injector to O2 kinematics file
539+
// Loopers are considered adaptive by default, meaning that the number of loopers is determined by the number of particles in the kinematics file per event
540+
FairGenerator *
541+
GeneratorLoopersInjector(std::string kineFileName = "genevents_Kine.root", std::string model_pairs = "tpcloopmodel.onnx", std::string model_compton = "tpcloopmodelcompton.onnx",
542+
std::string scaler_pair = "scaler_pair.json", std::string scaler_compton = "scaler_compton.json", float loopers_fraction = 0.1, float fraction_pairs = 0.08)
543+
{
544+
// Expand all environment paths
545+
model_pairs = gSystem->ExpandPathName(model_pairs.c_str());
546+
model_compton = gSystem->ExpandPathName(model_compton.c_str());
547+
scaler_pair = gSystem->ExpandPathName(scaler_pair.c_str());
548+
scaler_compton = gSystem->ExpandPathName(scaler_compton.c_str());
549+
const std::array<std::string, 2> models = {model_pairs, model_compton};
550+
const std::array<std::string, 2> local_names = {"WGANpair.onnx", "WGANcompton.onnx"};
551+
const std::array<bool, 2> isAlien = {models[0].starts_with("alien://"), models[1].starts_with("alien://")};
552+
const std::array<bool, 2> isCCDB = {models[0].starts_with("ccdb://"), models[1].starts_with("ccdb://")};
553+
if (std::any_of(isAlien.begin(), isAlien.end(), [](bool v)
554+
{ return v; }))
555+
{
556+
if (!gGrid)
557+
{
558+
TGrid::Connect("alien://");
559+
if (!gGrid)
560+
{
561+
LOG(fatal) << "AliEn connection failed, check token.";
562+
exit(1);
563+
}
564+
}
565+
for (size_t i = 0; i < models.size(); ++i)
566+
{
567+
if (isAlien[i] && !TFile::Cp(models[i].c_str(), local_names[i].c_str()))
568+
{
569+
LOG(fatal) << "Error: Model file " << models[i] << " does not exist!";
570+
exit(1);
571+
}
572+
}
573+
}
574+
if (std::any_of(isCCDB.begin(), isCCDB.end(), [](bool v)
575+
{ return v; }))
576+
{
577+
o2::ccdb::CcdbApi ccdb_api;
578+
ccdb_api.init("http://alice-ccdb.cern.ch");
579+
for (size_t i = 0; i < models.size(); ++i)
580+
{
581+
if (isCCDB[i])
582+
{
583+
auto model_path = models[i].substr(7); // Remove "ccdb://"
584+
// Treat filename if provided in the CCDB path
585+
auto extension = model_path.find(".onnx");
586+
if (extension != std::string::npos)
587+
{
588+
auto last_slash = model_path.find_last_of('/');
589+
model_path = model_path.substr(0, last_slash);
590+
}
591+
std::map<std::string, std::string> filter;
592+
if (!ccdb_api.retrieveBlob(model_path, "./", filter, o2::ccdb::getCurrentTimestamp(), false, local_names[i].c_str()))
593+
{
594+
LOG(fatal) << "Error: issues in retrieving " << model_path << " from CCDB!";
595+
exit(1);
596+
}
597+
}
598+
}
599+
}
600+
model_pairs = isAlien[0] || isCCDB[0] ? local_names[0] : model_pairs;
601+
model_compton = isAlien[1] || isCCDB[1] ? local_names[1] : model_compton;
602+
auto generator = new o2::eventgen::GenLoopersInjector(kineFileName, model_pairs, model_compton, scaler_pair, scaler_compton);
603+
generator->setLoopsFractions(loopers_fraction, fraction_pairs);
604+
return generator;
453605
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# TPC loopers injector
2+
[GeneratorExternal]
3+
fileName = ${O2DPG_MC_CONFIG_ROOT}/MC/config/common/external/generator/TPCLoopers.C
4+
funcName = GeneratorLoopersInjector("genevents_Kine.root", "ccdb://Users/m/mgiacalo/WGAN_ExtGenPair", "ccdb://Users/m/mgiacalo/WGAN_ExtGenCompton", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerPairParams.json", "${O2DPG_MC_CONFIG_ROOT}/MC/config/common/TPCloopers/ScalerComptonParams.json")

0 commit comments

Comments
 (0)