12#include <framework/gearbox/Const.h>
13#include <framework/gearbox/Unit.h>
14#include <framework/logging/Logger.h>
17#include <mva/interface/Weightfile.h>
22#include <TParameter.h>
26#include <boost/algorithm/string/predicate.hpp>
60 const std::string& thetaVarName =
"clusterTheta",
61 bool implictNaNmasking =
false)
118 const double* pBins,
const int nPBins,
119 const double* chargeBins,
const int nChargeBins)
122 m_categories = std::make_unique<TH3F>(
"clustertheta_p_charge_binsgrid",
123 ";ECL cluster #theta;p_{lab};Q",
124 nClusterThetaBins, clusterThetaBins,
126 nChargeBins, chargeBins);
139 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
143 B2FATAL(
"PDG: " << pdg <<
" is not that of a valid charged particle! Aborting...");
147 for (
const auto& path : filepaths) {
150 auto bin_centres_tuple = categoryBinCentres.at(idx);
152 auto theta_bin_centre = std::get<0>(bin_centres_tuple);
153 auto p_bin_centre = std::get<1>(bin_centres_tuple);
154 auto charge_bin_centre = std::get<2>(bin_centres_tuple);
156 auto h_idx =
getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
158 B2FATAL(
"xml file:\n" << path <<
"\nindex in input vector:\n" << idx <<
"\ndoes not correspond to:\n" << h_idx <<
159 "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre <<
", " << p_bin_centre
162 ")\nPlease check how the input xml file list is being filled.");
166 if (boost::ends_with(path,
".root")) {
168 }
else if (boost::ends_with(path,
".xml")) {
171 B2WARNING(
"Unknown file extension for file: " << path <<
", fallback to xml...");
177 std::stringstream ss;
197 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
213 void storeCuts(
const int pdg,
const std::vector<std::string>& cutfiles,
214 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
218 B2FATAL(
"PDG: " << pdg <<
" is not that of a valid charged particle! Aborting...");
222 for (
const auto& cutfile : cutfiles) {
224 auto bin_centres_tuple = categoryBinCentres.at(idx);
226 auto theta_bin_centre = std::get<0>(bin_centres_tuple);
227 auto p_bin_centre = std::get<1>(bin_centres_tuple);
228 auto charge_bin_centre = std::get<2>(bin_centres_tuple);
230 auto h_idx =
getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
232 B2FATAL(
"Cut file:\n" << cutfile <<
"\nindex in input vector:\n" << idx <<
"\ndoes not correspond to:\n" << h_idx <<
233 "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre <<
", " << p_bin_centre
236 ")\nPlease check how the input cut file list is being filled.");
239 std::ifstream ifs(cutfile);
240 std::string cut((std::istreambuf_iterator<char>(ifs)), (std::istreambuf_iterator<char>()));
243 cut.erase(std::remove(cut.begin(), cut.end(),
'\n'), cut.end());
245 m_cuts[pdg].push_back(cut);
263 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
265 storeCuts(0, cutfiles, categoryBinCentres);
317 const std::vector<std::string>*
getCuts(
const int pdg)
const
357 unsigned int getMVAWeightIdx(
const double& theta,
const double& p,
const double& charge,
int& idx_theta,
int& idx_p,
358 int& idx_charge)
const
362 B2FATAL(
"No (clusterTheta, p, charge) TH3 grid was found in the DB payload. Most likely, you are using a GT w/ an old payload which is no longer compatible with the DB object class implementation. This should not happen! Abort...");
369 m_categories->GetBinXYZ(glob_bin_idx, idx_theta, idx_p, idx_charge);
373 return (idx_theta - 1) + nbins_th * ((idx_p - 1) + nbins_p * (idx_charge - 1));
379 unsigned int getMVAWeightIdx(
const double& theta,
const double& p,
const double& charge)
const
381 int idx_theta, idx_p, idx_charge;
382 return getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
395 void dumpPayload(
const double& theta,
const double& p,
const double& charge,
const int pdg,
bool dump_all =
false)
const
398 B2INFO(
"Dumping payload content for:");
399 B2INFO(
"clusterTheta(theta) = " << theta <<
" [rad], p = " << p <<
" [GeV/c], charge = " << charge);
402 std::string filename =
"db_payload_chargedpidmva__theta_p_charge_categories.root";
403 B2INFO(
"\tWriting ROOT file w/ TH3F grid that defines categories:" << filename);
404 auto f = std::make_unique<TFile>(filename.c_str(),
"RECREATE");
408 B2WARNING(
"\tThe TH3F object that defines categories is a nullptr!");
413 if (!dump_all && pdg != pdgId)
continue;
417 auto serialized_weightfile = weights.at(idx);
419 std::string filename =
"db_payload_chargedpidmva__weightfile_pdg_" + std::to_string(pdgId) +
420 "_glob_bin_" + std::to_string(idx + 1) +
".xml";
422 auto cutstr =
getCuts(pdgId)->at(idx);
424 B2INFO(
"\tpdgId = " << pdgId);
425 B2INFO(
"\tCut: " << cutstr);
426 B2INFO(
"\tWriting weight file: " << filename);
428 std::ofstream weightfile;
429 weightfile.open(filename.c_str(), std::ios::out);
430 weightfile << serialized_weightfile << std::endl;
487 int findBin(
const double& x,
const double& y,
const double& z)
const
500 if (x < m_categories->GetXaxis()->GetBinLowEdge(1)) { xx =
m_categories->GetXaxis()->GetBinCenter(1); }
501 if (x >=
m_categories->GetXaxis()->GetBinLowEdge(nbinsx_vis + 1)) { xx =
m_categories->GetXaxis()->GetBinCenter(nbinsx_vis); }
502 if (y < m_categories->GetYaxis()->GetBinLowEdge(1)) { yy =
m_categories->GetYaxis()->GetBinCenter(1); }
503 if (y >=
m_categories->GetYaxis()->GetBinLowEdge(nbinsy_vis + 1)) { yy =
m_categories->GetYaxis()->GetBinCenter(nbinsy_vis); }
504 if (z < m_categories->GetZaxis()->GetBinLowEdge(1)) { zz =
m_categories->GetZaxis()->GetBinCenter(1); }
505 if (z >=
m_categories->GetZaxis()->GetBinLowEdge(nbinsz_vis + 1)) { zz =
m_categories->GetZaxis()->GetBinCenter(nbinsz_vis); }
514 return j + nbinsx * (i + nbinsy * k);
544 { 0, std::vector<std::string>() },
563 { 0, std::vector<std::string>() },
Class to contain the payload of MVA weightfiles needed for charged particle identification.
void storeCutsMultiClass(const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
For the multi-class mode, store the list of selection cuts (one for each category) into the payload.
void storeMVAWeightsMultiClass(const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
For the multi-class mode, store the list of MVA weight files (one for each category) into the payload...
TParameter< double > m_energy_unit
The energy unit used for defining the bins grid.
const std::vector< std::string > * getCutsMulticlass() const
For the multi-class mode, get the list of selection cuts stored in the payload, one for each category...
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge, int &idx_theta, int &idx_p, int &idx_charge) const
Get the index of the XML weight file, for a given reconstructed triplet (clusterTheta(theta),...
bool isValidPdg(const int pdg) const
Check if the input pdgId is that of a valid charged particle.
VariablesByAlias m_aliases
A map that associates variable aliases used in the MVA training to variable names known to the VariableManager.
const std::vector< std::string > * getMVAWeightsMulticlass() const
For the multi-class mode, get the list of (serialized) MVA weightfiles stored in the payload,...
const TH3F * getWeightCategories() const
Get the raw pointer to the 3D grid representing the categories for which weightfiles are defined.
std::unordered_map< int, std::vector< std::string > > WeightfilesByParticle
Typedef: maps a particle pdgId to a list of strings, one per category.
ChargedPidMVATrainingMode
A (strongly-typed) enumerator identifier for each valid MVA training mode.
@ c_PSD_Multiclass
Multi-class classification, including PSD.
@ c_Classification
Binary classification.
@ c_PSD_Classification
Binary classification, including PSD.
@ c_ECL_Multiclass
Multi-class classification, ECL only.
@ c_ECL_PSD_Classification
Binary classification, ECL only, including PSD.
@ c_ECL_Classification
Binary classification, ECL only.
@ c_ECL_PSD_Multiclass
Multi-class classification, ECL only, including PSD.
@ c_Multiclass
Multi-class classification.
std::unique_ptr< TH3F > m_categories
A 3D histogram whose bins represent the categories for which XML weight files are defined.
void dumpPayloadMulticlass(const double &theta, const double &p, const double &charge) const
Special version for multi-class mode.
std::map< std::string, std::string > VariablesByAlias
Typedef: maps each variable alias to the variable name it stands for.
const VariablesByAlias * getAliases() const
Get the map of unique aliases.
void setWeightCategories(const double *clusterThetaBins, const int nClusterThetaBins, const double *pBins, const int nPBins, const double *chargeBins, const int nChargeBins)
Set the 3D (clusterTheta, p, charge) grid representing the categories for which weightfiles are defin...
std::string getThetaVarName() const
Get the name of the polar angle variable.
std::string m_thetaVarName
The name of the polar angle variable used in the MVA categorisation.
bool hasImplicitNaNmasking() const
Check flag for implicit NaN masking.
void dumpPayload(const double &theta, const double &p, const double &charge, const int pdg, bool dump_all=false) const
Read and dump the payload content from the internal 'matryoshka' maps into an XML weightfile for the given (clusterTheta, p, charge) triplet and pdgId.
TParameter< double > m_ang_unit
The angular unit used for defining the bins grid.
ChargedPidMVAWeights()
Default constructor, necessary for ROOT to stream the object.
const std::vector< std::string > * getMVAWeights(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of (serialized) MVA weightfiles stored in the p...
void storeMVAWeights(const int pdg, const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of MVA weight files (one for each category) i...
void storeAliases(const VariablesByAlias &aliases)
Store the map associating variable aliases to variable names known to the VariableManager.
void storeCuts(const int pdg, const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of selection cuts (one for each category) int...
WeightfilesByParticle m_weightfiles
For each charged particle mass hypothesis' pdgId, this map contains a list of (serialized) Weightfile...
~ChargedPidMVAWeights()
Destructor.
bool m_implicitNaNmasking
Flag to indicate whether the MVA variables have been NaN-masked directly in the weightfiles.
ChargedPidMVAWeights(const double &energyUnit, const double &angUnit, const std::string &thetaVarName="clusterTheta", bool implictNaNmasking=false)
Specialized constructor.
int findBin(const double &x, const double &y, const double &z) const
Find global bin index of the 3D categories histogram for the given (x, y, z) values.
ClassDef(ChargedPidMVAWeights, 10)
2: add energy/angular units.
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge) const
Overloaded method, to be used if not interested in knowing the 3D bin coordinates.
void setAngularUnit(const double &unit)
Set the angular unit to ensure consistency w/ the one used to define the bins grid.
WeightfilesByParticle m_cuts
For each charged particle mass hypothesis' pdgId, this map contains a list of selection cuts to be st...
void setEnergyUnit(const double &unit)
Set the energy unit to ensure consistency w/ the one used to define the bins grid.
const std::vector< std::string > * getCuts(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of selection cuts stored in the payload,...
const ParticleType & find(int pdg) const
Returns particle in set with given PDG code, or invalidParticle if not found.
int getPDGCode() const
PDG code.
static const ChargedStable muon
muon particle
static const ParticleSet chargedStableSet
set of charged stable particles
static const ChargedStable pion
charged pion particle
static const ChargedStable proton
proton particle
static const ParticleType invalidParticle
Invalid particle, used internally.
static const ChargedStable kaon
charged kaon particle
static const ChargedStable electron
electron particle
static const ChargedStable deuteron
deuteron particle
The Weightfile class serializes all information about a training into an xml tree.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Abstract base class for different kinds of events.