Skip to content

Commit 2250d7e

Browse files
author
Changhwan Choi
committed
GNN-based b-jet tagging analysis code updated. (training/evaluation dataset separation, etc.)
1 parent a2c6af3 commit 2250d7e

2 files changed

Lines changed: 37 additions & 27 deletions

File tree

PWGJE/Tasks/bjetTaggingGnn.cxx

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ struct BjetTaggingGnn {
5959

6060
Configurable<float> trackNppCrit{"trackNppCrit", 0.95, "track not physical primary ratio"};
6161

62-
// track level configurables
62+
// sv level configurables
6363
Configurable<float> svPtMin{"svPtMin", 0.5, "minimum SV pT"};
6464

6565
// jet level configurables
@@ -70,9 +70,15 @@ struct BjetTaggingGnn {
7070

7171
Configurable<std::vector<double>> jetRadii{"jetRadii", std::vector<double>{0.4}, "jet resolution parameters"};
7272

73+
Configurable<double> dbMin{"dbMin", -10., "minimum GNN Db"};
74+
Configurable<double> dbMax{"dbMax", 20., "maximum GNN Db"};
75+
Configurable<int> dbNbins{"dbNbins", 3000, "number of bins in axisDbFine"};
76+
7377
Configurable<bool> doDataDriven{"doDataDriven", false, "Flag whether to use fill THnSpase for data driven methods"};
7478
Configurable<bool> callSumw2{"callSumw2", false, "Flag whether to call THnSparse::Sumw2() for error calculation"};
7579

80+
Configurable<int> trainingDatasetRatioParam{"trainingDatasetRatioParam", 0, "Parameter for splitting training/evaluation datasets by collisionId"};
81+
7682
std::vector<int> eventSelectionBits;
7783

7884
std::vector<double> jetRadiiValues;
@@ -83,19 +89,19 @@ struct BjetTaggingGnn {
8389

8490
eventSelectionBits = jetderiveddatautilities::initialiseEventSelectionBits(static_cast<std::string>(eventSelections));
8591

86-
registry.add("h_vertexZ", "Vertex Z;#it{Z} (cm)", {HistType::kTH1F, {{40, -20.0, 20.0}}});
92+
registry.add("h_vertexZ", "Vertex Z;#it{Z} (cm)", {HistType::kTH1F, {{100, -20.0, 20.0}}});
8793

8894
const AxisSpec axisJetpT{200, 0., 200., "#it{p}_{T} (GeV/#it{c})"};
89-
const AxisSpec axisDb{200, -10., 20., "#it{D}_{b}"};
90-
const AxisSpec axisDbFine{3000, -10., 20., "#it{D}_{b}"};
95+
const AxisSpec axisDb{200, dbMin, dbMax, "#it{D}_{b}"};
96+
const AxisSpec axisDbFine{dbNbins, dbMin, dbMax, "#it{D}_{b}"};
9197
const AxisSpec axisSVMass{200, 0., 10., "#it{m}_{SV} (GeV/#it{c}^{2})"};
9298
const AxisSpec axisSVEnergy{200, 0., 100., "#it{E}_{SV} (GeV)"};
9399
const AxisSpec axisSLxy{200, 0., 100., "#it{SL}_{xy}"};
94100
const AxisSpec axisJetMass{200, 0., 50., "#it{m}_{jet} (GeV/#it{c}^{2})"};
95101
const AxisSpec axisJetProb{200, 0., 40., "-ln(JP)"};
96102
const AxisSpec axisNTracks{42, 0, 42, "#it{n}_{tracks}"};
97103

98-
registry.add("h_jetpT", "", {HistType::kTH1F, {axisJetpT}});
104+
registry.add("h_jetpT", "", {HistType::kTH1F, {axisJetpT}}, callSumw2);
99105
registry.add("h_Db", "", {HistType::kTH1F, {axisDbFine}});
100106
registry.add("h2_jetpT_Db", "", {HistType::kTH2F, {axisJetpT, axisDb}});
101107
registry.add("h2_jetpT_SVMass", "", {HistType::kTH2F, {axisJetpT, axisSVMass}});
@@ -104,9 +110,9 @@ struct BjetTaggingGnn {
104110
registry.add("h2_jetpT_nTracks", "", {HistType::kTH2F, {axisJetpT, axisNTracks}});
105111

106112
if (doprocessMCJets) {
107-
registry.add("h_jetpT_b", "b-jet", {HistType::kTH1F, {axisJetpT}});
108-
registry.add("h_jetpT_c", "c-jet", {HistType::kTH1F, {axisJetpT}});
109-
registry.add("h_jetpT_lf", "lf-jet", {HistType::kTH1F, {axisJetpT}});
113+
registry.add("h_jetpT_b", "b-jet", {HistType::kTH1F, {axisJetpT}}, callSumw2);
114+
registry.add("h_jetpT_c", "c-jet", {HistType::kTH1F, {axisJetpT}}, callSumw2);
115+
registry.add("h_jetpT_lf", "lf-jet", {HistType::kTH1F, {axisJetpT}}, callSumw2);
110116
registry.add("h_Db_b", "b-jet", {HistType::kTH1F, {axisDbFine}});
111117
registry.add("h_Db_c", "c-jet", {HistType::kTH1F, {axisDbFine}});
112118
registry.add("h_Db_lf", "lf-jet", {HistType::kTH1F, {axisDbFine}});
@@ -125,10 +131,10 @@ struct BjetTaggingGnn {
125131
registry.add("h2_jetpT_nTracks_b", "b-jet", {HistType::kTH2F, {axisJetpT, axisNTracks}});
126132
registry.add("h2_jetpT_nTracks_c", "c-jet", {HistType::kTH2F, {axisJetpT, axisNTracks}});
127133
registry.add("h2_jetpT_nTracks_lf", "lf-jet", {HistType::kTH2F, {axisJetpT, axisNTracks}});
128-
registry.add("h2_Response_DetjetpT_PartjetpT", "", {HistType::kTH2F, {axisJetpT, axisJetpT}});
129-
registry.add("h2_Response_DetjetpT_PartjetpT_b", "b-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}});
130-
registry.add("h2_Response_DetjetpT_PartjetpT_c", "c-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}});
131-
registry.add("h2_Response_DetjetpT_PartjetpT_lf", "lf-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}});
134+
registry.add("h2_Response_DetjetpT_PartjetpT", "", {HistType::kTH2F, {axisJetpT, axisJetpT}}, callSumw2);
135+
registry.add("h2_Response_DetjetpT_PartjetpT_b", "b-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}}, callSumw2);
136+
registry.add("h2_Response_DetjetpT_PartjetpT_c", "c-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}}, callSumw2);
137+
registry.add("h2_Response_DetjetpT_PartjetpT_lf", "lf-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}, callSumw2});
132138
registry.add("h2_jetpT_Db_lf_none", "lf-jet (none)", {HistType::kTH2F, {axisJetpT, axisDb}});
133139
registry.add("h2_jetpT_Db_lf_matched", "lf-jet (matched)", {HistType::kTH2F, {axisJetpT, axisDb}});
134140
registry.add("h2_jetpT_Db_npp", "NotPhysPrim", {HistType::kTH2F, {axisJetpT, axisDb}});
@@ -146,10 +152,10 @@ struct BjetTaggingGnn {
146152
}
147153

148154
if (doprocessMCTruthJets) {
149-
registry.add("h_jetpT_particle", "", {HistType::kTH1F, {axisJetpT}});
150-
registry.add("h_jetpT_particle_b", "particle b-jet", {HistType::kTH1F, {axisJetpT}});
151-
registry.add("h_jetpT_particle_c", "particle c-jet", {HistType::kTH1F, {axisJetpT}});
152-
registry.add("h_jetpT_particle_lf", "particle lf-jet", {HistType::kTH1F, {axisJetpT}});
155+
registry.add("h_jetpT_particle", "", {HistType::kTH1F, {axisJetpT}}, callSumw2);
156+
registry.add("h_jetpT_particle_b", "particle b-jet", {HistType::kTH1F, {axisJetpT}}, callSumw2);
157+
registry.add("h_jetpT_particle_c", "particle c-jet", {HistType::kTH1F, {axisJetpT}}, callSumw2);
158+
registry.add("h_jetpT_particle_lf", "particle lf-jet", {HistType::kTH1F, {axisJetpT}}, callSumw2);
153159
}
154160

155161
if (doDataDriven) {
@@ -287,6 +293,11 @@ struct BjetTaggingGnn {
287293
if (!jetderiveddatautilities::selectCollision(collision, eventSelectionBits)) {
288294
return;
289295
}
296+
297+
// Uses only collisionId % trainingDatasetRaioParam != 0 for evaluation dataset
298+
if (trainingDatasetRatioParam && collision.collisionId() % trainingDatasetRatioParam == 0) {
299+
return;
300+
}
290301

291302
registry.fill(HIST("h_vertexZ"), collision.posZ());
292303

PWGJE/Tasks/bjetTreeCreator.cxx

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ struct BJetTreeCreator {
241241

242242
Configurable<float> vtxRes{"vtxRes", 0.01, "Vertex position resolution (cluster size) for GNN vertex predictions (cm)"};
243243

244+
Configurable<int> trainingDatasetRatioParam{"trainingDatasetRatioParam", 0, "Parameter for splitting training/evaluation datasets by collisionId"};
245+
244246
std::vector<int> eventSelectionBits;
245247

246248
std::vector<double> jetRadiiValues;
@@ -707,7 +709,7 @@ struct BJetTreeCreator {
707709
}
708710
PROCESS_SWITCH(BJetTreeCreator, processMCJets, "jet information in MC", false);
709711

710-
using MCDJetTableNoSV = soa::Filtered<soa::Join<aod::ChargedMCDetectorLevelJets, aod::ChargedMCDetectorLevelJetConstituents, aod::ChargedMCDetectorLevelJetsMatchedToChargedMCParticleLevelJets, aod::ChargedMCDetectorLevelJetEventWeights>>;
712+
using MCDJetTableNoSV = soa::Filtered<soa::Join<aod::ChargedMCDetectorLevelJets, aod::ChargedMCDetectorLevelJetConstituents, aod::ChargedMCDetectorLevelJetsMatchedToChargedMCParticleLevelJets, aod::ChargedMCDetectorLevelJetFlavourDef, aod::ChargedMCDetectorLevelJetEventWeights>>;
711713
using JetParticleswID = soa::Join<aod::JetParticles, aod::JMcParticlePIs>;
712714

713715
void processMCJetsForGNN(FilteredCollisionMCD::iterator const& collision, aod::JMcCollisions const&, MCDJetTableNoSV const& MCDjets, MCPJetTable const& MCPjets, JetTracksMCDwID const& allTracks, JetParticleswID const& MCParticles, OriginalTracks const& origTracks, aod::McParticles const& origParticles)
@@ -716,6 +718,11 @@ struct BJetTreeCreator {
716718
return;
717719
}
718720

721+
// Uses only collisionId % trainingDatasetRaioParam == 0 for training dataset
722+
if (trainingDatasetRatioParam && collision.collisionId() % trainingDatasetRatioParam != 0) {
723+
return;
724+
}
725+
719726
registry.fill(HIST("h_vertexZ"), collision.posZ());
720727

721728
auto const mcParticlesPerColl = MCParticles.sliceBy(mcParticlesPerCollision, collision.mcCollisionId());
@@ -738,15 +745,7 @@ struct BJetTreeCreator {
738745
std::vector<int> indicesTracks;
739746
std::vector<int> indicesSVs;
740747

741-
int16_t jetFlavor = 0;
742-
743-
for (const auto& mcpjet : analysisJet.template matchedJetGeo_as<MCPJetTable>()) {
744-
if (useQuarkDef) {
745-
jetFlavor = jettaggingutilities::getJetFlavor(mcpjet, mcParticlesPerColl);
746-
} else {
747-
jetFlavor = jettaggingutilities::getJetFlavorHadron(mcpjet, mcParticlesPerColl);
748-
}
749-
}
748+
int16_t jetFlavor = analysisJet.origin();
750749

751750
if ((jetFlavor != JetTaggingSpecies::charm && jetFlavor != JetTaggingSpecies::beauty) && (static_cast<double>(std::rand()) / RAND_MAX < getReductionFactor(analysisJet.pt()))) {
752751
continue;
@@ -760,7 +759,7 @@ struct BJetTreeCreator {
760759
analyzeJetTrackInfoForGNN(collision, analysisJet, allTracks, origTracks, indicesTracks, jetFlavor, eventWeight, &trkLabels);
761760

762761
registry.fill(HIST("h2_jetMass_jetpT"), analysisJet.pt(), analysisJet.mass(), eventWeight);
763-
registry.fill(HIST("h2_nTracks_jetpT"), analysisJet.pt(), indicesTracks.size());
762+
registry.fill(HIST("h2_nTracks_jetpT"), analysisJet.pt(), indicesTracks.size(), eventWeight);
764763

765764
//+jet
766765
registry.fill(HIST("h_jet_pt"), analysisJet.pt());

0 commit comments

Comments
 (0)