Belle II Software  release-06-02-00
ChargedPidMVAWeights.h
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #pragma once
10 
11 // FRAMEWORK
12 #include <framework/gearbox/Const.h>
13 #include <framework/gearbox/Unit.h>
14 #include <framework/logging/Logger.h>
15 
16 // MVA
17 #include <mva/interface/Weightfile.h>
18 
19 // ROOT
20 #include <TObject.h>
21 #include <TH3F.h>
22 #include <TParameter.h>
23 #include <TFile.h>
24 
25 //BOOST
26 #include <boost/algorithm/string/predicate.hpp>
27 
28 
29 namespace Belle2 {
38  class ChargedPidMVAWeights : public TObject {
39 
40  typedef std::unordered_map<int, std::vector<std::string> > WeightfilesByParticle;
42  public:
43 
48  m_energy_unit("energyUnit", Unit::GeV),
49  m_ang_unit("angularUnit", Unit::rad)
50  {};
51 
52 
57 
61  enum class ChargedPidMVATrainingMode : unsigned int {
63  c_Classification = 0,
65  c_Multiclass = 1,
69  c_ECL_Multiclass = 3,
73  c_PSD_Multiclass = 5,
78  };
79 
80 
84  void setEnergyUnit(const double& unit) { m_energy_unit.SetVal(unit); }
85 
86 
90  void setAngularUnit(const double& unit) { m_ang_unit.SetVal(unit); }
91 
92 
97  void setWeightCategories(TH3F* h)
98  {
99  m_categories = h;
100  }
101 
111  void storeMVAWeights(const int pdg, const std::vector<std::string>& filepaths,
112  const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
113  {
114 
115  if (!isValidPdg(pdg)) {
116  B2FATAL("PDG: " << pdg << " is not that of a valid charged particle! Aborting...");
117  }
118 
119  unsigned int idx(0);
120  for (const auto& path : filepaths) {
121 
122  // Index consistency check.
123  auto bin_centres_tuple = categoryBinCentres.at(idx);
124 
125  auto theta_bin_centre = std::get<0>(bin_centres_tuple);
126  auto p_bin_centre = std::get<1>(bin_centres_tuple);
127  auto charge_bin_centre = std::get<2>(bin_centres_tuple);
128 
129  auto h_idx = getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
130  if (idx != h_idx) {
131  B2FATAL("xml file:\n" << path << "\nindex in input vector:\n" << idx << "\ndoes not correspond to:\n" << h_idx <<
132  "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre << ", " << p_bin_centre
133  << ", " <<
134  charge_bin_centre <<
135  ")\nPlease check how the input xml file list is being filled.");
136  }
137 
138  Belle2::MVA::Weightfile weightfile;
139  if (boost::ends_with(path, ".root")) {
141  } else if (boost::ends_with(path, ".xml")) {
142  weightfile = Belle2::MVA::Weightfile::loadFromXMLFile(path);
143  } else {
144  B2WARNING("Unknown file extension for file: " << path << ", fallback to xml...");
145  weightfile = Belle2::MVA::Weightfile::loadFromXMLFile(path);
146  }
147 
148  // Serialize the MVA::Weightfile object into a string for storage in the database,
149  // otherwise there are issues w/ dictionary generation for the payload class...
150  std::stringstream ss;
152  m_weightfiles[pdg].push_back(ss.str());
153 
154  ++idx;
155  }
156 
157  }
158 
159 
169  void storeMVAWeightsMultiClass(const std::vector<std::string>& filepaths,
170  const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
171  {
172  storeMVAWeights(0, filepaths, categoryBinCentres);
173  }
174 
175 
186  void storeCuts(const int pdg, const std::vector<std::string>& cutfiles,
187  const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
188  {
189 
190  if (!isValidPdg(pdg)) {
191  B2FATAL("PDG: " << pdg << " is not that of a valid charged particle! Aborting...");
192  }
193 
194  unsigned int idx(0);
195  for (const auto& cutfile : cutfiles) {
196 
197  auto bin_centres_tuple = categoryBinCentres.at(idx);
198 
199  auto theta_bin_centre = std::get<0>(bin_centres_tuple);
200  auto p_bin_centre = std::get<1>(bin_centres_tuple);
201  auto charge_bin_centre = std::get<2>(bin_centres_tuple);
202 
203  auto h_idx = getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
204  if (idx != h_idx) {
205  B2FATAL("Cut file:\n" << cutfile << "\nindex in input vector:\n" << idx << "\ndoes not correspond to:\n" << h_idx <<
206  "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre << ", " << p_bin_centre
207  << ", " <<
208  charge_bin_centre <<
209  ")\nPlease check how the input cut file list is being filled.");
210  }
211 
212  std::ifstream ifs(cutfile);
213  std::string cut((std::istreambuf_iterator<char>(ifs)), (std::istreambuf_iterator<char>()));
214 
215  // Strip trailing newline.
216  cut.erase(std::remove(cut.begin(), cut.end(), '\n'), cut.end());
217 
218  // Conditional expression separator must use square brackets in basf2.
219  std::replace(cut.begin(), cut.end(), '(', '[');
220  std::replace(cut.begin(), cut.end(), ')', ']');
221 
222  m_cuts[pdg].push_back(cut);
223 
224  ++idx;
225  }
226 
227  }
228 
239  void storeCutsMultiClass(const std::vector<std::string>& cutfiles,
240  const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
241  {
242  storeCuts(0, cutfiles, categoryBinCentres);
243  }
244 
245 
251  const std::vector<std::string>* getMVAWeights(const int pdg) const
252  {
253  return &(m_weightfiles.at(pdg));
254  }
255 
256 
262  const std::vector<std::string>* getMVAWeightsMulticlass() const
263  {
264  return getMVAWeights(0);
265  }
266 
267 
274  const std::vector<std::string>* getCuts(const int pdg) const
275  {
276  return &(m_cuts.at(pdg));
277  }
278 
279 
285  const std::vector<std::string>* getCutsMulticlass() const
286  {
287  return getCuts(0);
288  }
289 
290 
305  unsigned int getMVAWeightIdx(const double& clusterTheta, const double& p, const double& charge, int& idx_theta, int& idx_p,
306  int& idx_charge) const
307  {
308 
309  if (!m_categories) {
310  B2FATAL("No (clusterTheta, p, charge) TH3 grid was found in the DB payload. Most likely, you are using a GT w/ an old payload which is no longer compatible with the DB object class implementation. This should not happen! Abort...");
311  }
312 
313  int nbins_th = m_categories->GetXaxis()->GetNbins(); // nr. of clusterTheta (visible) bins, along X.
314  int nbins_p = m_categories->GetYaxis()->GetNbins(); // nr. of p (visible) bins, along Y.
315 
316  int glob_bin_idx = findBin(m_categories, clusterTheta / m_ang_unit.GetVal(), p / m_energy_unit.GetVal(), charge);
317  m_categories->GetBinXYZ(glob_bin_idx, idx_theta, idx_p, idx_charge);
318 
319  // The index of the linearised 3D (clusterTheta, p, charge) m_categories.
320  // The unit offset is b/c ROOT sets global bin idx also for overflows and underflows.
321  return (idx_theta - 1) + nbins_th * ((idx_p - 1) + nbins_p * (idx_charge - 1));
322  }
323 
324 
328  unsigned int getMVAWeightIdx(const double& theta, const double& p, const double& charge) const
329  {
330  int idx_theta, idx_p, idx_charge;
331  return getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
332  }
333 
334 
344  void dumpPayload(const double& clusterTheta, const double& p, const double& charge, const int pdg, bool dump_all = false) const
345  {
346 
347  B2INFO("Dumping payload content for:");
348  B2INFO("clusterTheta = " << clusterTheta << " [rad], p = " << p << " [GeV/c], charge = " << charge);
349 
350  if (m_categories) {
351  std::string filename = "db_payload_chargedpidmva__clustertheta_p_charge_categories.root";
352  B2INFO("\tWriting ROOT file w/ (clusterTheta, p, charge) TH3F grid that defines categories:" << filename);
353  auto f = std::make_unique<TFile>(filename.c_str(), "RECREATE");
354  m_categories->Write();
355  f->Close();
356  } else {
357  B2WARNING("\tThe TH3F object that defines categories is a nullptr!");
358  }
359 
360  for (const auto& [pdgId, weights] : m_weightfiles) {
361 
362  if (!dump_all && pdg != pdgId) continue;
363 
364  auto idx = getMVAWeightIdx(clusterTheta, p, charge);
365 
366  auto serialized_weightfile = weights.at(idx);
367 
368  std::string filename = "db_payload_chargedpidmva__weightfile_pdg_" + std::to_string(pdgId) +
369  "_glob_bin_" + std::to_string(idx + 1) + ".xml";
370 
371  auto cutstr = getCuts(pdgId)->at(idx);
372 
373  B2INFO("\tpdgId = " << pdgId);
374  B2INFO("\tCut: " << cutstr);
375  B2INFO("\tWriting weight file: " << filename);
376 
377  std::ofstream weightfile;
378  weightfile.open(filename.c_str(), std::ios::out);
379  weightfile << serialized_weightfile << std::endl;
380  weightfile.close();
381 
382  }
383 
384  };
385 
386 
391  void dumpPayloadMulticlass(const double& theta, const double& p, const double& charge) const
392  {
393  dumpPayload(theta, p, charge, 0);
394  }
395 
396 
401  bool isValidPdg(const int pdg) const
402  {
403  bool isValid = (Const::chargedStableSet.find(pdg) != Const::invalidParticle) || (pdg == 0);
404  return isValid;
405  }
406 
407 
408  private:
409 
410 
420  int findBin(const TH3F* h, const double& x, const double& y, const double& z) const
421  {
422 
423  int nbinsx_vis = h->GetXaxis()->GetNbins();
424  int nbinsy_vis = h->GetYaxis()->GetNbins();
425  int nbinsz_vis = h->GetZaxis()->GetNbins();
426 
427  double xx = x;
428  double yy = y;
429  double zz = z;
430 
431  // If x, y, z are outside of the 3D grid (visible) range, set their value to
432  // fall in the last (first) bin before (after) overflow (underflow).
433  if (x < h->GetXaxis()->GetBinLowEdge(1)) { xx = h->GetXaxis()->GetBinCenter(1); }
434  if (x >= h->GetXaxis()->GetBinLowEdge(nbinsx_vis + 1)) { xx = h->GetXaxis()->GetBinCenter(nbinsx_vis); }
435  if (y < h->GetYaxis()->GetBinLowEdge(1)) { yy = h->GetYaxis()->GetBinCenter(1); }
436  if (y >= h->GetYaxis()->GetBinLowEdge(nbinsy_vis + 1)) { yy = h->GetYaxis()->GetBinCenter(nbinsy_vis); }
437  if (z < h->GetZaxis()->GetBinLowEdge(1)) { zz = h->GetZaxis()->GetBinCenter(1); }
438  if (z >= h->GetZaxis()->GetBinLowEdge(nbinsz_vis + 1)) { zz = h->GetZaxis()->GetBinCenter(nbinsz_vis); }
439 
440  int nbinsx = h->GetXaxis()->GetNbins() + 2;
441  int nbinsy = h->GetYaxis()->GetNbins() + 2;
442 
443  int j = h->GetXaxis()->FindBin(xx);
444  int i = h->GetYaxis()->FindBin(yy);
445  int k = h->GetZaxis()->FindBin(zz);
446 
447  return j + nbinsx * (i + nbinsy * k);
448  }
449 
450 
451  private:
452 
453 
454  TParameter<double> m_energy_unit;
455  TParameter<double> m_ang_unit;
462  TH3F* m_categories = nullptr;
463 
464 
474  { 0, std::vector<std::string>() },
475  { Const::electron.getPDGCode(), std::vector<std::string>() },
476  { Const::muon.getPDGCode(), std::vector<std::string>() },
477  { Const::pion.getPDGCode(), std::vector<std::string>() },
478  { Const::kaon.getPDGCode(), std::vector<std::string>() },
479  { Const::proton.getPDGCode(), std::vector<std::string>() },
480  { Const::deuteron.getPDGCode(), std::vector<std::string>() }
481  };
482 
483 
493  { 0, std::vector<std::string>() },
494  { Const::electron.getPDGCode(), std::vector<std::string>() },
495  { Const::muon.getPDGCode(), std::vector<std::string>() },
496  { Const::pion.getPDGCode(), std::vector<std::string>() },
497  { Const::kaon.getPDGCode(), std::vector<std::string>() },
498  { Const::proton.getPDGCode(), std::vector<std::string>() },
499  { Const::deuteron.getPDGCode(), std::vector<std::string>() }
500  };
501 
502 
511  };
512 
514 }
Class to contain the payload of MVA weightfiles needed for charged particle identification.
TParameter< double > m_energy_unit
The energy unit used for defining the bins grid.
void setWeightCategories(TH3F *h)
Set the 3D (clusterTheta, p, charge) grid representing the categories for which weightfiles are defin...
const std::vector< std::string > * getCuts(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of selection cuts stored in the payload,...
bool isValidPdg(const int pdg) const
Check if the input pdgId is that of a valid charged particle.
void storeMVAWeightsMultiClass(const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
For the multi-class mode, store the list of MVA weight files (one for each category) into the payload...
std::unordered_map< int, std::vector< std::string > > WeightfilesByParticle
Typedef.
ChargedPidMVATrainingMode
A (strongly-typed) enumerator identifier for each valid MVA training mode.
@ c_PSD_Multiclass
Multi-class classification, including PSD.
@ c_PSD_Classification
Binary classification, including PSD.
@ c_ECL_Multiclass
Multi-class classification, ECL only.
@ c_ECL_PSD_Classification
Binary classification, ECL only, including PSD.
@ c_ECL_Classification
Binary classification, ECL only.
@ c_ECL_PSD_Multiclass
Multi-class classification, ECL only, including PSD.
void dumpPayloadMulticlass(const double &theta, const double &p, const double &charge) const
Special version for multi-class mode.
const std::vector< std::string > * getMVAWeightsMulticlass() const
For the multi-class mode, get the list of (serialized) MVA weightfiles stored in the payload,...
ClassDef(ChargedPidMVAWeights, 7)
2: add energy/angular units.
TParameter< double > m_ang_unit
The angular unit used for defining the bins grid.
ChargedPidMVAWeights()
Default constructor, necessary for ROOT to stream the object.
void storeMVAWeights(const int pdg, const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of MVA weight files (one for each category) i...
unsigned int getMVAWeightIdx(const double &clusterTheta, const double &p, const double &charge, int &idx_theta, int &idx_p, int &idx_charge) const
Get the index of the XML weight file, for a given reconstructed pair (clusterTheta,...
void storeCuts(const int pdg, const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of selection cuts (one for each category) int...
WeightfilesByParticle m_weightfiles
For each charged particle mass hypothesis' pdgId, this map contains a list of (serialized) Weightfile...
const std::vector< std::string > * getMVAWeights(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of (serialized) MVA weightfiles stored in the p...
void dumpPayload(const double &clusterTheta, const double &p, const double &charge, const int pdg, bool dump_all=false) const
Read and dump the payload content from the internal 'matryoshka' (nested) maps into an XML weightfile for the g...
const std::vector< std::string > * getCutsMulticlass() const
For the multi-class mode, get the list of selection cuts stored in the payload, one for each (cluster...
void storeCutsMultiClass(const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
For the multi-class mode, store the list of selection cuts (one for each category) into the payload.
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge) const
Overloaded method, to be used if not interested in knowing the 3D (clusterTheta, p,...
void setAngularUnit(const double &unit)
Set the angular unit to ensure consistency w/ the one used to define the bins grid.
WeightfilesByParticle m_cuts
For each charged particle mass hypothesis' pdgId, this map contains a list of selection cuts to be st...
void setEnergyUnit(const double &unit)
Set the energy unit to ensure consistency w/ the one used to define the bins grid.
int findBin(const TH3F *h, const double &x, const double &y, const double &z) const
Find global bin index of a 3D histogram for the given (x, y, z) values.
TH3F * m_categories
A 3D (clusterTheta, p, charge) histogram whose bins represent the categories for which XML weight fil...
const ParticleType & find(int pdg) const
Returns particle in set with given PDG code, or invalidParticle if not found.
Definition: Const.h:452
int getPDGCode() const
PDG code.
Definition: Const.h:354
static const ChargedStable muon
muon particle
Definition: Const.h:541
static const ParticleSet chargedStableSet
set of charged stable particles
Definition: Const.h:499
static const ChargedStable pion
charged pion particle
Definition: Const.h:542
static const ChargedStable proton
proton particle
Definition: Const.h:544
static const ParticleType invalidParticle
Invalid particle, used internally.
Definition: Const.h:561
static const ChargedStable kaon
charged kaon particle
Definition: Const.h:543
static const ChargedStable electron
electron particle
Definition: Const.h:540
static const ChargedStable deuteron
deuteron particle
Definition: Const.h:545
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
Definition: Weightfile.cc:239
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
Definition: Weightfile.cc:216
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Definition: Weightfile.cc:184
The Unit class.
Definition: Unit.h:40
Abstract base class for different kinds of events.