9#include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAModule.h>
12#include <analysis/dataobjects/Particle.h>
13#include <analysis/dataobjects/ParticleList.h>
14#include <analysis/DecayDescriptor/DecayDescriptor.h>
15#include <analysis/VariableManager/Utility.h>
16#include <analysis/variables/ECLVariables.h>
19#include <framework/logging/LogConfig.h>
20#include <framework/logging/LogSystem.h>
21#include <mva/interface/Interface.h>
29 setDescription(
"This module evaluates the response of an MVA trained for binary charged particle identification between two hypotheses, S and B. For a given input set of (S,B) mass hypotheses, it takes the Particle objects in the appropriate charged stable particle's ParticleLists, calculates the MVA score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
35 "The input signal mass hypothesis' pdgId.",
39 "The input background mass hypothesis' pdgId.",
43 "The input list of DecayStrings, where each selected (^) daughter should correspond to a standard charged ParticleList, e.g. ['Lambda0:sig -> ^p+ ^pi-', 'J/psi:sig -> ^mu+ ^mu-']. One can also directly pass a list of standard charged ParticleLists, e.g. ['e+:my_electrons', 'pi+:my_pions']. Note that charge-conjugated ParticleLists will automatically be included.",
44 std::vector<std::string>());
47 "The name of the database payload object with the MVA weights.",
48 std::string(
"ChargedPidMVAWeights"));
51 "Specify whether to use a charge-independent training of the MVA.",
55 "Specify whether to use an ECL-only training of the MVA.",
65 m_event_metadata.isRequired();
67 m_weightfiles_representation = std::make_unique<DBObjPtr<ChargedPidMVAWeights>>(m_payload_name);
69 if (!(*m_weightfiles_representation.get())->isValidPdg(m_sig_pdg)) {
70 B2FATAL(
"PDG: " << m_sig_pdg <<
71 " of the signal mass hypothesis is not that of a valid particle in Const::chargedStableSet! Aborting...");
73 if (!(*m_weightfiles_representation.get())->isValidPdg(m_bkg_pdg)) {
74 B2FATAL(
"PDG: " << m_bkg_pdg <<
75 " of the background mass hypothesis is not that of a valid particle in Const::chargedStableSet! Aborting...");
79 (*m_weightfiles_representation.get()).addCallback([
this]() { initializeMVA(); });
82 m_score_varname =
"pidPairChargedBDTScore_" + std::to_string(m_sig_pdg) +
"_VS_" + std::to_string(m_bkg_pdg);
85 m_score_varname +=
"_" + std::to_string(Const::ECL);
88 m_score_varname +=
"_" + std::to_string(det);
103 std::map<int, std::string> debugStr = {
108 B2DEBUG(11,
"EVENT: " << m_event_metadata->getEvent());
110 for (
auto decayString : m_decayStrings) {
113 decayDescriptor.init(decayString);
114 auto pListName = decayDescriptor.getMother()->getFullName();
116 unsigned short m_nSelectedDaughters = decayDescriptor.getSelectionNames().size();
120 B2FATAL(
"ParticleList: " << pListName <<
" could not be found. Aborting...");
123 auto pListSize = pList->getListSize();
125 B2DEBUG(11,
"ParticleList: " << pList->getParticleListName() <<
" - N = " << pListSize <<
" particles.");
127 const auto nTargetParticles = (m_nSelectedDaughters == 0) ? pListSize : pListSize * m_nSelectedDaughters;
130 std::vector<int> pdgs;
131 if (m_nSelectedDaughters == 0) {
132 pdgs.push_back(pList->getPDGCode());
134 pdgs = decayDescriptor.getSelectionPDGCodes();
136 for (
auto pdg : pdgs) {
138 if (!(*m_weightfiles_representation.get())->isValidPdg(abs(pdg))) {
139 B2FATAL(
"PDG: " << pdg <<
" of ParticleList: " << pListName <<
140 " is not that of a valid particle in Const::chargedStableSet! Aborting...");
143 std::vector<const Particle*> targetParticles;
144 if (m_nSelectedDaughters > 0) {
145 for (
unsigned int iPart(0); iPart < pListSize; ++iPart) {
146 auto* iParticle = pList->getParticle(iPart);
147 auto daughters = decayDescriptor.getSelectionParticles(iParticle);
148 for (
auto* iDaughter : daughters) {
149 targetParticles.push_back(iDaughter);
154 for (
unsigned int ipart(0); ipart < nTargetParticles; ++ipart) {
156 const Particle* particle = (m_nSelectedDaughters == 0) ? pList->getParticle(ipart) : targetParticles[ipart];
158 if (!(*m_weightfiles_representation.get())->hasImplicitNaNmasking()) {
160 const ECLCluster* eclCluster = particle->getECLCluster();
162 B2DEBUG(11,
"\nParticle [" << ipart <<
"] has invalid Track-ECLCluster relation, skip MVA application...");
169 auto thVarName = (*m_weightfiles_representation.get())->getThetaVarName();
171 auto p = particle->getP();
173 auto charge = (!m_charge_independent) ? particle->getCharge() : 0.0;
174 if (std::isnan(theta) or std::isnan(p) or std::isnan(charge)) {
175 B2DEBUG(11,
"\nParticle [" << ipart <<
"] has invalid input variable, skip MVA application..." <<
176 " polar angle: " << theta <<
", p: " << p <<
", charge: " << charge);
180 int idx_theta, idx_p, idx_charge;
181 auto index = (*m_weightfiles_representation.get())->getMVAWeightIdx(theta, p, charge, idx_theta, idx_p, idx_charge);
183 auto hasMatch = std::isnormal(Variable::eclClusterTrackMatched(particle));
185 debugStr[11] +=
"\n";
186 debugStr[11] += (
"Particle [" + std::to_string(ipart) +
"]\n");
187 debugStr[11] += (
"Has ECL cluster match? " + std::to_string(hasMatch) +
"\n");
188 debugStr[11] += (
"polar angle: " + thVarName +
" = " + std::to_string(theta) +
" [rad]\n");
189 debugStr[11] += (
"p = " + std::to_string(p) +
" [GeV/c]\n");
190 if (!m_charge_independent) {
191 debugStr[11] += (
"charge = " + std::to_string(charge) +
"\n");
193 debugStr[11] += (
"Is brems corrected ? " + std::to_string(particle->hasExtraInfo(
"bremsCorrected")) +
"\n");
194 debugStr[11] += (
"Weightfile idx = " + std::to_string(index) +
" - (polar angle, p, charge) = (" + std::to_string(
195 idx_theta) +
", " + std::to_string(idx_p) +
", " +
196 std::to_string(idx_charge) +
")\n");
197 if (m_cuts.at(index)) {
198 debugStr[11] += (
"Category cut: " + m_cuts.at(index)->decompile() +
"\n");
201 B2DEBUG(11, debugStr[11]);
202 debugStr[11].clear();
205 if (m_cuts.at(index)) {
206 if (!m_cuts.at(index)->check(particle)) {
207 B2DEBUG(11,
"\nParticle [" << ipart <<
"] didn't pass MVA category cut, skip MVA application...");
214 debugStr[11] +=
"\n";
215 debugStr[11] +=
"MVA variables:\n";
217 auto nvars = m_variables.at(index).size();
218 for (
unsigned int ivar(0); ivar < nvars; ++ivar) {
220 auto varobj = m_variables.at(index).at(ivar);
222 double var = std::numeric_limits<double>::quiet_NaN();
223 auto var_result = varobj->function(particle);
224 if (std::holds_alternative<double>(var_result)) {
225 var = std::get<double>(var_result);
226 }
else if (std::holds_alternative<int>(var_result)) {
227 var = std::get<int>(var_result);
228 }
else if (std::holds_alternative<bool>(var_result)) {
229 var = std::get<bool>(var_result);
231 B2ERROR(
"Variable '" << varobj->name <<
"' has wrong data type! It must be one of double, integer, or bool.");
234 if (!(*m_weightfiles_representation.get())->hasImplicitNaNmasking()) {
236 var = (std::isnan(var)) ? -999.0 : var;
239 debugStr[11] += (
"\tvar[" + std::to_string(ivar) +
"] : " + varobj->name +
" = " + std::to_string(var) +
"\n");
241 m_datasets.at(index)->m_input[ivar] = var;
245 B2DEBUG(11, debugStr[11]);
246 debugStr[11].clear();
251 debugStr[12] +=
"\n";
252 debugStr[12] +=
"MVA spectators:\n";
254 auto nspecs = m_spectators.at(index).size();
255 for (
unsigned int ispec(0); ispec < nspecs; ++ispec) {
257 auto specobj = m_spectators.at(index).at(ispec);
259 double spec = std::numeric_limits<double>::quiet_NaN();
260 auto spec_result = specobj->function(particle);
261 if (std::holds_alternative<double>(spec_result)) {
262 spec = std::get<double>(spec_result);
263 }
else if (std::holds_alternative<int>(spec_result)) {
264 spec = std::get<int>(spec_result);
265 }
else if (std::holds_alternative<bool>(spec_result)) {
266 spec = std::get<bool>(spec_result);
268 B2ERROR(
"Variable '" << specobj->name <<
"' has wrong data type! It must be one of double, integer, or bool.");
271 debugStr[12] += (
"\tspec[" + std::to_string(ispec) +
"] : " + specobj->name +
" = " + std::to_string(spec) +
"\n");
273 m_datasets.at(index)->m_spectators[ispec] = spec;
277 B2DEBUG(12, debugStr[12]);
278 debugStr[12].clear();
284 debugStr[11] +=
"\n";
285 debugStr[12] +=
"\n";
286 debugStr[11] +=
"MVA response:\n";
288 float score = m_experts.at(index)->apply(*m_datasets.at(index))[0];
290 debugStr[11] += (
"\tscore = " + std::to_string(score));
291 debugStr[12] += (
"\textraInfo: " + m_score_varname +
"\n");
294 m_particles[particle->getArrayIndex()]->writeExtraInfo(m_score_varname, score);
296 B2DEBUG(11, debugStr[11]);
297 B2DEBUG(12, debugStr[12]);
298 debugStr[11].clear();
299 debugStr[12].clear();
314 std::map<std::string, std::string> aliasesLegacy;
316 aliasesLegacy.insert(std::make_pair(
"__event__",
"evtNum"));
323 aliasesLegacy.insert(std::make_pair(
"missingLogL_" + detName,
"pidMissingProbabilityExpert(" + detName +
")"));
325 for (
auto& [pdgId, info] : m_stdChargedInfo) {
327 std::string alias =
"deltaLogL_" + std::get<0>(info) +
"_" + std::get<1>(info) +
"_" + detName;
328 std::string var =
"pidDeltaLogLikelihoodValueExpert(" + std::to_string(pdgId) +
", " + std::to_string(std::get<2>
329 (info)) +
"," + detName +
")";
331 aliasesLegacy.insert(std::make_pair(alias, var));
333 if (it.getIndex() == 0) {
334 alias =
"deltaLogL_" + std::get<0>(info) +
"_" + std::get<1>(info) +
"_ALL";
335 var =
"pidDeltaLogLikelihoodValueExpert(" + std::to_string(pdgId) +
", " + std::to_string(std::get<2>(info)) +
", ALL)";
336 aliasesLegacy.insert(std::make_pair(alias, var));
343 B2INFO(
"Setting hard-coded aliases for the ChargedPidMVA algorithm.");
345 std::string debugStr(
"\n");
346 for (
const auto& [alias, variable] : aliasesLegacy) {
347 debugStr += (alias +
" --> " + variable +
"\n");
349 B2ERROR(
"Something went wrong with setting alias: " << alias <<
" for variable: " << variable);
352 B2DEBUG(10, debugStr);
360 auto aliases = (*m_weightfiles_representation.get())->getAliases();
362 if (!aliases->empty()) {
364 B2INFO(
"Setting aliases for the ChargedPidMVA algorithm read from the payload.");
366 std::string debugStr(
"\n");
367 for (
const auto& [alias, variable] : *aliases) {
368 if (alias != variable) {
369 debugStr += (alias +
" --> " + variable +
"\n");
371 B2ERROR(
"Something went wrong with setting alias: " << alias <<
" for variable: " << variable);
375 B2DEBUG(10, debugStr);
382 this->registerAliasesLegacy();
390 B2INFO(
"Run: " << m_event_metadata->getRun() <<
". Load supported MVA interfaces for binary charged particle identification...");
393 this->registerAliases();
399 B2INFO(
"\tLoading weightfiles from the payload class for SIGNAL particle hypothesis: " << m_sig_pdg);
401 auto serialized_weightfiles = (*m_weightfiles_representation.get())->getMVAWeights(m_sig_pdg);
402 auto nfiles = serialized_weightfiles->size();
404 B2INFO(
"\tConstruct the MVA experts and datasets from N = " << nfiles <<
" weightfiles...");
408 m_experts.resize(nfiles);
409 m_datasets.resize(nfiles);
410 m_cuts.resize(nfiles);
411 m_variables.resize(nfiles);
412 m_spectators.resize(nfiles);
414 for (
unsigned int idx(0); idx < nfiles; idx++) {
416 B2DEBUG(12,
"\t\tweightfile[" << idx <<
"]");
419 std::stringstream ss(serialized_weightfiles->at(idx));
423 weightfile.getOptions(general_options);
427 m_variables[idx] = manager.getVariables(general_options.m_variables);
428 m_spectators[idx] = manager.getVariables(general_options.m_spectators);
430 B2DEBUG(12,
"\t\tRetrieved N = " << general_options.m_variables.size()
431 <<
" variables, N = " << general_options.m_spectators.size()
435 m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
436 m_experts.at(idx)->load(weightfile);
438 B2DEBUG(12,
"\t\tweightfile loaded successfully into expert[" << idx <<
"]!");
441 std::vector<float> v(general_options.m_variables.size(), 0.0);
442 std::vector<float> s(general_options.m_spectators.size(), 0.0);
443 m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
445 B2DEBUG(12,
"\t\tdataset[" << idx <<
"] created successfully!");
448 const auto cuts = (*m_weightfiles_representation.get())->getCuts(m_sig_pdg);
449 const auto cutstr = (!cuts->empty()) ? cuts->at(idx) :
"";
452 B2DEBUG(12,
"\t\tcut[" << idx <<
"] created successfully!");
int m_bkg_pdg
The input background mass hypothesis' pdgId.
std::vector< std::string > m_decayStrings
The input list of DecayStrings, where each selected (^) daughter should correspond to a standard char...
virtual void initialize() override
Use this to initialize resources or memory your module needs.
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
virtual void event() override
Called once for each event.
virtual ~ChargedPidMVAModule()
Destructor, use this to clean up anything you created in the constructor.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual void beginRun() override
Called once before a new run begins.
void registerAliases()
Set variable aliases needed by the MVA.
int m_sig_pdg
The input signal mass hypothesis' pdgId.
void registerAliasesLegacy()
Set variable aliases needed by the MVA.
ChargedPidMVAModule()
Constructor, for setting module description and parameters.
std::string m_payload_name
The name of the database payload object with the MVA weights.
Iterator end() const
Ending iterator.
static DetectorSet set()
Accessor for the set of valid detector IDs.
EDetector
Enum for identifying the detector components (detector and subdetector).
static std::string parseDetectors(EDetector det)
Converts Const::EDetector object to string.
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
@ c_Debug
Debug: for code development.
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
static void initSupportedInterfaces()
Static function which initliazes all supported interfaces, has to be called once before getSupportedI...
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
General options which are shared by all MVA trainings.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
void setDescription(const std::string &description)
Sets the description of the module.
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Class to store reconstructed particles.
Type-safe access to single objects in the data store.
Global list of available variables.
static Manager & Instance()
get singleton instance.
bool addAlias(const std::string &alias, const std::string &variable)
Add alias Return true if the alias was successfully added.
void addParam(const std::string &name, T ¶mVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Abstract base class for different kinds of events.