2 #include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAMulticlassModule.h>
5 #include <mva/interface/Interface.h>
6 #include <mva/methods/TMVA.h>
7 #include <analysis/VariableManager/Utility.h>
8 #include <analysis/dataobjects/Particle.h>
11 #include <mdst/dataobjects/ECLCluster.h>
19 setDescription(
"This module evaluates the response of a multi-class MVA trained for global charged particle identification.. It takes the Particle objects in the input charged stable particles' ParticleLists, calculates the MVA per-class score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
21 setPropertyFlags(c_ParallelProcessingCertified);
23 addParam(
"particleLists",
25 "The input list of ParticleList names.",
26 std::vector<std::string>());
27 addParam(
"payloadName",
29 "The name of the database payload object with the MVA weights.",
30 std::string(
"ChargedPidMVAWeights"));
31 addParam(
"useECLOnlyTraining",
33 "Specify whether to use an ECL-only training of the MVA.",
69 if (!pList) { B2FATAL(
"ParticleList: " << name <<
" could not be found. Aborting..."); }
72 int pdg = abs(pList->getPDGCode());
76 B2FATAL(
"PDG: " << pList->getPDGCode() <<
" of ParticleList: " << pList->getParticleListName() <<
77 " is not that of a valid particle in Const::chargedStableSet! Aborting...");
80 B2DEBUG(11,
"ParticleList: " << pList->getParticleListName() <<
" - N = " << pList->getListSize() <<
" particles.");
82 for (
unsigned int ipart(0); ipart < pList->getListSize(); ++ipart) {
84 Particle* particle = pList->getParticle(ipart);
86 B2DEBUG(11,
"\tParticle [" << ipart <<
"]");
90 const ECLCluster* eclCluster = particle->getECLCluster();
92 B2WARNING(
"\tParticle has invalid Track-ECLCluster relation, skip MVA application...");
99 auto p = particle->getP();
105 const auto cutstr = (!cuts->empty()) ? cuts->at(index) :
"";
107 B2DEBUG(11,
"\t\tcharge = " << particle->getCharge());
108 B2DEBUG(11,
"\t\tclusterTheta = " << theta <<
" [rad]");
109 B2DEBUG(11,
"\t\tp = " << p <<
" [GeV/c]");
110 B2DEBUG(11,
"\t\tBrems corrected = " << particle->hasExtraInfo(
"bremsCorrectedPhotonEnergy"));
111 B2DEBUG(11,
"\t\tWeightfile idx = " << index <<
" - (clusterTheta, p) = (" << jth <<
", " << ip <<
")");
112 if (!cutstr.empty()) {
113 B2DEBUG(11,
"\t\tCategory cut = " << cutstr);
118 B2DEBUG(11,
"\tMVA variables:");
121 for (
unsigned int ivar(0); ivar < nvars; ++ivar) {
125 auto var = varobj->function(particle);
128 var = (std::isnan(var)) ? -999.0 : var;
130 B2DEBUG(11,
"\t\tvar[" << ivar <<
"] : " << varobj->name <<
" = " << var);
136 B2DEBUG(12,
"\tMVA spectators:");
139 for (
unsigned int ispec(0); ispec < nspecs; ++ispec) {
143 auto spec = specobj->function(particle);
145 B2DEBUG(12,
"\t\tspec[" << ispec <<
"] : " << specobj->name <<
" = " << spec);
147 m_datasets.at(index)->m_spectators[ispec] = spec;
152 if (!cutstr.empty()) {
156 if (!cut->check(particle)) {
157 B2WARNING(
"\tParticle didn't pass MVA category cut, skip MVA application...");
165 B2DEBUG(11,
"\tMVA response:");
167 std::string score_varname(
"");
168 for (
unsigned int classID(0); classID <
m_classes.size(); ++classID) {
170 const std::string className(
m_classes.at(classID));
173 score_varname =
"pidChargedBDTScore_" + className;
176 score_varname +=
"_" + std::to_string(Const::ECL);
178 for (
size_t iDet(0); iDet < Const::PIDDetectors::set().size(); ++iDet) {
179 score_varname +=
"_" + std::to_string(Const::PIDDetectors::set()[iDet]);
183 B2DEBUG(11,
"\t\tclass[" << classID <<
"] = " << className <<
" - score = " << score);
184 B2DEBUG(12,
"\t\tExtraInfo: " << score_varname);
187 particle->writeExtraInfo(score_varname, score);
200 B2INFO(
"Load supported MVA interfaces for multi-class charged particle identification...");
206 B2INFO(
"\tLoading weightfiles from the payload class.");
209 auto nfiles = serialized_weightfiles->size();
211 B2INFO(
"\tConstruct the MVA experts and datasets from N = " << nfiles <<
" weightfiles...");
220 for (
unsigned int idx(0); idx < nfiles; idx++) {
222 B2DEBUG(12,
"\t\tweightfile[" << idx <<
"]");
225 std::stringstream ss(serialized_weightfiles->at(idx));
229 weightfile.getOptions(general_options);
233 m_variables[idx] = manager.getVariables(general_options.m_variables);
234 m_spectators[idx] = manager.getVariables(general_options.m_spectators);
236 B2DEBUG(12,
"\t\tRetrieved N = " << general_options.m_variables.size()
237 <<
" variables, N = " << general_options.m_spectators.size()
241 m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
244 B2DEBUG(12,
"\t\tweightfile loaded successfully into expert[" << idx <<
"]!");
247 std::vector<float> v(general_options.m_variables.size(), 0.0);
248 std::vector<float> s(general_options.m_spectators.size(), 0.0);
249 m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
251 B2DEBUG(12,
"\t\tdataset[" << idx <<
"] created successfully!");
259 weightfile.getOptions(specific_options);
261 if (specific_options.m_classes.empty()) {
262 B2FATAL(
"MVA::SpecificOptions of weightfile[" << idx <<
263 "] has no registered MVA classes! This shouldn't happen in multi-class mode. Aborting...");
266 for (
const auto& cls : specific_options.m_classes) {