Belle II Software  release-08-01-10
ChargedPidMVAWeights.h
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #pragma once
10 
11 // FRAMEWORK
12 #include <framework/gearbox/Const.h>
13 #include <framework/gearbox/Unit.h>
14 #include <framework/logging/Logger.h>
15 
16 // MVA
17 #include <mva/interface/Weightfile.h>
18 
19 // ROOT
20 #include <TObject.h>
21 #include <TH3F.h>
22 #include <TParameter.h>
23 #include <TFile.h>
24 
25 //BOOST
26 #include <boost/algorithm/string/predicate.hpp>
27 
28 
29 namespace Belle2 {
38  class ChargedPidMVAWeights : public TObject {
39 
40  typedef std::unordered_map<int, std::vector<std::string> > WeightfilesByParticle;
41  typedef std::map<std::string, std::string> VariablesByAlias;
43  public:
44 
49  m_energy_unit("energyUnit", Unit::GeV),
50  m_ang_unit("angularUnit", Unit::rad),
51  m_thetaVarName("clusterTheta"),
53  {};
54 
55 
59  ChargedPidMVAWeights(const double& energyUnit, const double& angUnit,
60  const std::string& thetaVarName = "clusterTheta",
61  bool implictNaNmasking = false)
62  {
63  setEnergyUnit(energyUnit);
64  setAngularUnit(angUnit);
65  m_thetaVarName = thetaVarName;
66  m_implicitNaNmasking = implictNaNmasking;
67  }
68 
73 
77  enum class ChargedPidMVATrainingMode : unsigned int {
79  c_Classification = 0,
81  c_Multiclass = 1,
85  c_ECL_Multiclass = 3,
89  c_PSD_Multiclass = 5,
94  };
95 
96 
100  void setEnergyUnit(const double& unit) { m_energy_unit.SetVal(unit); }
101 
102 
106  void setAngularUnit(const double& unit) { m_ang_unit.SetVal(unit); }
107 
117  void setWeightCategories(const double* clusterThetaBins, const int nClusterThetaBins,
118  const double* pBins, const int nPBins,
119  const double* chargeBins, const int nChargeBins)
120  {
121 
122  m_categories = std::make_unique<TH3F>("clustertheta_p_charge_binsgrid",
123  ";ECL cluster #theta;p_{lab};Q",
124  nClusterThetaBins, clusterThetaBins,
125  nPBins, pBins,
126  nChargeBins, chargeBins);
127  }
128 
138  void storeMVAWeights(const int pdg, const std::vector<std::string>& filepaths,
139  const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
140  {
141 
142  if (!isValidPdg(pdg)) {
143  B2FATAL("PDG: " << pdg << " is not that of a valid charged particle! Aborting...");
144  }
145 
146  unsigned int idx(0);
147  for (const auto& path : filepaths) {
148 
149  // Index consistency check.
150  auto bin_centres_tuple = categoryBinCentres.at(idx);
151 
152  auto theta_bin_centre = std::get<0>(bin_centres_tuple);
153  auto p_bin_centre = std::get<1>(bin_centres_tuple);
154  auto charge_bin_centre = std::get<2>(bin_centres_tuple);
155 
156  auto h_idx = getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
157  if (idx != h_idx) {
158  B2FATAL("xml file:\n" << path << "\nindex in input vector:\n" << idx << "\ndoes not correspond to:\n" << h_idx <<
159  "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre << ", " << p_bin_centre
160  << ", " <<
161  charge_bin_centre <<
162  ")\nPlease check how the input xml file list is being filled.");
163  }
164 
165  Belle2::MVA::Weightfile weightfile;
166  if (boost::ends_with(path, ".root")) {
168  } else if (boost::ends_with(path, ".xml")) {
169  weightfile = Belle2::MVA::Weightfile::loadFromXMLFile(path);
170  } else {
171  B2WARNING("Unknown file extension for file: " << path << ", fallback to xml...");
172  weightfile = Belle2::MVA::Weightfile::loadFromXMLFile(path);
173  }
174 
175  // Serialize the MVA::Weightfile object into a string for storage in the database,
176  // otherwise there are issues w/ dictionary generation for the payload class...
177  std::stringstream ss;
179  m_weightfiles[pdg].push_back(ss.str());
180 
181  ++idx;
182  }
183 
184  }
185 
186 
196  void storeMVAWeightsMultiClass(const std::vector<std::string>& filepaths,
197  const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
198  {
199  storeMVAWeights(0, filepaths, categoryBinCentres);
200  }
201 
202 
213  void storeCuts(const int pdg, const std::vector<std::string>& cutfiles,
214  const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
215  {
216 
217  if (!isValidPdg(pdg)) {
218  B2FATAL("PDG: " << pdg << " is not that of a valid charged particle! Aborting...");
219  }
220 
221  unsigned int idx(0);
222  for (const auto& cutfile : cutfiles) {
223 
224  auto bin_centres_tuple = categoryBinCentres.at(idx);
225 
226  auto theta_bin_centre = std::get<0>(bin_centres_tuple);
227  auto p_bin_centre = std::get<1>(bin_centres_tuple);
228  auto charge_bin_centre = std::get<2>(bin_centres_tuple);
229 
230  auto h_idx = getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
231  if (idx != h_idx) {
232  B2FATAL("Cut file:\n" << cutfile << "\nindex in input vector:\n" << idx << "\ndoes not correspond to:\n" << h_idx <<
233  "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre << ", " << p_bin_centre
234  << ", " <<
235  charge_bin_centre <<
236  ")\nPlease check how the input cut file list is being filled.");
237  }
238 
239  std::ifstream ifs(cutfile);
240  std::string cut((std::istreambuf_iterator<char>(ifs)), (std::istreambuf_iterator<char>()));
241 
242  // Strip trailing newline.
243  cut.erase(std::remove(cut.begin(), cut.end(), '\n'), cut.end());
244 
245  m_cuts[pdg].push_back(cut);
246 
247  ++idx;
248  }
249 
250  }
251 
262  void storeCutsMultiClass(const std::vector<std::string>& cutfiles,
263  const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
264  {
265  storeCuts(0, cutfiles, categoryBinCentres);
266  }
267 
268 
274  void storeAliases(const VariablesByAlias& aliases)
275  {
276  m_aliases = VariablesByAlias(aliases);
277  }
278 
279 
284  const TH3F* getWeightCategories() const
285  {
286  return m_categories.get();
287  }
288 
289 
295  const std::vector<std::string>* getMVAWeights(const int pdg) const
296  {
297  return &(m_weightfiles.at(pdg));
298  }
299 
300 
306  const std::vector<std::string>* getMVAWeightsMulticlass() const
307  {
308  return getMVAWeights(0);
309  }
310 
311 
318  const std::vector<std::string>* getCuts(const int pdg) const
319  {
320  return &(m_cuts.at(pdg));
321  }
322 
323 
329  const std::vector<std::string>* getCutsMulticlass() const
330  {
331  return getCuts(0);
332  }
333 
334 
339  {
340  return &m_aliases;
341  }
342 
343 
358  unsigned int getMVAWeightIdx(const double& theta, const double& p, const double& charge, int& idx_theta, int& idx_p,
359  int& idx_charge) const
360  {
361 
362  if (!m_categories) {
363  B2FATAL("No (clusterTheta, p, charge) TH3 grid was found in the DB payload. Most likely, you are using a GT w/ an old payload which is no longer compatible with the DB object class implementation. This should not happen! Abort...");
364  }
365 
366  int nbins_th = m_categories->GetXaxis()->GetNbins(); // nr. of theta (visible) bins, along X.
367  int nbins_p = m_categories->GetYaxis()->GetNbins(); // nr. of p (visible) bins, along Y.
368 
369  int glob_bin_idx = findBin(theta / m_ang_unit.GetVal(), p / m_energy_unit.GetVal(), charge);
370  m_categories->GetBinXYZ(glob_bin_idx, idx_theta, idx_p, idx_charge);
371 
372  // The index of the linearised 3D m_categories.
373  // The unit offset is b/c ROOT sets global bin idx also for overflows and underflows.
374  return (idx_theta - 1) + nbins_th * ((idx_p - 1) + nbins_p * (idx_charge - 1));
375  }
376 
380  unsigned int getMVAWeightIdx(const double& theta, const double& p, const double& charge) const
381  {
382  int idx_theta, idx_p, idx_charge;
383  return getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
384  }
385 
386 
396  void dumpPayload(const double& theta, const double& p, const double& charge, const int pdg, bool dump_all = false) const
397  {
398 
399  B2INFO("Dumping payload content for:");
400  B2INFO("clusterTheta(theta) = " << theta << " [rad], p = " << p << " [GeV/c], charge = " << charge);
401 
402  if (m_categories) {
403  std::string filename = "db_payload_chargedpidmva__theta_p_charge_categories.root";
404  B2INFO("\tWriting ROOT file w/ TH3F grid that defines categories:" << filename);
405  auto f = std::make_unique<TFile>(filename.c_str(), "RECREATE");
406  m_categories->Write();
407  f->Close();
408  } else {
409  B2WARNING("\tThe TH3F object that defines categories is a nullptr!");
410  }
411 
412  for (const auto& [pdgId, weights] : m_weightfiles) {
413 
414  if (!dump_all && pdg != pdgId) continue;
415 
416  auto idx = getMVAWeightIdx(theta, p, charge);
417 
418  auto serialized_weightfile = weights.at(idx);
419 
420  std::string filename = "db_payload_chargedpidmva__weightfile_pdg_" + std::to_string(pdgId) +
421  "_glob_bin_" + std::to_string(idx + 1) + ".xml";
422 
423  auto cutstr = getCuts(pdgId)->at(idx);
424 
425  B2INFO("\tpdgId = " << pdgId);
426  B2INFO("\tCut: " << cutstr);
427  B2INFO("\tWriting weight file: " << filename);
428 
429  std::ofstream weightfile;
430  weightfile.open(filename.c_str(), std::ios::out);
431  weightfile << serialized_weightfile << std::endl;
432  weightfile.close();
433 
434  }
435 
436  };
437 
438 
443  void dumpPayloadMulticlass(const double& theta, const double& p, const double& charge) const
444  {
445  dumpPayload(theta, p, charge, 0);
446  }
447 
448 
453  bool isValidPdg(const int pdg) const
454  {
455  bool isValid = (Const::chargedStableSet.find(pdg) != Const::invalidParticle) || (pdg == 0);
456  return isValid;
457  }
458 
462  std::string getThetaVarName() const
463  {
464  return m_thetaVarName;
465  }
466 
467 
472  {
473  return m_implicitNaNmasking;
474  }
475 
476 
477  private:
478 
479 
488  int findBin(const double& x, const double& y, const double& z) const
489  {
490 
491  int nbinsx_vis = m_categories->GetXaxis()->GetNbins();
492  int nbinsy_vis = m_categories->GetYaxis()->GetNbins();
493  int nbinsz_vis = m_categories->GetZaxis()->GetNbins();
494 
495  double xx = x;
496  double yy = y;
497  double zz = z;
498 
499  // If x, y, z are outside of the 3D grid (visible) range, set their value to
500  // fall in the last (first) bin before (after) overflow (underflow).
501  if (x < m_categories->GetXaxis()->GetBinLowEdge(1)) { xx = m_categories->GetXaxis()->GetBinCenter(1); }
502  if (x >= m_categories->GetXaxis()->GetBinLowEdge(nbinsx_vis + 1)) { xx = m_categories->GetXaxis()->GetBinCenter(nbinsx_vis); }
503  if (y < m_categories->GetYaxis()->GetBinLowEdge(1)) { yy = m_categories->GetYaxis()->GetBinCenter(1); }
504  if (y >= m_categories->GetYaxis()->GetBinLowEdge(nbinsy_vis + 1)) { yy = m_categories->GetYaxis()->GetBinCenter(nbinsy_vis); }
505  if (z < m_categories->GetZaxis()->GetBinLowEdge(1)) { zz = m_categories->GetZaxis()->GetBinCenter(1); }
506  if (z >= m_categories->GetZaxis()->GetBinLowEdge(nbinsz_vis + 1)) { zz = m_categories->GetZaxis()->GetBinCenter(nbinsz_vis); }
507 
508  int nbinsx = m_categories->GetXaxis()->GetNbins() + 2;
509  int nbinsy = m_categories->GetYaxis()->GetNbins() + 2;
510 
511  int j = m_categories->GetXaxis()->FindBin(xx);
512  int i = m_categories->GetYaxis()->FindBin(yy);
513  int k = m_categories->GetZaxis()->FindBin(zz);
514 
515  return j + nbinsx * (i + nbinsy * k);
516  }
517 
518 
519  private:
520 
521 
522  TParameter<double> m_energy_unit;
523  TParameter<double> m_ang_unit;
524  std::string
533  std::unique_ptr<TH3F> m_categories;
534 
535 
545  { 0, std::vector<std::string>() },
546  { Const::electron.getPDGCode(), std::vector<std::string>() },
547  { Const::muon.getPDGCode(), std::vector<std::string>() },
548  { Const::pion.getPDGCode(), std::vector<std::string>() },
549  { Const::kaon.getPDGCode(), std::vector<std::string>() },
550  { Const::proton.getPDGCode(), std::vector<std::string>() },
551  { Const::deuteron.getPDGCode(), std::vector<std::string>() }
552  };
553 
554 
564  { 0, std::vector<std::string>() },
565  { Const::electron.getPDGCode(), std::vector<std::string>() },
566  { Const::muon.getPDGCode(), std::vector<std::string>() },
567  { Const::pion.getPDGCode(), std::vector<std::string>() },
568  { Const::kaon.getPDGCode(), std::vector<std::string>() },
569  { Const::proton.getPDGCode(), std::vector<std::string>() },
570  { Const::deuteron.getPDGCode(), std::vector<std::string>() }
571  };
572 
573 
578 
579 
591  };
592 
594 }
Class to contain the payload of MVA weightfiles needed for charged particle identification.
TParameter< double > m_energy_unit
The energy unit used for defining the bins grid.
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge, int &idx_theta, int &idx_p, int &idx_charge) const
Get the index of the XML weight file, for a given reconstructed triplet (clusterTheta(theta),...
const std::vector< std::string > * getCuts(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of selection cuts stored in the payload,...
bool isValidPdg(const int pdg) const
Check if the input pdgId is that of a valid charged particle.
void storeMVAWeightsMultiClass(const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
For the multi-class mode, store the list of MVA weight files (one for each category) into the payload...
VariablesByAlias m_aliases
A map that associates variable aliases used in the MVA training to variable names known to the Variab...
const VariablesByAlias * getAliases() const
Get the map of unique aliases.
std::unordered_map< int, std::vector< std::string > > WeightfilesByParticle
Typedef.
ChargedPidMVATrainingMode
A (strongly-typed) enumerator identifier for each valid MVA training mode.
@ c_PSD_Multiclass
Multi-class classification, including PSD.
@ c_PSD_Classification
Binary classification, including PSD.
@ c_ECL_Multiclass
Multi-class classification, ECL only.
@ c_ECL_PSD_Classification
Binary classification, ECL only, including PSD.
@ c_ECL_Classification
Binary classification, ECL only.
@ c_ECL_PSD_Multiclass
Multi-class classification, ECL only, including PSD.
std::unique_ptr< TH3F > m_categories
A 3D histogram whose bins represent the categories for which XML weight files are defined.
void dumpPayloadMulticlass(const double &theta, const double &p, const double &charge) const
Special version for multi-class mode.
std::map< std::string, std::string > VariablesByAlias
Typedef.
const std::vector< std::string > * getMVAWeightsMulticlass() const
For the multi-class mode, get the list of (serialized) MVA weightfiles stored in the payload,...
void setWeightCategories(const double *clusterThetaBins, const int nClusterThetaBins, const double *pBins, const int nPBins, const double *chargeBins, const int nChargeBins)
Set the 3D (clusterTheta, p, charge) grid representing the categories for which weightfiles are defin...
std::string getThetaVarName() const
Get the name of the polar angle variable.
std::string m_thetaVarName
The name of the polar angle variable used in the MVA categorisation.
bool hasImplicitNaNmasking() const
Check flag for implicit NaN masking.
void dumpPayload(const double &theta, const double &p, const double &charge, const int pdg, bool dump_all=false) const
Read and dump the payload content from the internal 'matrioska' maps into an XML weightfile for the g...
TParameter< double > m_ang_unit
The angular unit used for defining the bins grid.
ChargedPidMVAWeights()
Default constructor, necessary for ROOT to stream the object.
void storeMVAWeights(const int pdg, const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of MVA weight files (one for each category) i...
void storeCuts(const int pdg, const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of selection cuts (one for each category) int...
void storeAliases(const VariablesByAlias &aliases)
Store the map associating variable aliases to variable names known to the VariableManager.
WeightfilesByParticle m_weightfiles
For each charged particle mass hypothesis' pdgId, this map contains a list of (serialized) Weightfile...
const std::vector< std::string > * getMVAWeights(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of (serialized) MVA weightfiles stored in the p...
const TH3F * getWeightCategories() const
Get the raw pointer to the 3D grid representing the categories for which weightfiles are defined.
const std::vector< std::string > * getCutsMulticlass() const
For the multi-class mode, get the list of selection cuts stored in the payload, one for each category...
bool m_implicitNaNmasking
Flag to indicate whether the MVA variables have been NaN-masked directly in the weightfiles.
ChargedPidMVAWeights(const double &energyUnit, const double &angUnit, const std::string &thetaVarName="clusterTheta", bool implictNaNmasking=false)
Specialized constructor.
int findBin(const double &x, const double &y, const double &z) const
Find global bin index of the 3D categories histogram for the given (x, y, z) values.
void storeCutsMultiClass(const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
For the multi-class mode, store the list of selection cuts (one for each category) into the payload.
ClassDef(ChargedPidMVAWeights, 10)
2: add energy/angular units.
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge) const
Overloaded method, to be used if not interested in knowing the 3D bin coordinates.
void setAngularUnit(const double &unit)
Set the angular unit to ensure consistency w/ the one used to define the bins grid.
WeightfilesByParticle m_cuts
For each charged particle mass hypothesis' pdgId, this map contains a list of selection cuts to be st...
void setEnergyUnit(const double &unit)
Set the energy unit to ensure consistency w/ the one used to define the bins grid.
const ParticleType & find(int pdg) const
Returns particle in set with given PDG code, or invalidParticle if not found.
Definition: Const.h:562
int getPDGCode() const
PDG code.
Definition: Const.h:464
static const ChargedStable muon
muon particle
Definition: Const.h:651
static const ParticleSet chargedStableSet
set of charged stable particles
Definition: Const.h:609
static const ChargedStable pion
charged pion particle
Definition: Const.h:652
static const ChargedStable proton
proton particle
Definition: Const.h:654
static const ParticleType invalidParticle
Invalid particle, used internally.
Definition: Const.h:672
static const ChargedStable kaon
charged kaon particle
Definition: Const.h:653
static const ChargedStable electron
electron particle
Definition: Const.h:650
static const ChargedStable deuteron
deuteron particle
Definition: Const.h:655
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
Definition: Weightfile.cc:240
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
Definition: Weightfile.cc:217
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Definition: Weightfile.cc:185
The Unit class.
Definition: Unit.h:40
Abstract base class for different kinds of events.