Belle II Software light-2406-ragdoll
ChargedPidMVAModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8//THIS MODULE
9#include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAModule.h>
10
11//ANALYSIS
12#include <mva/interface/Interface.h>
13#include <analysis/VariableManager/Utility.h>
14#include <analysis/dataobjects/Particle.h>
15#include <analysis/dataobjects/ParticleList.h>
16
17// FRAMEWORK
18#include <framework/logging/LogConfig.h>
19#include <framework/logging/LogSystem.h>
20
21
// Bring the Belle2 namespace into scope for the definitions in this file.
22using namespace Belle2;
23
// Register this module with the framework; the macro takes the class name without its 'Module' suffix.
24REG_MODULE(ChargedPidMVA);
25
26ChargedPidMVAModule::ChargedPidMVAModule() : Module()
27{
28 setDescription("This module evaluates the response of an MVA trained for binary charged particle identification between two hypotheses, S and B. For a given input set of (S,B) mass hypotheses, it takes the Particle objects in the appropriate charged stable particle's ParticleLists, calculates the MVA score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
29
30 setPropertyFlags(c_ParallelProcessingCertified);
31
32 addParam("sigHypoPDGCode",
33 m_sig_pdg,
34 "The input signal mass hypothesis' pdgId.",
35 int(0));
36 addParam("bkgHypoPDGCode",
37 m_bkg_pdg,
38 "The input background mass hypothesis' pdgId.",
39 int(0));
40 addParam("particleLists",
41 m_decayStrings,
42 "The input list of DecayStrings, where each selected (^) daughter should correspond to a standard charged ParticleList, e.g. ['Lambda0:sig -> ^p+ ^pi-', 'J/psi:sig -> ^mu+ ^mu-']. One can also directly pass a list of standard charged ParticleLists, e.g. ['e+:my_electrons', 'pi+:my_pions']. Note that charge-conjugated ParticleLists will automatically be included.",
43 std::vector<std::string>());
44 addParam("payloadName",
45 m_payload_name,
46 "The name of the database payload object with the MVA weights.",
47 std::string("ChargedPidMVAWeights"));
48 addParam("chargeIndependent",
49 m_charge_independent,
50 "Specify whether to use a charge-independent training of the MVA.",
51 bool(false));
52 addParam("useECLOnlyTraining",
53 m_ecl_only,
54 "Specify whether to use an ECL-only training of the MVA.",
55 bool(false));
56}
57
58
// Destructor: compiler-generated (defaulted), no custom cleanup code is written here.
59ChargedPidMVAModule::~ChargedPidMVAModule() = default;
60
61
62void ChargedPidMVAModule::initialize()
63{
64 m_event_metadata.isRequired();
65
66 m_weightfiles_representation = std::make_unique<DBObjPtr<ChargedPidMVAWeights>>(m_payload_name);
67
68 if (!(*m_weightfiles_representation.get())->isValidPdg(m_sig_pdg)) {
69 B2FATAL("PDG: " << m_sig_pdg <<
70 " of the signal mass hypothesis is not that of a valid particle in Const::chargedStableSet! Aborting...");
71 }
72 if (!(*m_weightfiles_representation.get())->isValidPdg(m_bkg_pdg)) {
73 B2FATAL("PDG: " << m_bkg_pdg <<
74 " of the background mass hypothesis is not that of a valid particle in Const::chargedStableSet! Aborting...");
75 }
76
77 /* Initialize MVA if the payload has changed and now. */
78 (*m_weightfiles_representation.get()).addCallback([this]() { initializeMVA(); });
80
81 m_score_varname = "pidPairChargedBDTScore_" + std::to_string(m_sig_pdg) + "_VS_" + std::to_string(m_bkg_pdg);
82
83 if (m_ecl_only) {
84 m_score_varname += "_" + std::to_string(Const::ECL);
85 } else {
86 for (const Const::EDetector& det : Const::PIDDetectorSet::set()) {
87 m_score_varname += "_" + std::to_string(det);
88 }
89 }
90}
91
92
94{
95}
96
97
99{
100
101 // Debug strings per log level.
102 std::map<int, std::string> debugStr = {
103 {11, ""},
104 {12, ""}
105 };
106
107 B2DEBUG(11, "EVENT: " << m_event_metadata->getEvent());
108
109 for (auto decayString : m_decayStrings) {
110
111 DecayDescriptor decayDescriptor;
112 decayDescriptor.init(decayString);
113 auto pListName = decayDescriptor.getMother()->getFullName();
114
115 unsigned short m_nSelectedDaughters = decayDescriptor.getSelectionNames().size();
116 StoreObjPtr<ParticleList> pList(pListName);
117
118 if (!pList) {
119 B2FATAL("ParticleList: " << pListName << " could not be found. Aborting...");
120 }
121
122 auto pListSize = pList->getListSize();
123
124 B2DEBUG(11, "ParticleList: " << pList->getParticleListName() << " - N = " << pListSize << " particles.");
125
126 const auto nTargetParticles = (m_nSelectedDaughters == 0) ? pListSize : pListSize * m_nSelectedDaughters;
127
128 // Need to get an absolute value in order to check if in Const::ChargedStable.
129 std::vector<int> pdgs;
130 if (m_nSelectedDaughters == 0) {
131 pdgs.push_back(pList->getPDGCode());
132 } else {
133 pdgs = decayDescriptor.getSelectionPDGCodes();
134 }
135 for (auto pdg : pdgs) {
136 // Check if this ParticleList is made up of legit Const::ChargedStable particles.
137 if (!(*m_weightfiles_representation.get())->isValidPdg(abs(pdg))) {
138 B2FATAL("PDG: " << pdg << " of ParticleList: " << pListName <<
139 " is not that of a valid particle in Const::chargedStableSet! Aborting...");
140 }
141 }
142 std::vector<const Particle*> targetParticles;
143 if (m_nSelectedDaughters > 0) {
144 for (unsigned int iPart(0); iPart < pListSize; ++iPart) {
145 auto* iParticle = pList->getParticle(iPart);
146 auto daughters = decayDescriptor.getSelectionParticles(iParticle);
147 for (auto* iDaughter : daughters) {
148 targetParticles.push_back(iDaughter);
149 }
150 }
151 }
152
153 for (unsigned int ipart(0); ipart < nTargetParticles; ++ipart) {
154
155 const Particle* particle = (m_nSelectedDaughters == 0) ? pList->getParticle(ipart) : targetParticles[ipart];
156
157 if (!(*m_weightfiles_representation.get())->hasImplicitNaNmasking()) {
158 // LEGACY TRAININGS: always require a track-cluster match.
159 const ECLCluster* eclCluster = particle->getECLCluster();
160 if (!eclCluster) {
161 B2DEBUG(11, "\nParticle [" << ipart << "] has invalid Track-ECLCluster relation, skip MVA application...");
162 continue;
163 }
164 }
165
166 // Retrieve the index for the correct MVA expert and dataset,
167 // given the reconstructed (polar angle, p, charge)
168 auto thVarName = (*m_weightfiles_representation.get())->getThetaVarName();
169 auto theta = std::get<double>(Variable::Manager::Instance().getVariable(thVarName)->function(particle));
170 auto p = particle->getP();
171 // Set a dummy charge of zero to pick charge-independent payloads, if requested.
172 auto charge = (!m_charge_independent) ? particle->getCharge() : 0.0;
173 if (std::isnan(theta) or std::isnan(p) or std::isnan(charge)) {
174 B2DEBUG(11, "\nParticle [" << ipart << "] has invalid input variable, skip MVA application..." <<
175 " polar angle: " << theta << ", p: " << p << ", charge: " << charge);
176 continue;
177 }
178
179 int idx_theta, idx_p, idx_charge;
180 auto index = (*m_weightfiles_representation.get())->getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
181
182 auto* matchVar = Variable::Manager::Instance().getVariable("clusterTrackMatch");
183 auto hasMatch = std::isnormal(std::get<double>(matchVar->function(particle)));
184
185 debugStr[11] += "\n";
186 debugStr[11] += ("Particle [" + std::to_string(ipart) + "]\n");
187 debugStr[11] += ("Has ECL cluster match? " + std::to_string(hasMatch) + "\n");
188 debugStr[11] += ("polar angle: " + thVarName + " = " + std::to_string(theta) + " [rad]\n");
189 debugStr[11] += ("p = " + std::to_string(p) + " [GeV/c]\n");
191 debugStr[11] += ("charge = " + std::to_string(charge) + "\n");
192 }
193 debugStr[11] += ("Is brems corrected ? " + std::to_string(particle->hasExtraInfo("bremsCorrected")) + "\n");
194 debugStr[11] += ("Weightfile idx = " + std::to_string(index) + " - (polar angle, p, charge) = (" + std::to_string(
195 idx_theta) + ", " + std::to_string(idx_p) + ", " +
196 std::to_string(idx_charge) + ")\n");
197 if (m_cuts.at(index)) {
198 debugStr[11] += ("Category cut: " + m_cuts.at(index)->decompile() + "\n");
199 }
200
201 B2DEBUG(11, debugStr[11]);
202 debugStr[11].clear();
203
204 // Don't even bother if particle does not fulfil the category selection.
205 if (m_cuts.at(index)) {
206 if (!m_cuts.at(index)->check(particle)) {
207 B2DEBUG(11, "\nParticle [" << ipart << "] didn't pass MVA category cut, skip MVA application...");
208 continue;
209 }
210 }
211
212 // Fill the MVA::SingleDataset w/ variables and spectators.
213
214 debugStr[11] += "\n";
215 debugStr[11] += "MVA variables:\n";
216
217 auto nvars = m_variables.at(index).size();
218 for (unsigned int ivar(0); ivar < nvars; ++ivar) {
219
220 auto varobj = m_variables.at(index).at(ivar);
221
222 double var = std::numeric_limits<double>::quiet_NaN();
223 auto var_result = varobj->function(particle);
224 if (std::holds_alternative<double>(var_result)) {
225 var = std::get<double>(var_result);
226 } else if (std::holds_alternative<int>(var_result)) {
227 var = std::get<int>(var_result);
228 } else if (std::holds_alternative<bool>(var_result)) {
229 var = std::get<bool>(var_result);
230 } else {
231 B2ERROR("Variable '" << varobj->name << "' has wrong data type! It must be one of double, integer, or bool.");
232 }
233
234 if (!(*m_weightfiles_representation.get())->hasImplicitNaNmasking()) {
235 // LEGACY TRAININGS: manual imputation value of -999 for NaN (undefined) variables. Needed by TMVA.
236 var = (std::isnan(var)) ? -999.0 : var;
237 }
238
239 debugStr[11] += ("\tvar[" + std::to_string(ivar) + "] : " + varobj->name + " = " + std::to_string(var) + "\n");
240
241 m_datasets.at(index)->m_input[ivar] = var;
242
243 }
244
245 B2DEBUG(11, debugStr[11]);
246 debugStr[11].clear();
247
248 // Check spectators only when in debug mode.
249 if (LogSystem::Instance().isLevelEnabled(LogConfig::c_Debug, 12)) {
250
251 debugStr[12] += "\n";
252 debugStr[12] += "MVA spectators:\n";
253
254 auto nspecs = m_spectators.at(index).size();
255 for (unsigned int ispec(0); ispec < nspecs; ++ispec) {
256
257 auto specobj = m_spectators.at(index).at(ispec);
258
259 double spec = std::numeric_limits<double>::quiet_NaN();
260 auto spec_result = specobj->function(particle);
261 if (std::holds_alternative<double>(spec_result)) {
262 spec = std::get<double>(spec_result);
263 } else if (std::holds_alternative<int>(spec_result)) {
264 spec = std::get<int>(spec_result);
265 } else if (std::holds_alternative<bool>(spec_result)) {
266 spec = std::get<bool>(spec_result);
267 } else {
268 B2ERROR("Variable '" << specobj->name << "' has wrong data type! It must be one of double, integer, or bool.");
269 }
270
271 debugStr[12] += ("\tspec[" + std::to_string(ispec) + "] : " + specobj->name + " = " + std::to_string(spec) + "\n");
272
273 m_datasets.at(index)->m_spectators[ispec] = spec;
274
275 }
276
277 B2DEBUG(12, debugStr[12]);
278 debugStr[12].clear();
279
280 }
281
282 // Compute MVA score.
283
284 debugStr[11] += "\n";
285 debugStr[12] += "\n";
286 debugStr[11] += "MVA response:\n";
287
288 float score = m_experts.at(index)->apply(*m_datasets.at(index))[0];
289
290 debugStr[11] += ("\tscore = " + std::to_string(score));
291 debugStr[12] += ("\textraInfo: " + m_score_varname + "\n");
292
293 // Store the MVA score as a new particle object property.
294 m_particles[particle->getArrayIndex()]->writeExtraInfo(m_score_varname, score);
295
296 B2DEBUG(11, debugStr[11]);
297 B2DEBUG(12, debugStr[12]);
298 debugStr[11].clear();
299 debugStr[12].clear();
300
301 }
302
303 }
304
305 // Clear the debug string map before next event.
306 debugStr.clear();
307
308}
309
310
312{
313
314 std::map<std::string, std::string> aliasesLegacy;
315
316 aliasesLegacy.insert(std::make_pair("__event__", "evtNum"));
317
319 it != Const::PIDDetectorSet::set().end(); ++it) {
320
321 auto detName = Const::parseDetectors(*it);
322
323 aliasesLegacy.insert(std::make_pair("missingLogL_" + detName, "pidMissingProbabilityExpert(" + detName + ")"));
324
325 for (auto& [pdgId, info] : m_stdChargedInfo) {
326
327 std::string alias = "deltaLogL_" + std::get<0>(info) + "_" + std::get<1>(info) + "_" + detName;
328 std::string var = "pidDeltaLogLikelihoodValueExpert(" + std::to_string(pdgId) + ", " + std::to_string(std::get<2>
329 (info)) + "," + detName + ")";
330
331 aliasesLegacy.insert(std::make_pair(alias, var));
332
333 if (it.getIndex() == 0) {
334 alias = "deltaLogL_" + std::get<0>(info) + "_" + std::get<1>(info) + "_ALL";
335 var = "pidDeltaLogLikelihoodValueExpert(" + std::to_string(pdgId) + ", " + std::to_string(std::get<2>(info)) + ", ALL)";
336 aliasesLegacy.insert(std::make_pair(alias, var));
337 }
338
339 }
340
341 }
342
343 B2INFO("Setting hard-coded aliases for the ChargedPidMVA algorithm.");
344
345 std::string debugStr("\n");
346 for (const auto& [alias, variable] : aliasesLegacy) {
347 debugStr += (alias + " --> " + variable + "\n");
348 if (!Variable::Manager::Instance().addAlias(alias, variable)) {
349 B2ERROR("Something went wrong with setting alias: " << alias << " for variable: " << variable);
350 }
351 }
352 B2DEBUG(10, debugStr);
353
354}
355
356
358{
359
360 auto aliases = (*m_weightfiles_representation.get())->getAliases();
361
362 if (!aliases->empty()) {
363
364 B2INFO("Setting aliases for the ChargedPidMVA algorithm read from the payload.");
365
366 std::string debugStr("\n");
367 for (const auto& [alias, variable] : *aliases) {
368 if (alias != variable) {
369 debugStr += (alias + " --> " + variable + "\n");
370 if (!Variable::Manager::Instance().addAlias(alias, variable)) {
371 B2ERROR("Something went wrong with setting alias: " << alias << " for variable: " << variable);
372 }
373 }
374 }
375 B2DEBUG(10, debugStr);
376
377 return;
378
379 }
380
381 // Manually set aliases - for bw compatibility
382 this->registerAliasesLegacy();
383
384}
385
386
388{
389
390 B2INFO("Run: " << m_event_metadata->getRun() << ". Load supported MVA interfaces for binary charged particle identification...");
391
392 // Set the necessary variable aliases from the payload.
393 this->registerAliases();
394
395 // The supported methods have to be initialized once (calling it more than once is safe).
397 auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
398
399 B2INFO("\tLoading weightfiles from the payload class for SIGNAL particle hypothesis: " << m_sig_pdg);
400
401 auto serialized_weightfiles = (*m_weightfiles_representation.get())->getMVAWeights(m_sig_pdg);
402 auto nfiles = serialized_weightfiles->size();
403
404 B2INFO("\tConstruct the MVA experts and datasets from N = " << nfiles << " weightfiles...");
405
406 // The size of the vectors must correspond
407 // to the number of available weightfiles for this pdgId.
408 m_experts.resize(nfiles);
409 m_datasets.resize(nfiles);
410 m_cuts.resize(nfiles);
411 m_variables.resize(nfiles);
412 m_spectators.resize(nfiles);
413
414 for (unsigned int idx(0); idx < nfiles; idx++) {
415
416 B2DEBUG(12, "\t\tweightfile[" << idx << "]");
417
418 // De-serialize the string into an MVA::Weightfile object.
419 std::stringstream ss(serialized_weightfiles->at(idx));
420 auto weightfile = MVA::Weightfile::loadFromStream(ss);
421
422 MVA::GeneralOptions general_options;
423 weightfile.getOptions(general_options);
424
425 // Store the list of pointers to the relevant variables for this xml file.
427 m_variables[idx] = manager.getVariables(general_options.m_variables);
428 m_spectators[idx] = manager.getVariables(general_options.m_spectators);
429
430 B2DEBUG(12, "\t\tRetrieved N = " << general_options.m_variables.size()
431 << " variables, N = " << general_options.m_spectators.size()
432 << " spectators");
433
434 // Store an MVA::Expert object.
435 m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
436 m_experts.at(idx)->load(weightfile);
437
438 B2DEBUG(12, "\t\tweightfile loaded successfully into expert[" << idx << "]!");
439
440 // Store an MVA::SingleDataset object, in which we will save our features later...
441 std::vector<float> v(general_options.m_variables.size(), 0.0);
442 std::vector<float> s(general_options.m_spectators.size(), 0.0);
443 m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
444
445 B2DEBUG(12, "\t\tdataset[" << idx << "] created successfully!");
446
447 // Compile cut for this category.
448 const auto cuts = (*m_weightfiles_representation.get())->getCuts(m_sig_pdg);
449 const auto cutstr = (!cuts->empty()) ? cuts->at(idx) : "";
450 m_cuts[idx] = (!cutstr.empty()) ? Variable::Cut::compile(cutstr) : nullptr;
451
452 B2DEBUG(12, "\t\tcut[" << idx << "] created successfully!");
453
454 }
455
456}
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
int m_bkg_pdg
The input background mass hypothesis' pdgId.
std::vector< std::string > m_decayStrings
The input list of DecayStrings, where each selected (^) daughter should correspond to a standard char...
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
std::string m_score_varname
The lookup name of the MVA score variable, given the input S, B mass hypotheses for the algorithm.
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
StoreArray< Particle > m_particles
StoreArray of Particles.
DatasetsList m_datasets
List of MVA::SingleDataset objects.
std::map< int, std::tuple< std::string, std::string, int > > m_stdChargedInfo
Map with standard charged particles' info.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual void event() override
Called once for each event.
void registerAliases()
Set variable aliases needed by the MVA.
int m_sig_pdg
The input signal mass hypothesis' pdgId.
virtual void beginRun() override
Called once before a new run begins.
VariablesLists m_variables
List of lists of feature variables.
void registerAliasesLegacy()
Set variable aliases needed by the MVA.
CutsList m_cuts
List of Cut objects.
VariablesLists m_spectators
List of lists of spectator variables.
ExpertsList m_experts
List of MVA::Expert objects.
std::string m_payload_name
The name of the database payload object with the MVA weights.
Iterator end() const
Ending iterator.
Definition: UnitConst.cc:220
static DetectorSet set()
Accessor for the set of valid detector IDs.
Definition: Const.h:333
EDetector
Enum for identifying the detector components (detector and subdetector).
Definition: Const.h:42
static std::string parseDetectors(EDetector det)
Converts Const::EDetector object to string.
Definition: UnitConst.cc:162
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
ECL cluster data.
Definition: ECLCluster.h:27
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
Definition: GeneralCut.h:84
@ c_Debug
Debug: for code development.
Definition: LogConfig.h:26
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Definition: LogSystem.cc:31
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:45
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
General options which are shared by all MVA trainings.
Definition: Options.h:62
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
Base class for Modules.
Definition: Module.h:72
Class to store reconstructed particles.
Definition: Particle.h:75
const ECLCluster * getECLCluster() const
Returns the pointer to the ECLCluster object that was used to create this Particle (if ParticleType =...
Definition: Particle.cc:891
bool hasExtraInfo(const std::string &name) const
Return whether the extra info with the given name is set.
Definition: Particle.cc:1266
double getCharge(void) const
Returns particle charge.
Definition: Particle.cc:622
double getP() const
Returns momentum magnitude (same as getMomentumMagnitude but with shorter name)
Definition: Particle.h:578
int getArrayIndex() const
Returns this object's array index (in StoreArray), or -1 if not found.
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:96
Global list of available variables.
Definition: Manager.h:101
const Var * getVariable(std::string name)
Get the variable belonging to the given key.
Definition: Manager.cc:57
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
bool addAlias(const std::string &alias, const std::string &variable)
Add alias Return true if the alias was successfully added.
Definition: Manager.cc:95
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
double charge(int pdgCode)
Returns electric charge of a particle with given pdg code.
Definition: EvtPDLUtil.cc:44
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24