Belle II Software development
ChargedPidMVAMulticlassModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8//THIS MODULE
9#include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAMulticlassModule.h>
10
11//ANALYSIS
12#include <mva/interface/Interface.h>
13#include <mva/methods/TMVA.h>
14#include <analysis/dataobjects/Particle.h>
15#include <analysis/variables/ECLVariables.h>
16
17// FRAMEWORK
18#include <framework/logging/LogConfig.h>
19#include <framework/logging/LogSystem.h>
20
21
using namespace Belle2;

// Register this module with the framework; the name is given without the 'Module' suffix.
REG_MODULE(ChargedPidMVAMulticlass);
25
27{
28 setDescription("This module evaluates the response of a multi-class MVA trained for global charged particle identification.. It takes the Particle objects in the input charged stable particles' ParticleLists, calculates the MVA per-class score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
29
31
32 addParam("particleLists",
34 "The input list of DecayStrings, where each selected (^) daughter should correspond to a standard charged ParticleList, e.g. ['Lambda0:sig -> ^p+ ^pi-', 'J/psi:sig -> ^mu+ ^mu-']. One can also directly pass a list of standard charged ParticleLists, e.g. ['e+:my_electrons', 'pi+:my_pions']. Note that charge-conjugated ParticleLists will automatically be included.",
35 std::vector<std::string>());
36 addParam("payloadName",
38 "The name of the database payload object with the MVA weights.",
39 std::string("ChargedPidMVAWeights"));
40 addParam("chargeIndependent",
42 "Specify whether to use a charge-independent training of the MVA.",
43 bool(false));
44 addParam("useECLOnlyTraining",
46 "Specify whether to use an ECL-only training of the MVA.",
47 bool(false));
48}
49
50
52
53
55{
56 m_event_metadata.isRequired();
57
58 m_weightfiles_representation = std::make_unique<DBObjPtr<ChargedPidMVAWeights>>(m_payload_name);
59
60 /* Initialize MVA if the payload has changed and now. */
61 (*m_weightfiles_representation.get()).addCallback([this]() { initializeMVA(); });
63}
64
65
67{
68}
69
70
72{
73
74 // Debug strings per log level.
75 std::map<int, std::string> debugStr = {
76 {11, ""},
77 {12, ""}
78 };
79
80 B2DEBUG(11, "EVENT: " << m_event_metadata->getEvent());
81
82 for (auto decayString : m_decayStrings) {
83
84 DecayDescriptor decayDescriptor;
85 decayDescriptor.init(decayString);
86 auto pListName = decayDescriptor.getMother()->getFullName();
87
88 unsigned short m_nSelectedDaughters = decayDescriptor.getSelectionNames().size();
89 StoreObjPtr<ParticleList> pList(pListName);
90
91 if (!pList) {
92 B2FATAL("ParticleList: " << pListName << " could not be found. Aborting...");
93 }
94
95 auto pListSize = pList->getListSize();
96
97 B2DEBUG(11, "ParticleList: " << pList->getParticleListName() << " - N = " << pListSize << " particles.");
98
99 const auto nTargetParticles = (m_nSelectedDaughters == 0) ? pListSize : pListSize * m_nSelectedDaughters;
100
101 // Need to get an absolute value in order to check if in Const::ChargedStable.
102 std::vector<int> pdgs;
103 if (m_nSelectedDaughters == 0) {
104 pdgs.push_back(pList->getPDGCode());
105 } else {
106 pdgs = decayDescriptor.getSelectionPDGCodes();
107 }
108 for (auto pdg : pdgs) {
109 // Check if this ParticleList is made up of legit Const::ChargedStable particles.
110 if (!(*m_weightfiles_representation.get())->isValidPdg(abs(pdg))) {
111 B2FATAL("PDG: " << pdg << " of ParticleList: " << pListName <<
112 " is not that of a valid particle in Const::chargedStableSet! Aborting...");
113 }
114 }
115 std::vector<const Particle*> targetParticles;
116 if (m_nSelectedDaughters > 0) {
117 for (unsigned int iPart(0); iPart < pListSize; ++iPart) {
118 auto* iParticle = pList->getParticle(iPart);
119 auto daughters = decayDescriptor.getSelectionParticles(iParticle);
120 for (auto* iDaughter : daughters) {
121 targetParticles.push_back(iDaughter);
122 }
123 }
124 }
125
126 for (unsigned int ipart(0); ipart < nTargetParticles; ++ipart) {
127
128 const Particle* particle = (m_nSelectedDaughters > 0) ? targetParticles[ipart] : pList->getParticle(ipart);
129
130 if (!(*m_weightfiles_representation.get())->hasImplicitNaNmasking()) {
131 // LEGACY TRAININGS: always require a track-cluster match.
132 const ECLCluster* eclCluster = particle->getECLCluster();
133 if (!eclCluster) {
134 B2DEBUG(11, "\nParticle [" << ipart << "] has invalid Track-ECLCluster relation, skip MVA application...");
135 continue;
136 }
137 }
138
139 // Retrieve the index for the correct MVA expert and dataset,
140 // given the reconstructed (polar angle, p, charge)
141 auto thVarName = (*m_weightfiles_representation.get())->getThetaVarName();
142 auto theta = std::get<double>(Variable::Manager::Instance().getVariable(thVarName)->function(particle));
143 auto p = particle->getP();
144 // Set a dummy charge of zero to pick charge-independent payloads, if requested.
145 auto charge = (!m_charge_independent) ? particle->getCharge() : 0.0;
146 if (std::isnan(theta) or std::isnan(p) or std::isnan(charge)) {
147 B2DEBUG(11, "\nParticle [" << ipart << "] has invalid input variable, skip MVA application..." <<
148 " polar angle: " << theta << ", p: " << p << ", charge: " << charge);
149 continue;
150 }
151
152 int idx_theta, idx_p, idx_charge;
153 auto index = (*m_weightfiles_representation.get())->getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
154
155 auto hasMatch = std::isnormal(Variable::eclClusterTrackMatched(particle));
156
157 debugStr[11] += "\n";
158 debugStr[11] += ("Particle [" + std::to_string(ipart) + "]\n");
159 debugStr[11] += ("Has ECL cluster match? " + std::to_string(hasMatch) + "\n");
160 debugStr[11] += ("polar angle: " + thVarName + " = " + std::to_string(theta) + " [rad]\n");
161 debugStr[11] += ("p = " + std::to_string(p) + " [GeV/c]\n");
163 debugStr[11] += ("charge = " + std::to_string(charge) + "\n");
164 }
165 debugStr[11] += ("Is brems corrected ? " + std::to_string(particle->hasExtraInfo("bremsCorrected")) + "\n");
166 debugStr[11] += ("Weightfile idx = " + std::to_string(index) + " - (polar angle, p, charge) = (" + std::to_string(
167 idx_theta) + ", " + std::to_string(idx_p) + ", " +
168 std::to_string(idx_charge) + ")\n");
169 if (m_cuts.at(index)) {
170 debugStr[11] += ("Category cut: " + m_cuts.at(index)->decompile() + "\n");
171 }
172
173 B2DEBUG(11, debugStr[11]);
174 debugStr[11].clear();
175
176 // Don't even bother if particle does not fulfil the category selection.
177 if (m_cuts.at(index)) {
178 if (!m_cuts.at(index)->check(particle)) {
179 B2DEBUG(11, "\nParticle [" << ipart << "] didn't pass MVA category cut, skip MVA application...");
180 continue;
181 }
182 }
183
184 // Fill the MVA::SingleDataset w/ variables and spectators.
185
186 debugStr[11] += "\n";
187 debugStr[11] += "MVA variables:\n";
188
189 auto nvars = m_variables.at(index).size();
190 for (unsigned int ivar(0); ivar < nvars; ++ivar) {
191
192 auto varobj = m_variables.at(index).at(ivar);
193
194 double var = std::numeric_limits<double>::quiet_NaN();
195 auto var_result = varobj->function(particle);
196 if (std::holds_alternative<double>(var_result)) {
197 var = std::get<double>(var_result);
198 } else if (std::holds_alternative<int>(var_result)) {
199 var = std::get<int>(var_result);
200 } else if (std::holds_alternative<bool>(var_result)) {
201 var = std::get<bool>(var_result);
202 } else {
203 B2ERROR("Variable '" << varobj->name << "' has wrong data type! It must be one of double, integer, or bool.");
204 }
205
206 if (!(*m_weightfiles_representation.get())->hasImplicitNaNmasking()) {
207 // LEGACY TRAININGS: manual imputation value of -999 for NaN (undefined) variables. Needed by TMVA.
208 var = (std::isnan(var)) ? -999.0 : var;
209 }
210
211 debugStr[11] += ("\tvar[" + std::to_string(ivar) + "] : " + varobj->name + " = " + std::to_string(var) + "\n");
212
213 m_datasets.at(index)->m_input[ivar] = var;
214
215 }
216
217 B2DEBUG(11, debugStr[11]);
218 debugStr[11].clear();
219
220 // Check spectators only when in debug mode.
221 if (LogSystem::Instance().isLevelEnabled(LogConfig::c_Debug, 12)) {
222
223 debugStr[12] += "\n";
224 debugStr[12] += "MVA spectators:\n";
225
226 auto nspecs = m_spectators.at(index).size();
227 for (unsigned int ispec(0); ispec < nspecs; ++ispec) {
228
229 auto specobj = m_spectators.at(index).at(ispec);
230
231 double spec = std::numeric_limits<double>::quiet_NaN();
232 auto spec_result = specobj->function(particle);
233 if (std::holds_alternative<double>(spec_result)) {
234 spec = std::get<double>(spec_result);
235 } else if (std::holds_alternative<int>(spec_result)) {
236 spec = std::get<int>(spec_result);
237 } else if (std::holds_alternative<bool>(spec_result)) {
238 spec = std::get<bool>(spec_result);
239 } else {
240 B2ERROR("Variable '" << specobj->name << "' has wrong data type! It must be one of double, integer, or bool.");
241 }
242
243 debugStr[12] += ("\tspec[" + std::to_string(ispec) + "] : " + specobj->name + " = " + std::to_string(spec) + "\n");
244
245 m_datasets.at(index)->m_spectators[ispec] = spec;
246
247 }
248
249 B2DEBUG(12, debugStr[12]);
250 debugStr[12].clear();
251
252 }
253
254 // Compute MVA score for each available class.
255
256 debugStr[11] += "\n";
257 debugStr[12] += "\n";
258 debugStr[11] += "MVA response:\n";
259
260 std::string score_varname("");
261 // We deal w/ a SingleDataset, so 0 is the only existing component by construction.
262 std::vector<float> scores = m_experts.at(index)->applyMulticlass(*m_datasets.at(index))[0];
263
264 for (unsigned int classID(0); classID < m_classes.size(); ++classID) {
265
266 const std::string className(m_classes.at(classID));
267
268 score_varname = "pidChargedBDTScore_" + className;
269
270 if (m_ecl_only) {
271 score_varname += "_" + std::to_string(Const::ECL);
272 } else {
273 for (const Const::EDetector& det : Const::PIDDetectorSet::set()) {
274 score_varname += "_" + std::to_string(det);
275 }
276 }
277
278 debugStr[11] += ("\tclass[" + std::to_string(classID) + "] = " + className + " - score = " + std::to_string(
279 scores[classID]) + "\n");
280 debugStr[12] += ("\textraInfo: " + score_varname + "\n");
281
282 // Store the MVA score as a new particle object property.
283 m_particles[particle->getArrayIndex()]->writeExtraInfo(score_varname, scores[classID]);
284
285 }
286
287 B2DEBUG(11, debugStr[11]);
288 B2DEBUG(12, debugStr[12]);
289 debugStr[11].clear();
290 debugStr[12].clear();
291
292 }
293
294 }
295
296 // Clear the debug string map before next event.
297 debugStr.clear();
298
299}
300
302{
303
304 std::string epsilon("1e-8");
305
306 std::map<std::string, std::string> aliasesLegacy;
307
308 aliasesLegacy.insert(std::make_pair("__event__", "evtNum"));
309
311 it != Const::PIDDetectorSet::set().end(); ++it) {
312
313 auto detName = Const::parseDetectors(*it);
314
315 aliasesLegacy.insert(std::make_pair("missingLogL_" + detName, "pidMissingProbabilityExpert(" + detName + ")"));
316
317 for (auto& [pdgId, fullName] : m_stdChargedInfo) {
318
319 std::string alias = fullName + "ID_" + detName;
320 std::string var = "pidProbabilityExpert(" + std::to_string(pdgId) + ", " + detName + ")";
321 std::string aliasLogTrf = alias + "_LogTransfo";
322 std::string varLogTrf = "formula(-1. * log10(formula(((1. - " + alias + ") + " + epsilon + ") / (" + alias + " + " + epsilon +
323 "))))";
324
325 aliasesLegacy.insert(std::make_pair(alias, var));
326 aliasesLegacy.insert(std::make_pair(aliasLogTrf, varLogTrf));
327
328 if (it.getIndex() == 0) {
329 aliasLogTrf = fullName + "ID_LogTransfo";
330 varLogTrf = "formula(-1. * log10(formula(((1. - " + fullName + "ID) + " + epsilon + ") / (" + fullName + "ID + " + epsilon +
331 "))))";
332 aliasesLegacy.insert(std::make_pair(aliasLogTrf, varLogTrf));
333 }
334
335 }
336
337 }
338
339 B2INFO("Setting hard-coded aliases for the ChargedPidMVA algorithm.");
340
341 std::string debugStr("\n");
342 for (const auto& [alias, variable] : aliasesLegacy) {
343 debugStr += (alias + " --> " + variable + "\n");
344 if (!Variable::Manager::Instance().addAlias(alias, variable)) {
345 B2ERROR("Something went wrong with setting alias: " << alias << " for variable: " << variable);
346 }
347 }
348 B2DEBUG(10, debugStr);
349
350}
351
352
354{
355
356 auto aliases = (*m_weightfiles_representation.get())->getAliases();
357
358 if (!aliases->empty()) {
359
360 B2INFO("Setting aliases for the ChargedPidMVA algorithm read from the payload.");
361
362 std::string debugStr("\n");
363 for (const auto& [alias, variable] : *aliases) {
364 if (alias != variable) {
365 debugStr += (alias + " --> " + variable + "\n");
366 if (!Variable::Manager::Instance().addAlias(alias, variable)) {
367 B2ERROR("Something went wrong with setting alias: " << alias << " for variable: " << variable);
368 }
369 }
370 }
371 B2DEBUG(10, debugStr);
372
373 return;
374
375 }
376
377 // Manually set aliases - for bw compatibility
378 this->registerAliasesLegacy();
379
380}
381
382
384{
385
386 B2INFO("Run: " << m_event_metadata->getRun() <<
387 ". Load supported MVA interfaces for multi-class charged particle identification...");
388
389 // Set the necessary variable aliases from the payload.
390 this->registerAliases();
391
392 // The supported methods have to be initialized once (calling it more than once is safe).
394 auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
395
396 B2INFO("\tLoading weightfiles from the payload class.");
397
398 auto serialized_weightfiles = (*m_weightfiles_representation.get())->getMVAWeightsMulticlass();
399 auto nfiles = serialized_weightfiles->size();
400
401 B2INFO("\tConstruct the MVA experts and datasets from N = " << nfiles << " weightfiles...");
402
403 // The size of the vectors must correspond
404 // to the number of available weightfiles for this pdgId.
405 m_experts.resize(nfiles);
406 m_datasets.resize(nfiles);
407 m_cuts.resize(nfiles);
408 m_variables.resize(nfiles);
409 m_spectators.resize(nfiles);
410
411 for (unsigned int idx(0); idx < nfiles; idx++) {
412
413 B2DEBUG(12, "\t\tweightfile[" << idx << "]");
414
415 // De-serialize the string into an MVA::Weightfile object.
416 std::stringstream ss(serialized_weightfiles->at(idx));
417 auto weightfile = MVA::Weightfile::loadFromStream(ss);
418
419 MVA::GeneralOptions general_options;
420 weightfile.getOptions(general_options);
421
422 // Store the list of pointers to the relevant variables for this xml file.
424 m_variables[idx] = manager.getVariables(general_options.m_variables);
425 m_spectators[idx] = manager.getVariables(general_options.m_spectators);
426
427 B2DEBUG(12, "\t\tRetrieved N = " << general_options.m_variables.size()
428 << " variables, N = " << general_options.m_spectators.size()
429 << " spectators");
430
431 // Store an MVA::Expert object.
432 m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
433 m_experts.at(idx)->load(weightfile);
434
435 B2DEBUG(12, "\t\tweightfile loaded successfully into expert[" << idx << "]!");
436
437 // Store an MVA::SingleDataset object, in which we will save our features later...
438 std::vector<float> v(general_options.m_variables.size(), 0.0);
439 std::vector<float> s(general_options.m_spectators.size(), 0.0);
440 m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
441
442 B2DEBUG(12, "\t\tdataset[" << idx << "] created successfully!");
443
444 // Compile cut for this category.
445 const auto cuts = (*m_weightfiles_representation.get())->getCutsMulticlass();
446 const auto cutstr = (!cuts->empty()) ? cuts->at(idx) : "";
447 m_cuts[idx] = (!cutstr.empty()) ? Variable::Cut::compile(cutstr) : nullptr;
448
449 B2DEBUG(12, "\t\tcut[" << idx << "] created successfully!");
450
451 // Register class names only once.
452 if (idx == 0) {
453 // QUESTION: could this be made generic?
454 // Problem is I am not sure how other MVA methods deal with multi-classification,
455 // so it's difficult to make an abstract interface that surely works for everything... ideas?
456 MVA::TMVAOptionsMulticlass specific_options;
457 weightfile.getOptions(specific_options);
458
459 if (specific_options.m_classes.empty()) {
460 B2FATAL("MVA::SpecificOptions of weightfile[" << idx <<
461 "] has no registered MVA classes! This shouldn't happen in multi-class mode. Aborting...");
462 }
463
464 m_classes.clear();
465 for (const auto& cls : specific_options.m_classes) {
466 m_classes.push_back(cls);
467 }
468
469 }
470 }
471
472}
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
std::vector< std::string > m_decayStrings
The input list of DecayStrings, where each selected (^) daughter should correspond to a standard charged ParticleList.
virtual void initialize() override
Use this to initialize resources or memory your module needs.
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
virtual void event() override
Called once for each event.
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
StoreArray< Particle > m_particles
StoreArray of Particles.
DatasetsList m_datasets
List of MVA::SingleDataset objects.
std::vector< std::string > m_classes
List of MVA class names.
std::map< int, std::string > m_stdChargedInfo
Map with standard charged particles' info.
ChargedPidMVAMulticlassModule()
Constructor, for setting module description and parameters.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual ~ChargedPidMVAMulticlassModule()
Destructor, use this to clean up anything you created in the constructor.
virtual void beginRun() override
Called once before a new run begins.
void registerAliases()
Set variable aliases needed by the MVA.
VariablesLists m_variables
List of lists of feature variables.
void registerAliasesLegacy()
Set variable aliases needed by the MVA.
VariablesLists m_spectators
List of lists of spectator variables.
ExpertsList m_experts
List of MVA::Expert objects.
std::string m_payload_name
The name of the database payload object with the MVA weights.
Iterator end() const
Ending iterator.
Definition: UnitConst.cc:220
static DetectorSet set()
Accessor for the set of valid detector IDs.
Definition: Const.h:333
EDetector
Enum for identifying the detector components (detector and subdetector).
Definition: Const.h:42
static std::string parseDetectors(EDetector det)
Converts Const::EDetector object to string.
Definition: UnitConst.cc:162
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
ECL cluster data.
Definition: ECLCluster.h:27
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
Definition: GeneralCut.h:84
@ c_Debug
Debug: for code development.
Definition: LogConfig.h:26
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Definition: LogSystem.cc:31
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
Definition: Interface.cc:45
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
General options which are shared by all MVA trainings.
Definition: Options.h:62
Options for the TMVA Multiclass MVA method.
Definition: TMVA.h:122
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (all I/O must be done through the data store).
Definition: Module.h:80
Class to store reconstructed particles.
Definition: Particle.h:75
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:96
Global list of available variables.
Definition: Manager.h:101
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
bool addAlias(const std::string &alias, const std::string &variable)
Add an alias. Returns true if the alias was successfully added.
Definition: Manager.cc:95
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
Abstract base class for different kinds of events.