Belle II Software light-2406-ragdoll
ChargedPidMVAWeights.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10
11// FRAMEWORK
12#include <framework/gearbox/Const.h>
13#include <framework/gearbox/Unit.h>
14#include <framework/logging/Logger.h>
15
16// MVA
17#include <mva/interface/Weightfile.h>
18
19// ROOT
20#include <TObject.h>
21#include <TH3F.h>
22#include <TParameter.h>
23#include <TFile.h>
24
25//BOOST
26#include <boost/algorithm/string/predicate.hpp>
27
28
29namespace Belle2 {
38 class ChargedPidMVAWeights : public TObject {
39
40 typedef std::unordered_map<int, std::vector<std::string> > WeightfilesByParticle;
41 typedef std::map<std::string, std::string> VariablesByAlias;
43 public:
44
49 m_energy_unit("energyUnit", Unit::GeV),
50 m_ang_unit("angularUnit", Unit::rad),
51 m_thetaVarName("clusterTheta"),
53 {};
54
55
59 ChargedPidMVAWeights(const double& energyUnit, const double& angUnit,
60 const std::string& thetaVarName = "clusterTheta",
61 bool implictNaNmasking = false)
62 {
63 setEnergyUnit(energyUnit);
64 setAngularUnit(angUnit);
65 m_thetaVarName = thetaVarName;
66 m_implicitNaNmasking = implictNaNmasking;
67 }
68
73
77 enum class ChargedPidMVATrainingMode : unsigned int {
81 c_Multiclass = 1,
94 };
95
96
100 void setEnergyUnit(const double& unit) { m_energy_unit.SetVal(unit); }
101
102
106 void setAngularUnit(const double& unit) { m_ang_unit.SetVal(unit); }
107
117 void setWeightCategories(const double* clusterThetaBins, const int nClusterThetaBins,
118 const double* pBins, const int nPBins,
119 const double* chargeBins, const int nChargeBins)
120 {
121
122 m_categories = std::make_unique<TH3F>("clustertheta_p_charge_binsgrid",
123 ";ECL cluster #theta;p_{lab};Q",
124 nClusterThetaBins, clusterThetaBins,
125 nPBins, pBins,
126 nChargeBins, chargeBins);
127 }
128
138 void storeMVAWeights(const int pdg, const std::vector<std::string>& filepaths,
139 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
140 {
141
142 if (!isValidPdg(pdg)) {
143 B2FATAL("PDG: " << pdg << " is not that of a valid charged particle! Aborting...");
144 }
145
146 unsigned int idx(0);
147 for (const auto& path : filepaths) {
148
149 // Index consistency check.
150 auto bin_centres_tuple = categoryBinCentres.at(idx);
151
152 auto theta_bin_centre = std::get<0>(bin_centres_tuple);
153 auto p_bin_centre = std::get<1>(bin_centres_tuple);
154 auto charge_bin_centre = std::get<2>(bin_centres_tuple);
155
156 auto h_idx = getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
157 if (idx != h_idx) {
158 B2FATAL("xml file:\n" << path << "\nindex in input vector:\n" << idx << "\ndoes not correspond to:\n" << h_idx <<
159 "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre << ", " << p_bin_centre
160 << ", " <<
161 charge_bin_centre <<
162 ")\nPlease check how the input xml file list is being filled.");
163 }
164
165 Belle2::MVA::Weightfile weightfile;
166 if (boost::ends_with(path, ".root")) {
168 } else if (boost::ends_with(path, ".xml")) {
170 } else {
171 B2WARNING("Unknown file extension for file: " << path << ", fallback to xml...");
173 }
174
175 // Serialize the MVA::Weightfile object into a string for storage in the database,
176 // otherwise there are issues w/ dictionary generation for the payload class...
177 std::stringstream ss;
179 m_weightfiles[pdg].push_back(ss.str());
180
181 ++idx;
182 }
183
184 }
185
186
196 void storeMVAWeightsMultiClass(const std::vector<std::string>& filepaths,
197 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
198 {
199 storeMVAWeights(0, filepaths, categoryBinCentres);
200 }
201
202
213 void storeCuts(const int pdg, const std::vector<std::string>& cutfiles,
214 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
215 {
216
217 if (!isValidPdg(pdg)) {
218 B2FATAL("PDG: " << pdg << " is not that of a valid charged particle! Aborting...");
219 }
220
221 unsigned int idx(0);
222 for (const auto& cutfile : cutfiles) {
223
224 auto bin_centres_tuple = categoryBinCentres.at(idx);
225
226 auto theta_bin_centre = std::get<0>(bin_centres_tuple);
227 auto p_bin_centre = std::get<1>(bin_centres_tuple);
228 auto charge_bin_centre = std::get<2>(bin_centres_tuple);
229
230 auto h_idx = getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
231 if (idx != h_idx) {
232 B2FATAL("Cut file:\n" << cutfile << "\nindex in input vector:\n" << idx << "\ndoes not correspond to:\n" << h_idx <<
233 "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre << ", " << p_bin_centre
234 << ", " <<
235 charge_bin_centre <<
236 ")\nPlease check how the input cut file list is being filled.");
237 }
238
239 std::ifstream ifs(cutfile);
240 std::string cut((std::istreambuf_iterator<char>(ifs)), (std::istreambuf_iterator<char>()));
241
242 // Strip trailing newline.
243 cut.erase(std::remove(cut.begin(), cut.end(), '\n'), cut.end());
244
245 m_cuts[pdg].push_back(cut);
246
247 ++idx;
248 }
249
250 }
251
262 void storeCutsMultiClass(const std::vector<std::string>& cutfiles,
263 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
264 {
265 storeCuts(0, cutfiles, categoryBinCentres);
266 }
267
268
274 void storeAliases(const VariablesByAlias& aliases)
275 {
276 m_aliases = VariablesByAlias(aliases);
277 }
278
279
284 const TH3F* getWeightCategories() const
285 {
286 return m_categories.get();
287 }
288
289
295 const std::vector<std::string>* getMVAWeights(const int pdg) const
296 {
297 return &(m_weightfiles.at(pdg));
298 }
299
300
306 const std::vector<std::string>* getMVAWeightsMulticlass() const
307 {
308 return getMVAWeights(0);
309 }
310
311
317 const std::vector<std::string>* getCuts(const int pdg) const
318 {
319 return &(m_cuts.at(pdg));
320 }
321
322
328 const std::vector<std::string>* getCutsMulticlass() const
329 {
330 return getCuts(0);
331 }
332
333
338 {
339 return &m_aliases;
340 }
341
342
357 unsigned int getMVAWeightIdx(const double& theta, const double& p, const double& charge, int& idx_theta, int& idx_p,
358 int& idx_charge) const
359 {
360
361 if (!m_categories) {
362 B2FATAL("No (clusterTheta, p, charge) TH3 grid was found in the DB payload. Most likely, you are using a GT w/ an old payload which is no longer compatible with the DB object class implementation. This should not happen! Abort...");
363 }
364
365 int nbins_th = m_categories->GetXaxis()->GetNbins(); // nr. of theta (visible) bins, along X.
366 int nbins_p = m_categories->GetYaxis()->GetNbins(); // nr. of p (visible) bins, along Y.
367
368 int glob_bin_idx = findBin(theta / m_ang_unit.GetVal(), p / m_energy_unit.GetVal(), charge);
369 m_categories->GetBinXYZ(glob_bin_idx, idx_theta, idx_p, idx_charge);
370
371 // The index of the linearised 3D m_categories.
372 // The unit offset is b/c ROOT sets global bin idx also for overflows and underflows.
373 return (idx_theta - 1) + nbins_th * ((idx_p - 1) + nbins_p * (idx_charge - 1));
374 }
375
379 unsigned int getMVAWeightIdx(const double& theta, const double& p, const double& charge) const
380 {
381 int idx_theta, idx_p, idx_charge;
382 return getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
383 }
384
385
395 void dumpPayload(const double& theta, const double& p, const double& charge, const int pdg, bool dump_all = false) const
396 {
397
398 B2INFO("Dumping payload content for:");
399 B2INFO("clusterTheta(theta) = " << theta << " [rad], p = " << p << " [GeV/c], charge = " << charge);
400
401 if (m_categories) {
402 std::string filename = "db_payload_chargedpidmva__theta_p_charge_categories.root";
403 B2INFO("\tWriting ROOT file w/ TH3F grid that defines categories:" << filename);
404 auto f = std::make_unique<TFile>(filename.c_str(), "RECREATE");
405 m_categories->Write();
406 f->Close();
407 } else {
408 B2WARNING("\tThe TH3F object that defines categories is a nullptr!");
409 }
410
411 for (const auto& [pdgId, weights] : m_weightfiles) {
412
413 if (!dump_all && pdg != pdgId) continue;
414
415 auto idx = getMVAWeightIdx(theta, p, charge);
416
417 auto serialized_weightfile = weights.at(idx);
418
419 std::string filename = "db_payload_chargedpidmva__weightfile_pdg_" + std::to_string(pdgId) +
420 "_glob_bin_" + std::to_string(idx + 1) + ".xml";
421
422 auto cutstr = getCuts(pdgId)->at(idx);
423
424 B2INFO("\tpdgId = " << pdgId);
425 B2INFO("\tCut: " << cutstr);
426 B2INFO("\tWriting weight file: " << filename);
427
428 std::ofstream weightfile;
429 weightfile.open(filename.c_str(), std::ios::out);
430 weightfile << serialized_weightfile << std::endl;
431 weightfile.close();
432
433 }
434
435 };
436
437
442 void dumpPayloadMulticlass(const double& theta, const double& p, const double& charge) const
443 {
444 dumpPayload(theta, p, charge, 0);
445 }
446
447
452 bool isValidPdg(const int pdg) const
453 {
454 bool isValid = (Const::chargedStableSet.find(pdg) != Const::invalidParticle) || (pdg == 0);
455 return isValid;
456 }
457
461 std::string getThetaVarName() const
462 {
463 return m_thetaVarName;
464 }
465
466
471 {
473 }
474
475
476 private:
477
478
487 int findBin(const double& x, const double& y, const double& z) const
488 {
489
490 int nbinsx_vis = m_categories->GetXaxis()->GetNbins();
491 int nbinsy_vis = m_categories->GetYaxis()->GetNbins();
492 int nbinsz_vis = m_categories->GetZaxis()->GetNbins();
493
494 double xx = x;
495 double yy = y;
496 double zz = z;
497
498 // If x, y, z are outside of the 3D grid (visible) range, set their value to
499 // fall in the last (first) bin before (after) overflow (underflow).
500 if (x < m_categories->GetXaxis()->GetBinLowEdge(1)) { xx = m_categories->GetXaxis()->GetBinCenter(1); }
501 if (x >= m_categories->GetXaxis()->GetBinLowEdge(nbinsx_vis + 1)) { xx = m_categories->GetXaxis()->GetBinCenter(nbinsx_vis); }
502 if (y < m_categories->GetYaxis()->GetBinLowEdge(1)) { yy = m_categories->GetYaxis()->GetBinCenter(1); }
503 if (y >= m_categories->GetYaxis()->GetBinLowEdge(nbinsy_vis + 1)) { yy = m_categories->GetYaxis()->GetBinCenter(nbinsy_vis); }
504 if (z < m_categories->GetZaxis()->GetBinLowEdge(1)) { zz = m_categories->GetZaxis()->GetBinCenter(1); }
505 if (z >= m_categories->GetZaxis()->GetBinLowEdge(nbinsz_vis + 1)) { zz = m_categories->GetZaxis()->GetBinCenter(nbinsz_vis); }
506
507 int nbinsx = m_categories->GetXaxis()->GetNbins() + 2;
508 int nbinsy = m_categories->GetYaxis()->GetNbins() + 2;
509
510 int j = m_categories->GetXaxis()->FindBin(xx);
511 int i = m_categories->GetYaxis()->FindBin(yy);
512 int k = m_categories->GetZaxis()->FindBin(zz);
513
514 return j + nbinsx * (i + nbinsy * k);
515 }
516
517
518 private:
519
520
521 TParameter<double> m_energy_unit;
522 TParameter<double> m_ang_unit;
523 std::string
532 std::unique_ptr<TH3F> m_categories;
533
534
544 { 0, std::vector<std::string>() },
545 { Const::electron.getPDGCode(), std::vector<std::string>() },
546 { Const::muon.getPDGCode(), std::vector<std::string>() },
547 { Const::pion.getPDGCode(), std::vector<std::string>() },
548 { Const::kaon.getPDGCode(), std::vector<std::string>() },
549 { Const::proton.getPDGCode(), std::vector<std::string>() },
550 { Const::deuteron.getPDGCode(), std::vector<std::string>() }
551 };
552
553
563 { 0, std::vector<std::string>() },
564 { Const::electron.getPDGCode(), std::vector<std::string>() },
565 { Const::muon.getPDGCode(), std::vector<std::string>() },
566 { Const::pion.getPDGCode(), std::vector<std::string>() },
567 { Const::kaon.getPDGCode(), std::vector<std::string>() },
568 { Const::proton.getPDGCode(), std::vector<std::string>() },
569 { Const::deuteron.getPDGCode(), std::vector<std::string>() }
570 };
571
572
577
578
590 };
591
593}
Class to contain the payload of MVA weightfiles needed for charged particle identification.
void storeCutsMultiClass(const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
For the multi-class mode, store the list of selection cuts (one for each category) into the payload.
void storeMVAWeightsMultiClass(const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
For the multi-class mode, store the list of MVA weight files (one for each category) into the payload...
TParameter< double > m_energy_unit
The energy unit used for defining the bins grid.
const std::vector< std::string > * getCutsMulticlass() const
For the multi-class mode, get the list of selection cuts stored in the payload, one for each category...
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge, int &idx_theta, int &idx_p, int &idx_charge) const
Get the index of the XML weight file, for a given reconstructed triplet (clusterTheta(theta),...
bool isValidPdg(const int pdg) const
Check if the input pdgId is that of a valid charged particle.
VariablesByAlias m_aliases
A map that associates variable aliases used in the MVA training to variable names known to the Variab...
const std::vector< std::string > * getMVAWeightsMulticlass() const
For the multi-class mode, get the list of (serialized) MVA weightfiles stored in the payload,...
const TH3F * getWeightCategories() const
Get the raw pointer to the 3D grid representing the categories for which weightfiles are defined.
std::unordered_map< int, std::vector< std::string > > WeightfilesByParticle
Typedef.
ChargedPidMVATrainingMode
A (strongly-typed) enumerator identifier for each valid MVA training mode.
@ c_PSD_Multiclass
Multi-class classification, including PSD.
@ c_PSD_Classification
Binary classification, including PSD.
@ c_ECL_Multiclass
Multi-class classification, ECL only.
@ c_ECL_PSD_Classification
Binary classification, ECL only, including PSD.
@ c_ECL_Classification
Binary classification, ECL only.
@ c_ECL_PSD_Multiclass
Multi-class classification, ECL only, including PSD.
std::unique_ptr< TH3F > m_categories
A 3D histogram whose bins represent the categories for which XML weight files are defined.
void dumpPayloadMulticlass(const double &theta, const double &p, const double &charge) const
Special version for multi-class mode.
std::map< std::string, std::string > VariablesByAlias
Typedef.
const VariablesByAlias * getAliases() const
Get the map of unique aliases.
void setWeightCategories(const double *clusterThetaBins, const int nClusterThetaBins, const double *pBins, const int nPBins, const double *chargeBins, const int nChargeBins)
Set the 3D (clusterTheta, p, charge) grid representing the categories for which weightfiles are defin...
std::string getThetaVarName() const
Get the name of the polar angle variable.
std::string m_thetaVarName
The name of the polar angle variable used in the MVA categorisation.
bool hasImplicitNaNmasking() const
Check flag for implicit NaN masking.
void dumpPayload(const double &theta, const double &p, const double &charge, const int pdg, bool dump_all=false) const
Read and dump the payload content from the internal 'matrioska' maps into an XML weightfile for the g...
TParameter< double > m_ang_unit
The angular unit used for defining the bins grid.
ChargedPidMVAWeights()
Default constructor, necessary for ROOT to stream the object.
const std::vector< std::string > * getMVAWeights(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of (serialized) MVA weightfiles stored in the p...
void storeMVAWeights(const int pdg, const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of MVA weight files (one for each category) i...
void storeAliases(const VariablesByAlias &aliases)
Store the map associating variable aliases to variable names known to VariableManager.
void storeCuts(const int pdg, const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of selection cuts (one for each category) int...
WeightfilesByParticle m_weightfiles
For each charged particle mass hypothesis' pdgId, this map contains a list of (serialized) Weightfile...
bool m_implicitNaNmasking
Flag to indicate whether the MVA variables have been NaN-masked directly in the weightfiles.
ChargedPidMVAWeights(const double &energyUnit, const double &angUnit, const std::string &thetaVarName="clusterTheta", bool implictNaNmasking=false)
Specialized constructor.
int findBin(const double &x, const double &y, const double &z) const
Find global bin index of the 3D categories histogram for the given (x, y, z) values.
ClassDef(ChargedPidMVAWeights, 10)
2: add energy/angular units.
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge) const
Overloaded method, to be used if not interested in knowing the 3D bin coordinates.
void setAngularUnit(const double &unit)
Set the angular unit to ensure consistency w/ the one used to define the bins grid.
WeightfilesByParticle m_cuts
For each charged particle mass hypothesis' pdgId, this map contains a list of selection cuts to be st...
void setEnergyUnit(const double &unit)
Set the energy unit to ensure consistency w/ the one used to define the bins grid.
const std::vector< std::string > * getCuts(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of selection cuts stored in the payload,...
const ParticleType & find(int pdg) const
Returns particle in set with given PDG code, or invalidParticle if not found.
Definition: Const.h:571
int getPDGCode() const
PDG code.
Definition: Const.h:473
static const ChargedStable muon
muon particle
Definition: Const.h:660
static const ParticleSet chargedStableSet
set of charged stable particles
Definition: Const.h:618
static const ChargedStable pion
charged pion particle
Definition: Const.h:661
static const ChargedStable proton
proton particle
Definition: Const.h:663
static const ParticleType invalidParticle
Invalid particle, used internally.
Definition: Const.h:681
static const ChargedStable kaon
charged kaon particle
Definition: Const.h:662
static const ChargedStable electron
electron particle
Definition: Const.h:659
static const ChargedStable deuteron
deuteron particle
Definition: Const.h:664
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
Definition: Weightfile.cc:240
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
Definition: Weightfile.cc:217
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Definition: Weightfile.cc:185
The Unit class.
Definition: Unit.h:40
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24