Belle II Software development
ChargedPidMVAModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8//THIS MODULE
9#include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAModule.h>
10
11//ANALYSIS
12#include <mva/interface/Interface.h>
13#include <analysis/VariableManager/Utility.h>
14#include <analysis/dataobjects/Particle.h>
15#include <analysis/dataobjects/ParticleList.h>
16#include <analysis/variables/ECLVariables.h>
17
18// FRAMEWORK
19#include <framework/logging/LogConfig.h>
20#include <framework/logging/LogSystem.h>
21
22
23using namespace Belle2;
24
25REG_MODULE(ChargedPidMVA);
26
27ChargedPidMVAModule::ChargedPidMVAModule() : Module()
28{
29 setDescription("This module evaluates the response of an MVA trained for binary charged particle identification between two hypotheses, S and B. For a given input set of (S,B) mass hypotheses, it takes the Particle objects in the appropriate charged stable particle's ParticleLists, calculates the MVA score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
30
31 setPropertyFlags(c_ParallelProcessingCertified);
32
33 addParam("sigHypoPDGCode",
34 m_sig_pdg,
35 "The input signal mass hypothesis' pdgId.",
36 int(0));
37 addParam("bkgHypoPDGCode",
38 m_bkg_pdg,
39 "The input background mass hypothesis' pdgId.",
40 int(0));
41 addParam("particleLists",
42 m_decayStrings,
43 "The input list of DecayStrings, where each selected (^) daughter should correspond to a standard charged ParticleList, e.g. ['Lambda0:sig -> ^p+ ^pi-', 'J/psi:sig -> ^mu+ ^mu-']. One can also directly pass a list of standard charged ParticleLists, e.g. ['e+:my_electrons', 'pi+:my_pions']. Note that charge-conjugated ParticleLists will automatically be included.",
44 std::vector<std::string>());
45 addParam("payloadName",
46 m_payload_name,
47 "The name of the database payload object with the MVA weights.",
48 std::string("ChargedPidMVAWeights"));
49 addParam("chargeIndependent",
50 m_charge_independent,
51 "Specify whether to use a charge-independent training of the MVA.",
52 bool(false));
53 addParam("useECLOnlyTraining",
54 m_ecl_only,
55 "Specify whether to use an ECL-only training of the MVA.",
56 bool(false));
57}
58
59
// Destructor: compiler-generated (defaulted), members are destroyed member-wise.
ChargedPidMVAModule::~ChargedPidMVAModule() = default;
61
62
63void ChargedPidMVAModule::initialize()
64{
65 m_event_metadata.isRequired();
66
67 m_weightfiles_representation = std::make_unique<DBObjPtr<ChargedPidMVAWeights>>(m_payload_name);
68
69 if (!(*m_weightfiles_representation.get())->isValidPdg(m_sig_pdg)) {
70 B2FATAL("PDG: " << m_sig_pdg <<
71 " of the signal mass hypothesis is not that of a valid particle in Const::chargedStableSet! Aborting...");
72 }
73 if (!(*m_weightfiles_representation.get())->isValidPdg(m_bkg_pdg)) {
74 B2FATAL("PDG: " << m_bkg_pdg <<
75 " of the background mass hypothesis is not that of a valid particle in Const::chargedStableSet! Aborting...");
76 }
77
78 /* Initialize MVA if the payload has changed and now. */
79 (*m_weightfiles_representation.get()).addCallback([this]() { initializeMVA(); });
81
82 m_score_varname = "pidPairChargedBDTScore_" + std::to_string(m_sig_pdg) + "_VS_" + std::to_string(m_bkg_pdg);
83
84 if (m_ecl_only) {
85 m_score_varname += "_" + std::to_string(Const::ECL);
86 } else {
87 for (const Const::EDetector& det : Const::PIDDetectorSet::set()) {
88 m_score_varname += "_" + std::to_string(det);
89 }
90 }
91}
92
93
95{
96}
97
98
100{
101
102 // Debug strings per log level.
103 std::map<int, std::string> debugStr = {
104 {11, ""},
105 {12, ""}
106 };
107
108 B2DEBUG(11, "EVENT: " << m_event_metadata->getEvent());
109
110 for (auto decayString : m_decayStrings) {
111
112 DecayDescriptor decayDescriptor;
113 decayDescriptor.init(decayString);
114 auto pListName = decayDescriptor.getMother()->getFullName();
115
116 unsigned short m_nSelectedDaughters = decayDescriptor.getSelectionNames().size();
117 StoreObjPtr<ParticleList> pList(pListName);
118
119 if (!pList) {
120 B2FATAL("ParticleList: " << pListName << " could not be found. Aborting...");
121 }
122
123 auto pListSize = pList->getListSize();
124
125 B2DEBUG(11, "ParticleList: " << pList->getParticleListName() << " - N = " << pListSize << " particles.");
126
127 const auto nTargetParticles = (m_nSelectedDaughters == 0) ? pListSize : pListSize * m_nSelectedDaughters;
128
129 // Need to get an absolute value in order to check if in Const::ChargedStable.
130 std::vector<int> pdgs;
131 if (m_nSelectedDaughters == 0) {
132 pdgs.push_back(pList->getPDGCode());
133 } else {
134 pdgs = decayDescriptor.getSelectionPDGCodes();
135 }
136 for (auto pdg : pdgs) {
137 // Check if this ParticleList is made up of legit Const::ChargedStable particles.
138 if (!(*m_weightfiles_representation.get())->isValidPdg(abs(pdg))) {
139 B2FATAL("PDG: " << pdg << " of ParticleList: " << pListName <<
140 " is not that of a valid particle in Const::chargedStableSet! Aborting...");
141 }
142 }
143 std::vector<const Particle*> targetParticles;
144 if (m_nSelectedDaughters > 0) {
145 for (unsigned int iPart(0); iPart < pListSize; ++iPart) {
146 auto* iParticle = pList->getParticle(iPart);
147 auto daughters = decayDescriptor.getSelectionParticles(iParticle);
148 for (auto* iDaughter : daughters) {
149 targetParticles.push_back(iDaughter);
150 }
151 }
152 }
153
154 for (unsigned int ipart(0); ipart < nTargetParticles; ++ipart) {
155
156 const Particle* particle = (m_nSelectedDaughters == 0) ? pList->getParticle(ipart) : targetParticles[ipart];
157
158 if (!(*m_weightfiles_representation.get())->hasImplicitNaNmasking()) {
159 // LEGACY TRAININGS: always require a track-cluster match.
160 const ECLCluster* eclCluster = particle->getECLCluster();
161 if (!eclCluster) {
162 B2DEBUG(11, "\nParticle [" << ipart << "] has invalid Track-ECLCluster relation, skip MVA application...");
163 continue;
164 }
165 }
166
167 // Retrieve the index for the correct MVA expert and dataset,
168 // given the reconstructed (polar angle, p, charge)
169 auto thVarName = (*m_weightfiles_representation.get())->getThetaVarName();
170 auto theta = std::get<double>(Variable::Manager::Instance().getVariable(thVarName)->function(particle));
171 auto p = particle->getP();
172 // Set a dummy charge of zero to pick charge-independent payloads, if requested.
173 auto charge = (!m_charge_independent) ? particle->getCharge() : 0.0;
174 if (std::isnan(theta) or std::isnan(p) or std::isnan(charge)) {
175 B2DEBUG(11, "\nParticle [" << ipart << "] has invalid input variable, skip MVA application..." <<
176 " polar angle: " << theta << ", p: " << p << ", charge: " << charge);
177 continue;
178 }
179
180 int idx_theta, idx_p, idx_charge;
181 auto index = (*m_weightfiles_representation.get())->getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
182
183 auto hasMatch = std::isnormal(Variable::eclClusterTrackMatched(particle));
184
185 debugStr[11] += "\n";
186 debugStr[11] += ("Particle [" + std::to_string(ipart) + "]\n");
187 debugStr[11] += ("Has ECL cluster match? " + std::to_string(hasMatch) + "\n");
188 debugStr[11] += ("polar angle: " + thVarName + " = " + std::to_string(theta) + " [rad]\n");
189 debugStr[11] += ("p = " + std::to_string(p) + " [GeV/c]\n");
191 debugStr[11] += ("charge = " + std::to_string(charge) + "\n");
192 }
193 debugStr[11] += ("Is brems corrected ? " + std::to_string(particle->hasExtraInfo("bremsCorrected")) + "\n");
194 debugStr[11] += ("Weightfile idx = " + std::to_string(index) + " - (polar angle, p, charge) = (" + std::to_string(
195 idx_theta) + ", " + std::to_string(idx_p) + ", " +
196 std::to_string(idx_charge) + ")\n");
197 if (m_cuts.at(index)) {
198 debugStr[11] += ("Category cut: " + m_cuts.at(index)->decompile() + "\n");
199 }
200
201 B2DEBUG(11, debugStr[11]);
202 debugStr[11].clear();
203
204 // Don't even bother if particle does not fulfil the category selection.
205 if (m_cuts.at(index)) {
206 if (!m_cuts.at(index)->check(particle)) {
207 B2DEBUG(11, "\nParticle [" << ipart << "] didn't pass MVA category cut, skip MVA application...");
208 continue;
209 }
210 }
211
212 // Fill the MVA::SingleDataset w/ variables and spectators.
213
214 debugStr[11] += "\n";
215 debugStr[11] += "MVA variables:\n";
216
217 auto nvars = m_variables.at(index).size();
218 for (unsigned int ivar(0); ivar < nvars; ++ivar) {
219
220 auto varobj = m_variables.at(index).at(ivar);
221
222 double var = std::numeric_limits<double>::quiet_NaN();
223 auto var_result = varobj->function(particle);
224 if (std::holds_alternative<double>(var_result)) {
225 var = std::get<double>(var_result);
226 } else if (std::holds_alternative<int>(var_result)) {
227 var = std::get<int>(var_result);
228 } else if (std::holds_alternative<bool>(var_result)) {
229 var = std::get<bool>(var_result);
230 } else {
231 B2ERROR("Variable '" << varobj->name << "' has wrong data type! It must be one of double, integer, or bool.");
232 }
233
234 if (!(*m_weightfiles_representation.get())->hasImplicitNaNmasking()) {
235 // LEGACY TRAININGS: manual imputation value of -999 for NaN (undefined) variables. Needed by TMVA.
236 var = (std::isnan(var)) ? -999.0 : var;
237 }
238
239 debugStr[11] += ("\tvar[" + std::to_string(ivar) + "] : " + varobj->name + " = " + std::to_string(var) + "\n");
240
241 m_datasets.at(index)->m_input[ivar] = var;
242
243 }
244
245 B2DEBUG(11, debugStr[11]);
246 debugStr[11].clear();
247
248 // Check spectators only when in debug mode.
249 if (LogSystem::Instance().isLevelEnabled(LogConfig::c_Debug, 12)) {
250
251 debugStr[12] += "\n";
252 debugStr[12] += "MVA spectators:\n";
253
254 auto nspecs = m_spectators.at(index).size();
255 for (unsigned int ispec(0); ispec < nspecs; ++ispec) {
256
257 auto specobj = m_spectators.at(index).at(ispec);
258
259 double spec = std::numeric_limits<double>::quiet_NaN();
260 auto spec_result = specobj->function(particle);
261 if (std::holds_alternative<double>(spec_result)) {
262 spec = std::get<double>(spec_result);
263 } else if (std::holds_alternative<int>(spec_result)) {
264 spec = std::get<int>(spec_result);
265 } else if (std::holds_alternative<bool>(spec_result)) {
266 spec = std::get<bool>(spec_result);
267 } else {
268 B2ERROR("Variable '" << specobj->name << "' has wrong data type! It must be one of double, integer, or bool.");
269 }
270
271 debugStr[12] += ("\tspec[" + std::to_string(ispec) + "] : " + specobj->name + " = " + std::to_string(spec) + "\n");
272
273 m_datasets.at(index)->m_spectators[ispec] = spec;
274
275 }
276
277 B2DEBUG(12, debugStr[12]);
278 debugStr[12].clear();
279
280 }
281
282 // Compute MVA score.
283
284 debugStr[11] += "\n";
285 debugStr[12] += "\n";
286 debugStr[11] += "MVA response:\n";
287
288 float score = m_experts.at(index)->apply(*m_datasets.at(index))[0];
289
290 debugStr[11] += ("\tscore = " + std::to_string(score));
291 debugStr[12] += ("\textraInfo: " + m_score_varname + "\n");
292
293 // Store the MVA score as a new particle object property.
294 m_particles[particle->getArrayIndex()]->writeExtraInfo(m_score_varname, score);
295
296 B2DEBUG(11, debugStr[11]);
297 B2DEBUG(12, debugStr[12]);
298 debugStr[11].clear();
299 debugStr[12].clear();
300
301 }
302
303 }
304
305 // Clear the debug string map before next event.
306 debugStr.clear();
307
308}
309
310
312{
313
314 std::map<std::string, std::string> aliasesLegacy;
315
316 aliasesLegacy.insert(std::make_pair("__event__", "evtNum"));
317
319 it != Const::PIDDetectorSet::set().end(); ++it) {
320
321 auto detName = Const::parseDetectors(*it);
322
323 aliasesLegacy.insert(std::make_pair("missingLogL_" + detName, "pidMissingProbabilityExpert(" + detName + ")"));
324
325 for (auto& [pdgId, info] : m_stdChargedInfo) {
326
327 std::string alias = "deltaLogL_" + std::get<0>(info) + "_" + std::get<1>(info) + "_" + detName;
328 std::string var = "pidDeltaLogLikelihoodValueExpert(" + std::to_string(pdgId) + ", " + std::to_string(std::get<2>
329 (info)) + "," + detName + ")";
330
331 aliasesLegacy.insert(std::make_pair(alias, var));
332
333 if (it.getIndex() == 0) {
334 alias = "deltaLogL_" + std::get<0>(info) + "_" + std::get<1>(info) + "_ALL";
335 var = "pidDeltaLogLikelihoodValueExpert(" + std::to_string(pdgId) + ", " + std::to_string(std::get<2>(info)) + ", ALL)";
336 aliasesLegacy.insert(std::make_pair(alias, var));
337 }
338
339 }
340
341 }
342
343 B2INFO("Setting hard-coded aliases for the ChargedPidMVA algorithm.");
344
345 std::string debugStr("\n");
346 for (const auto& [alias, variable] : aliasesLegacy) {
347 debugStr += (alias + " --> " + variable + "\n");
348 if (!Variable::Manager::Instance().addAlias(alias, variable)) {
349 B2ERROR("Something went wrong with setting alias: " << alias << " for variable: " << variable);
350 }
351 }
352 B2DEBUG(10, debugStr);
353
354}
355
356
358{
359
360 auto aliases = (*m_weightfiles_representation.get())->getAliases();
361
362 if (!aliases->empty()) {
363
364 B2INFO("Setting aliases for the ChargedPidMVA algorithm read from the payload.");
365
366 std::string debugStr("\n");
367 for (const auto& [alias, variable] : *aliases) {
368 if (alias != variable) {
369 debugStr += (alias + " --> " + variable + "\n");
370 if (!Variable::Manager::Instance().addAlias(alias, variable)) {
371 B2ERROR("Something went wrong with setting alias: " << alias << " for variable: " << variable);
372 }
373 }
374 }
375 B2DEBUG(10, debugStr);
376
377 return;
378
379 }
380
381 // Manually set aliases - for bw compatibility
382 this->registerAliasesLegacy();
383
384}
385
386
388{
389
390 B2INFO("Run: " << m_event_metadata->getRun() << ". Load supported MVA interfaces for binary charged particle identification...");
391
392 // Set the necessary variable aliases from the payload.
393 this->registerAliases();
394
395 // The supported methods have to be initialized once (calling it more than once is safe).
397 auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
398
399 B2INFO("\tLoading weightfiles from the payload class for SIGNAL particle hypothesis: " << m_sig_pdg);
400
401 auto serialized_weightfiles = (*m_weightfiles_representation.get())->getMVAWeights(m_sig_pdg);
402 auto nfiles = serialized_weightfiles->size();
403
404 B2INFO("\tConstruct the MVA experts and datasets from N = " << nfiles << " weightfiles...");
405
406 // The size of the vectors must correspond
407 // to the number of available weightfiles for this pdgId.
408 m_experts.resize(nfiles);
409 m_datasets.resize(nfiles);
410 m_cuts.resize(nfiles);
411 m_variables.resize(nfiles);
412 m_spectators.resize(nfiles);
413
414 for (unsigned int idx(0); idx < nfiles; idx++) {
415
416 B2DEBUG(12, "\t\tweightfile[" << idx << "]");
417
418 // De-serialize the string into an MVA::Weightfile object.
419 std::stringstream ss(serialized_weightfiles->at(idx));
420 auto weightfile = MVA::Weightfile::loadFromStream(ss);
421
422 MVA::GeneralOptions general_options;
423 weightfile.getOptions(general_options);
424
425 // Store the list of pointers to the relevant variables for this xml file.
427 m_variables[idx] = manager.getVariables(general_options.m_variables);
428 m_spectators[idx] = manager.getVariables(general_options.m_spectators);
429
430 B2DEBUG(12, "\t\tRetrieved N = " << general_options.m_variables.size()
431 << " variables, N = " << general_options.m_spectators.size()
432 << " spectators");
433
434 // Store an MVA::Expert object.
435 m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
436 m_experts.at(idx)->load(weightfile);
437
438 B2DEBUG(12, "\t\tweightfile loaded successfully into expert[" << idx << "]!");
439
440 // Store an MVA::SingleDataset object, in which we will save our features later...
441 std::vector<float> v(general_options.m_variables.size(), 0.0);
442 std::vector<float> s(general_options.m_spectators.size(), 0.0);
443 m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
444
445 B2DEBUG(12, "\t\tdataset[" << idx << "] created successfully!");
446
447 // Compile cut for this category.
448 const auto cuts = (*m_weightfiles_representation.get())->getCuts(m_sig_pdg);
449 const auto cutstr = (!cuts->empty()) ? cuts->at(idx) : "";
450 m_cuts[idx] = (!cutstr.empty()) ? Variable::Cut::compile(cutstr) : nullptr;
451
452 B2DEBUG(12, "\t\tcut[" << idx << "] created successfully!");
453
454 }
455
456}
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
int m_bkg_pdg
The input background mass hypothesis' pdgId.
std::vector< std::string > m_decayStrings
The input list of DecayStrings, where each selected (^) daughter should correspond to a standard char...
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
std::string m_score_varname
The lookup name of the MVA score variable, given the input S, B mass hypotheses for the algorithm.
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
StoreArray< Particle > m_particles
StoreArray of Particles.
DatasetsList m_datasets
List of MVA::SingleDataset objects.
std::map< int, std::tuple< std::string, std::string, int > > m_stdChargedInfo
Map with standard charged particles' info.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual void event() override
Called once for each event.
void registerAliases()
Set variable aliases needed by the MVA.
int m_sig_pdg
The input signal mass hypothesis' pdgId.
virtual void beginRun() override
Called once before a new run begins.
VariablesLists m_variables
List of lists of feature variables.
void registerAliasesLegacy()
Set variable aliases needed by the MVA.
CutsList m_cuts
List of Cut objects.
VariablesLists m_spectators
List of lists of spectator variables.
ExpertsList m_experts
List of MVA::Expert objects.
std::string m_payload_name
The name of the database payload object with the MVA weights.
Iterator end() const
Ending iterator.
Definition: UnitConst.cc:220
static DetectorSet set()
Accessor for the set of valid detector IDs.
Definition: Const.h:333
EDetector
Enum for identifying the detector components (detector and subdetector).
Definition: Const.h:42
static std::string parseDetectors(EDetector det)
Converts Const::EDetector object to string.
Definition: UnitConst.cc:162
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
ECL cluster data.
Definition: ECLCluster.h:27
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
Definition: GeneralCut.h:84
@ c_Debug
Debug: for code development.
Definition: LogConfig.h:26
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Definition: LogSystem.cc:31
static void initSupportedInterfaces()
Static function which initliazes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:45
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
General options which are shared by all MVA trainings.
Definition: Options.h:62
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
Base class for Modules.
Definition: Module.h:72
Class to store reconstructed particles.
Definition: Particle.h:75
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:96
Global list of available variables.
Definition: Manager.h:101
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
bool addAlias(const std::string &alias, const std::string &variable)
Add alias Return true if the alias was successfully added.
Definition: Manager.cc:95
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
double charge(int pdgCode)
Returns electric charge of a particle with given pdg code.
Definition: EvtPDLUtil.cc:44
Abstract base class for different kinds of events.