Belle II Software  light-2205-abys
ChargedPidMVAMulticlassModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
//THIS MODULE
#include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAMulticlassModule.h>

//ANALYSIS
#include <mva/interface/Interface.h>
#include <mva/methods/TMVA.h>
#include <analysis/VariableManager/Utility.h>
#include <analysis/dataobjects/Particle.h>

//MDST
#include <mdst/dataobjects/ECLCluster.h>

//C++
#include <cmath>
#include <cstdlib>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <variant>
#include <vector>

20 using namespace Belle2;
21 
22 REG_MODULE(ChargedPidMVAMulticlass);
23 
25 {
26  setDescription("This module evaluates the response of a multi-class MVA trained for global charged particle identification.. It takes the Particle objects in the input charged stable particles' ParticleLists, calculates the MVA per-class score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
27 
29 
30  addParam("particleLists",
32  "The input list of decay strings, where the mother particle string should correspond to a full name of a particle list. One can select to run on daughters instead of mother particle, e.g. ['Lambda0 -> ^p+ ^pi-'].",
33  std::vector<std::string>());
34  addParam("payloadName",
36  "The name of the database payload object with the MVA weights.",
37  std::string("ChargedPidMVAWeights"));
38  addParam("chargeIndependent",
40  "Specify whether to use a charge-independent training of the MVA.",
41  bool(false));
42  addParam("useECLOnlyTraining",
43  m_ecl_only,
44  "Specify whether to use an ECL-only training of the MVA.",
45  bool(false));
46 }
47 
48 
50 
51 
53 {
54 
55  m_event_metadata.isRequired();
56 
57  m_weightfiles_representation = std::make_unique<DBObjPtr<ChargedPidMVAWeights>>(m_payload_name);
58 }
59 
60 
62 {
63 
64  // Retrieve the payload from the DB.
65  (*m_weightfiles_representation.get()).addCallback([this]() { initializeMVA(); });
66  initializeMVA();
67 
68 }
69 
70 
72 {
73 
74  B2DEBUG(11, "EVENT: " << m_event_metadata->getEvent());
75 
76  for (auto decayString : m_decayStrings) {
77  DecayDescriptor decayDescriptor;
78  decayDescriptor.init(decayString);
79  auto pl_name = decayDescriptor.getMother()->getFullName();
80 
81  unsigned short m_nSelectedDaughters = decayDescriptor.getSelectionNames().size();
82  StoreObjPtr<ParticleList> pList(pl_name);
83 
84  if (!pList) { B2FATAL("ParticleList: " << pl_name << " could not be found. Aborting..."); }
85  const auto nTargetParticles = (m_nSelectedDaughters == 0) ? pList->getListSize() : pList->getListSize() *
86  m_nSelectedDaughters;
87  // Need to get an absolute value in order to check if in Const::ChargedStable.
88  std::vector<int> pdgs;
89  if (m_nSelectedDaughters == 0)
90  pdgs.push_back(pList->getPDGCode());
91  else
92  pdgs = decayDescriptor.getSelectionPDGCodes();
93  for (auto pdg : pdgs) {
94  // Check if this ParticleList is made up of legit Const::ChargedStable particles.
95  if (!(*m_weightfiles_representation.get())->isValidPdg(abs(pdg))) {
96  B2FATAL("PDG: " << pdg << " of ParticleList: " << pl_name <<
97  " is not that of a valid particle in Const::chargedStableSet! Aborting...");
98  }
99  }
100  std::vector<const Particle*> targetParticles;
101  if (m_nSelectedDaughters > 0) {
102  for (unsigned int iPart(0); iPart < pList->getListSize(); ++iPart) {
103  auto* iParticle = pList->getParticle(iPart);
104  auto daughters = decayDescriptor.getSelectionParticles(iParticle);
105  for (auto* iDaughter : daughters) {
106  targetParticles.push_back(iDaughter);
107  }
108  }
109  }
110  B2DEBUG(11, "ParticleList: " << pList->getParticleListName() << " - N = " << pList->getListSize() << " particles.");
111 
112  for (unsigned int ipart(0); ipart < nTargetParticles; ++ipart) {
113 
114  const Particle* particle = (m_nSelectedDaughters > 0) ? targetParticles[ipart] : pList->getParticle(ipart);
115 
116  B2DEBUG(11, "\tParticle [" << ipart << "]");
117 
118  // Check that the particle has a valid relation set between track and ECL cluster.
119  // Otherwise, skip to next.
120  const ECLCluster* eclCluster = particle->getECLCluster();
121  if (!eclCluster) {
122  B2DEBUG(11, "\t\tParticle has invalid Track-ECLCluster relation, skip MVA application...");
123  continue;
124  }
125 
126  // Retrieve the index for the correct MVA expert and dataset,
127  // given the reconstructed (clusterTheta, p, charge)
128  auto clusterTheta = eclCluster->getTheta();
129  auto p = particle->getP();
130  // Set a dummy charge of zero to pick charge-independent payloads, if requested.
131  auto charge = (!m_charge_independent) ? particle->getCharge() : 0.0;
132  int idx_theta, idx_p, idx_charge;
133  auto index = (*m_weightfiles_representation.get())->getMVAWeightIdx(clusterTheta, p, charge, idx_theta, idx_p, idx_charge);
134 
135  // Get the cut defining the MVA category under exam (this reflects the one used in the training).
136  const auto cuts = (*m_weightfiles_representation.get())->getCutsMulticlass();
137  const auto cutstr = (!cuts->empty()) ? cuts->at(index) : "";
138 
139  B2DEBUG(11, "\t\tclusterTheta = " << clusterTheta << " [rad]");
140  B2DEBUG(11, "\t\tp = " << p << " [GeV/c]");
141  if (!m_charge_independent) {
142  B2DEBUG(11, "\t\tcharge = " << charge);
143  }
144  B2DEBUG(11, "\t\tBrems corrected = " << particle->hasExtraInfo("bremsCorrectedPhotonEnergy"));
145  B2DEBUG(11, "\t\tWeightfile idx = " << index << " - (clusterTheta, p, charge) = (" << idx_theta << ", " << idx_p << ", " <<
146  idx_charge << ")");
147  if (!cutstr.empty()) {
148  B2DEBUG(11, "\t\tCategory cut = " << cutstr);
149  }
150 
151  // Fill the MVA::SingleDataset w/ variables and spectators.
152 
153  B2DEBUG(11, "\tMVA variables:");
154 
155  auto nvars = m_variables.at(index).size();
156  for (unsigned int ivar(0); ivar < nvars; ++ivar) {
157 
158  auto varobj = m_variables.at(index).at(ivar);
159 
160  double var = -999.0;
161  auto var_result = varobj->function(particle);
162  if (std::holds_alternative<double>(var_result)) {
163  var = std::get<double>(var_result);
164  } else if (std::holds_alternative<int>(var_result)) {
165  var = std::get<int>(var_result);
166  } else if (std::holds_alternative<bool>(var_result)) {
167  var = std::get<bool>(var_result);
168  } else {
169  B2ERROR("Variable '" << varobj->name << "' has wrong data type! It must be one of double, integer, or bool.");
170  }
171 
172  // Manual imputation value of -999 for NaN (undefined) variables. Needed by TMVA.
173  var = (std::isnan(var)) ? -999.0 : var;
174 
175  B2DEBUG(11, "\t\tvar[" << ivar << "] : " << varobj->name << " = " << var);
176 
177  m_datasets.at(index)->m_input[ivar] = var;
178 
179  }
180 
181  B2DEBUG(12, "\tMVA spectators:");
182 
183  auto nspecs = m_spectators.at(index).size();
184  for (unsigned int ispec(0); ispec < nspecs; ++ispec) {
185 
186  auto specobj = m_spectators.at(index).at(ispec);
187 
188  double spec = std::numeric_limits<double>::quiet_NaN();
189  auto spec_result = specobj->function(particle);
190  if (std::holds_alternative<double>(spec_result)) {
191  spec = std::get<double>(spec_result);
192  } else if (std::holds_alternative<int>(spec_result)) {
193  spec = std::get<int>(spec_result);
194  } else if (std::holds_alternative<bool>(spec_result)) {
195  spec = std::get<bool>(spec_result);
196  } else {
197  B2ERROR("Variable '" << specobj->name << "' has wrong data type! It must be one of double, integer, or bool.");
198  }
199 
200  B2DEBUG(12, "\t\tspec[" << ispec << "] : " << specobj->name << " = " << spec);
201 
202  m_datasets.at(index)->m_spectators[ispec] = spec;
203 
204  }
205 
206  // Compute MVA score only if particle fulfils category selection.
207  if (!cutstr.empty()) {
208 
209  std::unique_ptr<Variable::Cut> cut = Variable::Cut::compile(cutstr);
210 
211  if (!cut->check(particle)) {
212  B2DEBUG(11, "\t\tParticle didn't pass MVA category cut, skip MVA application...");
213  continue;
214  }
215 
216  }
217 
218  // Compute MVA score for each available class.
219 
220  B2DEBUG(11, "\tMVA response:");
221 
222  std::string score_varname("");
223  // We deal w/ a SingleDataset, so 0 is the only existing component by construction.
224  std::vector<float> scores = m_experts.at(index)->applyMulticlass(*m_datasets.at(index))[0];
225 
226  for (unsigned int classID(0); classID < m_classes.size(); ++classID) {
227 
228  const std::string className(m_classes.at(classID));
229 
230  score_varname = "pidChargedBDTScore_" + className;
231 
232  if (m_ecl_only) {
233  score_varname += "_" + std::to_string(Const::ECL);
234  } else {
235  for (size_t iDet(0); iDet < Const::PIDDetectors::set().size(); ++iDet) {
236  score_varname += "_" + std::to_string(Const::PIDDetectors::set()[iDet]);
237  }
238  }
239 
240  B2DEBUG(11, "\t\tclass[" << classID << "] = " << className << " - score = " << scores[classID]);
241  B2DEBUG(12, "\t\tExtraInfo: " << score_varname);
242 
243  // Store the MVA score as a new particle object property.
244  m_particles[particle->getArrayIndex()]->writeExtraInfo(score_varname, scores[classID]);
245 
246  }
247 
248  }
249 
250  }
251 }
252 
253 
255 {
256 
257  B2INFO("Run: " << m_event_metadata->getRun() <<
258  ". Load supported MVA interfaces for multi-class charged particle identification...");
259 
260  // The supported methods have to be initialized once (calling it more than once is safe).
262  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
263 
264  B2INFO("\tLoading weightfiles from the payload class.");
265 
266  auto serialized_weightfiles = (*m_weightfiles_representation.get())->getMVAWeightsMulticlass();
267  auto nfiles = serialized_weightfiles->size();
268 
269  B2INFO("\tConstruct the MVA experts and datasets from N = " << nfiles << " weightfiles...");
270 
271  // The size of the vectors must correspond
272  // to the number of available weightfiles for this pdgId.
273  m_experts.resize(nfiles);
274  m_datasets.resize(nfiles);
275  m_variables.resize(nfiles);
276  m_spectators.resize(nfiles);
277 
278  for (unsigned int idx(0); idx < nfiles; idx++) {
279 
280  B2DEBUG(12, "\t\tweightfile[" << idx << "]");
281 
282  // De-serialize the string into an MVA::Weightfile object.
283  std::stringstream ss(serialized_weightfiles->at(idx));
284  auto weightfile = MVA::Weightfile::loadFromStream(ss);
285 
286  MVA::GeneralOptions general_options;
287  weightfile.getOptions(general_options);
288 
289  // Store the list of pointers to the relevant variables for this xml file.
291  m_variables[idx] = manager.getVariables(general_options.m_variables);
292  m_spectators[idx] = manager.getVariables(general_options.m_spectators);
293 
294  B2DEBUG(12, "\t\tRetrieved N = " << general_options.m_variables.size()
295  << " variables, N = " << general_options.m_spectators.size()
296  << " spectators");
297 
298  // Store an MVA::Expert object.
299  m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
300  m_experts.at(idx)->load(weightfile);
301 
302  B2DEBUG(12, "\t\tweightfile loaded successfully into expert[" << idx << "]!");
303 
304  // Store an MVA::SingleDataset object, in which we will save our features later...
305  std::vector<float> v(general_options.m_variables.size(), 0.0);
306  std::vector<float> s(general_options.m_spectators.size(), 0.0);
307  m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
308 
309  B2DEBUG(12, "\t\tdataset[" << idx << "] created successfully!");
310 
311  // Register class names only once.
312  if (idx == 0) {
313  // QUESTION: could this be made generic?
314  // Problem is I am not sure how other MVA methods deal with multi-classification,
315  // so it's difficult to make an abstract interface that surely works for everything... ideas?
316  MVA::TMVAOptionsMulticlass specific_options;
317  weightfile.getOptions(specific_options);
318 
319  if (specific_options.m_classes.empty()) {
320  B2FATAL("MVA::SpecificOptions of weightfile[" << idx <<
321  "] has no registered MVA classes! This shouldn't happen in multi-class mode. Aborting...");
322  }
323 
324  m_classes.clear();
325  for (const auto& cls : specific_options.m_classes) {
326  m_classes.push_back(cls);
327  }
328 
329  }
330  }
331 }
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
std::vector< std::string > m_decayStrings
The input list of decay strings to which MVA weights will be applied.
virtual void initialize() override
Use this to initialize resources or memory your module needs.
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
virtual void event() override
Called once for each event.
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
StoreArray< Particle > m_particles
StoreArray of Particles.
DatasetsList m_datasets
List of MVA::SingleDataset objects.
std::vector< std::string > m_classes
List of MVA class names.
ChargedPidMVAMulticlassModule()
Constructor, for setting module description and parameters.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual ~ChargedPidMVAMulticlassModule()
Destructor, use this to clean up anything you created in the constructor.
virtual void beginRun() override
Called once before a new run begins.
VariablesLists m_variables
List of lists of feature variables.
VariablesLists m_spectators
List of lists of spectator variables.
ExpertsList m_experts
List of MVA::Expert objects.
std::string m_payload_name
The name of the database payload object with the MVA weights.
size_t size() const
Getter for number of detector IDs in this set.
Definition: UnitConst.cc:256
static DetectorSet set()
Accessor function for the set of valid detectors.
Definition: Const.h:255
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
ECL cluster data.
Definition: ECLCluster.h:27
double getTheta() const
Return Corrected Theta of Shower (radian).
Definition: ECLCluster.h:304
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
Definition: GeneralCut.h:84
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
Options for the TMVA Multiclass MVA method.
Definition: TMVA.h:122
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:250
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
Class to store reconstructed particles.
Definition: Particle.h:74
const ECLCluster * getECLCluster() const
Returns the pointer to the ECLCluster object that was used to create this Particle (if ParticleType =...
Definition: Particle.cc:883
bool hasExtraInfo(const std::string &name) const
Return whether the extra info with the given name is set.
Definition: Particle.cc:1255
double getCharge(void) const
Returns particle charge.
Definition: Particle.cc:645
double getP() const
Returns momentum magnitude (same as getMomentumMagnitude but with shorter name)
Definition: Particle.h:515
int getArrayIndex() const
Returns this object's array index (in StoreArray), or -1 if not found.
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:95
Global list of available variables.
Definition: Manager.h:101
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
REG_MODULE(B2BIIConvertBeamParams)
Register the module.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:23