Belle II Software development
ChargedPidMVAMulticlassModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8//THIS MODULE
9#include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAMulticlassModule.h>
10
11//ANALYSIS
12#include <analysis/dataobjects/Particle.h>
13#include <analysis/dbobjects/ChargedPidMVAWeights.h>
14#include <analysis/DecayDescriptor/DecayDescriptor.h>
15#include <analysis/variables/ECLVariables.h>
16
17// FRAMEWORK
18#include <framework/logging/LogConfig.h>
19#include <framework/logging/LogSystem.h>
20#include <mva/interface/Interface.h>
21#include <mva/methods/TMVA.h>
22
23using namespace Belle2;
24
// Register this module with the framework; the registered name is the class name without the 'Module' suffix.
25REG_MODULE(ChargedPidMVAMulticlass);
26
28{
29 setDescription("This module evaluates the response of a multi-class MVA trained for global charged particle identification.. It takes the Particle objects in the input charged stable particles' ParticleLists, calculates the MVA per-class score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
30
32
33 addParam("particleLists",
35 "The input list of DecayStrings, where each selected (^) daughter should correspond to a standard charged ParticleList, e.g. ['Lambda0:sig -> ^p+ ^pi-', 'J/psi:sig -> ^mu+ ^mu-']. One can also directly pass a list of standard charged ParticleLists, e.g. ['e+:my_electrons', 'pi+:my_pions']. Note that charge-conjugated ParticleLists will automatically be included.",
36 std::vector<std::string>());
37 addParam("payloadName",
39 "The name of the database payload object with the MVA weights.",
40 std::string("ChargedPidMVAWeights"));
41 addParam("chargeIndependent",
43 "Specify whether to use a charge-independent training of the MVA.",
44 bool(false));
45 addParam("useECLOnlyTraining",
47 "Specify whether to use an ECL-only training of the MVA.",
48 bool(false));
49}
50
51
53
54
56{
57 m_event_metadata.isRequired();
58
59 m_weightfiles_representation = std::make_unique<DBObjPtr<ChargedPidMVAWeights>>(m_payload_name);
60
61 /* Initialize MVA if the payload has changed and now. */
62 (*m_weightfiles_representation.get()).addCallback([this]() { initializeMVA(); });
64}
65
66
70
71
73{
74
75 // Debug strings per log level.
76 std::map<int, std::string> debugStr = {
77 {11, ""},
78 {12, ""}
79 };
80
81 B2DEBUG(11, "EVENT: " << m_event_metadata->getEvent());
82
83 for (auto decayString : m_decayStrings) {
84
85 DecayDescriptor decayDescriptor;
86 decayDescriptor.init(decayString);
87 auto pListName = decayDescriptor.getMother()->getFullName();
88
89 unsigned short m_nSelectedDaughters = decayDescriptor.getSelectionNames().size();
90 StoreObjPtr<ParticleList> pList(pListName);
91
92 if (!pList) {
93 B2FATAL("ParticleList: " << pListName << " could not be found. Aborting...");
94 }
95
96 auto pListSize = pList->getListSize();
97
98 B2DEBUG(11, "ParticleList: " << pList->getParticleListName() << " - N = " << pListSize << " particles.");
99
100 const auto nTargetParticles = (m_nSelectedDaughters == 0) ? pListSize : pListSize * m_nSelectedDaughters;
101
102 // Need to get an absolute value in order to check if in Const::ChargedStable.
103 std::vector<int> pdgs;
104 if (m_nSelectedDaughters == 0) {
105 pdgs.push_back(pList->getPDGCode());
106 } else {
107 pdgs = decayDescriptor.getSelectionPDGCodes();
108 }
109 for (auto pdg : pdgs) {
110 // Check if this ParticleList is made up of legit Const::ChargedStable particles.
111 if (!(*m_weightfiles_representation.get())->isValidPdg(abs(pdg))) {
112 B2FATAL("PDG: " << pdg << " of ParticleList: " << pListName <<
113 " is not that of a valid particle in Const::chargedStableSet! Aborting...");
114 }
115 }
116 std::vector<const Particle*> targetParticles;
117 if (m_nSelectedDaughters > 0) {
118 for (unsigned int iPart(0); iPart < pListSize; ++iPart) {
119 auto* iParticle = pList->getParticle(iPart);
120 auto daughters = decayDescriptor.getSelectionParticles(iParticle);
121 for (auto* iDaughter : daughters) {
122 targetParticles.push_back(iDaughter);
123 }
124 }
125 }
126
127 for (unsigned int ipart(0); ipart < nTargetParticles; ++ipart) {
128
129 const Particle* particle = (m_nSelectedDaughters > 0) ? targetParticles[ipart] : pList->getParticle(ipart);
130
131 if (!(*m_weightfiles_representation.get())->hasImplicitNaNmasking()) {
132 // LEGACY TRAININGS: always require a track-cluster match.
133 const ECLCluster* eclCluster = particle->getECLCluster();
134 if (!eclCluster) {
135 B2DEBUG(11, "\nParticle [" << ipart << "] has invalid Track-ECLCluster relation, skip MVA application...");
136 continue;
137 }
138 }
139
140 // Retrieve the index for the correct MVA expert and dataset,
141 // given the reconstructed (polar angle, p, charge)
142 auto thVarName = (*m_weightfiles_representation.get())->getThetaVarName();
143 auto theta = std::get<double>(Variable::Manager::Instance().getVariable(thVarName)->function(particle));
144 auto p = particle->getP();
145 // Set a dummy charge of zero to pick charge-independent payloads, if requested.
146 auto charge = (!m_charge_independent) ? particle->getCharge() : 0.0;
147 if (std::isnan(theta) or std::isnan(p) or std::isnan(charge)) {
148 B2DEBUG(11, "\nParticle [" << ipart << "] has invalid input variable, skip MVA application..." <<
149 " polar angle: " << theta << ", p: " << p << ", charge: " << charge);
150 continue;
151 }
152
153 int idx_theta, idx_p, idx_charge;
154 auto index = (*m_weightfiles_representation.get())->getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
155
156 auto hasMatch = std::isnormal(Variable::eclClusterTrackMatched(particle));
157
158 debugStr[11] += "\n";
159 debugStr[11] += ("Particle [" + std::to_string(ipart) + "]\n");
160 debugStr[11] += ("Has ECL cluster match? " + std::to_string(hasMatch) + "\n");
161 debugStr[11] += ("polar angle: " + thVarName + " = " + std::to_string(theta) + " [rad]\n");
162 debugStr[11] += ("p = " + std::to_string(p) + " [GeV/c]\n");
164 debugStr[11] += ("charge = " + std::to_string(charge) + "\n");
165 }
166 debugStr[11] += ("Is brems corrected ? " + std::to_string(particle->hasExtraInfo("bremsCorrected")) + "\n");
167 debugStr[11] += ("Weightfile idx = " + std::to_string(index) + " - (polar angle, p, charge) = (" + std::to_string(
168 idx_theta) + ", " + std::to_string(idx_p) + ", " +
169 std::to_string(idx_charge) + ")\n");
170 if (m_cuts.at(index)) {
171 debugStr[11] += ("Category cut: " + m_cuts.at(index)->decompile() + "\n");
172 }
173
174 B2DEBUG(11, debugStr[11]);
175 debugStr[11].clear();
176
177 // Don't even bother if particle does not fulfil the category selection.
178 if (m_cuts.at(index)) {
179 if (!m_cuts.at(index)->check(particle)) {
180 B2DEBUG(11, "\nParticle [" << ipart << "] didn't pass MVA category cut, skip MVA application...");
181 continue;
182 }
183 }
184
185 // Fill the MVA::SingleDataset w/ variables and spectators.
186
187 debugStr[11] += "\n";
188 debugStr[11] += "MVA variables:\n";
189
190 auto nvars = m_variables.at(index).size();
191 for (unsigned int ivar(0); ivar < nvars; ++ivar) {
192
193 auto varobj = m_variables.at(index).at(ivar);
194
195 double var = std::numeric_limits<double>::quiet_NaN();
196 auto var_result = varobj->function(particle);
197 if (std::holds_alternative<double>(var_result)) {
198 var = std::get<double>(var_result);
199 } else if (std::holds_alternative<int>(var_result)) {
200 var = std::get<int>(var_result);
201 } else if (std::holds_alternative<bool>(var_result)) {
202 var = std::get<bool>(var_result);
203 } else {
204 B2ERROR("Variable '" << varobj->name << "' has wrong data type! It must be one of double, integer, or bool.");
205 }
206
207 if (!(*m_weightfiles_representation.get())->hasImplicitNaNmasking()) {
208 // LEGACY TRAININGS: manual imputation value of -999 for NaN (undefined) variables. Needed by TMVA.
209 var = (std::isnan(var)) ? -999.0 : var;
210 }
211
212 debugStr[11] += ("\tvar[" + std::to_string(ivar) + "] : " + varobj->name + " = " + std::to_string(var) + "\n");
213
214 m_datasets.at(index)->m_input[ivar] = var;
215
216 }
217
218 B2DEBUG(11, debugStr[11]);
219 debugStr[11].clear();
220
221 // Check spectators only when in debug mode.
222 if (LogSystem::Instance().isLevelEnabled(LogConfig::c_Debug, 12)) {
223
224 debugStr[12] += "\n";
225 debugStr[12] += "MVA spectators:\n";
226
227 auto nspecs = m_spectators.at(index).size();
228 for (unsigned int ispec(0); ispec < nspecs; ++ispec) {
229
230 auto specobj = m_spectators.at(index).at(ispec);
231
232 double spec = std::numeric_limits<double>::quiet_NaN();
233 auto spec_result = specobj->function(particle);
234 if (std::holds_alternative<double>(spec_result)) {
235 spec = std::get<double>(spec_result);
236 } else if (std::holds_alternative<int>(spec_result)) {
237 spec = std::get<int>(spec_result);
238 } else if (std::holds_alternative<bool>(spec_result)) {
239 spec = std::get<bool>(spec_result);
240 } else {
241 B2ERROR("Variable '" << specobj->name << "' has wrong data type! It must be one of double, integer, or bool.");
242 }
243
244 debugStr[12] += ("\tspec[" + std::to_string(ispec) + "] : " + specobj->name + " = " + std::to_string(spec) + "\n");
245
246 m_datasets.at(index)->m_spectators[ispec] = spec;
247
248 }
249
250 B2DEBUG(12, debugStr[12]);
251 debugStr[12].clear();
252
253 }
254
255 // Compute MVA score for each available class.
256
257 debugStr[11] += "\n";
258 debugStr[12] += "\n";
259 debugStr[11] += "MVA response:\n";
260
261 std::string score_varname("");
262 // We deal w/ a SingleDataset, so 0 is the only existing component by construction.
263 std::vector<float> scores = m_experts.at(index)->applyMulticlass(*m_datasets.at(index))[0];
264
265 for (unsigned int classID(0); classID < m_classes.size(); ++classID) {
266
267 const std::string className(m_classes.at(classID));
268
269 score_varname = "pidChargedBDTScore_" + className;
270
271 if (m_ecl_only) {
272 score_varname += "_" + std::to_string(Const::ECL);
273 } else {
274 for (const Const::EDetector& det : Const::PIDDetectorSet::set()) {
275 score_varname += "_" + std::to_string(det);
276 }
277 }
278
279 debugStr[11] += ("\tclass[" + std::to_string(classID) + "] = " + className + " - score = " + std::to_string(
280 scores[classID]) + "\n");
281 debugStr[12] += ("\textraInfo: " + score_varname + "\n");
282
283 // Store the MVA score as a new particle object property.
284 m_particles[particle->getArrayIndex()]->writeExtraInfo(score_varname, scores[classID]);
285
286 }
287
288 B2DEBUG(11, debugStr[11]);
289 B2DEBUG(12, debugStr[12]);
290 debugStr[11].clear();
291 debugStr[12].clear();
292
293 }
294
295 }
296
297 // Clear the debug string map before next event.
298 debugStr.clear();
299
300}
301
303{
304
305 std::string epsilon("1e-8");
306
307 std::map<std::string, std::string> aliasesLegacy;
308
309 aliasesLegacy.insert(std::make_pair("__event__", "evtNum"));
310
312 it != Const::PIDDetectorSet::set().end(); ++it) {
313
314 auto detName = Const::parseDetectors(*it);
315
316 aliasesLegacy.insert(std::make_pair("missingLogL_" + detName, "pidMissingProbabilityExpert(" + detName + ")"));
317
318 for (auto& [pdgId, fullName] : m_stdChargedInfo) {
319
320 std::string alias = fullName + "ID_" + detName;
321 std::string var = "pidProbabilityExpert(" + std::to_string(pdgId) + ", " + detName + ")";
322 std::string aliasLogTrf = alias + "_LogTransfo";
323 std::string varLogTrf = "formula(-1. * log10(formula(((1. - " + alias + ") + " + epsilon + ") / (" + alias + " + " + epsilon +
324 "))))";
325
326 aliasesLegacy.insert(std::make_pair(alias, var));
327 aliasesLegacy.insert(std::make_pair(aliasLogTrf, varLogTrf));
328
329 if (it.getIndex() == 0) {
330 aliasLogTrf = fullName + "ID_LogTransfo";
331 varLogTrf = "formula(-1. * log10(formula(((1. - " + fullName + "ID) + " + epsilon + ") / (" + fullName + "ID + " + epsilon +
332 "))))";
333 aliasesLegacy.insert(std::make_pair(aliasLogTrf, varLogTrf));
334 }
335
336 }
337
338 }
339
340 B2INFO("Setting hard-coded aliases for the ChargedPidMVA algorithm.");
341
342 std::string debugStr("\n");
343 for (const auto& [alias, variable] : aliasesLegacy) {
344 debugStr += (alias + " --> " + variable + "\n");
345 if (!Variable::Manager::Instance().addAlias(alias, variable)) {
346 B2ERROR("Something went wrong with setting alias: " << alias << " for variable: " << variable);
347 }
348 }
349 B2DEBUG(10, debugStr);
350
351}
352
353
355{
356
357 auto aliases = (*m_weightfiles_representation.get())->getAliases();
358
359 if (!aliases->empty()) {
360
361 B2INFO("Setting aliases for the ChargedPidMVA algorithm read from the payload.");
362
363 std::string debugStr("\n");
364 for (const auto& [alias, variable] : *aliases) {
365 if (alias != variable) {
366 debugStr += (alias + " --> " + variable + "\n");
367 if (!Variable::Manager::Instance().addAlias(alias, variable)) {
368 B2ERROR("Something went wrong with setting alias: " << alias << " for variable: " << variable);
369 }
370 }
371 }
372 B2DEBUG(10, debugStr);
373
374 return;
375
376 }
377
378 // Manually set aliases - for bw compatibility
379 this->registerAliasesLegacy();
380
381}
382
383
385{
386
387 B2INFO("Run: " << m_event_metadata->getRun() <<
388 ". Load supported MVA interfaces for multi-class charged particle identification...");
389
390 // Set the necessary variable aliases from the payload.
391 this->registerAliases();
392
393 // The supported methods have to be initialized once (calling it more than once is safe).
395 auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
396
397 B2INFO("\tLoading weightfiles from the payload class.");
398
399 auto serialized_weightfiles = (*m_weightfiles_representation.get())->getMVAWeightsMulticlass();
400 auto nfiles = serialized_weightfiles->size();
401
402 B2INFO("\tConstruct the MVA experts and datasets from N = " << nfiles << " weightfiles...");
403
404 // The size of the vectors must correspond
405 // to the number of available weightfiles for this pdgId.
406 m_experts.resize(nfiles);
407 m_datasets.resize(nfiles);
408 m_cuts.resize(nfiles);
409 m_variables.resize(nfiles);
410 m_spectators.resize(nfiles);
411
412 for (unsigned int idx(0); idx < nfiles; idx++) {
413
414 B2DEBUG(12, "\t\tweightfile[" << idx << "]");
415
416 // De-serialize the string into an MVA::Weightfile object.
417 std::stringstream ss(serialized_weightfiles->at(idx));
418 auto weightfile = MVA::Weightfile::loadFromStream(ss);
419
420 MVA::GeneralOptions general_options;
421 weightfile.getOptions(general_options);
422
423 // Store the list of pointers to the relevant variables for this xml file.
425 m_variables[idx] = manager.getVariables(general_options.m_variables);
426 m_spectators[idx] = manager.getVariables(general_options.m_spectators);
427
428 B2DEBUG(12, "\t\tRetrieved N = " << general_options.m_variables.size()
429 << " variables, N = " << general_options.m_spectators.size()
430 << " spectators");
431
432 // Store an MVA::Expert object.
433 m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
434 m_experts.at(idx)->load(weightfile);
435
436 B2DEBUG(12, "\t\tweightfile loaded successfully into expert[" << idx << "]!");
437
438 // Store an MVA::SingleDataset object, in which we will save our features later...
439 std::vector<float> v(general_options.m_variables.size(), 0.0);
440 std::vector<float> s(general_options.m_spectators.size(), 0.0);
441 m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
442
443 B2DEBUG(12, "\t\tdataset[" << idx << "] created successfully!");
444
445 // Compile cut for this category.
446 const auto cuts = (*m_weightfiles_representation.get())->getCutsMulticlass();
447 const auto cutstr = (!cuts->empty()) ? cuts->at(idx) : "";
448 m_cuts[idx] = (!cutstr.empty()) ? Variable::Cut::compile(cutstr) : nullptr;
449
450 B2DEBUG(12, "\t\tcut[" << idx << "] created successfully!");
451
452 // Register class names only once.
453 if (idx == 0) {
454 // QUESTION: could this be made generic?
455 // Problem is I am not sure how other MVA methods deal with multi-classification,
456 // so it's difficult to make an abstract interface that surely works for everything... ideas?
457 MVA::TMVAOptionsMulticlass specific_options;
458 weightfile.getOptions(specific_options);
459
460 if (specific_options.m_classes.empty()) {
461 B2FATAL("MVA::SpecificOptions of weightfile[" << idx <<
462 "] has no registered MVA classes! This shouldn't happen in multi-class mode. Aborting...");
463 }
464
465 m_classes.clear();
466 for (const auto& cls : specific_options.m_classes) {
467 m_classes.push_back(cls);
468 }
469
470 }
471 }
472
473}
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
std::vector< std::string > m_decayStrings
The input list of DecayStrings, where each selected (^) daughter should correspond to a standard char...
virtual void initialize() override
Use this to initialize resources or memory your module needs.
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
virtual void event() override
Called once for each event.
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
StoreArray< Particle > m_particles
StoreArray of Particles.
DatasetsList m_datasets
List of MVA::SingleDataset objects.
std::vector< std::string > m_classes
List of MVA class names.
std::map< int, std::string > m_stdChargedInfo
Map with standard charged particles' info.
ChargedPidMVAMulticlassModule()
Constructor, for setting module description and parameters.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual ~ChargedPidMVAMulticlassModule()
Destructor, use this to clean up anything you created in the constructor.
virtual void beginRun() override
Called once before a new run begins.
void registerAliases()
Set variable aliases needed by the MVA.
VariablesLists m_variables
List of lists of feature variables.
void registerAliasesLegacy()
Set variable aliases needed by the MVA.
VariablesLists m_spectators
List of lists of spectator variables.
ExpertsList m_experts
List of MVA::Expert objects.
std::string m_payload_name
The name of the database payload object with the MVA weights.
Iterator end() const
Ending iterator.
Definition UnitConst.cc:217
EDetector
Enum for identifying the detector components (detector and subdetector).
Definition Const.h:42
static std::string parseDetectors(EDetector det)
Converts Const::EDetector object to string.
Definition UnitConst.cc:161
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
ECL cluster data.
Definition ECLCluster.h:27
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Definition GeneralCut.h:84
@ c_Debug
Debug: for code development.
Definition LogConfig.h:26
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Definition LogSystem.cc:28
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition Interface.cc:46
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition Interface.h:53
General options which are shared by all MVA trainings.
Definition Options.h:62
Options for the TMVA Multiclass MVA method.
Definition TMVA.h:122
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
void setDescription(const std::string &description)
Sets the description of the module.
Definition Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition Module.cc:208
Module()
Constructor.
Definition Module.cc:30
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition Module.h:80
Class to store reconstructed particles.
Definition Particle.h:76
Type-safe access to single objects in the data store.
Definition StoreObjPtr.h:96
Global list of available variables.
Definition Manager.h:100
static Manager & Instance()
get singleton instance.
Definition Manager.cc:26
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition Module.h:559
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
Abstract base class for different kinds of events.