9#include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAMulticlassModule.h>
12#include <mva/interface/Interface.h>
13#include <mva/methods/TMVA.h>
14#include <analysis/dataobjects/Particle.h>
15#include <analysis/variables/ECLVariables.h>
18#include <framework/logging/LogConfig.h>
19#include <framework/logging/LogSystem.h>
28 setDescription(
"This module evaluates the response of a multi-class MVA trained for global charged particle identification.. It takes the Particle objects in the input charged stable particles' ParticleLists, calculates the MVA per-class score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
34 "The input list of DecayStrings, where each selected (^) daughter should correspond to a standard charged ParticleList, e.g. ['Lambda0:sig -> ^p+ ^pi-', 'J/psi:sig -> ^mu+ ^mu-']. One can also directly pass a list of standard charged ParticleLists, e.g. ['e+:my_electrons', 'pi+:my_pions']. Note that charge-conjugated ParticleLists will automatically be included.",
35 std::vector<std::string>());
38 "The name of the database payload object with the MVA weights.",
39 std::string(
"ChargedPidMVAWeights"));
42 "Specify whether to use a charge-independent training of the MVA.",
46 "Specify whether to use an ECL-only training of the MVA.",
75 std::map<int, std::string> debugStr = {
85 decayDescriptor.init(decayString);
86 auto pListName = decayDescriptor.getMother()->getFullName();
88 unsigned short m_nSelectedDaughters = decayDescriptor.getSelectionNames().size();
92 B2FATAL(
"ParticleList: " << pListName <<
" could not be found. Aborting...");
95 auto pListSize = pList->getListSize();
97 B2DEBUG(11,
"ParticleList: " << pList->getParticleListName() <<
" - N = " << pListSize <<
" particles.");
99 const auto nTargetParticles = (m_nSelectedDaughters == 0) ? pListSize : pListSize * m_nSelectedDaughters;
102 std::vector<int> pdgs;
103 if (m_nSelectedDaughters == 0) {
104 pdgs.push_back(pList->getPDGCode());
106 pdgs = decayDescriptor.getSelectionPDGCodes();
108 for (
auto pdg : pdgs) {
111 B2FATAL(
"PDG: " << pdg <<
" of ParticleList: " << pListName <<
112 " is not that of a valid particle in Const::chargedStableSet! Aborting...");
115 std::vector<const Particle*> targetParticles;
116 if (m_nSelectedDaughters > 0) {
117 for (
unsigned int iPart(0); iPart < pListSize; ++iPart) {
118 auto* iParticle = pList->getParticle(iPart);
119 auto daughters = decayDescriptor.getSelectionParticles(iParticle);
120 for (
auto* iDaughter : daughters) {
121 targetParticles.push_back(iDaughter);
126 for (
unsigned int ipart(0); ipart < nTargetParticles; ++ipart) {
128 const Particle* particle = (m_nSelectedDaughters > 0) ? targetParticles[ipart] : pList->getParticle(ipart);
132 const ECLCluster* eclCluster = particle->getECLCluster();
134 B2DEBUG(11,
"\nParticle [" << ipart <<
"] has invalid Track-ECLCluster relation, skip MVA application...");
143 auto p = particle->getP();
146 if (std::isnan(theta) or std::isnan(p) or std::isnan(charge)) {
147 B2DEBUG(11,
"\nParticle [" << ipart <<
"] has invalid input variable, skip MVA application..." <<
148 " polar angle: " << theta <<
", p: " << p <<
", charge: " << charge);
152 int idx_theta, idx_p, idx_charge;
155 auto hasMatch = std::isnormal(Variable::eclClusterTrackMatched(particle));
157 debugStr[11] +=
"\n";
158 debugStr[11] += (
"Particle [" + std::to_string(ipart) +
"]\n");
159 debugStr[11] += (
"Has ECL cluster match? " + std::to_string(hasMatch) +
"\n");
160 debugStr[11] += (
"polar angle: " + thVarName +
" = " + std::to_string(theta) +
" [rad]\n");
161 debugStr[11] += (
"p = " + std::to_string(p) +
" [GeV/c]\n");
163 debugStr[11] += (
"charge = " + std::to_string(charge) +
"\n");
165 debugStr[11] += (
"Is brems corrected ? " + std::to_string(particle->hasExtraInfo(
"bremsCorrected")) +
"\n");
166 debugStr[11] += (
"Weightfile idx = " + std::to_string(index) +
" - (polar angle, p, charge) = (" + std::to_string(
167 idx_theta) +
", " + std::to_string(idx_p) +
", " +
168 std::to_string(idx_charge) +
")\n");
170 debugStr[11] += (
"Category cut: " +
m_cuts.at(index)->decompile() +
"\n");
173 B2DEBUG(11, debugStr[11]);
174 debugStr[11].clear();
178 if (!
m_cuts.at(index)->check(particle)) {
179 B2DEBUG(11,
"\nParticle [" << ipart <<
"] didn't pass MVA category cut, skip MVA application...");
186 debugStr[11] +=
"\n";
187 debugStr[11] +=
"MVA variables:\n";
190 for (
unsigned int ivar(0); ivar < nvars; ++ivar) {
194 double var = std::numeric_limits<double>::quiet_NaN();
195 auto var_result = varobj->function(particle);
196 if (std::holds_alternative<double>(var_result)) {
197 var = std::get<double>(var_result);
198 }
else if (std::holds_alternative<int>(var_result)) {
199 var = std::get<int>(var_result);
200 }
else if (std::holds_alternative<bool>(var_result)) {
201 var = std::get<bool>(var_result);
203 B2ERROR(
"Variable '" << varobj->name <<
"' has wrong data type! It must be one of double, integer, or bool.");
208 var = (std::isnan(var)) ? -999.0 : var;
211 debugStr[11] += (
"\tvar[" + std::to_string(ivar) +
"] : " + varobj->name +
" = " + std::to_string(var) +
"\n");
217 B2DEBUG(11, debugStr[11]);
218 debugStr[11].clear();
223 debugStr[12] +=
"\n";
224 debugStr[12] +=
"MVA spectators:\n";
227 for (
unsigned int ispec(0); ispec < nspecs; ++ispec) {
231 double spec = std::numeric_limits<double>::quiet_NaN();
232 auto spec_result = specobj->function(particle);
233 if (std::holds_alternative<double>(spec_result)) {
234 spec = std::get<double>(spec_result);
235 }
else if (std::holds_alternative<int>(spec_result)) {
236 spec = std::get<int>(spec_result);
237 }
else if (std::holds_alternative<bool>(spec_result)) {
238 spec = std::get<bool>(spec_result);
240 B2ERROR(
"Variable '" << specobj->name <<
"' has wrong data type! It must be one of double, integer, or bool.");
243 debugStr[12] += (
"\tspec[" + std::to_string(ispec) +
"] : " + specobj->name +
" = " + std::to_string(spec) +
"\n");
245 m_datasets.at(index)->m_spectators[ispec] = spec;
249 B2DEBUG(12, debugStr[12]);
250 debugStr[12].clear();
256 debugStr[11] +=
"\n";
257 debugStr[12] +=
"\n";
258 debugStr[11] +=
"MVA response:\n";
260 std::string score_varname(
"");
262 std::vector<float> scores =
m_experts.at(index)->applyMulticlass(*
m_datasets.at(index))[0];
264 for (
unsigned int classID(0); classID <
m_classes.size(); ++classID) {
266 const std::string className(
m_classes.at(classID));
268 score_varname =
"pidChargedBDTScore_" + className;
271 score_varname +=
"_" + std::to_string(Const::ECL);
274 score_varname +=
"_" + std::to_string(det);
278 debugStr[11] += (
"\tclass[" + std::to_string(classID) +
"] = " + className +
" - score = " + std::to_string(
279 scores[classID]) +
"\n");
280 debugStr[12] += (
"\textraInfo: " + score_varname +
"\n");
283 m_particles[particle->getArrayIndex()]->writeExtraInfo(score_varname, scores[classID]);
287 B2DEBUG(11, debugStr[11]);
288 B2DEBUG(12, debugStr[12]);
289 debugStr[11].clear();
290 debugStr[12].clear();
304 std::string epsilon(
"1e-8");
306 std::map<std::string, std::string> aliasesLegacy;
308 aliasesLegacy.insert(std::make_pair(
"__event__",
"evtNum"));
315 aliasesLegacy.insert(std::make_pair(
"missingLogL_" + detName,
"pidMissingProbabilityExpert(" + detName +
")"));
319 std::string alias = fullName +
"ID_" + detName;
320 std::string var =
"pidProbabilityExpert(" + std::to_string(pdgId) +
", " + detName +
")";
321 std::string aliasLogTrf = alias +
"_LogTransfo";
322 std::string varLogTrf =
"formula(-1. * log10(formula(((1. - " + alias +
") + " + epsilon +
") / (" + alias +
" + " + epsilon +
325 aliasesLegacy.insert(std::make_pair(alias, var));
326 aliasesLegacy.insert(std::make_pair(aliasLogTrf, varLogTrf));
328 if (it.getIndex() == 0) {
329 aliasLogTrf = fullName +
"ID_LogTransfo";
330 varLogTrf =
"formula(-1. * log10(formula(((1. - " + fullName +
"ID) + " + epsilon +
") / (" + fullName +
"ID + " + epsilon +
332 aliasesLegacy.insert(std::make_pair(aliasLogTrf, varLogTrf));
339 B2INFO(
"Setting hard-coded aliases for the ChargedPidMVA algorithm.");
341 std::string debugStr(
"\n");
342 for (
const auto& [alias, variable] : aliasesLegacy) {
343 debugStr += (alias +
" --> " + variable +
"\n");
345 B2ERROR(
"Something went wrong with setting alias: " << alias <<
" for variable: " << variable);
348 B2DEBUG(10, debugStr);
358 if (!aliases->empty()) {
360 B2INFO(
"Setting aliases for the ChargedPidMVA algorithm read from the payload.");
362 std::string debugStr(
"\n");
363 for (
const auto& [alias, variable] : *aliases) {
364 if (alias != variable) {
365 debugStr += (alias +
" --> " + variable +
"\n");
367 B2ERROR(
"Something went wrong with setting alias: " << alias <<
" for variable: " << variable);
371 B2DEBUG(10, debugStr);
387 ". Load supported MVA interfaces for multi-class charged particle identification...");
396 B2INFO(
"\tLoading weightfiles from the payload class.");
399 auto nfiles = serialized_weightfiles->size();
401 B2INFO(
"\tConstruct the MVA experts and datasets from N = " << nfiles <<
" weightfiles...");
411 for (
unsigned int idx(0); idx < nfiles; idx++) {
413 B2DEBUG(12,
"\t\tweightfile[" << idx <<
"]");
416 std::stringstream ss(serialized_weightfiles->at(idx));
420 weightfile.getOptions(general_options);
424 m_variables[idx] = manager.getVariables(general_options.m_variables);
425 m_spectators[idx] = manager.getVariables(general_options.m_spectators);
427 B2DEBUG(12,
"\t\tRetrieved N = " << general_options.m_variables.size()
428 <<
" variables, N = " << general_options.m_spectators.size()
432 m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
435 B2DEBUG(12,
"\t\tweightfile loaded successfully into expert[" << idx <<
"]!");
438 std::vector<float> v(general_options.m_variables.size(), 0.0);
439 std::vector<float> s(general_options.m_spectators.size(), 0.0);
440 m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
442 B2DEBUG(12,
"\t\tdataset[" << idx <<
"] created successfully!");
446 const auto cutstr = (!cuts->empty()) ? cuts->at(idx) :
"";
449 B2DEBUG(12,
"\t\tcut[" << idx <<
"] created successfully!");
457 weightfile.getOptions(specific_options);
459 if (specific_options.m_classes.empty()) {
460 B2FATAL(
"MVA::SpecificOptions of weightfile[" << idx <<
461 "] has no registered MVA classes! This shouldn't happen in multi-class mode. Aborting...");
465 for (
const auto& cls : specific_options.m_classes) {
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
std::vector< std::string > m_decayStrings
The input list of DecayStrings, where each selected (^) daughter should correspond to a standard charged ParticleList.
virtual void initialize() override
Use this to initialize resources or memory your module needs.
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
virtual void event() override
Called once for each event.
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
StoreArray< Particle > m_particles
StoreArray of Particles.
DatasetsList m_datasets
List of MVA::SingleDataset objects.
std::vector< std::string > m_classes
List of MVA class names.
std::map< int, std::string > m_stdChargedInfo
Map with standard charged particles' info.
ChargedPidMVAMulticlassModule()
Constructor, for setting module description and parameters.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual ~ChargedPidMVAMulticlassModule()
Destructor, use this to clean up anything you created in the constructor.
virtual void beginRun() override
Called once before a new run begins.
void registerAliases()
Set variable aliases needed by the MVA.
VariablesLists m_variables
List of lists of feature variables.
void registerAliasesLegacy()
Set variable aliases needed by the MVA.
CutsList m_cuts
List of Cut objects.
VariablesLists m_spectators
List of lists of spectator variables.
ExpertsList m_experts
List of MVA::Expert objects.
std::string m_payload_name
The name of the database payload object with the MVA weights.
Iterator end() const
Ending iterator.
static DetectorSet set()
Accessor for the set of valid detector IDs.
EDetector
Enum for identifying the detector components (detector and subdetector).
static std::string parseDetectors(EDetector det)
Converts Const::EDetector object to string.
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
@ c_Debug
Debug: for code development.
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; has to be called once before getSupportedInterfaces().
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
General options which are shared by all MVA trainings.
Options for the TMVA Multiclass MVA method.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
void setDescription(const std::string &description)
Sets the description of the module.
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Class to store reconstructed particles.
Type-safe access to single objects in the data store.
Global list of available variables.
static Manager & Instance()
get singleton instance.
bool addAlias(const std::string &alias, const std::string &variable)
Add alias Return true if the alias was successfully added.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Abstract base class for different kinds of events.