Belle II Software development
ChargedPidMVAWeights.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10
11// FRAMEWORK
12#include <framework/gearbox/Const.h>
13#include <framework/gearbox/Unit.h>
14#include <framework/logging/Logger.h>
15
16// MVA
17#include <mva/interface/Weightfile.h>
18
19// ROOT
20#include <TObject.h>
21#include <TH3F.h>
22#include <TParameter.h>
23#include <TFile.h>
24
25namespace Belle2 {
30
34 class ChargedPidMVAWeights : public TObject {
35
36 typedef std::unordered_map<int, std::vector<std::string> > WeightfilesByParticle;
37 typedef std::map<std::string, std::string> VariablesByAlias;
38
39 public:
40
45 m_energy_unit("energyUnit", Unit::GeV),
46 m_ang_unit("angularUnit", Unit::rad),
47 m_thetaVarName("clusterTheta"),
49 {};
50
51
55 ChargedPidMVAWeights(const double& energyUnit, const double& angUnit,
56 const std::string& thetaVarName = "clusterTheta",
57 bool implictNaNmasking = false)
58 {
59 setEnergyUnit(energyUnit);
60 setAngularUnit(angUnit);
61 m_thetaVarName = thetaVarName;
62 m_implicitNaNmasking = implictNaNmasking;
63 }
64
69
91
92
96 void setEnergyUnit(const double& unit) { m_energy_unit.SetVal(unit); }
97
98
102 void setAngularUnit(const double& unit) { m_ang_unit.SetVal(unit); }
103
113 void setWeightCategories(const double* clusterThetaBins, const int nClusterThetaBins,
114 const double* pBins, const int nPBins,
115 const double* chargeBins, const int nChargeBins)
116 {
117
118 m_categories = std::make_unique<TH3F>("clustertheta_p_charge_binsgrid",
119 ";ECL cluster #theta;p_{lab};Q",
120 nClusterThetaBins, clusterThetaBins,
121 nPBins, pBins,
122 nChargeBins, chargeBins);
123 }
124
134 void storeMVAWeights(const int pdg, const std::vector<std::string>& filepaths,
135 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
136 {
137
138 if (!isValidPdg(pdg)) {
139 B2FATAL("PDG: " << pdg << " is not that of a valid charged particle! Aborting...");
140 }
141
142 unsigned int idx(0);
143 for (const auto& path : filepaths) {
144
145 // Index consistency check.
146 auto bin_centres_tuple = categoryBinCentres.at(idx);
147
148 auto theta_bin_centre = std::get<0>(bin_centres_tuple);
149 auto p_bin_centre = std::get<1>(bin_centres_tuple);
150 auto charge_bin_centre = std::get<2>(bin_centres_tuple);
151
152 auto h_idx = getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
153 if (idx != h_idx) {
154 B2FATAL("xml file:\n" << path << "\nindex in input vector:\n" << idx << "\ndoes not correspond to:\n" << h_idx <<
155 "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre << ", " << p_bin_centre
156 << ", " <<
157 charge_bin_centre <<
158 ")\nPlease check how the input xml file list is being filled.");
159 }
160
161 MVA::Weightfile weightfile;
162 if (path.ends_with(".root")) {
163 weightfile = MVA::Weightfile::loadFromROOTFile(path);
164 } else if (path.ends_with(".xml")) {
165 weightfile = MVA::Weightfile::loadFromXMLFile(path);
166 } else {
167 B2WARNING("Unknown file extension for file: " << path << ", fallback to xml...");
168 weightfile = MVA::Weightfile::loadFromXMLFile(path);
169 }
170
171 // Serialize the MVA::Weightfile object into a string for storage in the database,
172 // otherwise there are issues w/ dictionary generation for the payload class...
173 std::stringstream ss;
174 MVA::Weightfile::saveToStream(weightfile, ss);
175 m_weightfiles[pdg].push_back(ss.str());
176
177 ++idx;
178 }
179
180 }
181
182
192 void storeMVAWeightsMultiClass(const std::vector<std::string>& filepaths,
193 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
194 {
195 storeMVAWeights(0, filepaths, categoryBinCentres);
196 }
197
198
209 void storeCuts(const int pdg, const std::vector<std::string>& cutfiles,
210 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
211 {
212
213 if (!isValidPdg(pdg)) {
214 B2FATAL("PDG: " << pdg << " is not that of a valid charged particle! Aborting...");
215 }
216
217 unsigned int idx(0);
218 for (const auto& cutfile : cutfiles) {
219
220 auto bin_centres_tuple = categoryBinCentres.at(idx);
221
222 auto theta_bin_centre = std::get<0>(bin_centres_tuple);
223 auto p_bin_centre = std::get<1>(bin_centres_tuple);
224 auto charge_bin_centre = std::get<2>(bin_centres_tuple);
225
226 auto h_idx = getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
227 if (idx != h_idx) {
228 B2FATAL("Cut file:\n" << cutfile << "\nindex in input vector:\n" << idx << "\ndoes not correspond to:\n" << h_idx <<
229 "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre << ", " << p_bin_centre
230 << ", " <<
231 charge_bin_centre <<
232 ")\nPlease check how the input cut file list is being filled.");
233 }
234
235 std::ifstream ifs(cutfile);
236 std::string cut((std::istreambuf_iterator<char>(ifs)), (std::istreambuf_iterator<char>()));
237
238 // Strip trailing newline.
239 cut.erase(std::remove(cut.begin(), cut.end(), '\n'), cut.end());
240
241 m_cuts[pdg].push_back(cut);
242
243 ++idx;
244 }
245
246 }
247
258 void storeCutsMultiClass(const std::vector<std::string>& cutfiles,
259 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
260 {
261 storeCuts(0, cutfiles, categoryBinCentres);
262 }
263
264
270 void storeAliases(const VariablesByAlias& aliases)
271 {
272 m_aliases = VariablesByAlias(aliases);
273 }
274
275
280 const TH3F* getWeightCategories() const
281 {
282 return m_categories.get();
283 }
284
285
291 const std::vector<std::string>* getMVAWeights(const int pdg) const
292 {
293 return &(m_weightfiles.at(pdg));
294 }
295
296
302 const std::vector<std::string>* getMVAWeightsMulticlass() const
303 {
304 return getMVAWeights(0);
305 }
306
307
313 const std::vector<std::string>* getCuts(const int pdg) const
314 {
315 return &(m_cuts.at(pdg));
316 }
317
318
324 const std::vector<std::string>* getCutsMulticlass() const
325 {
326 return getCuts(0);
327 }
328
329
334 {
335 return &m_aliases;
336 }
337
338
353 unsigned int getMVAWeightIdx(const double& theta, const double& p, const double& charge, int& idx_theta, int& idx_p,
354 int& idx_charge) const
355 {
356
357 if (!m_categories) {
358 B2FATAL("No (clusterTheta, p, charge) TH3 grid was found in the DB payload. Most likely, you are using a GT w/ an old payload which is no longer compatible with the DB object class implementation. This should not happen! Abort...");
359 }
360
361 int nbins_th = m_categories->GetXaxis()->GetNbins(); // nr. of theta (visible) bins, along X.
362 int nbins_p = m_categories->GetYaxis()->GetNbins(); // nr. of p (visible) bins, along Y.
363
364 int glob_bin_idx = findBin(theta / m_ang_unit.GetVal(), p / m_energy_unit.GetVal(), charge);
365 m_categories->GetBinXYZ(glob_bin_idx, idx_theta, idx_p, idx_charge);
366
367 // The index of the linearised 3D m_categories.
368 // The unit offset is b/c ROOT sets global bin idx also for overflows and underflows.
369 return (idx_theta - 1) + nbins_th * ((idx_p - 1) + nbins_p * (idx_charge - 1));
370 }
371
375 unsigned int getMVAWeightIdx(const double& theta, const double& p, const double& charge) const
376 {
377 int idx_theta, idx_p, idx_charge;
378 return getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
379 }
380
381
391 void dumpPayload(const double& theta, const double& p, const double& charge, const int pdg, bool dump_all = false) const
392 {
393
394 B2INFO("Dumping payload content for:");
395 B2INFO("clusterTheta(theta) = " << theta << " [rad], p = " << p << " [GeV/c], charge = " << charge);
396
397 if (m_categories) {
398 std::string filename = "db_payload_chargedpidmva__theta_p_charge_categories.root";
399 B2INFO("\tWriting ROOT file w/ TH3F grid that defines categories:" << filename);
400 auto f = std::make_unique<TFile>(filename.c_str(), "RECREATE");
401 m_categories->Write();
402 f->Close();
403 } else {
404 B2WARNING("\tThe TH3F object that defines categories is a nullptr!");
405 }
406
407 for (const auto& [pdgId, weights] : m_weightfiles) {
408
409 if (!dump_all && pdg != pdgId) continue;
410
411 auto idx = getMVAWeightIdx(theta, p, charge);
412
413 auto serialized_weightfile = weights.at(idx);
414
415 std::string filename = "db_payload_chargedpidmva__weightfile_pdg_" + std::to_string(pdgId) +
416 "_glob_bin_" + std::to_string(idx + 1) + ".xml";
417
418 auto cutstr = getCuts(pdgId)->at(idx);
419
420 B2INFO("\tpdgId = " << pdgId);
421 B2INFO("\tCut: " << cutstr);
422 B2INFO("\tWriting weight file: " << filename);
423
424 std::ofstream weightfile;
425 weightfile.open(filename.c_str(), std::ios::out);
426 weightfile << serialized_weightfile << std::endl;
427 weightfile.close();
428
429 }
430
431 };
432
433
438 void dumpPayloadMulticlass(const double& theta, const double& p, const double& charge) const
439 {
440 dumpPayload(theta, p, charge, 0);
441 }
442
443
448 bool isValidPdg(const int pdg) const
449 {
450 bool isValid = (Const::chargedStableSet.find(pdg) != Const::invalidParticle) || (pdg == 0);
451 return isValid;
452 }
453
457 std::string getThetaVarName() const
458 {
459 return m_thetaVarName;
460 }
461
462
467 {
469 }
470
471
472 private:
473
474
483 int findBin(const double& x, const double& y, const double& z) const
484 {
485
486 int nbinsx_vis = m_categories->GetXaxis()->GetNbins();
487 int nbinsy_vis = m_categories->GetYaxis()->GetNbins();
488 int nbinsz_vis = m_categories->GetZaxis()->GetNbins();
489
490 double xx = x;
491 double yy = y;
492 double zz = z;
493
494 // If x, y, z are outside of the 3D grid (visible) range, set their value to
495 // fall in the last (first) bin before (after) overflow (underflow).
496 if (x < m_categories->GetXaxis()->GetBinLowEdge(1)) { xx = m_categories->GetXaxis()->GetBinCenter(1); }
497 if (x >= m_categories->GetXaxis()->GetBinLowEdge(nbinsx_vis + 1)) { xx = m_categories->GetXaxis()->GetBinCenter(nbinsx_vis); }
498 if (y < m_categories->GetYaxis()->GetBinLowEdge(1)) { yy = m_categories->GetYaxis()->GetBinCenter(1); }
499 if (y >= m_categories->GetYaxis()->GetBinLowEdge(nbinsy_vis + 1)) { yy = m_categories->GetYaxis()->GetBinCenter(nbinsy_vis); }
500 if (z < m_categories->GetZaxis()->GetBinLowEdge(1)) { zz = m_categories->GetZaxis()->GetBinCenter(1); }
501 if (z >= m_categories->GetZaxis()->GetBinLowEdge(nbinsz_vis + 1)) { zz = m_categories->GetZaxis()->GetBinCenter(nbinsz_vis); }
502
503 int nbinsx = m_categories->GetXaxis()->GetNbins() + 2;
504 int nbinsy = m_categories->GetYaxis()->GetNbins() + 2;
505
506 int j = m_categories->GetXaxis()->FindBin(xx);
507 int i = m_categories->GetYaxis()->FindBin(yy);
508 int k = m_categories->GetZaxis()->FindBin(zz);
509
510 return j + nbinsx * (i + nbinsy * k);
511 }
512
513
514 private:
515
516
517 TParameter<double> m_energy_unit;
518 TParameter<double> m_ang_unit;
519 std::string
522
523
528 std::unique_ptr<TH3F> m_categories;
529
530
540 { 0, std::vector<std::string>() },
541 { Const::electron.getPDGCode(), std::vector<std::string>() },
542 { Const::muon.getPDGCode(), std::vector<std::string>() },
543 { Const::pion.getPDGCode(), std::vector<std::string>() },
544 { Const::kaon.getPDGCode(), std::vector<std::string>() },
545 { Const::proton.getPDGCode(), std::vector<std::string>() },
546 { Const::deuteron.getPDGCode(), std::vector<std::string>() }
547 };
548
549
559 { 0, std::vector<std::string>() },
560 { Const::electron.getPDGCode(), std::vector<std::string>() },
561 { Const::muon.getPDGCode(), std::vector<std::string>() },
562 { Const::pion.getPDGCode(), std::vector<std::string>() },
563 { Const::kaon.getPDGCode(), std::vector<std::string>() },
564 { Const::proton.getPDGCode(), std::vector<std::string>() },
565 { Const::deuteron.getPDGCode(), std::vector<std::string>() }
566 };
567
568
573
574
586 };
587
589}
void storeCutsMultiClass(const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
For the multi-class mode, store the list of selection cuts (one for each category) into the payload.
void storeMVAWeightsMultiClass(const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
For the multi-class mode, store the list of MVA weight files (one for each category) into the payload...
TParameter< double > m_energy_unit
The energy unit used for defining the bins grid.
const std::vector< std::string > * getCutsMulticlass() const
For the multi-class mode, get the list of selection cuts stored in the payload, one for each category...
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge, int &idx_theta, int &idx_p, int &idx_charge) const
Get the index of the XML weight file, for a given reconstructed triplet (clusterTheta(theta),...
bool isValidPdg(const int pdg) const
Check if the input pdgId is that of a valid charged particle.
VariablesByAlias m_aliases
A map that associates variable aliases used in the MVA training to variable names known to the Variab...
const std::vector< std::string > * getMVAWeightsMulticlass() const
For the multi-class mode, get the list of (serialized) MVA weightfiles stored in the payload,...
const TH3F * getWeightCategories() const
Get the raw pointer to the 3D grid representing the categories for which weightfiles are defined.
std::unordered_map< int, std::vector< std::string > > WeightfilesByParticle
Typedef.
ChargedPidMVATrainingMode
A (strongly-typed) enumerator identifier for each valid MVA training mode.
@ c_PSD_Multiclass
Multi-class classification, including PSD.
@ c_PSD_Classification
Binary classification, including PSD.
@ c_ECL_Multiclass
Multi-class classification, ECL only.
@ c_ECL_PSD_Classification
Binary classification, ECL only, including PSD.
@ c_ECL_PSD_Multiclass
Multi-class classification, ECL only, including PSD.
std::unique_ptr< TH3F > m_categories
A 3D histogram whose bins represent the categories for which XML weight files are defined.
void dumpPayloadMulticlass(const double &theta, const double &p, const double &charge) const
Special version for multi-class mode.
std::map< std::string, std::string > VariablesByAlias
Typedef.
const VariablesByAlias * getAliases() const
Get the map of unique aliases.
void setWeightCategories(const double *clusterThetaBins, const int nClusterThetaBins, const double *pBins, const int nPBins, const double *chargeBins, const int nChargeBins)
Set the 3D (clusterTheta, p, charge) grid representing the categories for which weightfiles are defin...
std::string getThetaVarName() const
Get the name of the polar angle variable.
std::string m_thetaVarName
The name of the polar angle variable used in the MVA categorisation.
bool hasImplicitNaNmasking() const
Check flag for implicit NaN masking.
void dumpPayload(const double &theta, const double &p, const double &charge, const int pdg, bool dump_all=false) const
Read and dump the payload content from the internal 'matrioska' maps into an XML weightfile for the g...
TParameter< double > m_ang_unit
The angular unit used for defining the bins grid.
ChargedPidMVAWeights()
Default constructor, necessary for ROOT to stream the object.
const std::vector< std::string > * getMVAWeights(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of (serialized) MVA weightfiles stored in the p...
void storeMVAWeights(const int pdg, const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of MVA weight files (one for each category) i...
void storeAliases(const VariablesByAlias &aliases)
Store the map associating variable aliases to variable names known to VariableManager.
void storeCuts(const int pdg, const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of selection cuts (one for each category) int...
WeightfilesByParticle m_weightfiles
For each charged particle mass hypothesis' pdgId, this map contains a list of (serialized) Weightfile...
bool m_implicitNaNmasking
Flag to indicate whether the MVA variables have been NaN-masked directly in the weightfiles.
ChargedPidMVAWeights(const double &energyUnit, const double &angUnit, const std::string &thetaVarName="clusterTheta", bool implictNaNmasking=false)
Specialized constructor.
int findBin(const double &x, const double &y, const double &z) const
Find global bin index of the 3D categories histogram for the given (x, y, z) values.
ClassDef(ChargedPidMVAWeights, 10)
2: add energy/angular units.
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge) const
Overloaded method, to be used if not interested in knowing the 3D bin coordinates.
void setAngularUnit(const double &unit)
Set the angular unit to ensure consistency w/ the one used to define the bins grid.
WeightfilesByParticle m_cuts
For each charged particle mass hypothesis' pdgId, this map contains a list of selection cuts to be st...
void setEnergyUnit(const double &unit)
Set the energy unit to ensure consistency w/ the one used to define the bins grid.
const std::vector< std::string > * getCuts(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of selection cuts stored in the payload,...
static const ChargedStable muon
muon particle
Definition Const.h:660
static const ParticleSet chargedStableSet
set of charged stable particles
Definition Const.h:618
static const ChargedStable pion
charged pion particle
Definition Const.h:661
static const ChargedStable proton
proton particle
Definition Const.h:663
static const ParticleType invalidParticle
Invalid particle, used internally.
Definition Const.h:681
static const ChargedStable kaon
charged kaon particle
Definition Const.h:662
static const ChargedStable electron
electron particle
Definition Const.h:659
static const ChargedStable deuteron
deuteron particle
Definition Const.h:664
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
The Unit class.
Definition Unit.h:40
Abstract base class for different kinds of events.