Belle II Software  release-08-01-10
ChargedPidMVAMulticlassModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 //THIS MODULE
9 #include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAMulticlassModule.h>
10 
11 //ANALYSIS
12 #include <mva/interface/Interface.h>
13 #include <mva/methods/TMVA.h>
14 #include <analysis/dataobjects/Particle.h>
15 
16 // FRAMEWORK
17 #include <framework/logging/LogConfig.h>
18 #include <framework/logging/LogSystem.h>
19 
20 
21 using namespace Belle2;
22 
// Register this module with the basf2 framework via the REG_MODULE macro.
23 REG_MODULE(ChargedPidMVAMulticlass);
24 
26 {
27  setDescription("This module evaluates the response of a multi-class MVA trained for global charged particle identification.. It takes the Particle objects in the input charged stable particles' ParticleLists, calculates the MVA per-class score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
28 
30 
31  addParam("particleLists",
33  "The input list of DecayStrings, where each selected (^) daughter should correspond to a standard charged ParticleList, e.g. ['Lambda0:sig -> ^p+ ^pi-', 'J/psi:sig -> ^mu+ ^mu-']. One can also directly pass a list of standard charged ParticleLists, e.g. ['e+:my_electrons', 'pi+:my_pions']. Note that charge-conjugated ParticleLists will automatically be included.",
34  std::vector<std::string>());
35  addParam("payloadName",
37  "The name of the database payload object with the MVA weights.",
38  std::string("ChargedPidMVAWeights"));
39  addParam("chargeIndependent",
41  "Specify whether to use a charge-independent training of the MVA.",
42  bool(false));
43  addParam("useECLOnlyTraining",
44  m_ecl_only,
45  "Specify whether to use an ECL-only training of the MVA.",
46  bool(false));
47 }
48 
49 
51 
52 
54 {
55  m_event_metadata.isRequired();
56 
57  m_weightfiles_representation = std::make_unique<DBObjPtr<ChargedPidMVAWeights>>(m_payload_name);
58 
59  /* Initialize MVA if the payload has changed and now. */
60  (*m_weightfiles_representation.get()).addCallback([this]() { initializeMVA(); });
61  initializeMVA();
62 }
63 
64 
66 {
67 }
68 
69 
71 {
72 
73  // Debug strings per log level.
74  std::map<int, std::string> debugStr = {
75  {11, ""},
76  {12, ""}
77  };
78 
79  B2DEBUG(11, "EVENT: " << m_event_metadata->getEvent());
80 
81  for (auto decayString : m_decayStrings) {
82 
83  DecayDescriptor decayDescriptor;
84  decayDescriptor.init(decayString);
85  auto pListName = decayDescriptor.getMother()->getFullName();
86 
87  unsigned short m_nSelectedDaughters = decayDescriptor.getSelectionNames().size();
88  StoreObjPtr<ParticleList> pList(pListName);
89 
90  if (!pList) {
91  B2FATAL("ParticleList: " << pListName << " could not be found. Aborting...");
92  }
93 
94  auto pListSize = pList->getListSize();
95 
96  B2DEBUG(11, "ParticleList: " << pList->getParticleListName() << " - N = " << pListSize << " particles.");
97 
98  const auto nTargetParticles = (m_nSelectedDaughters == 0) ? pListSize : pListSize * m_nSelectedDaughters;
99 
100  // Need to get an absolute value in order to check if in Const::ChargedStable.
101  std::vector<int> pdgs;
102  if (m_nSelectedDaughters == 0) {
103  pdgs.push_back(pList->getPDGCode());
104  } else {
105  pdgs = decayDescriptor.getSelectionPDGCodes();
106  }
107  for (auto pdg : pdgs) {
108  // Check if this ParticleList is made up of legit Const::ChargedStable particles.
109  if (!(*m_weightfiles_representation.get())->isValidPdg(abs(pdg))) {
110  B2FATAL("PDG: " << pdg << " of ParticleList: " << pListName <<
111  " is not that of a valid particle in Const::chargedStableSet! Aborting...");
112  }
113  }
114  std::vector<const Particle*> targetParticles;
115  if (m_nSelectedDaughters > 0) {
116  for (unsigned int iPart(0); iPart < pListSize; ++iPart) {
117  auto* iParticle = pList->getParticle(iPart);
118  auto daughters = decayDescriptor.getSelectionParticles(iParticle);
119  for (auto* iDaughter : daughters) {
120  targetParticles.push_back(iDaughter);
121  }
122  }
123  }
124 
125  for (unsigned int ipart(0); ipart < nTargetParticles; ++ipart) {
126 
127  const Particle* particle = (m_nSelectedDaughters > 0) ? targetParticles[ipart] : pList->getParticle(ipart);
128 
129  if (!(*m_weightfiles_representation.get())->hasImplicitNaNmasking()) {
130  // LEGACY TRAININGS: always require a track-cluster match.
131  const ECLCluster* eclCluster = particle->getECLCluster();
132  if (!eclCluster) {
133  B2DEBUG(11, "\nParticle [" << ipart << "] has invalid Track-ECLCluster relation, skip MVA application...");
134  continue;
135  }
136  }
137 
138  // Retrieve the index for the correct MVA expert and dataset,
139  // given the reconstructed (polar angle, p, charge)
140  auto thVarName = (*m_weightfiles_representation.get())->getThetaVarName();
141  auto theta = std::get<double>(Variable::Manager::Instance().getVariable(thVarName)->function(particle));
142  auto p = particle->getP();
143  // Set a dummy charge of zero to pick charge-independent payloads, if requested.
144  auto charge = (!m_charge_independent) ? particle->getCharge() : 0.0;
145  if (std::isnan(theta) or std::isnan(p) or std::isnan(charge)) {
146  B2DEBUG(11, "\nParticle [" << ipart << "] has invalid input variable, skip MVA application..." <<
147  " polar angle: " << theta << ", p: " << p << ", charge: " << charge);
148  continue;
149  }
150 
151  int idx_theta, idx_p, idx_charge;
152  auto index = (*m_weightfiles_representation.get())->getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
153 
154  auto* matchVar = Variable::Manager::Instance().getVariable("clusterTrackMatch");
155  auto hasMatch = std::isnormal(std::get<double>(matchVar->function(particle)));
156 
157  debugStr[11] += "\n";
158  debugStr[11] += ("Particle [" + std::to_string(ipart) + "]\n");
159  debugStr[11] += ("Has ECL cluster match? " + std::to_string(hasMatch) + "\n");
160  debugStr[11] += ("polar angle: " + thVarName + " = " + std::to_string(theta) + " [rad]\n");
161  debugStr[11] += ("p = " + std::to_string(p) + " [GeV/c]\n");
162  if (!m_charge_independent) {
163  debugStr[11] += ("charge = " + std::to_string(charge) + "\n");
164  }
165  debugStr[11] += ("Is brems corrected ? " + std::to_string(particle->hasExtraInfo("bremsCorrected")) + "\n");
166  debugStr[11] += ("Weightfile idx = " + std::to_string(index) + " - (polar angle, p, charge) = (" + std::to_string(
167  idx_theta) + ", " + std::to_string(idx_p) + ", " +
168  std::to_string(idx_charge) + ")\n");
169  if (m_cuts.at(index)) {
170  debugStr[11] += ("Category cut: " + m_cuts.at(index)->decompile() + "\n");
171  }
172 
173  B2DEBUG(11, debugStr[11]);
174  debugStr[11].clear();
175 
176  // Don't even bother if particle does not fulfil the category selection.
177  if (m_cuts.at(index)) {
178  if (!m_cuts.at(index)->check(particle)) {
179  B2DEBUG(11, "\nParticle [" << ipart << "] didn't pass MVA category cut, skip MVA application...");
180  continue;
181  }
182  }
183 
184  // Fill the MVA::SingleDataset w/ variables and spectators.
185 
186  debugStr[11] += "\n";
187  debugStr[11] += "MVA variables:\n";
188 
189  auto nvars = m_variables.at(index).size();
190  for (unsigned int ivar(0); ivar < nvars; ++ivar) {
191 
192  auto varobj = m_variables.at(index).at(ivar);
193 
194  double var = std::numeric_limits<double>::quiet_NaN();
195  auto var_result = varobj->function(particle);
196  if (std::holds_alternative<double>(var_result)) {
197  var = std::get<double>(var_result);
198  } else if (std::holds_alternative<int>(var_result)) {
199  var = std::get<int>(var_result);
200  } else if (std::holds_alternative<bool>(var_result)) {
201  var = std::get<bool>(var_result);
202  } else {
203  B2ERROR("Variable '" << varobj->name << "' has wrong data type! It must be one of double, integer, or bool.");
204  }
205 
206  if (!(*m_weightfiles_representation.get())->hasImplicitNaNmasking()) {
207  // LEGACY TRAININGS: manual imputation value of -999 for NaN (undefined) variables. Needed by TMVA.
208  var = (std::isnan(var)) ? -999.0 : var;
209  }
210 
211  debugStr[11] += ("\tvar[" + std::to_string(ivar) + "] : " + varobj->name + " = " + std::to_string(var) + "\n");
212 
213  m_datasets.at(index)->m_input[ivar] = var;
214 
215  }
216 
217  B2DEBUG(11, debugStr[11]);
218  debugStr[11].clear();
219 
220  // Check spectators only when in debug mode.
221  if (LogSystem::Instance().isLevelEnabled(LogConfig::c_Debug, 12)) {
222 
223  debugStr[12] += "\n";
224  debugStr[12] += "MVA spectators:\n";
225 
226  auto nspecs = m_spectators.at(index).size();
227  for (unsigned int ispec(0); ispec < nspecs; ++ispec) {
228 
229  auto specobj = m_spectators.at(index).at(ispec);
230 
231  double spec = std::numeric_limits<double>::quiet_NaN();
232  auto spec_result = specobj->function(particle);
233  if (std::holds_alternative<double>(spec_result)) {
234  spec = std::get<double>(spec_result);
235  } else if (std::holds_alternative<int>(spec_result)) {
236  spec = std::get<int>(spec_result);
237  } else if (std::holds_alternative<bool>(spec_result)) {
238  spec = std::get<bool>(spec_result);
239  } else {
240  B2ERROR("Variable '" << specobj->name << "' has wrong data type! It must be one of double, integer, or bool.");
241  }
242 
243  debugStr[12] += ("\tspec[" + std::to_string(ispec) + "] : " + specobj->name + " = " + std::to_string(spec) + "\n");
244 
245  m_datasets.at(index)->m_spectators[ispec] = spec;
246 
247  }
248 
249  B2DEBUG(12, debugStr[12]);
250  debugStr[12].clear();
251 
252  }
253 
254  // Compute MVA score for each available class.
255 
256  debugStr[11] += "\n";
257  debugStr[12] += "\n";
258  debugStr[11] += "MVA response:\n";
259 
260  std::string score_varname("");
261  // We deal w/ a SingleDataset, so 0 is the only existing component by construction.
262  std::vector<float> scores = m_experts.at(index)->applyMulticlass(*m_datasets.at(index))[0];
263 
264  for (unsigned int classID(0); classID < m_classes.size(); ++classID) {
265 
266  const std::string className(m_classes.at(classID));
267 
268  score_varname = "pidChargedBDTScore_" + className;
269 
270  if (m_ecl_only) {
271  score_varname += "_" + std::to_string(Const::ECL);
272  } else {
273  for (const Const::EDetector& det : Const::PIDDetectorSet::set()) {
274  score_varname += "_" + std::to_string(det);
275  }
276  }
277 
278  debugStr[11] += ("\tclass[" + std::to_string(classID) + "] = " + className + " - score = " + std::to_string(
279  scores[classID]) + "\n");
280  debugStr[12] += ("\textraInfo: " + score_varname + "\n");
281 
282  // Store the MVA score as a new particle object property.
283  m_particles[particle->getArrayIndex()]->writeExtraInfo(score_varname, scores[classID]);
284 
285  }
286 
287  B2DEBUG(11, debugStr[11]);
288  B2DEBUG(12, debugStr[12]);
289  debugStr[11].clear();
290  debugStr[12].clear();
291 
292  }
293 
294  }
295 
296  // Clear the debug string map before next event.
297  debugStr.clear();
298 
299 }
300 
302 {
303 
304  std::string epsilon("1e-8");
305 
306  std::map<std::string, std::string> aliasesLegacy;
307 
308  aliasesLegacy.insert(std::make_pair("__event__", "evtNum"));
309 
311  it != Const::PIDDetectorSet::set().end(); ++it) {
312 
313  auto detName = Const::parseDetectors(*it);
314 
315  aliasesLegacy.insert(std::make_pair("missingLogL_" + detName, "pidMissingProbabilityExpert(" + detName + ")"));
316 
317  for (auto& [pdgId, fullName] : m_stdChargedInfo) {
318 
319  std::string alias = fullName + "ID_" + detName;
320  std::string var = "pidProbabilityExpert(" + std::to_string(pdgId) + ", " + detName + ")";
321  std::string aliasLogTrf = alias + "_LogTransfo";
322  std::string varLogTrf = "formula(-1. * log10(formula(((1. - " + alias + ") + " + epsilon + ") / (" + alias + " + " + epsilon +
323  "))))";
324 
325  aliasesLegacy.insert(std::make_pair(alias, var));
326  aliasesLegacy.insert(std::make_pair(aliasLogTrf, varLogTrf));
327 
328  if (it.getIndex() == 0) {
329  aliasLogTrf = fullName + "ID_LogTransfo";
330  varLogTrf = "formula(-1. * log10(formula(((1. - " + fullName + "ID) + " + epsilon + ") / (" + fullName + "ID + " + epsilon +
331  "))))";
332  aliasesLegacy.insert(std::make_pair(aliasLogTrf, varLogTrf));
333  }
334 
335  }
336 
337  }
338 
339  B2INFO("Setting hard-coded aliases for the ChargedPidMVA algorithm.");
340 
341  std::string debugStr("\n");
342  for (const auto& [alias, variable] : aliasesLegacy) {
343  debugStr += (alias + " --> " + variable + "\n");
344  if (!Variable::Manager::Instance().addAlias(alias, variable)) {
345  B2ERROR("Something went wrong with setting alias: " << alias << " for variable: " << variable);
346  }
347  }
348  B2DEBUG(10, debugStr);
349 
350 }
351 
352 
354 {
355 
356  auto aliases = (*m_weightfiles_representation.get())->getAliases();
357 
358  if (!aliases->empty()) {
359 
360  B2INFO("Setting aliases for the ChargedPidMVA algorithm read from the payload.");
361 
362  std::string debugStr("\n");
363  for (const auto& [alias, variable] : *aliases) {
364  if (alias != variable) {
365  debugStr += (alias + " --> " + variable + "\n");
366  if (!Variable::Manager::Instance().addAlias(alias, variable)) {
367  B2ERROR("Something went wrong with setting alias: " << alias << " for variable: " << variable);
368  }
369  }
370  }
371  B2DEBUG(10, debugStr);
372 
373  return;
374 
375  }
376 
377  // Manually set aliases - for bw compatibility
378  this->registerAliasesLegacy();
379 
380 }
381 
382 
384 {
385 
386  B2INFO("Run: " << m_event_metadata->getRun() <<
387  ". Load supported MVA interfaces for multi-class charged particle identification...");
388 
389  // Set the necessary variable aliases from the payload.
390  this->registerAliases();
391 
392  // The supported methods have to be initialized once (calling it more than once is safe).
394  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
395 
396  B2INFO("\tLoading weightfiles from the payload class.");
397 
398  auto serialized_weightfiles = (*m_weightfiles_representation.get())->getMVAWeightsMulticlass();
399  auto nfiles = serialized_weightfiles->size();
400 
401  B2INFO("\tConstruct the MVA experts and datasets from N = " << nfiles << " weightfiles...");
402 
403  // The size of the vectors must correspond
404  // to the number of available weightfiles for this pdgId.
405  m_experts.resize(nfiles);
406  m_datasets.resize(nfiles);
407  m_cuts.resize(nfiles);
408  m_variables.resize(nfiles);
409  m_spectators.resize(nfiles);
410 
411  for (unsigned int idx(0); idx < nfiles; idx++) {
412 
413  B2DEBUG(12, "\t\tweightfile[" << idx << "]");
414 
415  // De-serialize the string into an MVA::Weightfile object.
416  std::stringstream ss(serialized_weightfiles->at(idx));
417  auto weightfile = MVA::Weightfile::loadFromStream(ss);
418 
419  MVA::GeneralOptions general_options;
420  weightfile.getOptions(general_options);
421 
422  // Store the list of pointers to the relevant variables for this xml file.
424  m_variables[idx] = manager.getVariables(general_options.m_variables);
425  m_spectators[idx] = manager.getVariables(general_options.m_spectators);
426 
427  B2DEBUG(12, "\t\tRetrieved N = " << general_options.m_variables.size()
428  << " variables, N = " << general_options.m_spectators.size()
429  << " spectators");
430 
431  // Store an MVA::Expert object.
432  m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
433  m_experts.at(idx)->load(weightfile);
434 
435  B2DEBUG(12, "\t\tweightfile loaded successfully into expert[" << idx << "]!");
436 
437  // Store an MVA::SingleDataset object, in which we will save our features later...
438  std::vector<float> v(general_options.m_variables.size(), 0.0);
439  std::vector<float> s(general_options.m_spectators.size(), 0.0);
440  m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
441 
442  B2DEBUG(12, "\t\tdataset[" << idx << "] created successfully!");
443 
444  // Compile cut for this category.
445  const auto cuts = (*m_weightfiles_representation.get())->getCutsMulticlass();
446  const auto cutstr = (!cuts->empty()) ? cuts->at(idx) : "";
447  m_cuts[idx] = (!cutstr.empty()) ? Variable::Cut::compile(cutstr) : nullptr;
448 
449  B2DEBUG(12, "\t\tcut[" << idx << "] created successfully!");
450 
451  // Register class names only once.
452  if (idx == 0) {
453  // QUESTION: could this be made generic?
454  // Problem is I am not sure how other MVA methods deal with multi-classification,
455  // so it's difficult to make an abstract interface that surely works for everything... ideas?
456  MVA::TMVAOptionsMulticlass specific_options;
457  weightfile.getOptions(specific_options);
458 
459  if (specific_options.m_classes.empty()) {
460  B2FATAL("MVA::SpecificOptions of weightfile[" << idx <<
461  "] has no registered MVA classes! This shouldn't happen in multi-class mode. Aborting...");
462  }
463 
464  m_classes.clear();
465  for (const auto& cls : specific_options.m_classes) {
466  m_classes.push_back(cls);
467  }
468 
469  }
470  }
471 
472 }
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
std::vector< std::string > m_decayStrings
The input list of DecayStrings, where each selected (^) daughter should correspond to a standard char...
virtual void initialize() override
Use this to initialize resources or memory your module needs.
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
virtual void event() override
Called once for each event.
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
StoreArray< Particle > m_particles
StoreArray of Particles.
DatasetsList m_datasets
List of MVA::SingleDataset objects.
std::vector< std::string > m_classes
List of MVA class names.
std::map< int, std::string > m_stdChargedInfo
Map with standard charged particles' info.
ChargedPidMVAMulticlassModule()
Constructor, for setting module description and parameters.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual ~ChargedPidMVAMulticlassModule()
Destructor, use this to clean up anything you created in the constructor.
virtual void beginRun() override
Called once before a new run begins.
void registerAliases()
Set variable aliases needed by the MVA.
VariablesLists m_variables
List of lists of feature variables.
void registerAliasesLegacy()
Set variable aliases needed by the MVA.
VariablesLists m_spectators
List of lists of spectator variables.
ExpertsList m_experts
List of MVA::Expert objects.
std::string m_payload_name
The name of the database payload object with the MVA weights.
Iterator end() const
Ending iterator.
Definition: UnitConst.cc:220
static DetectorSet set()
Accessor for the set of valid detector IDs.
Definition: Const.h:324
EDetector
Enum for identifying the detector components (detector and subdetector).
Definition: Const.h:42
static std::string parseDetectors(EDetector det)
Converts Const::EDetector object to string.
Definition: UnitConst.cc:162
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
ECL cluster data.
Definition: ECLCluster.h:27
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
Definition: GeneralCut.h:84
@ c_Debug
Debug: for code development.
Definition: LogConfig.h:26
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Definition: LogSystem.cc:31
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
Options for the TMVA Multiclass MVA method.
Definition: TMVA.h:122
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
Class to store reconstructed particles.
Definition: Particle.h:75
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:96
Global list of available variables.
Definition: Manager.h:101
const Var * getVariable(std::string name)
Get the variable belonging to the given key.
Definition: Manager.cc:57
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
bool addAlias(const std::string &alias, const std::string &variable)
Add an alias. Returns true if the alias was successfully added.
Definition: Manager.cc:95
REG_MODULE(arichBtest)
Register the Module.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
Abstract base class for different kinds of events.