@@ -376,6 +376,91 @@ class GenTPCLoopers : public Generator
376376 double mMass_p = mPDG -> GetParticle (-11 )-> Mass ();
377377};
378378
379+ class GenLoopersInjector : public Generator
380+ {
381+ public :
382+ GenLoopersInjector (std ::string kineFN = "genevents_Kine.root" , std ::string model_pairs = "tpcloopmodel.onnx" , std ::string model_compton = "tpcloopmodelcompton.onnx" ,
383+ std ::string scaler_pair = "scaler_pair.json" , std ::string scaler_compton = "scaler_compton.json" , std ::string poisson = "" , std ::string gauss = "" )
384+ {
385+ mKineGen = std ::make_unique < GeneratorFromO2Kine > (kineFN .c_str ());
386+ mGenTPCLoopers = std ::make_unique < GenTPCLoopers > (model_pairs , model_compton , poisson , gauss , scaler_pair , scaler_compton );
387+ Generator ::setTimeUnit (1.0 );
388+ Generator ::setPositionUnit (1.0 );
389+ Generator ::setMomentumUnit (1.0 );
390+ Generator ::setEnergyUnit (1.0 );
391+ }
392+
393+ void setAdaptiveLoopers (Bool_t adaptive )
394+ {
395+ mAdaptiveLoopers = adaptive ;
396+ LOG (info ) << "Adaptive loopers: " << (mAdaptiveLoopers ? "ON" : "OFF" );
397+ }
398+
399+ void setLoopsFractions (float & fraction , float & fractionPairs )
400+ {
401+ if (fraction < 0 || fraction >= 1 )
402+ {
403+ LOG (fatal ) << "Error: Loops fraction must be in the range [0, 1)." ;
404+ exit (1 );
405+ }
406+ mLoopsFraction = fraction ;
407+ if (fractionPairs < 0 || fractionPairs > 1 )
408+ {
409+ LOG (fatal ) << "Error: Loops fraction for pairs must be in the range [0, 1]." ;
410+ exit (1 );
411+ }
412+ mLoopsFractionPairs = fractionPairs ;
413+ LOG (info ) << "Pairs fraction set to: " << mLoopsFraction ;
414+ }
415+
416+ Bool_t generateEvent () override
417+ {
418+ // Trivial, real work in importParticles
419+ return true;
420+ }
421+
422+ Bool_t importParticles () override
423+ {
424+ mParticles .clear (); // Clear the particles stack before importing new ones
425+ // Combination of import particles from GeneratorFromO2Kine and GenTPCLoopers
426+ mKineGen -> clearParticles (); // Clear particles from O2 Kinematics generator
427+ mGenTPCLoopers -> clearParticles (); // Clear particles from GenTPCLoopers
428+ auto stat1 = mKineGen -> importParticles ();
429+ // Check size of mParticles stack and set loopers accordingly if mAdaptiveLoopers is true
430+ if (mAdaptiveLoopers )
431+ {
432+ int nParticles = mParticles .size ();
433+ if (nParticles > 0 )
434+ {
435+ // Calculate the number of loopers to inject adaptively
436+ short int nLoopers = static_cast < short int > (std ::round ((nParticles * mLoopsFraction ) / (1 - mLoopsFractionPairs )));
437+ short int nLoopersPairs = static_cast < short int > (std ::round (nLoopers * mLoopsFractionPairs ));
438+ short int nLoopersCompton = nLoopers - nLoopersPairs ;
439+ mGenTPCLoopers -> SetNLoopers (nLoopersPairs , nLoopersCompton );
440+ mGenTPCLoopers -> generateEvent ();
441+ }
442+ }
443+ auto stat2 = mGenTPCLoopers -> importParticles ();
444+ if (stat1 && stat2 )
445+ {
446+ // Merge particles from both generators
447+ mParticles .insert (mParticles .end (), mKineGen -> getParticles ().begin (), mKineGen -> getParticles ().end ());
448+ mParticles .insert (mParticles .end (), mGenTPCLoopers -> getParticles ().begin (), mGenTPCLoopers -> getParticles ().end ());
449+ } else {
450+ LOG (error ) << "Failed to import particles from O2 Kinematics or TPCLoopers" ;
451+ return false;
452+ }
453+ return true;
454+ }
455+
456+ private :
457+ std ::unique_ptr < GeneratorFromO2Kine > mKineGen = nullptr ; // Instance of GeneratorFromO2Kine to read particles from O2 kinematics file
458+ std ::unique_ptr < GenTPCLoopers > mGenTPCLoopers = nullptr ; // Instance of GenTPCLoopers to generate loopers
459+ Bool_t mAdaptiveLoopers = true; // Flag to indicate if adaptive loopers are used
460+ float mLoopsFraction = 0.1 ; // Fraction of loopers to be injected adaptively
461+ float mLoopsFractionPairs = 0.08 ; // Fraction of loopers from Pairs
462+ };
463+
379464} // namespace eventgen
380465} // namespace o2
381466
@@ -450,4 +535,73 @@ FairGenerator *
450535 generator -> SetNLoopers (nloopers_pairs , nloopers_compton );
451536 generator -> SetMultiplier (mult );
452537 return generator ;
538+ }
539+
540+ // Loopers injector to O2 kinematics file
541+ // Loopers are considered adaptive by default, meaning that the number of loopers is determined by the number of particles in the kinematics file per event
542+ FairGenerator *
543+ GeneratorLoopersInjector (std ::string kineFileName = "genevents_Kine.root" , std ::string model_pairs = "tpcloopmodel.onnx" , std ::string model_compton = "tpcloopmodelcompton.onnx" ,
544+ std ::string scaler_pair = "scaler_pair.json" , std ::string scaler_compton = "scaler_compton.json" , float loopers_fraction = 0.1 , float fraction_pairs = 0.08 )
545+ {
546+ // Expand all environment paths
547+ model_pairs = gSystem -> ExpandPathName (model_pairs .c_str ()) ;
548+ model_compton = gSystem -> ExpandPathName (model_compton .c_str ());
549+ scaler_pair = gSystem -> ExpandPathName (scaler_pair .c_str ());
550+ scaler_compton = gSystem -> ExpandPathName (scaler_compton .c_str ());
551+ const std ::array < std ::string , 2 > models = {model_pairs , model_compton };
552+ const std ::array < std ::string , 2 > local_names = {"WGANpair.onnx" , "WGANcompton.onnx" };
553+ const std ::array < bool , 2 > isAlien = {models [0 ].starts_with ("alien://" ), models [1 ].starts_with ("alien://" )};
554+ const std ::array < bool , 2 > isCCDB = {models [0 ].starts_with ("ccdb://" ), models [1 ].starts_with ("ccdb://" )};
555+ if (std ::any_of (isAlien .begin (), isAlien .end (), [](bool v )
556+ { return v ; }))
557+ {
558+ if (!gGrid )
559+ {
560+ TGrid ::Connect ("alien://" );
561+ if (!gGrid )
562+ {
563+ LOG (fatal ) << "AliEn connection failed, check token." ;
564+ exit (1 );
565+ }
566+ }
567+ for (size_t i = 0 ; i < models .size (); ++ i )
568+ {
569+ if (isAlien [i ] && !TFile ::Cp (models [i ].c_str (), local_names [i ].c_str ()))
570+ {
571+ LOG (fatal ) << "Error: Model file " << models [i ] << " does not exist!" ;
572+ exit (1 );
573+ }
574+ }
575+ }
576+ if (std ::any_of (isCCDB .begin (), isCCDB .end (), [](bool v )
577+ { return v ; }))
578+ {
579+ o2 ::ccdb ::CcdbApi ccdb_api ;
580+ ccdb_api .init ("http://alice-ccdb.cern.ch" );
581+ for (size_t i = 0 ; i < models .size (); ++ i )
582+ {
583+ if (isCCDB [i ])
584+ {
585+ auto model_path = models [i ].substr (7 ); // Remove "ccdb://"
586+ // Treat filename if provided in the CCDB path
587+ auto extension = model_path .find (".onnx" );
588+ if (extension != std ::string ::npos )
589+ {
590+ auto last_slash = model_path .find_last_of ('/' );
591+ model_path = model_path .substr (0 , last_slash );
592+ }
593+ std ::map < std ::string , std ::string > filter ;
594+ if (!ccdb_api .retrieveBlob (model_path , "./" , filter , o2 ::ccdb ::getCurrentTimestamp (), false, local_names [i ].c_str ()))
595+ {
596+ LOG (fatal ) << "Error: issues in retrieving " << model_path << " from CCDB!" ;
597+ exit (1 );
598+ }
599+ }
600+ }
601+ }
602+ model_pairs = isAlien [0 ] || isCCDB [0 ] ? local_names [0 ] : model_pairs ;
603+ model_compton = isAlien [1 ] || isCCDB [1 ] ? local_names [1 ] : model_compton ;
604+ auto generator = new o2 ::eventgen ::GenLoopersInjector (kineFileName , model_pairs , model_compton , scaler_pair , scaler_compton );
605+ generator -> setLoopsFractions (loopers_fraction , fraction_pairs );
606+ return generator ;
453607}
0 commit comments