Belle II Software  release-05-02-19
ChargedPidMVAMulticlassModule.cc
1 //THIS MODULE
2 #include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAMulticlassModule.h>
3 
4 //ANALYSIS
5 #include <mva/interface/Interface.h>
6 #include <mva/methods/TMVA.h>
7 #include <analysis/VariableManager/Utility.h>
8 #include <analysis/dataobjects/Particle.h>
9 
10 //MDST
11 #include <mdst/dataobjects/ECLCluster.h>
12 
13 using namespace Belle2;
14 
15 REG_MODULE(ChargedPidMVAMulticlass)
16 
18 {
19  setDescription("This module evaluates the response of a multi-class MVA trained for global charged particle identification.. It takes the Particle objects in the input charged stable particles' ParticleLists, calculates the MVA per-class score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
20 
21  setPropertyFlags(c_ParallelProcessingCertified);
22 
23  addParam("particleLists",
24  m_particle_lists,
25  "The input list of ParticleList names.",
26  std::vector<std::string>());
27  addParam("payloadName",
28  m_payload_name,
29  "The name of the database payload object with the MVA weights.",
30  std::string("ChargedPidMVAWeights"));
31  addParam("useECLOnlyTraining",
32  m_ecl_only,
33  "Specify whether to use an ECL-only training of the MVA.",
34  bool(false));
35 }
36 
37 
39 
40 
42 {
43 
44  m_event_metadata.isRequired();
45 
46  m_weightfiles_representation = std::make_unique<DBObjPtr<ChargedPidMVAWeights>>(m_payload_name);
47 
48 }
49 
50 
52 {
53 
54  // Retrieve the payload from the DB.
55  (*m_weightfiles_representation.get()).addCallback([this]() { initializeMVA(); });
56  initializeMVA();
57 
58 }
59 
60 
62 {
63 
64  B2DEBUG(11, "EVENT: " << m_event_metadata->getEvent());
65 
66  for (const auto& name : m_particle_lists) {
67 
68  StoreObjPtr<ParticleList> pList(name);
69  if (!pList) { B2FATAL("ParticleList: " << name << " could not be found. Aborting..."); }
70 
71  // Need to get an absolute value in order to check if in Const::ChargedStable.
72  int pdg = abs(pList->getPDGCode());
73 
74  // Check if this ParticleList is made up of legit Const::ChargedStable particles.
75  if (!(*m_weightfiles_representation.get())->isValidPdg(pdg)) {
76  B2FATAL("PDG: " << pList->getPDGCode() << " of ParticleList: " << pList->getParticleListName() <<
77  " is not that of a valid particle in Const::chargedStableSet! Aborting...");
78  }
79 
80  B2DEBUG(11, "ParticleList: " << pList->getParticleListName() << " - N = " << pList->getListSize() << " particles.");
81 
82  for (unsigned int ipart(0); ipart < pList->getListSize(); ++ipart) {
83 
84  Particle* particle = pList->getParticle(ipart);
85 
86  B2DEBUG(11, "\tParticle [" << ipart << "]");
87 
88  // Check that the particle has a valid relation set between track and ECL cluster.
89  // Otherwise, skip to next.
90  const ECLCluster* eclCluster = particle->getECLCluster();
91  if (!eclCluster) {
92  B2WARNING("\tParticle has invalid Track-ECLCluster relation, skip MVA application...");
93  continue;
94  }
95 
96  // Retrieve the index for the correct MVA expert and dataset,
97  // given reconstructed (clusterTheta, p)
98  auto theta = eclCluster->getTheta();
99  auto p = particle->getP();
100  int jth, ip;
101  auto index = (*m_weightfiles_representation.get())->getMVAWeightIdx(theta, p, jth, ip);
102 
103  // Get the cut defining the MVA category under exam (this reflects the one used in the training).
104  const auto cuts = (*m_weightfiles_representation.get())->getCutsMulticlass();
105  const auto cutstr = (!cuts->empty()) ? cuts->at(index) : "";
106 
107  B2DEBUG(11, "\t\tcharge = " << particle->getCharge());
108  B2DEBUG(11, "\t\tclusterTheta = " << theta << " [rad]");
109  B2DEBUG(11, "\t\tp = " << p << " [GeV/c]");
110  B2DEBUG(11, "\t\tBrems corrected = " << particle->hasExtraInfo("bremsCorrectedPhotonEnergy"));
111  B2DEBUG(11, "\t\tWeightfile idx = " << index << " - (clusterTheta, p) = (" << jth << ", " << ip << ")");
112  if (!cutstr.empty()) {
113  B2DEBUG(11, "\t\tCategory cut = " << cutstr);
114  }
115 
116  // Fill the MVA::SingleDataset w/ variables and spectators.
117 
118  B2DEBUG(11, "\tMVA variables:");
119 
120  auto nvars = m_variables.at(index).size();
121  for (unsigned int ivar(0); ivar < nvars; ++ivar) {
122 
123  auto varobj = m_variables.at(index).at(ivar);
124 
125  auto var = varobj->function(particle);
126 
127  // Manual imputation value of -999 for NaN (undefined) variables.
128  var = (std::isnan(var)) ? -999.0 : var;
129 
130  B2DEBUG(11, "\t\tvar[" << ivar << "] : " << varobj->name << " = " << var);
131 
132  m_datasets.at(index)->m_input[ivar] = var;
133 
134  }
135 
136  B2DEBUG(12, "\tMVA spectators:");
137 
138  auto nspecs = m_spectators.at(index).size();
139  for (unsigned int ispec(0); ispec < nspecs; ++ispec) {
140 
141  auto specobj = m_spectators.at(index).at(ispec);
142 
143  auto spec = specobj->function(particle);
144 
145  B2DEBUG(12, "\t\tspec[" << ispec << "] : " << specobj->name << " = " << spec);
146 
147  m_datasets.at(index)->m_spectators[ispec] = spec;
148 
149  }
150 
151  // Compute MVA score only if particle fulfils category selection.
152  if (!cutstr.empty()) {
153 
154  std::unique_ptr<Variable::Cut> cut = Variable::Cut::compile(cutstr);
155 
156  if (!cut->check(particle)) {
157  B2WARNING("\tParticle didn't pass MVA category cut, skip MVA application...");
158  continue;
159  }
160 
161  }
162 
163  // Compute MVA score for each available class.
164 
165  B2DEBUG(11, "\tMVA response:");
166 
167  std::string score_varname("");
168  for (unsigned int classID(0); classID < m_classes.size(); ++classID) {
169 
170  const std::string className(m_classes.at(classID));
171 
172  float score = m_experts.at(index)->apply(*m_datasets.at(index), classID)[0];
173  score_varname = "pidChargedBDTScore_" + className;
174 
175  if (m_ecl_only) {
176  score_varname += "_" + std::to_string(Const::ECL);
177  } else {
178  for (size_t iDet(0); iDet < Const::PIDDetectors::set().size(); ++iDet) {
179  score_varname += "_" + std::to_string(Const::PIDDetectors::set()[iDet]);
180  }
181  }
182 
183  B2DEBUG(11, "\t\tclass[" << classID << "] = " << className << " - score = " << score);
184  B2DEBUG(12, "\t\tExtraInfo: " << score_varname);
185 
186  // Store the MVA score as a new particle object property.
187  particle->writeExtraInfo(score_varname, score);
188 
189  }
190 
191  }
192 
193  }
194 }
195 
196 
198 {
199 
200  B2INFO("Load supported MVA interfaces for multi-class charged particle identification...");
201 
202  // The supported methods have to be initialized once (calling it more than once is safe).
204  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
205 
206  B2INFO("\tLoading weightfiles from the payload class.");
207 
208  auto serialized_weightfiles = (*m_weightfiles_representation.get())->getMVAWeightsMulticlass();
209  auto nfiles = serialized_weightfiles->size();
210 
211  B2INFO("\tConstruct the MVA experts and datasets from N = " << nfiles << " weightfiles...");
212 
213  // The size of the vectors must correspond
214  // to the number of available weightfiles for this pdgId.
215  m_experts.resize(nfiles);
216  m_datasets.resize(nfiles);
217  m_variables.resize(nfiles);
218  m_spectators.resize(nfiles);
219 
220  for (unsigned int idx(0); idx < nfiles; idx++) {
221 
222  B2DEBUG(12, "\t\tweightfile[" << idx << "]");
223 
224  // De-serialize the string into an MVA::Weightfile object.
225  std::stringstream ss(serialized_weightfiles->at(idx));
226  auto weightfile = MVA::Weightfile::loadFromStream(ss);
227 
228  MVA::GeneralOptions general_options;
229  weightfile.getOptions(general_options);
230 
231  // Store the list of pointers to the relevant variables for this xml file.
233  m_variables[idx] = manager.getVariables(general_options.m_variables);
234  m_spectators[idx] = manager.getVariables(general_options.m_spectators);
235 
236  B2DEBUG(12, "\t\tRetrieved N = " << general_options.m_variables.size()
237  << " variables, N = " << general_options.m_spectators.size()
238  << " spectators");
239 
240  // Store an MVA::Expert object.
241  m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
242  m_experts.at(idx)->load(weightfile);
243 
244  B2DEBUG(12, "\t\tweightfile loaded successfully into expert[" << idx << "]!");
245 
246  // Store an MVA::SingleDataset object, in which we will save our features later...
247  std::vector<float> v(general_options.m_variables.size(), 0.0);
248  std::vector<float> s(general_options.m_spectators.size(), 0.0);
249  m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
250 
251  B2DEBUG(12, "\t\tdataset[" << idx << "] created successfully!");
252 
253  // Register class names only once.
254  if (idx == 0) {
255  // QUESTION: could this be made generic?
256  // Problem is I am not sure how other MVA methods deal with multi-classification,
257  // so it's difficult to make an abstract interface that surely works for everything... ideas?
258  MVA::TMVAOptionsMulticlass specific_options;
259  weightfile.getOptions(specific_options);
260 
261  if (specific_options.m_classes.empty()) {
262  B2FATAL("MVA::SpecificOptions of weightfile[" << idx <<
263  "] has no registered MVA classes! This shouldn't happen in multi-class mode. Aborting...");
264  }
265 
266  for (const auto& cls : specific_options.m_classes) {
267  m_classes.push_back(cls);
268  }
269  }
270  }
271 }
Belle2::GeneralCut::compile
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
Definition: GeneralCut.h:114
Belle2::ChargedPidMVAMulticlassModule::beginRun
virtual void beginRun() override
Called once before a new run begins.
Definition: ChargedPidMVAMulticlassModule.cc:51
Belle2::ChargedPidMVAMulticlassModule::m_datasets
DatasetsList m_datasets
List of MVA::SingleDataset objects.
Definition: ChargedPidMVAMulticlassModule.h:157
Belle2::MVA::AbstractInterface::getSupportedInterfaces
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:55
Belle2::ECLCluster
ECL cluster data.
Definition: ECLCluster.h:39
REG_MODULE
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:652
Belle2::ChargedPidMVAMulticlassModule::m_variables
VariablesLists m_variables
List of lists of feature variables.
Definition: ChargedPidMVAMulticlassModule.h:163
Belle2::ECLCluster::getTheta
double getTheta() const
Return Corrected Theta of Shower (radian).
Definition: ECLCluster.h:326
Belle2::ChargedPidMVAMulticlassModule::m_spectators
VariablesLists m_spectators
List of lists of spectator variables.
Definition: ChargedPidMVAMulticlassModule.h:169
Belle2::ChargedPidMVAMulticlassModule::m_experts
ExpertsList m_experts
List of MVA::Expert objects.
Definition: ChargedPidMVAMulticlassModule.h:151
Belle2::ChargedPidMVAMulticlassModule::~ChargedPidMVAMulticlassModule
virtual ~ChargedPidMVAMulticlassModule()
Destructor, use this to clean up anything you created in the constructor.
Belle2::ChargedPidMVAMulticlassModule::m_payload_name
std::string m_payload_name
The name of the database payload object with the MVA weights.
Definition: ChargedPidMVAMulticlassModule.h:127
Belle2::ChargedPidMVAMulticlassModule::event
virtual void event() override
Called once for each event.
Definition: ChargedPidMVAMulticlassModule.cc:61
Belle2::Module
Base class for Modules.
Definition: Module.h:74
Belle2::ChargedPidMVAMulticlassModule::initialize
virtual void initialize() override
Use this to initialize resources or memory your module needs.
Definition: ChargedPidMVAMulticlassModule.cc:41
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::StoreObjPtr
Type-safe access to single objects in the data store.
Definition: ParticleList.h:33
Belle2::ChargedPidMVAMulticlassModule::m_particle_lists
std::vector< std::string > m_particle_lists
The input list of ParticleList names.
Definition: ChargedPidMVAMulticlassModule.h:122
Belle2::MVA::TMVAOptionsMulticlass
Options for the TMVA Multiclass MVA method.
Definition: TMVA.h:126
Belle2::MVA::GeneralOptions
General options which are shared by all MVA trainings.
Definition: Options.h:64
Belle2::MVA::AbstractInterface::initSupportedInterfaces
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:55
Belle2::ChargedPidMVAMulticlassModule::m_classes
std::vector< std::string > m_classes
List of MVA class names.
Definition: ChargedPidMVAMulticlassModule.h:174
Belle2::ChargedPidMVAMulticlassModule::m_event_metadata
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
Definition: ChargedPidMVAMulticlassModule.h:137
Belle2::ChargedPidMVAMulticlassModule::m_weightfiles_representation
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
Definition: ChargedPidMVAMulticlassModule.h:144
Belle2::ChargedPidMVAMulticlassModule::m_ecl_only
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
Definition: ChargedPidMVAMulticlassModule.h:132
Belle2::Particle
Class to store reconstructed particles.
Definition: Particle.h:77
Belle2::MVA::Weightfile::loadFromStream
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:260
Belle2::ChargedPidMVAMulticlassModule::initializeMVA
void initializeMVA()
Definition: ChargedPidMVAMulticlassModule.cc:197
Belle2::Variable::Manager
Global list of available variables.
Definition: Manager.h:108
Belle2::ChargedPidMVAMulticlassModule
This module evaluates the response of a multi-class MVA trained for global charged particle identific...
Definition: ChargedPidMVAMulticlassModule.h:52
Belle2::Variable::Manager::Instance
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:27