Belle II Software release-06-00-14
ChargedPidMVAModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 //THIS MODULE
9 #include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAModule.h>
10 
11 //ANALYSIS
12 #include <mva/interface/Interface.h>
13 #include <analysis/VariableManager/Utility.h>
14 #include <analysis/dataobjects/Particle.h>
15 #include <analysis/dataobjects/ParticleList.h>
16 
17 //MDST
18 #include <mdst/dataobjects/ECLCluster.h>
19 
20 using namespace Belle2;
21 
// Register this module with the framework as "ChargedPidMVA" (the class name without the 'Module' suffix).
REG_MODULE(ChargedPidMVA)
23 
25 {
26  setDescription("This module evaluates the response of an MVA trained for binary charged particle identification between two hypotheses, S and B. For a given input set of (S,B) mass hypotheses, it takes the Particle objects in the appropriate charged stable particle's ParticleLists, calculates the MVA score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
27 
28  setPropertyFlags(c_ParallelProcessingCertified);
29 
30  addParam("sigHypoPDGCode",
31  m_sig_pdg,
32  "The input signal mass hypothesis' pdgId.",
33  int(0));
34  addParam("bkgHypoPDGCode",
35  m_bkg_pdg,
36  "The input background mass hypothesis' pdgId.",
37  int(0));
38  addParam("particleLists",
39  m_particle_lists,
40  "The input list of ParticleList names.",
41  std::vector<std::string>());
42  addParam("payloadName",
43  m_payload_name,
44  "The name of the database payload object with the MVA weights.",
45  std::string("ChargedPidMVAWeights"));
46  addParam("chargeIndependent",
47  m_charge_independent,
48  "Specify whether to use a charge-independent training of the MVA.",
49  bool(false));
50  addParam("useECLOnlyTraining",
51  m_ecl_only,
52  "Specify whether to use an ECL-only training of the MVA.",
53  bool(false));
54 }
55 
56 
/// Destructor. Defaulted: no manual cleanup is performed here; members release their own resources.
ChargedPidMVAModule::~ChargedPidMVAModule() = default;
58 
59 
60 void ChargedPidMVAModule::initialize()
61 {
62 
63  m_event_metadata.isRequired();
64 
65  m_weightfiles_representation = std::make_unique<DBObjPtr<ChargedPidMVAWeights>>(m_payload_name);
66 
67 }
68 
69 
71 {
72 
73  // Retrieve the payload from the DB.
74  (*m_weightfiles_representation.get()).addCallback([this]() { initializeMVA(); });
75  initializeMVA();
76 
77  if (!(*m_weightfiles_representation.get())->isValidPdg(m_sig_pdg)) {
78  B2FATAL("PDG: " << m_sig_pdg <<
79  " of the signal mass hypothesis is not that of a valid particle in Const::chargedStableSet! Aborting...");
80  }
81  if (!(*m_weightfiles_representation.get())->isValidPdg(m_bkg_pdg)) {
82  B2FATAL("PDG: " << m_bkg_pdg <<
83  " of the background mass hypothesis is not that of a valid particle in Const::chargedStableSet! Aborting...");
84  }
85 
86  m_score_varname = "pidPairChargedBDTScore_" + std::to_string(m_sig_pdg) + "_VS_" + std::to_string(m_bkg_pdg);
87 
88  if (m_ecl_only) {
89  m_score_varname += "_" + std::to_string(Const::ECL);
90  } else {
91  for (size_t iDet(0); iDet < Const::PIDDetectors::set().size(); ++iDet) {
92  m_score_varname += "_" + std::to_string(Const::PIDDetectors::set()[iDet]);
93  }
94  }
95 }
96 
97 
99 {
100 
101  B2DEBUG(11, "EVENT: " << m_event_metadata->getEvent());
102 
103  for (const auto& name : m_particle_lists) {
104 
105  StoreObjPtr<ParticleList> pList(name);
106  if (!pList) { B2FATAL("ParticleList: " << name << " could not be found. Aborting..."); }
107 
108  // Need to get an absolute value in order to check if in Const::ChargedStable.
109  int pdg = abs(pList->getPDGCode());
110 
111  // Check if this ParticleList is made up of legit Const::ChargedStable particles.
112  if (!(*m_weightfiles_representation.get())->isValidPdg(pdg)) {
113  B2FATAL("PDG: " << pList->getPDGCode() << " of ParticleList: " << pList->getParticleListName() <<
114  " is not that of a valid particle in Const::chargedStableSet! Aborting...");
115  }
116 
117  // Skip if this ParticleList does not match any of the input (S, B) hypotheses.
118  if (pdg != m_sig_pdg && pdg != m_bkg_pdg) {
119  continue;
120  }
121 
122  B2DEBUG(11, "ParticleList: " << pList->getParticleListName() << " - N = " << pList->getListSize() << " particles.");
123 
124  for (unsigned int ipart(0); ipart < pList->getListSize(); ++ipart) {
125 
126  Particle* particle = pList->getParticle(ipart);
127 
128  B2DEBUG(11, "\tParticle [" << ipart << "]");
129 
130  // Check that the particle has a valid relation set between track and ECL cluster.
131  // Otherwise, skip to next.
132  const ECLCluster* eclCluster = particle->getECLCluster();
133  if (!eclCluster) {
134  B2DEBUG(11, "\t\tParticle has invalid Track-ECLCluster relation, skip MVA application...");
135  continue;
136  }
137 
138  // Retrieve the index for the correct MVA expert and dataset,
139  // given reconstructed (clusterTheta, p, charge)
140  auto clusterTheta = eclCluster->getTheta();
141  auto p = particle->getP();
142  // Set a dummy charge of zero to pick charge-independent payloads, if requested.
143  auto charge = (!m_charge_independent) ? particle->getCharge() : 0.0;
144  int idx_theta, idx_p, idx_charge;
145  auto index = (*m_weightfiles_representation.get())->getMVAWeightIdx(clusterTheta, p, charge, idx_theta, idx_p, idx_charge);
146 
147  // Get the cut defining the MVA category under exam (this reflects the one used in the training).
148  const auto cuts = (*m_weightfiles_representation.get())->getCuts(m_sig_pdg);
149  const auto cutstr = (!cuts->empty()) ? cuts->at(index) : "";
150 
151  B2DEBUG(11, "\t\tclusterTheta = " << clusterTheta << " [rad]");
152  B2DEBUG(11, "\t\tp = " << p << " [GeV/c]");
153  if (!m_charge_independent) {
154  B2DEBUG(11, "\t\tcharge = " << charge);
155  }
156  B2DEBUG(11, "\t\tBrems corrected = " << particle->hasExtraInfo("bremsCorrectedPhotonEnergy"));
157  B2DEBUG(11, "\t\tWeightfile idx = " << index << " - (clusterTheta, p, charge) = (" << idx_theta << ", " << idx_p << ", " <<
158  idx_charge << ")");
159  if (!cutstr.empty()) {
160  B2DEBUG(11, "\tCategory cut: " << cutstr);
161  }
162 
163  // Fill the MVA::SingleDataset w/ variables and spectators.
164 
165  B2DEBUG(11, "\tMVA variables:");
166 
167  auto nvars = m_variables.at(index).size();
168  for (unsigned int ivar(0); ivar < nvars; ++ivar) {
169 
170  auto varobj = m_variables.at(index).at(ivar);
171 
172  auto var = varobj->function(particle);
173 
174  // Manual imputation value of -999 for NaN (undefined) variables.
175  var = (std::isnan(var)) ? -999.0 : var;
176 
177  B2DEBUG(11, "\t\tvar[" << ivar << "] : " << varobj->name << " = " << var);
178 
179  m_datasets.at(index)->m_input[ivar] = var;
180 
181  }
182 
183  B2DEBUG(12, "\tMVA spectators:");
184 
185  auto nspecs = m_spectators.at(index).size();
186  for (unsigned int ispec(0); ispec < nspecs; ++ispec) {
187 
188  auto specobj = m_spectators.at(index).at(ispec);
189 
190  auto spec = specobj->function(particle);
191 
192  B2DEBUG(12, "\t\tspec[" << ispec << "] : " << specobj->name << " = " << spec);
193 
194  m_datasets.at(index)->m_spectators[ispec] = spec;
195 
196  }
197 
198  // Compute MVA score only if particle fulfils category selection.
199  if (!cutstr.empty()) {
200 
201  std::unique_ptr<Variable::Cut> cut = Variable::Cut::compile(cutstr);
202 
203  if (!cut->check(particle)) {
204  B2WARNING("\tParticle didn't pass MVA category cut, skip MVA application...");
205  continue;
206  }
207 
208  }
209 
210  float score = m_experts.at(index)->apply(*m_datasets.at(index))[0];
211 
212  B2DEBUG(11, "\tMVA score = " << score);
213  B2DEBUG(12, "\tExtraInfo: " << m_score_varname);
214 
215  // Store the MVA score as a new particle object property.
216  particle->writeExtraInfo(m_score_varname, score);
217 
218  }
219 
220  }
221 }
222 
223 
225 {
226 
227  B2INFO("Run: " << m_event_metadata->getRun() << ". Load supported MVA interfaces for binary charged particle identification...");
228 
229  // The supported methods have to be initialized once (calling it more than once is safe).
231  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
232 
233  B2INFO("\tLoading weightfiles from the payload class for SIGNAL particle hypothesis: " << m_sig_pdg);
234 
235  auto serialized_weightfiles = (*m_weightfiles_representation.get())->getMVAWeights(m_sig_pdg);
236  auto nfiles = serialized_weightfiles->size();
237 
238  B2INFO("\tConstruct the MVA experts and datasets from N = " << nfiles << " weightfiles...");
239 
240  // The size of the vectors must correspond
241  // to the number of available weightfiles for this pdgId.
242  m_experts.resize(nfiles);
243  m_datasets.resize(nfiles);
244  m_variables.resize(nfiles);
245  m_spectators.resize(nfiles);
246 
247  for (unsigned int idx(0); idx < nfiles; idx++) {
248 
249  B2DEBUG(12, "\t\tweightfile[" << idx << "]");
250 
251  // De-serialize the string into an MVA::Weightfile object.
252  std::stringstream ss(serialized_weightfiles->at(idx));
253  auto weightfile = MVA::Weightfile::loadFromStream(ss);
254 
255  MVA::GeneralOptions general_options;
256  weightfile.getOptions(general_options);
257 
258  // Store the list of pointers to the relevant variables for this xml file.
260  m_variables[idx] = manager.getVariables(general_options.m_variables);
261  m_spectators[idx] = manager.getVariables(general_options.m_spectators);
262 
263  B2DEBUG(12, "\t\tRetrieved N = " << general_options.m_variables.size()
264  << " variables, N = " << general_options.m_spectators.size()
265  << " spectators");
266 
267  // Store an MVA::Expert object.
268  m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
269  m_experts.at(idx)->load(weightfile);
270 
271  B2DEBUG(12, "\t\tweightfile loaded successfully into expert[" << idx << "]!");
272 
273  // Store an MVA::SingleDataset object, in which we will save our features later...
274  std::vector<float> v(general_options.m_variables.size(), 0.0);
275  std::vector<float> s(general_options.m_spectators.size(), 0.0);
276  m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
277 
278  B2DEBUG(12, "\t\tdataset[" << idx << "] created successfully!");
279 
280  }
281 
282 }
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
int m_bkg_pdg
The input background mass hypothesis' pdgId.
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
std::string m_score_varname
The lookup name of the MVA score variable, given the input S, B mass hypotheses for the algorithm.
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
DatasetsList m_datasets
List of MVA::SingleDataset objects.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual void event() override
Called once for each event.
int m_sig_pdg
The input signal mass hypothesis' pdgId.
virtual void beginRun() override
Called once before a new run begins.
VariablesLists m_variables
List of lists of feature variables.
ChargedPidMVAModule()
Constructor, for setting module description and parameters.
VariablesLists m_spectators
List of lists of spectator variables.
std::vector< std::string > m_particle_lists
The input list of names of ParticleList objects to which MVA weights will be applied.
ExpertsList m_experts
List of MVA::Expert objects.
std::string m_payload_name
The name of the database payload object with the MVA weights.
size_t size() const
Getter for number of detector IDs in this set.
Definition: UnitConst.cc:256
static DetectorSet set()
Accessor function for the set of valid detectors.
Definition: Const.h:255
ECL cluster data.
Definition: ECLCluster.h:27
double getTheta() const
Return Corrected Theta of Shower (radian).
Definition: ECLCluster.h:304
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
Definition: GeneralCut.h:104
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:250
Base class for Modules.
Definition: Module.h:72
Class to store reconstructed particles.
Definition: Particle.h:74
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:95
Global list of available variables.
Definition: Manager.h:98
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
double charge(int pdgCode)
Returns electric charge of a particle with given pdg code.
Definition: EvtPDLUtil.cc:44
Abstract base class for different kinds of events.