Belle II Software  release-06-02-00
ChargedPidMVAMulticlassModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 //THIS MODULE
9 #include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAMulticlassModule.h>
10 
11 //ANALYSIS
12 #include <mva/interface/Interface.h>
13 #include <mva/methods/TMVA.h>
14 #include <analysis/VariableManager/Utility.h>
15 #include <analysis/dataobjects/Particle.h>
16 
17 //MDST
18 #include <mdst/dataobjects/ECLCluster.h>
19 
20 using namespace Belle2;
21 
22 REG_MODULE(ChargedPidMVAMulticlass)
23 
25 {
26  setDescription("This module evaluates the response of a multi-class MVA trained for global charged particle identification.. It takes the Particle objects in the input charged stable particles' ParticleLists, calculates the MVA per-class score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
27 
28  setPropertyFlags(c_ParallelProcessingCertified);
29 
30  addParam("particleLists",
31  m_particle_lists,
32  "The input list of ParticleList names.",
33  std::vector<std::string>());
34  addParam("payloadName",
35  m_payload_name,
36  "The name of the database payload object with the MVA weights.",
37  std::string("ChargedPidMVAWeights"));
38  addParam("chargeIndependent",
39  m_charge_independent,
40  "Specify whether to use a charge-independent training of the MVA.",
41  bool(false));
42  addParam("useECLOnlyTraining",
43  m_ecl_only,
44  "Specify whether to use an ECL-only training of the MVA.",
45  bool(false));
46 }
47 
48 
50 
51 
53 {
54 
55  m_event_metadata.isRequired();
56 
57  m_weightfiles_representation = std::make_unique<DBObjPtr<ChargedPidMVAWeights>>(m_payload_name);
58 
59 }
60 
61 
63 {
64 
65  // Retrieve the payload from the DB.
66  (*m_weightfiles_representation.get()).addCallback([this]() { initializeMVA(); });
67  initializeMVA();
68 
69 }
70 
71 
73 {
74 
75  B2DEBUG(11, "EVENT: " << m_event_metadata->getEvent());
76 
77  for (const auto& name : m_particle_lists) {
78 
79  StoreObjPtr<ParticleList> pList(name);
80  if (!pList) { B2FATAL("ParticleList: " << name << " could not be found. Aborting..."); }
81 
82  // Need to get an absolute value in order to check if in Const::ChargedStable.
83  int pdg = abs(pList->getPDGCode());
84 
85  // Check if this ParticleList is made up of legit Const::ChargedStable particles.
86  if (!(*m_weightfiles_representation.get())->isValidPdg(pdg)) {
87  B2FATAL("PDG: " << pList->getPDGCode() << " of ParticleList: " << pList->getParticleListName() <<
88  " is not that of a valid particle in Const::chargedStableSet! Aborting...");
89  }
90 
91  B2DEBUG(11, "ParticleList: " << pList->getParticleListName() << " - N = " << pList->getListSize() << " particles.");
92 
93  for (unsigned int ipart(0); ipart < pList->getListSize(); ++ipart) {
94 
95  Particle* particle = pList->getParticle(ipart);
96 
97  B2DEBUG(11, "\tParticle [" << ipart << "]");
98 
99  // Check that the particle has a valid relation set between track and ECL cluster.
100  // Otherwise, skip to next.
101  const ECLCluster* eclCluster = particle->getECLCluster();
102  if (!eclCluster) {
103  B2DEBUG(11, "\t\tParticle has invalid Track-ECLCluster relation, skip MVA application...");
104  continue;
105  }
106 
107  // Retrieve the index for the correct MVA expert and dataset,
108  // given the reconstructed (clusterTheta, p, charge)
109  auto clusterTheta = eclCluster->getTheta();
110  auto p = particle->getP();
111  // Set a dummy charge of zero to pick charge-independent payloads, if requested.
112  auto charge = (!m_charge_independent) ? particle->getCharge() : 0.0;
113  int idx_theta, idx_p, idx_charge;
114  auto index = (*m_weightfiles_representation.get())->getMVAWeightIdx(clusterTheta, p, charge, idx_theta, idx_p, idx_charge);
115 
116  // Get the cut defining the MVA category under exam (this reflects the one used in the training).
117  const auto cuts = (*m_weightfiles_representation.get())->getCutsMulticlass();
118  const auto cutstr = (!cuts->empty()) ? cuts->at(index) : "";
119 
120  B2DEBUG(11, "\t\tclusterTheta = " << clusterTheta << " [rad]");
121  B2DEBUG(11, "\t\tp = " << p << " [GeV/c]");
122  if (!m_charge_independent) {
123  B2DEBUG(11, "\t\tcharge = " << charge);
124  }
125  B2DEBUG(11, "\t\tBrems corrected = " << particle->hasExtraInfo("bremsCorrectedPhotonEnergy"));
126  B2DEBUG(11, "\t\tWeightfile idx = " << index << " - (clusterTheta, p, charge) = (" << idx_theta << ", " << idx_p << ", " <<
127  idx_charge << ")");
128  if (!cutstr.empty()) {
129  B2DEBUG(11, "\t\tCategory cut = " << cutstr);
130  }
131 
132  // Fill the MVA::SingleDataset w/ variables and spectators.
133 
134  B2DEBUG(11, "\tMVA variables:");
135 
136  auto nvars = m_variables.at(index).size();
137  for (unsigned int ivar(0); ivar < nvars; ++ivar) {
138 
139  auto varobj = m_variables.at(index).at(ivar);
140 
141  auto var = varobj->function(particle);
142 
143  // Manual imputation value of -999 for NaN (undefined) variables. Needed by TMVA.
144  var = (std::isnan(var)) ? -999.0 : var;
145 
146  B2DEBUG(11, "\t\tvar[" << ivar << "] : " << varobj->name << " = " << var);
147 
148  m_datasets.at(index)->m_input[ivar] = var;
149 
150  }
151 
152  B2DEBUG(12, "\tMVA spectators:");
153 
154  auto nspecs = m_spectators.at(index).size();
155  for (unsigned int ispec(0); ispec < nspecs; ++ispec) {
156 
157  auto specobj = m_spectators.at(index).at(ispec);
158 
159  auto spec = specobj->function(particle);
160 
161  B2DEBUG(12, "\t\tspec[" << ispec << "] : " << specobj->name << " = " << spec);
162 
163  m_datasets.at(index)->m_spectators[ispec] = spec;
164 
165  }
166 
167  // Compute MVA score only if particle fulfils category selection.
168  if (!cutstr.empty()) {
169 
170  std::unique_ptr<Variable::Cut> cut = Variable::Cut::compile(cutstr);
171 
172  if (!cut->check(particle)) {
173  B2WARNING("\tParticle didn't pass MVA category cut, skip MVA application...");
174  continue;
175  }
176 
177  }
178 
179  // Compute MVA score for each available class.
180 
181  B2DEBUG(11, "\tMVA response:");
182 
183  std::string score_varname("");
184  // We deal w/ a SingleDataset, so 0 is the only existing component by construction.
185  std::vector<float> scores = m_experts.at(index)->applyMulticlass(*m_datasets.at(index))[0];
186 
187  for (unsigned int classID(0); classID < m_classes.size(); ++classID) {
188 
189  const std::string className(m_classes.at(classID));
190 
191  score_varname = "pidChargedBDTScore_" + className;
192 
193  if (m_ecl_only) {
194  score_varname += "_" + std::to_string(Const::ECL);
195  } else {
196  for (size_t iDet(0); iDet < Const::PIDDetectors::set().size(); ++iDet) {
197  score_varname += "_" + std::to_string(Const::PIDDetectors::set()[iDet]);
198  }
199  }
200 
201  B2DEBUG(11, "\t\tclass[" << classID << "] = " << className << " - score = " << scores[classID]);
202  B2DEBUG(12, "\t\tExtraInfo: " << score_varname);
203 
204  // Store the MVA score as a new particle object property.
205  particle->writeExtraInfo(score_varname, scores[classID]);
206 
207  }
208 
209  }
210 
211  }
212 }
213 
214 
216 {
217 
218  B2INFO("Run: " << m_event_metadata->getRun() <<
219  ". Load supported MVA interfaces for multi-class charged particle identification...");
220 
221  // The supported methods have to be initialized once (calling it more than once is safe).
223  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
224 
225  B2INFO("\tLoading weightfiles from the payload class.");
226 
227  auto serialized_weightfiles = (*m_weightfiles_representation.get())->getMVAWeightsMulticlass();
228  auto nfiles = serialized_weightfiles->size();
229 
230  B2INFO("\tConstruct the MVA experts and datasets from N = " << nfiles << " weightfiles...");
231 
232  // The size of the vectors must correspond
233  // to the number of available weightfiles for this pdgId.
234  m_experts.resize(nfiles);
235  m_datasets.resize(nfiles);
236  m_variables.resize(nfiles);
237  m_spectators.resize(nfiles);
238 
239  for (unsigned int idx(0); idx < nfiles; idx++) {
240 
241  B2DEBUG(12, "\t\tweightfile[" << idx << "]");
242 
243  // De-serialize the string into an MVA::Weightfile object.
244  std::stringstream ss(serialized_weightfiles->at(idx));
245  auto weightfile = MVA::Weightfile::loadFromStream(ss);
246 
247  MVA::GeneralOptions general_options;
248  weightfile.getOptions(general_options);
249 
250  // Store the list of pointers to the relevant variables for this xml file.
252  m_variables[idx] = manager.getVariables(general_options.m_variables);
253  m_spectators[idx] = manager.getVariables(general_options.m_spectators);
254 
255  B2DEBUG(12, "\t\tRetrieved N = " << general_options.m_variables.size()
256  << " variables, N = " << general_options.m_spectators.size()
257  << " spectators");
258 
259  // Store an MVA::Expert object.
260  m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
261  m_experts.at(idx)->load(weightfile);
262 
263  B2DEBUG(12, "\t\tweightfile loaded successfully into expert[" << idx << "]!");
264 
265  // Store an MVA::SingleDataset object, in which we will save our features later...
266  std::vector<float> v(general_options.m_variables.size(), 0.0);
267  std::vector<float> s(general_options.m_spectators.size(), 0.0);
268  m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
269 
270  B2DEBUG(12, "\t\tdataset[" << idx << "] created successfully!");
271 
272  // Register class names only once.
273  if (idx == 0) {
274  // QUESTION: could this be made generic?
275  // Problem is I am not sure how other MVA methods deal with multi-classification,
276  // so it's difficult to make an abstract interface that surely works for everything... ideas?
277  MVA::TMVAOptionsMulticlass specific_options;
278  weightfile.getOptions(specific_options);
279 
280  if (specific_options.m_classes.empty()) {
281  B2FATAL("MVA::SpecificOptions of weightfile[" << idx <<
282  "] has no registered MVA classes! This shouldn't happen in multi-class mode. Aborting...");
283  }
284 
285  m_classes.clear();
286  for (const auto& cls : specific_options.m_classes) {
287  m_classes.push_back(cls);
288  }
289 
290  }
291  }
292 }
This module evaluates the response of a multi-class MVA trained for global charged particle identific...
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
virtual void initialize() override
Use this to initialize resources or memory your module needs.
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
virtual void event() override
Called once for each event.
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
DatasetsList m_datasets
List of MVA::SingleDataset objects.
std::vector< std::string > m_classes
List of MVA class names.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual ~ChargedPidMVAMulticlassModule()
Destructor, use this to clean up anything you created in the constructor.
virtual void beginRun() override
Called once before a new run begins.
VariablesLists m_variables
List of lists of feature variables.
VariablesLists m_spectators
List of lists of spectator variables.
std::vector< std::string > m_particle_lists
The input list of ParticleList names.
ExpertsList m_experts
List of MVA::Expert objects.
std::string m_payload_name
The name of the database payload object with the MVA weights.
size_t size() const
Getter for number of detector IDs in this set.
Definition: UnitConst.cc:256
static DetectorSet set()
Accessor function for the set of valid detectors.
Definition: Const.h:255
ECL cluster data.
Definition: ECLCluster.h:27
double getTheta() const
Return Corrected Theta of Shower (radian).
Definition: ECLCluster.h:304
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
Definition: GeneralCut.h:104
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
Options for the TMVA Multiclass MVA method.
Definition: TMVA.h:122
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:250
Base class for Modules.
Definition: Module.h:72
Class to store reconstructed particles.
Definition: Particle.h:74
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:95
Global list of available variables.
Definition: Manager.h:98
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
Abstract base class for different kinds of events.