12#include <framework/gearbox/Const.h>
13#include <framework/gearbox/Unit.h>
14#include <framework/logging/Logger.h>
17#include <mva/interface/Weightfile.h>
22#include <TParameter.h>
56 const std::string& thetaVarName =
"clusterTheta",
57 bool implictNaNmasking =
false)
114 const double* pBins,
const int nPBins,
115 const double* chargeBins,
const int nChargeBins)
118 m_categories = std::make_unique<TH3F>(
"clustertheta_p_charge_binsgrid",
119 ";ECL cluster #theta;p_{lab};Q",
120 nClusterThetaBins, clusterThetaBins,
122 nChargeBins, chargeBins);
135 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
139 B2FATAL(
"PDG: " << pdg <<
" is not that of a valid charged particle! Aborting...");
143 for (
const auto& path : filepaths) {
146 auto bin_centres_tuple = categoryBinCentres.at(idx);
148 auto theta_bin_centre = std::get<0>(bin_centres_tuple);
149 auto p_bin_centre = std::get<1>(bin_centres_tuple);
150 auto charge_bin_centre = std::get<2>(bin_centres_tuple);
152 auto h_idx =
getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
154 B2FATAL(
"xml file:\n" << path <<
"\nindex in input vector:\n" << idx <<
"\ndoes not correspond to:\n" << h_idx <<
155 "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre <<
", " << p_bin_centre
158 ")\nPlease check how the input xml file list is being filled.");
162 if (path.ends_with(
".root")) {
164 }
else if (path.ends_with(
".xml")) {
167 B2WARNING(
"Unknown file extension for file: " << path <<
", fallback to xml...");
173 std::stringstream ss;
193 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
209 void storeCuts(
const int pdg,
const std::vector<std::string>& cutfiles,
210 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
214 B2FATAL(
"PDG: " << pdg <<
" is not that of a valid charged particle! Aborting...");
218 for (
const auto& cutfile : cutfiles) {
220 auto bin_centres_tuple = categoryBinCentres.at(idx);
222 auto theta_bin_centre = std::get<0>(bin_centres_tuple);
223 auto p_bin_centre = std::get<1>(bin_centres_tuple);
224 auto charge_bin_centre = std::get<2>(bin_centres_tuple);
226 auto h_idx =
getMVAWeightIdx(theta_bin_centre, p_bin_centre, charge_bin_centre);
228 B2FATAL(
"Cut file:\n" << cutfile <<
"\nindex in input vector:\n" << idx <<
"\ndoes not correspond to:\n" << h_idx <<
229 "\n, i.e. the linearised index of the 3D bin centered in (clusterTheta, p, charge) = (" << theta_bin_centre <<
", " << p_bin_centre
232 ")\nPlease check how the input cut file list is being filled.");
235 std::ifstream ifs(cutfile);
236 std::string cut((std::istreambuf_iterator<char>(ifs)), (std::istreambuf_iterator<char>()));
239 cut.erase(std::remove(cut.begin(), cut.end(),
'\n'), cut.end());
241 m_cuts[pdg].push_back(cut);
259 const std::vector<std::tuple<double, double, double>>& categoryBinCentres)
261 storeCuts(0, cutfiles, categoryBinCentres);
313 const std::vector<std::string>*
getCuts(
const int pdg)
const
353 unsigned int getMVAWeightIdx(
const double& theta,
const double& p,
const double& charge,
int& idx_theta,
int& idx_p,
354 int& idx_charge)
const
358 B2FATAL(
"No (clusterTheta, p, charge) TH3 grid was found in the DB payload. Most likely, you are using a GT w/ an old payload which is no longer compatible with the DB object class implementation. This should not happen! Abort...");
365 m_categories->GetBinXYZ(glob_bin_idx, idx_theta, idx_p, idx_charge);
369 return (idx_theta - 1) + nbins_th * ((idx_p - 1) + nbins_p * (idx_charge - 1));
375 unsigned int getMVAWeightIdx(
const double& theta,
const double& p,
const double& charge)
const
377 int idx_theta, idx_p, idx_charge;
378 return getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
391 void dumpPayload(
const double& theta,
const double& p,
const double& charge,
const int pdg,
bool dump_all =
false)
const
394 B2INFO(
"Dumping payload content for:");
395 B2INFO(
"clusterTheta(theta) = " << theta <<
" [rad], p = " << p <<
" [GeV/c], charge = " << charge);
398 std::string filename =
"db_payload_chargedpidmva__theta_p_charge_categories.root";
399 B2INFO(
"\tWriting ROOT file w/ TH3F grid that defines categories:" << filename);
400 auto f = std::make_unique<TFile>(filename.c_str(),
"RECREATE");
404 B2WARNING(
"\tThe TH3F object that defines categories is a nullptr!");
409 if (!dump_all && pdg != pdgId)
continue;
413 auto serialized_weightfile = weights.at(idx);
415 std::string filename =
"db_payload_chargedpidmva__weightfile_pdg_" + std::to_string(pdgId) +
416 "_glob_bin_" + std::to_string(idx + 1) +
".xml";
418 auto cutstr =
getCuts(pdgId)->at(idx);
420 B2INFO(
"\tpdgId = " << pdgId);
421 B2INFO(
"\tCut: " << cutstr);
422 B2INFO(
"\tWriting weight file: " << filename);
424 std::ofstream weightfile;
425 weightfile.open(filename.c_str(), std::ios::out);
426 weightfile << serialized_weightfile << std::endl;
483 int findBin(
const double& x,
const double& y,
const double& z)
const
496 if (x < m_categories->GetXaxis()->GetBinLowEdge(1)) { xx =
m_categories->GetXaxis()->GetBinCenter(1); }
497 if (x >=
m_categories->GetXaxis()->GetBinLowEdge(nbinsx_vis + 1)) { xx =
m_categories->GetXaxis()->GetBinCenter(nbinsx_vis); }
498 if (y < m_categories->GetYaxis()->GetBinLowEdge(1)) { yy =
m_categories->GetYaxis()->GetBinCenter(1); }
499 if (y >=
m_categories->GetYaxis()->GetBinLowEdge(nbinsy_vis + 1)) { yy =
m_categories->GetYaxis()->GetBinCenter(nbinsy_vis); }
500 if (z < m_categories->GetZaxis()->GetBinLowEdge(1)) { zz =
m_categories->GetZaxis()->GetBinCenter(1); }
501 if (z >=
m_categories->GetZaxis()->GetBinLowEdge(nbinsz_vis + 1)) { zz =
m_categories->GetZaxis()->GetBinCenter(nbinsz_vis); }
510 return j + nbinsx * (i + nbinsy * k);
540 { 0, std::vector<std::string>() },
542 {
Const::muon.getPDGCode(), std::vector<std::string>() },
543 {
Const::pion.getPDGCode(), std::vector<std::string>() },
544 {
Const::kaon.getPDGCode(), std::vector<std::string>() },
559 { 0, std::vector<std::string>() },
561 {
Const::muon.getPDGCode(), std::vector<std::string>() },
562 {
Const::pion.getPDGCode(), std::vector<std::string>() },
563 {
Const::kaon.getPDGCode(), std::vector<std::string>() },
void storeCutsMultiClass(const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
For the multi-class mode, store the list of selection cuts (one for each category) into the payload.
void storeMVAWeightsMultiClass(const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
For the multi-class mode, store the list of MVA weight files (one for each category) into the payload...
TParameter< double > m_energy_unit
The energy unit used for defining the bins grid.
const std::vector< std::string > * getCutsMulticlass() const
For the multi-class mode, get the list of selection cuts stored in the payload, one for each category...
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge, int &idx_theta, int &idx_p, int &idx_charge) const
Get the index of the XML weight file, for a given reconstructed triplet (clusterTheta(theta),...
bool isValidPdg(const int pdg) const
Check if the input pdgId is that of a valid charged particle.
VariablesByAlias m_aliases
A map that associates variable aliases used in the MVA training to variable names known to the Variab...
const std::vector< std::string > * getMVAWeightsMulticlass() const
For the multi-class mode, get the list of (serialized) MVA weightfiles stored in the payload,...
const TH3F * getWeightCategories() const
Get the raw pointer to the 3D grid representing the categories for which weightfiles are defined.
std::unordered_map< int, std::vector< std::string > > WeightfilesByParticle
Typedef.
ChargedPidMVATrainingMode
A (strongly-typed) enumerator identifier for each valid MVA training mode.
@ c_PSD_Multiclass
Multi-class classification, including PSD.
@ c_Classification
Binary classification.
@ c_PSD_Classification
Binary classification, including PSD.
@ c_ECL_Multiclass
Multi-class classification, ECL only.
@ c_ECL_PSD_Classification
Binary classification, ECL only, including PSD.
@ c_ECL_Classification
Binary classification, ECL only.
@ c_ECL_PSD_Multiclass
Multi-class classification, ECL only, including PSD.
@ c_Multiclass
Multi-class classification.
std::unique_ptr< TH3F > m_categories
A 3D histogram whose bins represent the categories for which XML weight files are defined.
void dumpPayloadMulticlass(const double &theta, const double &p, const double &charge) const
Special version for multi-class mode.
std::map< std::string, std::string > VariablesByAlias
Typedef.
const VariablesByAlias * getAliases() const
Get the map of unique aliases.
void setWeightCategories(const double *clusterThetaBins, const int nClusterThetaBins, const double *pBins, const int nPBins, const double *chargeBins, const int nChargeBins)
Set the 3D (clusterTheta, p, charge) grid representing the categories for which weightfiles are defin...
std::string getThetaVarName() const
Get the name of the polar angle variable.
std::string m_thetaVarName
The name of the polar angle variable used in the MVA categorisation.
bool hasImplicitNaNmasking() const
Check flag for implicit NaN masking.
void dumpPayload(const double &theta, const double &p, const double &charge, const int pdg, bool dump_all=false) const
Read and dump the payload content from the internal 'matrioska' maps into an XML weightfile for the g...
TParameter< double > m_ang_unit
The angular unit used for defining the bins grid.
ChargedPidMVAWeights()
Default constructor, necessary for ROOT to stream the object.
const std::vector< std::string > * getMVAWeights(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of (serialized) MVA weightfiles stored in the p...
void storeMVAWeights(const int pdg, const std::vector< std::string > &filepaths, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of MVA weight files (one for each category) i...
void storeAliases(const VariablesByAlias &aliases)
Store the map associating variable aliases to variable names known to VariableManager.
void storeCuts(const int pdg, const std::vector< std::string > &cutfiles, const std::vector< std::tuple< double, double, double > > &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of selection cuts (one for each category) int...
WeightfilesByParticle m_weightfiles
For each charged particle mass hypothesis' pdgId, this map contains a list of (serialized) Weightfile...
~ChargedPidMVAWeights()
Destructor.
bool m_implicitNaNmasking
Flag to indicate whether the MVA variables have been NaN-masked directly in the weightfiles.
ChargedPidMVAWeights(const double &energyUnit, const double &angUnit, const std::string &thetaVarName="clusterTheta", bool implictNaNmasking=false)
Specialized constructor.
int findBin(const double &x, const double &y, const double &z) const
Find global bin index of the 3D categories histogram for the given (x, y, z) values.
ClassDef(ChargedPidMVAWeights, 10)
2: add energy/angular units.
unsigned int getMVAWeightIdx(const double &theta, const double &p, const double &charge) const
Overloaded method, to be used if not interested in knowing the 3D bin coordinates.
void setAngularUnit(const double &unit)
Set the angular unit to ensure consistency w/ the one used to define the bins grid.
WeightfilesByParticle m_cuts
For each charged particle mass hypothesis' pdgId, this map contains a list of selection cuts to be st...
void setEnergyUnit(const double &unit)
Set the energy unit to ensure consistency w/ the one used to define the bins grid.
const std::vector< std::string > * getCuts(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of selection cuts stored in the payload,...
static const ChargedStable muon
muon particle
static const ParticleSet chargedStableSet
set of charged stable particles
static const ChargedStable pion
charged pion particle
static const ChargedStable proton
proton particle
static const ParticleType invalidParticle
Invalid particle, used internally.
static const ChargedStable kaon
charged kaon particle
static const ChargedStable electron
electron particle
static const ChargedStable deuteron
deuteron particle
The Weightfile class serializes all information about a training into an xml tree.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Abstract base class for different kinds of events.