12 #include <framework/gearbox/Const.h>
13 #include <framework/gearbox/Unit.h>
14 #include <framework/logging/Logger.h>
17 #include <mva/interface/Weightfile.h>
22 #include <TParameter.h>
26 #include <boost/algorithm/string/predicate.hpp>
112 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
116 B2FATAL(
"PDG: " << pdg <<
" is not that of a valid charged particle! Aborting...");
120 for (
const auto& path : filepaths) {
123 auto bin_centres_tuple = categoryBinCentres.at(idx);
125 auto theta_bin_centre = std::get<0>(bin_centres_tuple);
126 auto p_bin_centre = std::get<1>(bin_centres_tuple);
127 auto charge_bin_centre = std::get<2>(bin_centres_tuple);
129 auto h_idx =
getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
131 B2FATAL(
"xml file:\n" << path <<
"\nindex in input vector:\n" << idx <<
"\ndoes not correspond to:\n" << h_idx <<
132 "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre <<
", " << p_bin_centre
135 ")\nPlease check how the input xml file list is being filled.");
139 if (boost::ends_with(path,
".root")) {
141 }
else if (boost::ends_with(path,
".xml")) {
144 B2WARNING(
"Unknown file extension for file: " << path <<
", fallback to xml...");
150 std::stringstream ss;
170 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
/**
 * Given a particle mass hypothesis' pdgId, store the list of selection cuts
 * (one for each category) into the payload.
 *
 * Each cut file is read in full, stripped of newline characters, has its round
 * brackets turned into square brackets, and is appended to m_cuts[pdg].
 *
 * @param pdg the particle mass hypothesis' pdgId, used as key of m_cuts.
 * @param cutfiles paths of the text files holding one selection cut each.
 * @param categoryBinCentres (clusterTheta, p, charge) bin centres, accessed at the same
 *        position as the corresponding entry of cutfiles.
 */
186 void storeCuts(
const int pdg,
const std::vector<std::string>& cutfiles,
187 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
// Fatal error for an invalid pdgId.
// NOTE(review): the guarding condition is not visible here - presumably !isValidPdg(pdg); confirm.
191 B2FATAL(
"PDG: " << pdg <<
" is not that of a valid charged particle! Aborting...");
// Walk over the cut files.
// NOTE(review): idx is not declared in the visible lines - presumably a running counter
// that follows the loop position; confirm.
195 for (
const auto& cutfile : cutfiles) {
// at() is bounds-checked: a categoryBinCentres shorter than cutfiles throws std::out_of_range.
197 auto bin_centres_tuple = categoryBinCentres.at(idx);
// Unpack the bin centre of this category along each of the three axes.
199 auto theta_bin_centre = std::get<0>(bin_centres_tuple);
200 auto p_bin_centre = std::get<1>(bin_centres_tuple);
201 auto charge_bin_centre = std::get<2>(bin_centres_tuple);
// Linearised index of the 3D bin that contains this bin centre.
203 auto h_idx =
getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
// Report a mismatch between the position in the input list and the linearised bin index.
// NOTE(review): the guarding condition is not visible here - presumably h_idx != idx; confirm.
205 B2FATAL(
"Cut file:\n" << cutfile <<
"\nindex in input vector:\n" << idx <<
"\ndoes not correspond to:\n" << h_idx <<
206 "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre <<
", " << p_bin_centre
209 ")\nPlease check how the input cut file list is being filled.");
// Read the whole content of the cut file into a single string.
// NOTE(review): the stream state is not checked after opening, so an unreadable file
// would presumably produce an empty cut string - confirm this is acceptable.
212 std::ifstream ifs(cutfile);
213 std::string cut((std::istreambuf_iterator<char>(ifs)), (std::istreambuf_iterator<char>()));
// Remove every newline character from the cut string.
216 cut.erase(std::remove(cut.begin(), cut.end(),
'\n'), cut.end());
// Replace round brackets with square brackets, opening ones first...
219 std::replace(cut.begin(), cut.end(),
'(',
'[');
// ...then closing ones.
220 std::replace(cut.begin(), cut.end(),
')',
']');
// Append the processed cut to the list kept for this pdgId.
222 m_cuts[pdg].push_back(cut);
240 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
242 storeCuts(0, cutfiles, categoryBinCentres);
/**
 * Given a particle mass hypothesis' pdgId, get the list of selection cuts stored in the payload.
 *
 * @param pdg the particle mass hypothesis' pdgId.
 * @return pointer to a const vector of cut strings.
 */
274 const std::vector<std::string>*
getCuts(
const int pdg)
const
/**
 * Get the index of the XML weight file for a given (clusterTheta, p, charge),
 * also returning the per-axis bin indices of the category grid.
 *
 * @param clusterTheta the cluster polar angle.
 * @param p the momentum.
 * @param charge the charge.
 * @param idx_theta output: bin index along the clusterTheta axis, filled by GetBinXYZ.
 * @param idx_p output: bin index along the p axis, filled by GetBinXYZ.
 * @param idx_charge output: bin index along the charge axis, filled by GetBinXYZ.
 * @return the linearised index built from the three per-axis indices, each decremented by one,
 *         with the clusterTheta index varying fastest.
 */
305 unsigned int getMVAWeightIdx(
const double& clusterTheta,
const double& p,
const double& charge,
int& idx_theta,
int& idx_p,
306 int& idx_charge)
const
// Fatal error when the categories grid is unavailable.
// NOTE(review): the guarding condition is not visible here - presumably a null check on m_categories; confirm.
310 B2FATAL(
"No (clusterTheta, p, charge) TH3 grid was found in the DB payload. Most likely, you are using a GT w/ an old payload which is no longer compatible with the DB object class implementation. This should not happen! Abort...");
// Decompose the global bin index into the three per-axis indices.
// NOTE(review): glob_bin_idx is computed outside the visible lines - presumably via findBin(); confirm.
317 m_categories->GetBinXYZ(glob_bin_idx, idx_theta, idx_p, idx_charge);
// Linearise the per-axis indices, each shifted down by one.
// NOTE(review): nbins_th and nbins_p are defined outside the visible lines - presumably the
// number of bins along the clusterTheta and p axes; confirm.
321 return (idx_theta - 1) + nbins_th * ((idx_p - 1) + nbins_p * (idx_charge - 1));
/**
 * Overloaded method, to be used when the per-axis bin indices are not needed.
 *
 * @param theta the cluster polar angle.
 * @param p the momentum.
 * @param charge the charge.
 * @return the value returned by the six-argument overload for the same (theta, p, charge).
 */
328 unsigned int getMVAWeightIdx(
const double& theta,
const double& p,
const double& charge)
const
// Throw-away receivers for the per-axis indices written by the six-argument overload.
330 int idx_theta, idx_p, idx_charge;
331 return getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
/**
 * Read and dump the payload content: a ROOT file for the (clusterTheta, p, charge) categories grid
 * and an XML weight file for the selected category.
 *
 * @param clusterTheta the cluster polar angle, logged in [rad].
 * @param p the momentum, logged in [GeV/c].
 * @param charge the charge.
 * @param pdg the particle mass hypothesis' pdgId to dump.
 * @param dump_all when true, entries whose pdgId differs from pdg are processed as well.
 */
344 void dumpPayload(
const double& clusterTheta,
const double& p,
const double& charge,
const int pdg,
bool dump_all =
false)
const
// Log the point of the (clusterTheta, p, charge) space that is being dumped.
347 B2INFO(
"Dumping payload content for:");
348 B2INFO(
"clusterTheta = " << clusterTheta <<
" [rad], p = " << p <<
" [GeV/c], charge = " << charge);
// Output ROOT file for the categories grid; "RECREATE" is the mode passed to TFile.
351 std::string filename =
"db_payload_chargedpidmva__clustertheta_p_charge_categories.root";
352 B2INFO(
"\tWriting ROOT file w/ (clusterTheta, p, charge) TH3F grid that defines categories:" << filename);
353 auto f = std::make_unique<TFile>(filename.c_str(),
"RECREATE");
// NOTE(review): the guarding condition is not visible here - presumably taken when m_categories is null; confirm.
357 B2WARNING(
"\tThe TH3F object that defines categories is a nullptr!");
// Skip entries for other pdgIds unless everything was requested.
// NOTE(review): pdgId, weights and idx come from lines that are not visible here - presumably the
// loop over m_weightfiles and the index returned by getMVAWeightIdx(); confirm.
362 if (!dump_all && pdg != pdgId)
continue;
// at() is bounds-checked and throws std::out_of_range on an invalid idx.
366 auto serialized_weightfile = weights.at(idx);
// The output file name carries the pdgId and idx + 1.
368 std::string filename =
"db_payload_chargedpidmva__weightfile_pdg_" + std::to_string(pdgId) +
369 "_glob_bin_" + std::to_string(idx + 1) +
".xml";
// NOTE(review): the pointer returned by getCuts() is dereferenced without a check - confirm it
// cannot be null for a pdgId that has weight files.
371 auto cutstr =
getCuts(pdgId)->at(idx);
373 B2INFO(
"\tpdgId = " << pdgId);
374 B2INFO(
"\tCut: " << cutstr);
375 B2INFO(
"\tWriting weight file: " << filename);
// Write the serialized weight file to disk.
// NOTE(review): the stream state is not checked after open(); a failure to create the file
// would presumably go unnoticed - confirm.
377 std::ofstream weightfile;
378 weightfile.open(filename.c_str(), std::ios::out);
379 weightfile << serialized_weightfile << std::endl;
/**
 * Find global bin index of a 3D histogram for the given (x, y, z) values.
 * Values outside the axis range are replaced by the centre of the first or last bin
 * of that axis before the lookup.
 *
 * @param h the 3D histogram.
 * @param x the value along the X axis.
 * @param y the value along the Y axis.
 * @param z the value along the Z axis.
 * @return j + nbinsx * (i + nbinsy * k), where j, i, k are the bin indices found on the
 *         X, Y, Z axes and nbinsx, nbinsy are the axis bin counts plus two.
 */
420 int findBin(
const TH3F* h,
const double& x,
const double& y,
const double& z)
const
// Number of bins on each axis, as returned by GetNbins().
423 int nbinsx_vis = h->GetXaxis()->GetNbins();
424 int nbinsy_vis = h->GetYaxis()->GetNbins();
425 int nbinsz_vis = h->GetZaxis()->GetNbins();
// Clamp each coordinate: below the low edge of bin 1 use the centre of bin 1, at or above
// the low edge of bin (nbins + 1) use the centre of the last bin.
// NOTE(review): xx, yy, zz are declared outside the visible lines - presumably initialised
// to x, y, z; confirm.
433 if (x < h->GetXaxis()->GetBinLowEdge(1)) { xx = h->GetXaxis()->GetBinCenter(1); }
434 if (x >= h->GetXaxis()->GetBinLowEdge(nbinsx_vis + 1)) { xx = h->GetXaxis()->GetBinCenter(nbinsx_vis); }
435 if (y < h->GetYaxis()->GetBinLowEdge(1)) { yy = h->GetYaxis()->GetBinCenter(1); }
436 if (y >= h->GetYaxis()->GetBinLowEdge(nbinsy_vis + 1)) { yy = h->GetYaxis()->GetBinCenter(nbinsy_vis); }
437 if (z < h->GetZaxis()->GetBinLowEdge(1)) { zz = h->GetZaxis()->GetBinCenter(1); }
438 if (z >= h->GetZaxis()->GetBinLowEdge(nbinsz_vis + 1)) { zz = h->GetZaxis()->GetBinCenter(nbinsz_vis); }
// Axis sizes used for the linearisation.
// NOTE(review): the "+ 2" presumably accounts for the underflow and overflow bins; confirm.
440 int nbinsx = h->GetXaxis()->GetNbins() + 2;
441 int nbinsy = h->GetYaxis()->GetNbins() + 2;
// Per-axis bin lookup on the clamped coordinates.
443 int j = h->GetXaxis()->FindBin(xx);
444 int i = h->GetYaxis()->FindBin(yy);
445 int k = h->GetZaxis()->FindBin(zz);
// Combine the per-axis indices into one global index, X varying fastest.
447 return j + nbinsx * (i + nbinsy * k);
474 { 0, std::vector<std::string>() },
493 { 0, std::vector<std::string>() },
Class to contain the payload of MVA weightfiles needed for charged particle identification.
TParameter< double > m_energy_unit
The energy unit used for defining the bins grid.
void setWeightCategories(TH3F *h)
Set the 3D (clusterTheta, p, charge) grid representing the categories for which weightfiles are defined.
const std::vector< std::string > * getCuts(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of selection cuts stored in the payload,...
bool isValidPdg(const int pdg) const
Check if the input pdgId is that of a valid charged particle.
void storeMVAWeightsMultiClass(const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
For the multi-class mode, store the list of MVA weight files (one for each category) into the payload...
std::unordered_map< int, std::vector< std::string > > WeightfilesByParticle
Typedef.
ChargedPidMVATrainingMode
A (strongly-typed) enumerator identifier for each valid MVA training mode.
@ c_PSD_Multiclass
Multi-class classification, including PSD.
@ c_Classification
Binary classification.
@ c_PSD_Classification
Binary classification, including PSD.
@ c_ECL_Multiclass
Multi-class classification, ECL only.
@ c_ECL_PSD_Classification
Binary classification, ECL only, including PSD.
@ c_ECL_Classification
Binary classification, ECL only.
@ c_ECL_PSD_Multiclass
Multi-class classification, ECL only, including PSD.
@ c_Multiclass
Multi-class classification.
void dumpPayloadMulticlass(const double &theta, const double &p, const double &charge) const
Special version for multi-class mode.
const std::vector< std::string > * getMVAWeightsMulticlass() const
For the multi-class mode, get the list of (serialized) MVA weightfiles stored in the payload,...
ClassDef(ChargedPidMVAWeights, 7)
2: add energy/angular units.
TParameter< double > m_ang_unit
The angular unit used for defining the bins grid.
ChargedPidMVAWeights()
Default constructor, necessary for ROOT to stream the object.
void storeMVAWeights(const int pdg, const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of MVA weight files (one for each category) into the payload.
unsigned int getMVAWeightIdx(const double &clusterTheta, const double &p, const double &charge, int &idx_theta, int &idx_p, int &idx_charge) const
Get the index of the XML weight file, for a given reconstructed (clusterTheta, p, charge) triplet.
void storeCuts(const int pdg, const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of selection cuts (one for each category) into the payload.
WeightfilesByParticle m_weightfiles
For each charged particle mass hypothesis' pdgId, this map contains a list of (serialized) Weightfile...
const std::vector< std::string > * getMVAWeights(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of (serialized) MVA weightfiles stored in the p...
~ChargedPidMVAWeights()
Destructor.
void dumpPayload(const double &clusterTheta, const double &p, const double &charge, const int pdg, bool dump_all=false) const
Read and dump the payload content from the internal 'matrioska' maps into an XML weightfile for the g...
const std::vector< std::string > * getCutsMulticlass() const
For the multi-class mode, get the list of selection cuts stored in the payload, one for each (cluster...
void storeCutsMultiClass(const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
For the multi-class mode, store the list of selection cuts (one for each category) into the payload.
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge) const
Overloaded method, to be used if not interested in knowing the 3D (clusterTheta, p, charge) bin indices.
void setAngularUnit(const double &unit)
Set the angular unit to ensure consistency w/ the one used to define the bins grid.
WeightfilesByParticle m_cuts
For each charged particle mass hypothesis' pdgId, this map contains a list of selection cuts to be st...
void setEnergyUnit(const double &unit)
Set the energy unit to ensure consistency w/ the one used to define the bins grid.
int findBin(const TH3F *h, const double &x, const double &y, const double &z) const
Find global bin index of a 3D histogram for the given (x, y, z) values.
TH3F * m_categories
A 3D (clusterTheta, p, charge) histogram whose bins represent the categories for which XML weight files are defined.
const ParticleType & find(int pdg) const
Returns particle in set with given PDG code, or invalidParticle if not found.
int getPDGCode() const
PDG code.
static const ChargedStable muon
muon particle
static const ParticleSet chargedStableSet
set of charged stable particles
static const ChargedStable pion
charged pion particle
static const ChargedStable proton
proton particle
static const ParticleType invalidParticle
Invalid particle, used internally.
static const ChargedStable kaon
charged kaon particle
static const ChargedStable electron
electron particle
static const ChargedStable deuteron
deuteron particle
The Weightfile class serializes all information about a training into an xml tree.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Abstract base class for different kinds of events.