 |
Belle II Software
release-05-01-25
|
14 #include <framework/gearbox/Const.h>
15 #include <framework/gearbox/Unit.h>
16 #include <framework/logging/Logger.h>
19 #include <mva/interface/Weightfile.h>
24 #include <TParameter.h>
28 #include <boost/algorithm/string/predicate.hpp>
40 class ChargedPidMVAWeights :
public TObject {
113 void storeMVAWeights(
const int pdg,
const std::vector<std::string>& filepaths,
114 const std::vector<std::pair<float, float>>& categoryBinCentres)
118 B2FATAL(
"PDG: " << pdg <<
" is not that of a valid charged particle! Aborting...");
122 for (
const auto& path : filepaths) {
125 auto theta_p = categoryBinCentres.at(idx);
128 B2FATAL(
"xml file:\n" << path <<
"\nindex in input vector:\n" << idx <<
"\ndoes not correspond to:\n" << h_idx <<
129 "\n, i.e. the linearised index of the 2D bin centered in (clusterTheta, p) = (" << theta_p.first <<
", " << theta_p.second <<
130 ")\nPlease check how the input xml file list is being filled.");
134 if (boost::ends_with(path,
".root")) {
136 }
else if (boost::ends_with(path,
".xml")) {
139 B2WARNING(
"Unkown file extension for file: " << path <<
", fallback to xml...");
145 std::stringstream ss;
164 const std::vector<std::pair<float, float>>& categoryBinCentres)
180 void storeCuts(
const int pdg,
const std::vector<std::string>& cutfiles,
181 const std::vector<std::pair<float, float>>& categoryBinCentres)
185 B2FATAL(
"PDG: " << pdg <<
" is not that of a valid charged particle! Aborting...");
189 for (
const auto& cutfile : cutfiles) {
192 auto theta_p = categoryBinCentres.at(idx);
195 B2FATAL(
"Cut file:\n" << cutfile <<
"\nindex in input vector:\n" << idx <<
"\ndoes not correspond to:\n" << h_idx <<
196 "\n, i.e. the linearised index of the 2D bin centered in (clusterTheta, p) = (" << theta_p.first <<
", " << theta_p.second <<
197 ")\nPlease check how the input cut file list is being filled.");
200 std::ifstream ifs(cutfile);
201 std::string cut((std::istreambuf_iterator<char>(ifs)), (std::istreambuf_iterator<char>()));
204 cut.erase(std::remove(cut.begin(), cut.end(),
'\n'), cut.end());
207 std::replace(cut.begin(), cut.end(),
'(',
'[');
208 std::replace(cut.begin(), cut.end(),
')',
']');
210 m_cuts[pdg].push_back(cut);
226 void storeCutsMultiClass(
const std::vector<std::string>& cutfiles,
const std::vector<std::pair<float, float>>& categoryBinCentres)
228 storeCuts(0, cutfiles, categoryBinCentres);
237 const std::vector<std::string>*
getMVAWeights(
const int pdg)
const
259 const std::vector<std::string>*
getCuts(
const int pdg)
const
287 unsigned int getMVAWeightIdx(
const double& theta,
const double& p,
int& jth,
int& ip)
const
291 B2FATAL(
"No (clusterTheta, p) TH2 grid was found in the DB payload. This should not happen! Abort...");
302 return (jth - 1) + nbins_th * (ip - 1);
310 unsigned int getMVAWeightIdx(
const double& theta,
const double& p)
const
325 void dumpPayload(
const double& theta,
const double& p,
const int pdg,
bool dump_all =
false)
const
328 B2INFO(
"Dumping payload content for...");
329 B2INFO(
"-) clusterTheta = " << theta <<
" [rad]");
330 B2INFO(
"-) p = " << p <<
" [GeV/c]");
333 std::string filename =
"db_payload_chargedpidmva__clustertheta_p_categories.root";
334 B2INFO(
"\tWriting ROOT file w/ (clusterTheta, p) TH2F grid that defines categories:" << filename);
335 auto f = std::make_unique<TFile>(filename.c_str(),
"RECREATE");
339 B2WARNING(
"\tThe TH2F object that defines categories is a nullptr!");
344 if (!dump_all && pdg != pdgId)
continue;
346 B2INFO(
"-) pdgId = " << pdgId);
350 auto serialized_weightfile = weights.at(idx);
352 std::string filename =
"db_payload_chargedpidmva__weightfile_pdg_" + std::to_string(pdgId) +
353 "_glob_bin_" + std::to_string(idx + 1) +
".xml";
355 auto cutstr =
getCuts(pdgId)->at(idx);
357 B2INFO(
"\tCut: " << cutstr);
358 B2INFO(
"\tWriting weight file: " << filename);
360 std::ofstream weightfile;
361 weightfile.open(filename.c_str(), std::ios::out);
362 weightfile << serialized_weightfile << std::endl;
401 int findBin(
const TH2F* h,
const double& x,
const double& y)
const
404 int nbinsx_vis = h->GetXaxis()->GetNbins();
405 int nbinsy_vis = h->GetYaxis()->GetNbins();
412 if (x < h->GetXaxis()->GetBinLowEdge(1)) { xx = h->GetXaxis()->GetBinCenter(1); }
413 if (x >= h->GetXaxis()->GetBinLowEdge(nbinsx_vis + 1)) { xx = h->GetXaxis()->GetBinCenter(nbinsx_vis); }
414 if (y < h->GetYaxis()->GetBinLowEdge(1)) { yy = h->GetYaxis()->GetBinCenter(1); }
415 if (y >= h->GetYaxis()->GetBinLowEdge(nbinsy_vis + 1)) { yy = h->GetYaxis()->GetBinCenter(nbinsy_vis); }
417 int nbinsx = h->GetXaxis()->GetNbins() + 2;
418 int j = h->GetXaxis()->FindBin(xx);
419 int i = h->GetYaxis()->FindBin(yy);
421 return j + nbinsx * i;
448 { 0, std::vector<std::string>() },
467 { 0, std::vector<std::string>() },
@ c_Classification
Binary classification.
~ChargedPidMVAWeights()
Destructor.
void dumpPayload(const double &theta, const double &p, const int pdg, bool dump_all=false) const
Read and dump the payload content from the internal 'matrioska' maps into an XML weightfile for the g...
@ c_Multiclass
Multi-class classification.
void dumpPayloadMulticlass(const double &theta, const double &p) const
Special version for multi-class mode.
static const ParticleType invalidParticle
Invalid particle, used internally.
static const ChargedStable electron
electron particle
ChargedPidMVAWeights()
Default constructor, necessary for ROOT to stream the object.
void storeMVAWeightsMultiClass(const std::vector< std::string > &filepaths, const std::vector< std::pair< float, float >> &categoryBinCentres)
For the multi-class mode, store the list of MVA weight files (one for each category) into the payload...
The Weightfile class serializes all information about a training into an xml tree.
bool isValid(EForwardBackward eForwardBackward)
Check whether the given enum instance is one of the valid values.
static const ParticleSet chargedStableSet
set of charged stable particles
void storeCutsMultiClass(const std::vector< std::string > &cutfiles, const std::vector< std::pair< float, float >> &categoryBinCentres)
For the multi-class mode, store the list of selection cuts (one for each category) into the payload.
int getPDGCode() const
PDG code.
WeightfilesByParticle m_weightfiles
For each charged particle mass hypothesis' pdgId, this map contains a list of (serialized) Weightfile...
const std::vector< std::string > * getCutsMulticlass() const
For the multi-class mode, get the list of selection cuts stored in the payload, one for each (cluster...
WeightfilesByParticle m_cuts
For each charged particle mass hypothesis' pdgId, this map contains a list of selection cuts to be st...
@ c_ECL_Classification
Binary classification, ECL only.
unsigned int getMVAWeightIdx(const double &theta, const double &p, int &jth, int &ip) const
Get the index of the XML weight file for a given reconstructed pair (clusterTheta, p).
const std::vector< std::string > * getMVAWeights(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of (serialized) MVA weightfiles stored in the p...
static const ChargedStable kaon
charged kaon particle
const std::vector< std::string > * getCuts(const int pdg) const
Given a particle mass hypothesis' pdgId, get the list of selection cuts stored in the payload,...
const std::vector< std::string > * getMVAWeightsMulticlass() const
For the multi-class mode, get the list of (serialized) MVA weightfiles stored in the payload,...
@ c_PSD_Multiclass
Multi-class classification, including PSD.
int findBin(const TH2F *h, const double &x, const double &y) const
Find global bin index of a 2D histogram for the given (x, y) values.
static const ChargedStable pion
charged pion particle
TH2F * m_categories
A 2D (clusterTheta, p) histogram whose bins represent the categories for which XML weight files are d...
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
@ c_ECL_PSD_Classification
Binary classification, ECL only, including PSD.
TParameter< double > m_energy_unit
The energy unit used for defining the bins grid.
const ParticleType & find(int pdg) const
Returns particle in set with given PDG code, or invalidParticle if not found.
Abstract base class for different kinds of events.
TParameter< double > m_ang_unit
The angular unit used for defining the bins grid.
static const ChargedStable deuteron
deuteron particle
@ c_ECL_PSD_Multiclass
Multi-class classification, ECL only, including PSD.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from an XML file.
void setWeightCategories(TH2F *h)
Set the 2D (clusterTheta, p) grid representing the categories for which weightfiles are defined.
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
ClassDef(ChargedPidMVAWeights, 5)
2: add energy/angular units.
std::unordered_map< int, std::vector< std::string > > WeightfilesByParticle
Typedef.
void setEnergyUnit(const double &unit)
Set the energy unit to ensure consistency w/ the one used to define the bins grid.
static const ChargedStable proton
proton particle
Class to contain the payload of MVA weightfiles needed for charged particle identification.
static const ChargedStable muon
muon particle
void storeCuts(const int pdg, const std::vector< std::string > &cutfiles, const std::vector< std::pair< float, float >> &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of selection cuts (one for each category) int...
bool isValidPdg(const int pdg) const
Check if the input pdgId is that of a valid charged particle.
ChargedPidMVATrainingMode
A (strongly-typed) enumerator identifier for each valid MVA training mode.
void storeMVAWeights(const int pdg, const std::vector< std::string > &filepaths, const std::vector< std::pair< float, float >> &categoryBinCentres)
Given a particle mass hypothesis' pdgId, store the list of MVA weight files (one for each category) i...
void setAngularUnit(const double &unit)
Set the angular unit to ensure consistency w/ the one used to define the bins grid.
@ c_ECL_Multiclass
Multi-class classification, ECL only.
@ c_PSD_Classification
Binary classification, including PSD.