@@ -376,6 +376,89 @@ class GenTPCLoopers : public Generator
376376 double mMass_p = mPDG -> GetParticle (-11 )-> Mass ();
377377};
378378
379+ class GenLoopersInjector : public Generator
380+ {
381+ public :
382+ GenLoopersInjector (std ::string kineFN = "genevents_Kine.root" , std ::string model_pairs = "tpcloopmodel.onnx" , std ::string model_compton = "tpcloopmodelcompton.onnx" ,
383+ std ::string scaler_pair = "scaler_pair.json" , std ::string scaler_compton = "scaler_compton.json" , std ::string poisson = "" , std ::string gauss = "" )
384+ {
385+ mKineGen = std ::make_unique < GeneratorFromO2Kine > (kineFN .c_str ());
386+ mGenTPCLoopers = std ::make_unique < GenTPCLoopers > (model_pairs , model_compton , poisson , gauss , scaler_pair , scaler_compton );
387+ Generator ::setTimeUnit (1.0 );
388+ Generator ::setPositionUnit (1.0 );
389+ Generator ::setMomentumUnit (1.0 );
390+ Generator ::setEnergyUnit (1.0 );
391+ }
392+
393+ void setAdaptiveLoopers (Bool_t adaptive )
394+ {
395+ mAdaptiveLoopers = adaptive ;
396+ LOG (info ) << "Adaptive loopers: " << (mAdaptiveLoopers ? "ON" : "OFF" );
397+ }
398+
399+ void setLoopsFractions (float & fraction , float & fractionPairs )
400+ {
401+ if (fraction < 0 || fraction >= 1 )
402+ {
403+ LOG (fatal ) << "Error: Loops fraction must be in the range [0, 1)." ;
404+ exit (1 );
405+ }
406+ mLoopsFraction = fraction ;
407+ if (fractionPairs < 0 || fractionPairs > 1 )
408+ {
409+ LOG (fatal ) << "Error: Loops fraction for pairs must be in the range [0, 1]." ;
410+ exit (1 );
411+ }
412+ mLoopsFractionPairs = fractionPairs ;
413+ LOG (info ) << "Pairs fraction set to: " << mLoopsFraction ;
414+ }
415+
416+ Bool_t generateEvent () override
417+ {
418+ // Trivial, real work in importParticles
419+ return true;
420+ }
421+
422+ Bool_t importParticles () override
423+ {
424+ mParticles .clear (); // Clear the particles stack before importing new ones
425+ // Combination of import particles from GeneratorFromO2Kine and GenTPCLoopers
426+ auto stat1 = mKineGen -> importParticles ();
427+ // Check size of mParticles stack and set loopers accordingly if mAdaptiveLoopers is true
428+ if (mAdaptiveLoopers )
429+ {
430+ int nParticles = mParticles .size ();
431+ if (nParticles > 0 )
432+ {
433+ // Calculate the number of loopers to inject adaptively
434+ short int nLoopers = static_cast < short int > (std ::round ((nParticles * mLoopsFraction ) / (1 - mLoopsFractionPairs )));
435+ short int nLoopersPairs = static_cast < short int > (std ::round (nLoopers * mLoopsFractionPairs ));
436+ short int nLoopersCompton = nLoopers - nLoopersPairs ;
437+ mGenTPCLoopers -> SetNLoopers (nLoopersPairs , nLoopersCompton );
438+ mGenTPCLoopers -> generateEvent ();
439+ }
440+ }
441+ auto stat2 = mGenTPCLoopers -> importParticles ();
442+ if (stat1 && stat2 )
443+ {
444+ // Merge particles from both generators
445+ mParticles .insert (mParticles .end (), mKineGen -> getParticles ().begin (), mKineGen -> getParticles ().end ());
446+ mParticles .insert (mParticles .end (), mGenTPCLoopers -> getParticles ().begin (), mGenTPCLoopers -> getParticles ().end ());
447+ } else {
448+ LOG (error ) << "Failed to import particles from O2 Kinematics or TPCLoopers" ;
449+ return false;
450+ }
451+ return true;
452+ }
453+
454+ private :
455+ std ::unique_ptr < GeneratorFromO2Kine > mKineGen = nullptr ; // Instance of GeneratorFromO2Kine to read particles from O2 kinematics file
456+ std ::unique_ptr < GenTPCLoopers > mGenTPCLoopers = nullptr ; // Instance of GenTPCLoopers to generate loopers
457+ Bool_t mAdaptiveLoopers = true; // Flag to indicate if adaptive loopers are used
458+ float mLoopsFraction = 0.1 ; // Fraction of loopers to be injected adaptively
459+ float mLoopsFractionPairs = 0.08 ; // Fraction of loopers from Pairs
460+ };
461+
379462} // namespace eventgen
380463} // namespace o2
381464
@@ -450,4 +533,73 @@ FairGenerator *
450533 generator -> SetNLoopers (nloopers_pairs , nloopers_compton );
451534 generator -> SetMultiplier (mult );
452535 return generator ;
536+ }
537+
538+ // Loopers injector to O2 kinematics file
539+ // Loopers are considered adaptive by default, meaning that the number of loopers is determined by the number of particles in the kinematics file per event
540+ FairGenerator *
541+ GeneratorLoopersInjector (std ::string kineFileName = "genevents_Kine.root" , std ::string model_pairs = "tpcloopmodel.onnx" , std ::string model_compton = "tpcloopmodelcompton.onnx" ,
542+ std ::string scaler_pair = "scaler_pair.json" , std ::string scaler_compton = "scaler_compton.json" , float loopers_fraction = 0.1 , float fraction_pairs = 0.08 )
543+ {
544+ // Expand all environment paths
545+ model_pairs = gSystem -> ExpandPathName (model_pairs .c_str ()) ;
546+ model_compton = gSystem -> ExpandPathName (model_compton .c_str ());
547+ scaler_pair = gSystem -> ExpandPathName (scaler_pair .c_str ());
548+ scaler_compton = gSystem -> ExpandPathName (scaler_compton .c_str ());
549+ const std ::array < std ::string , 2 > models = {model_pairs , model_compton };
550+ const std ::array < std ::string , 2 > local_names = {"WGANpair.onnx" , "WGANcompton.onnx" };
551+ const std ::array < bool , 2 > isAlien = {models [0 ].starts_with ("alien://" ), models [1 ].starts_with ("alien://" )};
552+ const std ::array < bool , 2 > isCCDB = {models [0 ].starts_with ("ccdb://" ), models [1 ].starts_with ("ccdb://" )};
553+ if (std ::any_of (isAlien .begin (), isAlien .end (), [](bool v )
554+ { return v ; }))
555+ {
556+ if (!gGrid )
557+ {
558+ TGrid ::Connect ("alien://" );
559+ if (!gGrid )
560+ {
561+ LOG (fatal ) << "AliEn connection failed, check token." ;
562+ exit (1 );
563+ }
564+ }
565+ for (size_t i = 0 ; i < models .size (); ++ i )
566+ {
567+ if (isAlien [i ] && !TFile ::Cp (models [i ].c_str (), local_names [i ].c_str ()))
568+ {
569+ LOG (fatal ) << "Error: Model file " << models [i ] << " does not exist!" ;
570+ exit (1 );
571+ }
572+ }
573+ }
574+ if (std ::any_of (isCCDB .begin (), isCCDB .end (), [](bool v )
575+ { return v ; }))
576+ {
577+ o2 ::ccdb ::CcdbApi ccdb_api ;
578+ ccdb_api .init ("http://alice-ccdb.cern.ch" );
579+ for (size_t i = 0 ; i < models .size (); ++ i )
580+ {
581+ if (isCCDB [i ])
582+ {
583+ auto model_path = models [i ].substr (7 ); // Remove "ccdb://"
584+ // Treat filename if provided in the CCDB path
585+ auto extension = model_path .find (".onnx" );
586+ if (extension != std ::string ::npos )
587+ {
588+ auto last_slash = model_path .find_last_of ('/' );
589+ model_path = model_path .substr (0 , last_slash );
590+ }
591+ std ::map < std ::string , std ::string > filter ;
592+ if (!ccdb_api .retrieveBlob (model_path , "./" , filter , o2 ::ccdb ::getCurrentTimestamp (), false, local_names [i ].c_str ()))
593+ {
594+ LOG (fatal ) << "Error: issues in retrieving " << model_path << " from CCDB!" ;
595+ exit (1 );
596+ }
597+ }
598+ }
599+ }
600+ model_pairs = isAlien [0 ] || isCCDB [0 ] ? local_names [0 ] : model_pairs ;
601+ model_compton = isAlien [1 ] || isCCDB [1 ] ? local_names [1 ] : model_compton ;
602+ auto generator = new o2 ::eventgen ::GenLoopersInjector (kineFileName , model_pairs , model_compton , scaler_pair , scaler_compton );
603+ generator -> setLoopsFractions (loopers_fraction , fraction_pairs );
604+ return generator ;
453605}
0 commit comments