Belle II Software  release-05-01-25
ChargedPidMVAModule.cc
1 //THIS MODULE
2 #include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAModule.h>
3 
4 //ANALYSIS
5 #include <mva/interface/Interface.h>
6 #include <analysis/VariableManager/Utility.h>
7 #include <analysis/dataobjects/Particle.h>
8 #include <analysis/dataobjects/ParticleList.h>
9 
10 //MDST
11 #include <mdst/dataobjects/ECLCluster.h>
12 
13 using namespace Belle2;
14 
15 REG_MODULE(ChargedPidMVA)
16 
18 {
19  setDescription("This module evaluates the response of an MVA trained for binary charged particle identification between two hypotheses, S and B. For a given input set of (S,B) mass hypotheses, it takes the Particle objects in the appropriate charged stable particle's ParticleLists, calculates the MVA score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
20 
22 
23  addParam("sigHypoPDGCode",
24  m_sig_pdg,
25  "The input signal mass hypothesis' pdgId.",
26  int(0));
27  addParam("bkgHypoPDGCode",
28  m_bkg_pdg,
29  "The input background mass hypothesis' pdgId.",
30  int(0));
31  addParam("particleLists",
33  "The input list of ParticleList names.",
34  std::vector<std::string>());
35  addParam("payloadName",
37  "The name of the database payload object with the MVA weights.",
38  std::string("ChargedPidMVAWeights"));
39  addParam("useECLOnlyTraining",
40  m_ecl_only,
41  "Specify whether to use an ECL-only training of the MVA.",
42  bool(false));
43 }
44 
45 
47 
48 
50 {
51 
52  m_event_metadata.isRequired();
53 
54  m_weightfiles_representation = std::make_unique<DBObjPtr<ChargedPidMVAWeights>>(m_payload_name);
55 
56 }
57 
58 
60 {
61 
62  // Retrieve the payload from the DB.
63  (*m_weightfiles_representation.get()).addCallback([this]() { initializeMVA(); });
64  initializeMVA();
65 
66  if (!(*m_weightfiles_representation.get())->isValidPdg(m_sig_pdg)) {
67  B2FATAL("PDG: " << m_sig_pdg <<
68  " of the signal mass hypothesis is not that of a valid particle in Const::chargedStableSet! Aborting...");
69  }
70  if (!(*m_weightfiles_representation.get())->isValidPdg(m_bkg_pdg)) {
71  B2FATAL("PDG: " << m_bkg_pdg <<
72  " of the background mass hypothesis is not that of a valid particle in Const::chargedStableSet! Aborting...");
73  }
74 
75  m_score_varname = "pidPairChargedBDTScore_" + std::to_string(m_sig_pdg) + "_VS_" + std::to_string(m_bkg_pdg);
76 
77  if (m_ecl_only) {
78  m_score_varname += "_" + std::to_string(Const::ECL);
79  } else {
80  for (size_t iDet(0); iDet < Const::PIDDetectors::set().size(); ++iDet) {
81  m_score_varname += "_" + std::to_string(Const::PIDDetectors::set()[iDet]);
82  }
83  }
84 }
85 
86 
88 {
89 
90  B2DEBUG(11, "EVENT: " << m_event_metadata->getEvent());
91 
92  for (const auto& name : m_particle_lists) {
93 
94  StoreObjPtr<ParticleList> pList(name);
95  if (!pList) { B2FATAL("ParticleList: " << name << " could not be found. Aborting..."); }
96 
97  // Need to get an absolute value in order to check if in Const::ChargedStable.
98  int pdg = abs(pList->getPDGCode());
99 
100  // Check if this ParticleList is made up of legit Const::ChargedStable particles.
101  if (!(*m_weightfiles_representation.get())->isValidPdg(pdg)) {
102  B2FATAL("PDG: " << pList->getPDGCode() << " of ParticleList: " << pList->getParticleListName() <<
103  " is not that of a valid particle in Const::chargedStableSet! Aborting...");
104  }
105 
106  // Skip if this ParticleList does not match any of the input (S, B) hypotheses.
107  if (pdg != m_sig_pdg && pdg != m_bkg_pdg) {
108  continue;
109  }
110 
111  B2DEBUG(11, "ParticleList: " << pList->getParticleListName() << " - N = " << pList->getListSize() << " particles.");
112 
113  for (unsigned int ipart(0); ipart < pList->getListSize(); ++ipart) {
114 
115  Particle* particle = pList->getParticle(ipart);
116 
117  B2DEBUG(11, "\tParticle [" << ipart << "]");
118 
119  // Check that the particle has a valid relation set between track and ECL cluster.
120  // Otherwise, skip to next.
121  const ECLCluster* eclCluster = particle->getECLCluster();
122  if (!eclCluster) {
123  B2WARNING("\tParticle has invalid Track-ECLCluster relation, skip MVA application...");
124  continue;
125  }
126 
127  // Retrieve the index for the correct MVA expert and dataset,
128  // given reconstructed (clusterTheta, p)
129  auto theta = eclCluster->getTheta();
130  auto p = particle->getP();
131  int jth, ip;
132  auto index = (*m_weightfiles_representation.get())->getMVAWeightIdx(theta, p, jth, ip);
133 
134  // Get the cut defining the MVA category under exam (this reflects the one used in the training).
135  const auto cuts = (*m_weightfiles_representation.get())->getCuts(m_sig_pdg);
136  const auto cutstr = (!cuts->empty()) ? cuts->at(index) : "";
137 
138  B2DEBUG(11, "\t\tcharge = " << particle->getCharge());
139  B2DEBUG(11, "\t\tclusterTheta = " << theta << " [rad]");
140  B2DEBUG(11, "\t\tp = " << p << " [GeV/c]");
141  B2DEBUG(11, "\t\tBrems corrected = " << particle->hasExtraInfo("bremsCorrectedPhotonEnergy"));
142  B2DEBUG(11, "\t\tWeightfile idx = " << index << " - (clusterTheta, p) = (" << jth << ", " << ip << ")");
143  if (!cutstr.empty()) {
144  B2DEBUG(11, "\tCategory cut: " << cutstr);
145  }
146 
147  // Fill the MVA::SingleDataset w/ variables and spectators.
148 
149  B2DEBUG(11, "\tMVA variables:");
150 
151  auto nvars = m_variables.at(index).size();
152  for (unsigned int ivar(0); ivar < nvars; ++ivar) {
153 
154  auto varobj = m_variables.at(index).at(ivar);
155 
156  auto var = varobj->function(particle);
157 
158  // Manual imputation value of -999 for NaN (undefined) variables.
159  var = (std::isnan(var)) ? -999.0 : var;
160 
161  B2DEBUG(11, "\t\tvar[" << ivar << "] : " << varobj->name << " = " << var);
162 
163  m_datasets.at(index)->m_input[ivar] = var;
164 
165  }
166 
167  B2DEBUG(12, "\tMVA spectators:");
168 
169  auto nspecs = m_spectators.at(index).size();
170  for (unsigned int ispec(0); ispec < nspecs; ++ispec) {
171 
172  auto specobj = m_spectators.at(index).at(ispec);
173 
174  auto spec = specobj->function(particle);
175 
176  B2DEBUG(12, "\t\tspec[" << ispec << "] : " << specobj->name << " = " << spec);
177 
178  m_datasets.at(index)->m_spectators[ispec] = spec;
179 
180  }
181 
182  // Compute MVA score only if particle fulfils category selection.
183  if (!cutstr.empty()) {
184 
185  std::unique_ptr<Variable::Cut> cut = Variable::Cut::compile(cutstr);
186 
187  if (!cut->check(particle)) {
188  B2WARNING("\tParticle didn't pass MVA category cut, skip MVA application...");
189  continue;
190  }
191 
192  }
193 
194  float score = m_experts.at(index)->apply(*m_datasets.at(index))[0];
195 
196  B2DEBUG(11, "\tMVA score = " << score);
197  B2DEBUG(12, "\tExtraInfo: " << m_score_varname);
198 
199  // Store the MVA score as a new particle object property.
200  particle->writeExtraInfo(m_score_varname, score);
201 
202  }
203 
204  }
205 }
206 
207 
209 {
210 
211  B2INFO("Load supported MVA interfaces for charged particle identification...");
212 
213  // The supported methods have to be initialized once (calling it more than once is safe).
215  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
216 
217  B2INFO("\tLoading weightfiles from the payload class for SIGNAL particle hypothesis: " << m_sig_pdg);
218 
219  auto serialized_weightfiles = (*m_weightfiles_representation.get())->getMVAWeights(m_sig_pdg);
220  auto nfiles = serialized_weightfiles->size();
221 
222  B2INFO("\tConstruct the MVA experts and datasets from N = " << nfiles << " weightfiles...");
223 
224  // The size of the vectors must correspond
225  // to the number of available weightfiles for this pdgId.
226  m_experts.resize(nfiles);
227  m_datasets.resize(nfiles);
228  m_variables.resize(nfiles);
229  m_spectators.resize(nfiles);
230 
231  for (unsigned int idx(0); idx < nfiles; idx++) {
232 
233  B2DEBUG(12, "\t\tweightfile[" << idx << "]");
234 
235  // De-serialize the string into an MVA::Weightfile object.
236  std::stringstream ss(serialized_weightfiles->at(idx));
237  auto weightfile = MVA::Weightfile::loadFromStream(ss);
238 
239  MVA::GeneralOptions general_options;
240  weightfile.getOptions(general_options);
241 
242  // Store the list of pointers to the relevant variables for this xml file.
244  m_variables[idx] = manager.getVariables(general_options.m_variables);
245  m_spectators[idx] = manager.getVariables(general_options.m_spectators);
246 
247  B2DEBUG(12, "\t\tRetrieved N = " << general_options.m_variables.size()
248  << " variables, N = " << general_options.m_spectators.size()
249  << " spectators");
250 
251  // Store an MVA::Expert object.
252  m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
253  m_experts.at(idx)->load(weightfile);
254 
255  B2DEBUG(12, "\t\tweightfile loaded successfully into expert[" << idx << "]!");
256 
257  // Store an MVA::SingleDataset object, in which we will save our features later...
258  std::vector<float> v(general_options.m_variables.size(), 0.0);
259  std::vector<float> s(general_options.m_spectators.size(), 0.0);
260  m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
261 
262  B2DEBUG(12, "\t\tdataset[" << idx << "] created successfully!");
263 
264  }
265 
266 }
Belle2::GeneralCut::compile
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
Definition: GeneralCut.h:114
Belle2::ChargedPidMVAModule::m_datasets
DatasetsList m_datasets
List of MVA::SingleDataset objects.
Definition: ChargedPidMVAModule.h:155
Belle2::ChargedPidMVAModule::m_weightfiles_representation
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
Definition: ChargedPidMVAModule.h:143
Belle2::MVA::AbstractInterface::getSupportedInterfaces
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:55
Belle2::ChargedPidMVAModule::m_payload_name
std::string m_payload_name
The name of the database payload object with the MVA weights.
Definition: ChargedPidMVAModule.h:121
Belle2::Module::setDescription
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:216
Belle2::ECLCluster
ECL cluster data.
Definition: ECLCluster.h:39
REG_MODULE
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:652
Belle2::Module::c_ParallelProcessingCertified
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:82
Belle2::ChargedPidMVAModule::~ChargedPidMVAModule
virtual ~ChargedPidMVAModule()
Destructor, use this to clean up anything you created in the constructor.
Belle2::ChargedPidMVAModule::ChargedPidMVAModule
ChargedPidMVAModule()
Constructor, for setting module description and parameters.
Belle2::ChargedPidMVAModule::initializeMVA
void initializeMVA()
Belle2::ChargedPidMVAModule::m_experts
ExpertsList m_experts
List of MVA::Expert objects.
Definition: ChargedPidMVAModule.h:149
Belle2::ECLCluster::getTheta
double getTheta() const
Return Corrected Theta of Shower (radian).
Definition: ECLCluster.h:326
Belle2::ChargedPidMVAModule::m_event_metadata
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
Definition: ChargedPidMVAModule.h:136
Belle2::ChargedPidMVAModule::m_variables
VariablesLists m_variables
List of lists of feature variables.
Definition: ChargedPidMVAModule.h:161
Belle2::ChargedPidMVAMulticlassModule::m_payload_name
std::string m_payload_name
The name of the database payload object with the MVA weights.
Definition: ChargedPidMVAMulticlassModule.h:127
Belle2::ChargedPidMVAModule::m_sig_pdg
int m_sig_pdg
The input signal mass hypothesis' pdgId.
Definition: ChargedPidMVAModule.h:106
Belle2::Module
Base class for Modules.
Definition: Module.h:74
Belle2::Module::setPropertyFlags
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:210
Belle2::ChargedPidMVAModule::m_particle_lists
std::vector< std::string > m_particle_lists
The input list of names of ParticleList objects to which MVA weights will be applied.
Definition: ChargedPidMVAModule.h:116
Belle2::ChargedPidMVAModule::m_score_varname
std::string m_score_varname
The lookup name of the MVA score variable, given the input S, B mass hypotheses for the algorithm.
Definition: ChargedPidMVAModule.h:131
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::StoreObjPtr
Type-safe access to single objects in the data store.
Definition: ParticleList.h:33
Belle2::ChargedPidMVAMulticlassModule::m_particle_lists
std::vector< std::string > m_particle_lists
The input list of ParticleList names.
Definition: ChargedPidMVAMulticlassModule.h:122
Belle2::ChargedPidMVAModule::m_ecl_only
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
Definition: ChargedPidMVAModule.h:126
Belle2::ChargedPidMVAModule::beginRun
virtual void beginRun() override
Called once before a new run begins.
Belle2::ChargedPidMVAModule::event
virtual void event() override
Called once for each event.
Belle2::MVA::GeneralOptions
General options which are shared by all MVA trainings.
Definition: Options.h:64
Belle2::MVA::AbstractInterface::initSupportedInterfaces
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces() can be used.
Definition: Interface.cc:55
Belle2::ChargedPidMVAMulticlassModule::m_ecl_only
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
Definition: ChargedPidMVAMulticlassModule.h:132
Belle2::Particle
Class to store reconstructed particles.
Definition: Particle.h:77
Belle2::Module::addParam
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:562
Belle2::MVA::Weightfile::loadFromStream
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:260
Belle2::ChargedPidMVAModule::m_bkg_pdg
int m_bkg_pdg
The input background mass hypothesis' pdgId.
Definition: ChargedPidMVAModule.h:111
Belle2::ChargedPidMVAModule::m_spectators
VariablesLists m_spectators
List of lists of spectator variables.
Definition: ChargedPidMVAModule.h:167
Belle2::ChargedPidMVAModule::initialize
virtual void initialize() override
Use this to initialize resources or memory your module needs.
Belle2::Variable::Manager
Global list of available variables.
Definition: Manager.h:108
Belle2::Variable::Manager::Instance
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:27