Belle II Software  release-05-01-25
ChargedPidMVAWeights.h
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2019 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Marco Milesi *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 
11 #pragma once
12 
13 // FRAMEWORK
14 #include <framework/gearbox/Const.h>
15 #include <framework/gearbox/Unit.h>
16 #include <framework/logging/Logger.h>
17 
18 // MVA
19 #include <mva/interface/Weightfile.h>
20 
21 // ROOT
22 #include <TObject.h>
23 #include <TH2F.h>
24 #include <TParameter.h>
25 #include <TFile.h>
26 
27 //BOOST
28 #include <boost/algorithm/string/predicate.hpp>
29 
30 
31 namespace Belle2 {
40  class ChargedPidMVAWeights : public TObject {
41 
/// Lookup table keyed by pdgId (0 is used by the multi-class mode); each value holds one string per (clusterTheta, p) category.
using WeightfilesByParticle = std::unordered_map<int, std::vector<std::string>>;
44  public:
45 
50  m_energy_unit("energyUnit", Unit::rad),
51  m_ang_unit("angularUnit", Unit::GeV)
52  {};
53 
54 
59 
/**
 * A (strongly-typed) enumerator identifier for each valid MVA training mode.
 * NOTE(review): this extract does not show every enumerator. c_ECL_Classification,
 * c_PSD_Classification, c_ECL_Multiclass, c_ECL_PSD_Classification and
 * c_ECL_PSD_Multiclass (and their values) are missing here; confirm against the full header.
 */
63  enum class ChargedPidMVATrainingMode : unsigned int {
/** Binary classification. */
65  c_Classification = 0,
/** Multi-class classification. */
67  c_Multiclass = 1,
/** Multi-class classification, including PSD. */
75  c_PSD_Multiclass = 5,
80  };
81 
82 
86  void setEnergyUnit(const double& unit) { m_energy_unit.SetVal(unit); }
87 
88 
92  void setAngularUnit(const double& unit) { m_ang_unit.SetVal(unit); }
93 
94 
99  void setWeightCategories(TH2F* h)
100  {
101  m_categories = h;
102  }
103 
113  void storeMVAWeights(const int pdg, const std::vector<std::string>& filepaths,
114  const std::vector<std::pair<float, float>>& categoryBinCentres)
115  {
116 
117  if (!isValidPdg(pdg)) {
118  B2FATAL("PDG: " << pdg << " is not that of a valid charged particle! Aborting...");
119  }
120 
121  unsigned int idx(0);
122  for (const auto& path : filepaths) {
123 
124  // Index consistency check.
125  auto theta_p = categoryBinCentres.at(idx);
126  auto h_idx = getMVAWeightIdx(theta_p.first, theta_p.second);
127  if (idx != h_idx) {
128  B2FATAL("xml file:\n" << path << "\nindex in input vector:\n" << idx << "\ndoes not correspond to:\n" << h_idx <<
129  "\n, i.e. the linearised index of the 2D bin centered in (clusterTheta, p) = (" << theta_p.first << ", " << theta_p.second <<
130  ")\nPlease check how the input xml file list is being filled.");
131  }
132 
133  Belle2::MVA::Weightfile weightfile;
134  if (boost::ends_with(path, ".root")) {
136  } else if (boost::ends_with(path, ".xml")) {
137  weightfile = Belle2::MVA::Weightfile::loadFromXMLFile(path);
138  } else {
139  B2WARNING("Unkown file extension for file: " << path << ", fallback to xml...");
140  weightfile = Belle2::MVA::Weightfile::loadFromXMLFile(path);
141  }
142 
143  // Serialize the MVA::Weightfile object into a string for storage in the database,
144  // otherwise there are issues w/ dictionary generation for the payload class...
145  std::stringstream ss;
147  m_weightfiles[pdg].push_back(ss.str());
148 
149  ++idx;
150  }
151 
152  }
153 
154 
163  void storeMVAWeightsMultiClass(const std::vector<std::string>& filepaths,
164  const std::vector<std::pair<float, float>>& categoryBinCentres)
165  {
166  storeMVAWeights(0, filepaths, categoryBinCentres);
167  }
168 
169 
180  void storeCuts(const int pdg, const std::vector<std::string>& cutfiles,
181  const std::vector<std::pair<float, float>>& categoryBinCentres)
182  {
183 
184  if (!isValidPdg(pdg)) {
185  B2FATAL("PDG: " << pdg << " is not that of a valid charged particle! Aborting...");
186  }
187 
188  unsigned int idx(0);
189  for (const auto& cutfile : cutfiles) {
190 
191  // Index consistency check.
192  auto theta_p = categoryBinCentres.at(idx);
193  auto h_idx = getMVAWeightIdx(theta_p.first, theta_p.second);
194  if (idx != h_idx) {
195  B2FATAL("Cut file:\n" << cutfile << "\nindex in input vector:\n" << idx << "\ndoes not correspond to:\n" << h_idx <<
196  "\n, i.e. the linearised index of the 2D bin centered in (clusterTheta, p) = (" << theta_p.first << ", " << theta_p.second <<
197  ")\nPlease check how the input cut file list is being filled.");
198  }
199 
200  std::ifstream ifs(cutfile);
201  std::string cut((std::istreambuf_iterator<char>(ifs)), (std::istreambuf_iterator<char>()));
202 
203  // Strip trailing newline.
204  cut.erase(std::remove(cut.begin(), cut.end(), '\n'), cut.end());
205 
206  // Conditional expression separator must use square brackets in basf2.
207  std::replace(cut.begin(), cut.end(), '(', '[');
208  std::replace(cut.begin(), cut.end(), ')', ']');
209 
210  m_cuts[pdg].push_back(cut);
211 
212  ++idx;
213  }
214 
215  }
216 
226  void storeCutsMultiClass(const std::vector<std::string>& cutfiles, const std::vector<std::pair<float, float>>& categoryBinCentres)
227  {
228  storeCuts(0, cutfiles, categoryBinCentres);
229  }
230 
231 
237  const std::vector<std::string>* getMVAWeights(const int pdg) const
238  {
239  return &(m_weightfiles.at(pdg));
240  }
241 
242 
247  const std::vector<std::string>* getMVAWeightsMulticlass() const
248  {
249  return getMVAWeights(0);
250  }
251 
252 
259  const std::vector<std::string>* getCuts(const int pdg) const
260  {
261  return &(m_cuts.at(pdg));
262  }
263 
264 
269  const std::vector<std::string>* getCutsMulticlass() const
270  {
271  return getCuts(0);
272  }
273 
274 
287  unsigned int getMVAWeightIdx(const double& theta, const double& p, int& jth, int& ip) const
288  {
289 
290  if (!m_categories) {
291  B2FATAL("No (clusterTheta, p) TH2 grid was found in the DB payload. This should not happen! Abort...");
292  }
293 
294  int nbins_th = m_categories->GetXaxis()->GetNbins(); // nr. of theta (visible) bins, along X.
295 
296  int glob_bin_idx = findBin(m_categories, theta / m_ang_unit.GetVal(), p / m_energy_unit.GetVal());
297  int k;
298  m_categories->GetBinXYZ(glob_bin_idx, jth, ip, k);
299 
300  // The index of the linearised 2D (theta,p) m_categories.
301  // The unit offset is b/c ROOT sets global bin idx also for overflows and underflows.
302  return (jth - 1) + nbins_th * (ip - 1);
303 
304  }
305 
306 
310  unsigned int getMVAWeightIdx(const double& theta, const double& p) const
311  {
312  int jth, ip;
313  return getMVAWeightIdx(theta, p, jth, ip);
314  }
315 
316 
325  void dumpPayload(const double& theta, const double& p, const int pdg, bool dump_all = false) const
326  {
327 
328  B2INFO("Dumping payload content for...");
329  B2INFO("-) clusterTheta = " << theta << " [rad]");
330  B2INFO("-) p = " << p << " [GeV/c]");
331 
332  if (m_categories) {
333  std::string filename = "db_payload_chargedpidmva__clustertheta_p_categories.root";
334  B2INFO("\tWriting ROOT file w/ (clusterTheta, p) TH2F grid that defines categories:" << filename);
335  auto f = std::make_unique<TFile>(filename.c_str(), "RECREATE");
336  m_categories->Write();
337  f->Close();
338  } else {
339  B2WARNING("\tThe TH2F object that defines categories is a nullptr!");
340  }
341 
342  for (const auto& [pdgId, weights] : m_weightfiles) {
343 
344  if (!dump_all && pdg != pdgId) continue;
345 
346  B2INFO("-) pdgId = " << pdgId);
347 
348  auto idx = getMVAWeightIdx(theta, p);
349 
350  auto serialized_weightfile = weights.at(idx);
351 
352  std::string filename = "db_payload_chargedpidmva__weightfile_pdg_" + std::to_string(pdgId) +
353  "_glob_bin_" + std::to_string(idx + 1) + ".xml";
354 
355  auto cutstr = getCuts(pdgId)->at(idx);
356 
357  B2INFO("\tCut: " << cutstr);
358  B2INFO("\tWriting weight file: " << filename);
359 
360  std::ofstream weightfile;
361  weightfile.open(filename.c_str(), std::ios::out);
362  weightfile << serialized_weightfile << std::endl;
363  weightfile.close();
364 
365  }
366 
367  };
368 
369 
373  void dumpPayloadMulticlass(const double& theta, const double& p) const
374  {
375  dumpPayload(theta, p, 0);
376  }
377 
378 
383  bool isValidPdg(const int pdg) const
384  {
385  bool isValid = (Const::chargedStableSet.find(pdg) != Const::invalidParticle) || (pdg == 0);
386  return isValid;
387  }
388 
389 
390  private:
391 
392 
401  int findBin(const TH2F* h, const double& x, const double& y) const
402  {
403 
404  int nbinsx_vis = h->GetXaxis()->GetNbins();
405  int nbinsy_vis = h->GetYaxis()->GetNbins();
406 
407  double xx = x;
408  double yy = y;
409 
410  // If x, y are outside of the 2D hogram grid (visible) range, set their value to
411  // fall in the last (first) bin before (after) overflow (underflow).
412  if (x < h->GetXaxis()->GetBinLowEdge(1)) { xx = h->GetXaxis()->GetBinCenter(1); }
413  if (x >= h->GetXaxis()->GetBinLowEdge(nbinsx_vis + 1)) { xx = h->GetXaxis()->GetBinCenter(nbinsx_vis); }
414  if (y < h->GetYaxis()->GetBinLowEdge(1)) { yy = h->GetYaxis()->GetBinCenter(1); }
415  if (y >= h->GetYaxis()->GetBinLowEdge(nbinsy_vis + 1)) { yy = h->GetYaxis()->GetBinCenter(nbinsy_vis); }
416 
417  int nbinsx = h->GetXaxis()->GetNbins() + 2;
418  int j = h->GetXaxis()->FindBin(xx);
419  int i = h->GetYaxis()->FindBin(yy);
420 
421  return j + nbinsx * i;
422  }
423 
424 
425  private:
426 
427 
/** The energy unit used for defining the bins grid (momenta are divided by it before lookup). */
428  TParameter<double> m_energy_unit;
/** The angular unit used for defining the bins grid (angles are divided by it before lookup). */
429  TParameter<double> m_ang_unit;
/**
 * A 2D (clusterTheta, p) histogram whose bins represent the categories for which weight files are defined.
 * NOTE(review): stored as a raw pointer by setWeightCategories(); the destructor is not shown in this
 * extract, so confirm who owns this histogram.
 */
436  TH2F* m_categories = nullptr;
437 
438 
// Initial entries of m_weightfiles: one empty list for key 0 (multi-class mode) and one for
// each charged stable mass hypothesis. These are the same keys accepted by isValidPdg().
448  { 0, std::vector<std::string>() },
449  { Const::electron.getPDGCode(), std::vector<std::string>() },
450  { Const::muon.getPDGCode(), std::vector<std::string>() },
451  { Const::pion.getPDGCode(), std::vector<std::string>() },
452  { Const::kaon.getPDGCode(), std::vector<std::string>() },
453  { Const::proton.getPDGCode(), std::vector<std::string>() },
454  { Const::deuteron.getPDGCode(), std::vector<std::string>() }
455  };
456 
457 
// Initial entries of m_cuts: one empty list for key 0 (multi-class mode) and one for
// each charged stable mass hypothesis, matching the keys of m_weightfiles.
467  { 0, std::vector<std::string>() },
468  { Const::electron.getPDGCode(), std::vector<std::string>() },
469  { Const::muon.getPDGCode(), std::vector<std::string>() },
470  { Const::pion.getPDGCode(), std::vector<std::string>() },
471  { Const::kaon.getPDGCode(), std::vector<std::string>() },
472  { Const::proton.getPDGCode(), std::vector<std::string>() },
473  { Const::deuteron.getPDGCode(), std::vector<std::string>() }
474  };
475 
476 
483  };
484 
486 }
Belle2::ChargedPidMVAWeights::ChargedPidMVATrainingMode::c_Classification
@ c_Classification
Binary classification.
Belle2::ChargedPidMVAWeights::~ChargedPidMVAWeights
~ChargedPidMVAWeights()
Destructor.
Definition: ChargedPidMVAWeights.h:66
Belle2::ChargedPidMVAWeights::dumpPayload
void dumpPayload(const double &theta, const double &p, const int pdg, bool dump_all=false) const
Read and dump the payload content from the internal 'matryoshka' maps into an XML weightfile for the g...
Definition: ChargedPidMVAWeights.h:333
Belle2::ChargedPidMVAWeights::ChargedPidMVATrainingMode::c_Multiclass
@ c_Multiclass
Multi-class classification.
Belle2::ChargedPidMVAWeights::dumpPayloadMulticlass
void dumpPayloadMulticlass(const double &theta, const double &p) const
Special version for multi-class mode.
Definition: ChargedPidMVAWeights.h:381
Belle2::Const::invalidParticle
static const ParticleType invalidParticle
Invalid particle, used internally.
Definition: Const.h:554
Belle2::Const::electron
static const ChargedStable electron
electron particle
Definition: Const.h:533
Belle2::ChargedPidMVAWeights::ChargedPidMVAWeights
ChargedPidMVAWeights()
Default constructor, necessary for ROOT to stream the object.
Definition: ChargedPidMVAWeights.h:57
Belle2::ChargedPidMVAWeights::storeMVAWeightsMultiClass
void storeMVAWeightsMultiClass(const std::vector< std::string > &filepaths, const std::vector< std::pair< float, float >> &categoryBinCentres)
For the multi-class mode, store the list of MVA weight files (one for each category) into the payload...
Definition: ChargedPidMVAWeights.h:171
Belle2::MVA::Weightfile
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:40
Belle2::TrackFindingCDC::NForwardBackward::isValid
bool isValid(EForwardBackward eForwardBackward)
Check whether the given enum instance is one of the valid values.
Definition: EForwardBackward.h:55
Belle2::Const::chargedStableSet
static const ParticleSet chargedStableSet
set of charged stable particles
Definition: Const.h:494
Belle2::ChargedPidMVAWeights::storeCutsMultiClass
void storeCutsMultiClass(const std::vector< std::string > &cutfiles, const std::vector< std::pair< float, float >> &categoryBinCentres)
For the multi-class mode, store the list of selection cuts (one for each category) into the payload.
Definition: ChargedPidMVAWeights.h:234
Belle2::Const::ParticleType::getPDGCode
int getPDGCode() const
PDG code.
Definition: Const.h:349
Belle2::ChargedPidMVAWeights::m_weightfiles
WeightfilesByParticle m_weightfiles
For each charged particle mass hypothesis' pdgId, this map contains a list of (serialized) Weightfile...
Definition: ChargedPidMVAWeights.h:455
Belle2::ChargedPidMVAWeights::getCutsMulticlass
const std::vector< std::string > * getCutsMulticlass() const
For the multi-class mode, get the list of selection cuts stored in the payload, one for each (cluster...
Definition: ChargedPidMVAWeights.h:277
Belle2::ChargedPidMVAWeights::m_cuts
WeightfilesByParticle m_cuts
For each charged particle mass hypothesis' pdgId, this map contains a list of selection cuts to be st...
Definition: ChargedPidMVAWeights.h:474
Belle2::ChargedPidMVAWeights::ChargedPidMVATrainingMode::c_ECL_Classification
@ c_ECL_Classification
Binary classification, ECL only.
Belle2::ChargedPidMVAWeights::getMVAWeightIdx
unsigned int getMVAWeightIdx(const double &theta, const double &p, int &jth, int &ip) const
Get the index of the XML weight file, for a given reconstructed pair (clusterTheta,...
Definition: ChargedPidMVAWeights.h:295
Belle2::ChargedPidMVAWeights::getMVAWeights
const std::vector< std::string > * getMVAWeights(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of (serialized) MVA weightfiles stored in the p...
Definition: ChargedPidMVAWeights.h:245
Belle2::Const::kaon
static const ChargedStable kaon
charged kaon particle
Definition: Const.h:536
Belle2::ChargedPidMVAWeights::getCuts
const std::vector< std::string > * getCuts(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of selection cuts stored in the payload,...
Definition: ChargedPidMVAWeights.h:267
Belle2::ChargedPidMVAWeights::getMVAWeightsMulticlass
const std::vector< std::string > * getMVAWeightsMulticlass() const
For the multi-class mode, get the list of (serialized) MVA weightfiles stored in the payload,...
Definition: ChargedPidMVAWeights.h:255
Belle2::ChargedPidMVAWeights::ChargedPidMVATrainingMode::c_PSD_Multiclass
@ c_PSD_Multiclass
Multi-class classification, including PSD.
Belle2::ChargedPidMVAWeights::findBin
int findBin(const TH2F *h, const double &x, const double &y) const
Find global bin index of a 2D histogram for the given (x, y) values.
Definition: ChargedPidMVAWeights.h:409
Belle2::Const::pion
static const ChargedStable pion
charged pion particle
Definition: Const.h:535
Belle2::ChargedPidMVAWeights::m_categories
TH2F * m_categories
A 2D (clusterTheta, p) histogram whose bins represent the categories for which XML weight files are d...
Definition: ChargedPidMVAWeights.h:444
Belle2::MVA::Weightfile::loadFromROOTFile
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
Definition: Weightfile.cc:226
Belle2::ChargedPidMVAWeights::ChargedPidMVATrainingMode::c_ECL_PSD_Classification
@ c_ECL_PSD_Classification
Binary classification, ECL only, including PSD.
Belle2::ChargedPidMVAWeights::m_energy_unit
TParameter< double > m_energy_unit
The energy unit used for defining the bins grid.
Definition: ChargedPidMVAWeights.h:436
Belle2::Const::ParticleSet::find
const ParticleType & find(int pdg) const
Returns particle in set with given PDG code, or invalidParticle if not found.
Definition: Const.h:447
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::ChargedPidMVAWeights::m_ang_unit
TParameter< double > m_ang_unit
The angular unit used for defining the bins grid.
Definition: ChargedPidMVAWeights.h:437
Belle2::Const::deuteron
static const ChargedStable deuteron
deuteron particle
Definition: Const.h:538
Belle2::ChargedPidMVAWeights::ChargedPidMVATrainingMode::c_ECL_PSD_Multiclass
@ c_ECL_PSD_Multiclass
Multi-class classification, ECL only, including PSD.
Belle2::MVA::Weightfile::loadFromXMLFile
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
Definition: Weightfile.cc:249
Belle2::ChargedPidMVAWeights::setWeightCategories
void setWeightCategories(TH2F *h)
Set the 2D (clusterTheta, p) grid representing the categories for which weightfiles are defined.
Definition: ChargedPidMVAWeights.h:107
Belle2::MVA::Weightfile::saveToStream
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Definition: Weightfile.cc:194
Belle2::ChargedPidMVAWeights::ClassDef
ClassDef(ChargedPidMVAWeights, 5)
2: add energy/angular units.
Belle2::ChargedPidMVAWeights::WeightfilesByParticle
std::unordered_map< int, std::vector< std::string > > WeightfilesByParticle
Typedef.
Definition: ChargedPidMVAWeights.h:50
Belle2::ChargedPidMVAWeights::setEnergyUnit
void setEnergyUnit(const double &unit)
Set the energy unit to ensure consistency w/ the one used to define the bins grid.
Definition: ChargedPidMVAWeights.h:94
Belle2::Const::proton
static const ChargedStable proton
proton particle
Definition: Const.h:537
Belle2::ChargedPidMVAWeights
Class to contain the payload of MVA weightfiles needed for charged particle identification.
Definition: ChargedPidMVAWeights.h:48
Belle2::Const::muon
static const ChargedStable muon
muon particle
Definition: Const.h:534
Belle2::Unit
The Unit class.
Definition: Unit.h:50
Belle2::ChargedPidMVAWeights::storeCuts
void storeCuts(const int pdg, const std::vector< std::string > &cutfiles, const std::vector< std::pair< float, float >> &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of selection cuts (one for each category) int...
Definition: ChargedPidMVAWeights.h:188
Belle2::ChargedPidMVAWeights::isValidPdg
bool isValidPdg(const int pdg) const
Check if the input pdgId is that of a valid charged particle.
Definition: ChargedPidMVAWeights.h:391
Belle2::ChargedPidMVAWeights::ChargedPidMVATrainingMode
ChargedPidMVATrainingMode
A (strongly-typed) enumerator identifier for each valid MVA training mode.
Definition: ChargedPidMVAWeights.h:71
Belle2::ChargedPidMVAWeights::storeMVAWeights
void storeMVAWeights(const int pdg, const std::vector< std::string > &filepaths, const std::vector< std::pair< float, float >> &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of MVA weight files (one for each category) i...
Definition: ChargedPidMVAWeights.h:121
Belle2::ChargedPidMVAWeights::setAngularUnit
void setAngularUnit(const double &unit)
Set the angular unit to ensure consistency w/ the one used to define the bins grid.
Definition: ChargedPidMVAWeights.h:100
Belle2::ChargedPidMVAWeights::ChargedPidMVATrainingMode::c_ECL_Multiclass
@ c_ECL_Multiclass
Multi-class classification, ECL only.
Belle2::ChargedPidMVAWeights::ChargedPidMVATrainingMode::c_PSD_Classification
@ c_PSD_Classification
Binary classification, including PSD.