12 #include <framework/gearbox/Const.h>
13 #include <framework/gearbox/Unit.h>
14 #include <framework/logging/Logger.h>
17 #include <mva/interface/Weightfile.h>
22 #include <TParameter.h>
26 #include <boost/algorithm/string/predicate.hpp>
60 const std::string& thetaVarName =
"clusterTheta",
61 bool implictNaNmasking =
false)
118 const double* pBins,
const int nPBins,
119 const double* chargeBins,
const int nChargeBins)
122 m_categories = std::make_unique<TH3F>(
"clustertheta_p_charge_binsgrid",
123 ";ECL cluster #theta;p_{lab};Q",
124 nClusterThetaBins, clusterThetaBins,
126 nChargeBins, chargeBins);
139 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
143 B2FATAL(
"PDG: " << pdg <<
" is not that of a valid charged particle! Aborting...");
147 for (
const auto& path : filepaths) {
150 auto bin_centres_tuple = categoryBinCentres.at(idx);
152 auto theta_bin_centre = std::get<0>(bin_centres_tuple);
153 auto p_bin_centre = std::get<1>(bin_centres_tuple);
154 auto charge_bin_centre = std::get<2>(bin_centres_tuple);
156 auto h_idx =
getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
158 B2FATAL(
"xml file:\n" << path <<
"\nindex in input vector:\n" << idx <<
"\ndoes not correspond to:\n" << h_idx <<
159 "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre <<
", " << p_bin_centre
162 ")\nPlease check how the input xml file list is being filled.");
166 if (boost::ends_with(path,
".root")) {
168 }
else if (boost::ends_with(path,
".xml")) {
171 B2WARNING(
"Unknown file extension for file: " << path <<
", fallback to xml...");
177 std::stringstream ss;
197 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
213 void storeCuts(
const int pdg,
const std::vector<std::string>& cutfiles,
214 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
218 B2FATAL(
"PDG: " << pdg <<
" is not that of a valid charged particle! Aborting...");
222 for (
const auto& cutfile : cutfiles) {
224 auto bin_centres_tuple = categoryBinCentres.at(idx);
226 auto theta_bin_centre = std::get<0>(bin_centres_tuple);
227 auto p_bin_centre = std::get<1>(bin_centres_tuple);
228 auto charge_bin_centre = std::get<2>(bin_centres_tuple);
230 auto h_idx =
getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
232 B2FATAL(
"Cut file:\n" << cutfile <<
"\nindex in input vector:\n" << idx <<
"\ndoes not correspond to:\n" << h_idx <<
233 "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre <<
", " << p_bin_centre
236 ")\nPlease check how the input cut file list is being filled.");
239 std::ifstream ifs(cutfile);
240 std::string cut((std::istreambuf_iterator<char>(ifs)), (std::istreambuf_iterator<char>()));
243 cut.erase(std::remove(cut.begin(), cut.end(),
'\n'), cut.end());
245 m_cuts[pdg].push_back(cut);
263 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
265 storeCuts(0, cutfiles, categoryBinCentres);
318 const std::vector<std::string>*
getCuts(
const int pdg)
const
358 unsigned int getMVAWeightIdx(
const double& theta,
const double& p,
const double& charge,
int& idx_theta,
int& idx_p,
359 int& idx_charge)
const
363 B2FATAL(
"No (clusterTheta, p, charge) TH3 grid was found in the DB payload. Most likely, you are using a GT w/ an old payload which is no longer compatible with the DB object class implementation. This should not happen! Abort...");
370 m_categories->GetBinXYZ(glob_bin_idx, idx_theta, idx_p, idx_charge);
374 return (idx_theta - 1) + nbins_th * ((idx_p - 1) + nbins_p * (idx_charge - 1));
380 unsigned int getMVAWeightIdx(
const double& theta,
const double& p,
const double& charge)
const
382 int idx_theta, idx_p, idx_charge;
383 return getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
396 void dumpPayload(
const double& theta,
const double& p,
const double& charge,
const int pdg,
bool dump_all =
false)
const
399 B2INFO(
"Dumping payload content for:");
400 B2INFO(
"clusterTheta(theta) = " << theta <<
" [rad], p = " << p <<
" [GeV/c], charge = " << charge);
403 std::string filename =
"db_payload_chargedpidmva__theta_p_charge_categories.root";
404 B2INFO(
"\tWriting ROOT file w/ TH3F grid that defines categories:" << filename);
405 auto f = std::make_unique<TFile>(filename.c_str(),
"RECREATE");
409 B2WARNING(
"\tThe TH3F object that defines categories is a nullptr!");
414 if (!dump_all && pdg != pdgId)
continue;
418 auto serialized_weightfile = weights.at(idx);
420 std::string filename =
"db_payload_chargedpidmva__weightfile_pdg_" + std::to_string(pdgId) +
421 "_glob_bin_" + std::to_string(idx + 1) +
".xml";
423 auto cutstr =
getCuts(pdgId)->at(idx);
425 B2INFO(
"\tpdgId = " << pdgId);
426 B2INFO(
"\tCut: " << cutstr);
427 B2INFO(
"\tWriting weight file: " << filename);
429 std::ofstream weightfile;
430 weightfile.open(filename.c_str(), std::ios::out);
431 weightfile << serialized_weightfile << std::endl;
488 int findBin(
const double& x,
const double& y,
const double& z)
const
501 if (x < m_categories->GetXaxis()->GetBinLowEdge(1)) { xx =
m_categories->GetXaxis()->GetBinCenter(1); }
502 if (x >=
m_categories->GetXaxis()->GetBinLowEdge(nbinsx_vis + 1)) { xx =
m_categories->GetXaxis()->GetBinCenter(nbinsx_vis); }
503 if (y < m_categories->GetYaxis()->GetBinLowEdge(1)) { yy =
m_categories->GetYaxis()->GetBinCenter(1); }
504 if (y >=
m_categories->GetYaxis()->GetBinLowEdge(nbinsy_vis + 1)) { yy =
m_categories->GetYaxis()->GetBinCenter(nbinsy_vis); }
505 if (z < m_categories->GetZaxis()->GetBinLowEdge(1)) { zz =
m_categories->GetZaxis()->GetBinCenter(1); }
506 if (z >=
m_categories->GetZaxis()->GetBinLowEdge(nbinsz_vis + 1)) { zz =
m_categories->GetZaxis()->GetBinCenter(nbinsz_vis); }
515 return j + nbinsx * (i + nbinsy * k);
545 { 0, std::vector<std::string>() },
564 { 0, std::vector<std::string>() },
Class to contain the payload of MVA weightfiles needed for charged particle identification.
TParameter< double > m_energy_unit
The energy unit used for defining the bins grid.
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge, int &idx_theta, int &idx_p, int &idx_charge) const
Get the index of the XML weight file, for a given reconstructed triplet (clusterTheta(theta),...
const std::vector< std::string > * getCuts(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of selection cuts stored in the payload,...
bool isValidPdg(const int pdg) const
Check if the input pdgId is that of a valid charged particle.
void storeMVAWeightsMultiClass(const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
For the multi-class mode, store the list of MVA weight files (one for each category) into the payload...
VariablesByAlias m_aliases
A map that associates variable aliases used in the MVA training to variable names known to the Variab...
const VariablesByAlias * getAliases() const
Get the map of unique aliases.
std::unordered_map< int, std::vector< std::string > > WeightfilesByParticle
Typedef.
ChargedPidMVATrainingMode
A (strongly-typed) enumerator identifier for each valid MVA training mode.
@ c_PSD_Multiclass
Multi-class classification, including PSD.
@ c_Classification
Binary classification.
@ c_PSD_Classification
Binary classification, including PSD.
@ c_ECL_Multiclass
Multi-class classification, ECL only.
@ c_ECL_PSD_Classification
Binary classification, ECL only, including PSD.
@ c_ECL_Classification
Binary classification, ECL only.
@ c_ECL_PSD_Multiclass
Multi-class classification, ECL only, including PSD.
@ c_Multiclass
Multi-class classification.
std::unique_ptr< TH3F > m_categories
A 3D histogram whose bins represent the categories for which XML weight files are defined.
void dumpPayloadMulticlass(const double &theta, const double &p, const double &charge) const
Special version for multi-class mode.
std::map< std::string, std::string > VariablesByAlias
Typedef.
const std::vector< std::string > * getMVAWeightsMulticlass() const
For the multi-class mode, get the list of (serialized) MVA weightfiles stored in the payload,...
void setWeightCategories(const double *clusterThetaBins, const int nClusterThetaBins, const double *pBins, const int nPBins, const double *chargeBins, const int nChargeBins)
Set the 3D (clusterTheta, p, charge) grid representing the categories for which weightfiles are defin...
std::string getThetaVarName() const
Get the name of the polar angle variable.
std::string m_thetaVarName
The name of the polar angle variable used in the MVA categorisation.
bool hasImplicitNaNmasking() const
Check flag for implicit NaN masking.
void dumpPayload(const double &theta, const double &p, const double &charge, const int pdg, bool dump_all=false) const
Read and dump the payload content from the internal 'matrioska' maps into an XML weightfile for the g...
TParameter< double > m_ang_unit
The angular unit used for defining the bins grid.
ChargedPidMVAWeights()
Default constructor, necessary for ROOT to stream the object.
void storeMVAWeights(const int pdg, const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of MVA weight files (one for each category) i...
void storeCuts(const int pdg, const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of selection cuts (one for each category) int...
void storeAliases(const VariablesByAlias &aliases)
Store the map associating variable aliases to variable names known to VariableManager.
WeightfilesByParticle m_weightfiles
For each charged particle mass hypothesis' pdgId, this map contains a list of (serialized) Weightfile...
const std::vector< std::string > * getMVAWeights(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of (serialized) MVA weightfiles stored in the p...
~ChargedPidMVAWeights()
Destructor.
const TH3F * getWeightCategories() const
Get the raw pointer to the 3D grid representing the categories for which weightfiles are defined.
const std::vector< std::string > * getCutsMulticlass() const
For the multi-class mode, get the list of selection cuts stored in the payload, one for each category...
bool m_implicitNaNmasking
Flag to indicate whether the MVA variables have been NaN-masked directly in the weightfiles.
ChargedPidMVAWeights(const double &energyUnit, const double &angUnit, const std::string &thetaVarName="clusterTheta", bool implictNaNmasking=false)
Specialized constructor.
int findBin(const double &x, const double &y, const double &z) const
Find global bin index of the 3D categories histogram for the given (x, y, z) values.
void storeCutsMultiClass(const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double >> &categoryBinCentres)
For the multi-class mode, store the list of selection cuts (one for each category) into the payload.
ClassDef(ChargedPidMVAWeights, 10)
2: add energy/angular units.
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge) const
Overloaded method, to be used if not interested in knowing the 3D bin coordinates.
void setAngularUnit(const double &unit)
Set the angular unit to ensure consistency w/ the one used to define the bins grid.
WeightfilesByParticle m_cuts
For each charged particle mass hypothesis' pdgId, this map contains a list of selection cuts to be st...
void setEnergyUnit(const double &unit)
Set the energy unit to ensure consistency w/ the one used to define the bins grid.
const ParticleType & find(int pdg) const
Returns particle in set with given PDG code, or invalidParticle if not found.
int getPDGCode() const
PDG code.
static const ChargedStable muon
muon particle
static const ParticleSet chargedStableSet
set of charged stable particles
static const ChargedStable pion
charged pion particle
static const ChargedStable proton
proton particle
static const ParticleType invalidParticle
Invalid particle, used internally.
static const ChargedStable kaon
charged kaon particle
static const ChargedStable electron
electron particle
static const ChargedStable deuteron
deuteron particle
The Weightfile class serializes all information about a training into an xml tree.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Abstract base class for different kinds of events.