9 #include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAModule.h>
12 #include <mva/interface/Interface.h>
13 #include <analysis/VariableManager/Utility.h>
14 #include <analysis/dataobjects/Particle.h>
15 #include <analysis/dataobjects/ParticleList.h>
18 #include <mdst/dataobjects/ECLCluster.h>
26 setDescription(
"This module evaluates the response of an MVA trained for binary charged particle identification between two hypotheses, S and B. For a given input set of (S,B) mass hypotheses, it takes the Particle objects in the appropriate charged stable particle's ParticleLists, calculates the MVA score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
28 setPropertyFlags(c_ParallelProcessingCertified);
30 addParam(
"sigHypoPDGCode",
32 "The input signal mass hypothesis' pdgId.",
34 addParam(
"bkgHypoPDGCode",
36 "The input background mass hypothesis' pdgId.",
38 addParam(
"particleLists",
40 "The input list of decay strings, where the mother particle string should correspond to a full name of a particle list. One can select to run on daughters instead of mother particle, e.g. ['Lambda0 -> ^p+ ^pi-'].",
41 std::vector<std::string>());
42 addParam(
"payloadName",
44 "The name of the database payload object with the MVA weights.",
45 std::string(
"ChargedPidMVAWeights"));
46 addParam(
"chargeIndependent",
48 "Specify whether to use a charge-independent training of the MVA.",
50 addParam(
"useECLOnlyTraining",
52 "Specify whether to use an ECL-only training of the MVA.",
/** Destructor. Defaulted: no custom cleanup is performed here; each member is released by its own destructor. */
57 ChargedPidMVAModule::~ChargedPidMVAModule() =
default;
60 void ChargedPidMVAModule::initialize()
78 " of the signal mass hypothesis is not that of a valid particle in Const::chargedStableSet! Aborting...");
82 " of the background mass hypothesis is not that of a valid particle in Const::chargedStableSet! Aborting...");
103 decayDescriptor.init(decayString);
104 auto pl_name = decayDescriptor.getMother()->getFullName();
105 unsigned short m_nSelectedDaughters = decayDescriptor.getSelectionNames().size();
107 if (!pList) { B2FATAL(
"ParticleList: " << pl_name <<
" could not be found. Aborting..."); }
110 std::vector<int> pdgs;
111 if (m_nSelectedDaughters == 0)
112 pdgs.push_back(pList->getPDGCode());
114 pdgs = decayDescriptor.getSelectionPDGCodes();
115 for (
auto pdg : pdgs) {
118 B2FATAL(
"PDG: " << pdg <<
" of ParticleList: " << pl_name <<
119 " is not that of a valid particle in Const::chargedStableSet! Aborting...");
123 B2DEBUG(11,
"ParticleList: " << pList->getParticleListName() <<
" - N = " << pList->getListSize() <<
" particles.");
124 const auto nTargetParticles = (m_nSelectedDaughters == 0) ? pList->getListSize() : pList->getListSize() *
125 m_nSelectedDaughters;
126 std::vector<const Particle*> targetParticles;
127 if (m_nSelectedDaughters > 0) {
128 for (
unsigned int iPart(0); iPart < pList->getListSize(); ++iPart) {
129 auto* iParticle = pList->getParticle(iPart);
130 auto daughters = decayDescriptor.getSelectionParticles(iParticle);
131 for (
auto* iDaughter : daughters) {
132 targetParticles.push_back(iDaughter);
136 for (
unsigned int ipart(0); ipart < nTargetParticles; ++ipart) {
138 const Particle* particle = (m_nSelectedDaughters == 0) ? pList->getParticle(ipart) : targetParticles[ipart];
139 B2DEBUG(11,
"\tParticle [" << ipart <<
"]");
145 B2DEBUG(11,
"\t\tParticle has invalid Track-ECLCluster relation, skip MVA application...");
151 auto clusterTheta = eclCluster->
getTheta();
152 auto p = particle->
getP();
155 int idx_theta, idx_p, idx_charge;
160 const auto cutstr = (!cuts->empty()) ? cuts->at(index) :
"";
162 B2DEBUG(11,
"\t\tclusterTheta = " << clusterTheta <<
" [rad]");
163 B2DEBUG(11,
"\t\tp = " << p <<
" [GeV/c]");
165 B2DEBUG(11,
"\t\tcharge = " << charge);
167 B2DEBUG(11,
"\t\tBrems corrected = " << particle->
hasExtraInfo(
"bremsCorrectedPhotonEnergy"));
168 B2DEBUG(11,
"\t\tWeightfile idx = " << index <<
" - (clusterTheta, p, charge) = (" << idx_theta <<
", " << idx_p <<
", " <<
170 if (!cutstr.empty()) {
171 B2DEBUG(11,
"\tCategory cut: " << cutstr);
176 B2DEBUG(11,
"\tMVA variables:");
179 for (
unsigned int ivar(0); ivar < nvars; ++ivar) {
184 auto var_result = varobj->function(particle);
185 if (std::holds_alternative<double>(var_result)) {
186 var = std::get<double>(var_result);
187 }
else if (std::holds_alternative<int>(var_result)) {
188 var = std::get<int>(var_result);
189 }
else if (std::holds_alternative<bool>(var_result)) {
190 var = std::get<bool>(var_result);
192 B2ERROR(
"Variable '" << varobj->name <<
"' has wrong data type! It must be one of double, integer, or bool.");
196 var = (std::isnan(var)) ? -999.0 : var;
198 B2DEBUG(11,
"\t\tvar[" << ivar <<
"] : " << varobj->name <<
" = " << var);
204 B2DEBUG(12,
"\tMVA spectators:");
207 for (
unsigned int ispec(0); ispec < nspecs; ++ispec) {
211 double spec = std::numeric_limits<double>::quiet_NaN();
212 auto spec_result = specobj->function(particle);
213 if (std::holds_alternative<double>(spec_result)) {
214 spec = std::get<double>(spec_result);
215 }
else if (std::holds_alternative<int>(spec_result)) {
216 spec = std::get<int>(spec_result);
217 }
else if (std::holds_alternative<bool>(spec_result)) {
218 spec = std::get<bool>(spec_result);
220 B2ERROR(
"Variable '" << specobj->name <<
"' has wrong data type! It must be one of double, integer, or bool.");
223 B2DEBUG(12,
"\t\tspec[" << ispec <<
"] : " << specobj->name <<
" = " << spec);
225 m_datasets.at(index)->m_spectators[ispec] = spec;
230 if (!cutstr.empty()) {
234 if (!cut->check(particle)) {
235 B2DEBUG(11,
"\t\tParticle didn't pass MVA category cut, skip MVA application...");
243 B2DEBUG(11,
"\tMVA score = " << score);
258 B2INFO(
"Run: " <<
m_event_metadata->getRun() <<
". Load supported MVA interfaces for binary charged particle identification...");
264 B2INFO(
"\tLoading weightfiles from the payload class for SIGNAL particle hypothesis: " <<
m_sig_pdg);
267 auto nfiles = serialized_weightfiles->size();
269 B2INFO(
"\tConstruct the MVA experts and datasets from N = " << nfiles <<
" weightfiles...");
278 for (
unsigned int idx(0); idx < nfiles; idx++) {
280 B2DEBUG(12,
"\t\tweightfile[" << idx <<
"]");
283 std::stringstream ss(serialized_weightfiles->at(idx));
287 weightfile.getOptions(general_options);
291 m_variables[idx] = manager.getVariables(general_options.m_variables);
292 m_spectators[idx] = manager.getVariables(general_options.m_spectators);
294 B2DEBUG(12,
"\t\tRetrieved N = " << general_options.m_variables.size()
295 <<
" variables, N = " << general_options.m_spectators.size()
299 m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
302 B2DEBUG(12,
"\t\tweightfile loaded successfully into expert[" << idx <<
"]!");
305 std::vector<float> v(general_options.m_variables.size(), 0.0);
306 std::vector<float> s(general_options.m_spectators.size(), 0.0);
307 m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
309 B2DEBUG(12,
"\t\tdataset[" << idx <<
"] created successfully!");
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
int m_bkg_pdg
The input background mass hypothesis' pdgId.
std::vector< std::string > m_decayStrings
The input list of decay strings to which MVA weights will be applied.
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
std::string m_score_varname
The lookup name of the MVA score variable, given the input S, B mass hypotheses for the algorithm.
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
StoreArray< Particle > m_particles
StoreArray of Particles.
DatasetsList m_datasets
List of MVA::SingleDataset objects.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual void event() override
Called once for each event.
int m_sig_pdg
The input signal mass hypothesis' pdgId.
virtual void beginRun() override
Called once before a new run begins.
VariablesLists m_variables
List of lists of feature variables.
ChargedPidMVAModule()
Constructor, for setting module description and parameters.
VariablesLists m_spectators
List of lists of spectator variables.
ExpertsList m_experts
List of MVA::Expert objects.
std::string m_payload_name
The name of the database payload object with the MVA weights.
size_t size() const
Getter for number of detector IDs in this set.
static DetectorSet set()
Accessor function for the set of valid detectors.
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
double getTheta() const
Returns the corrected theta of the shower (in radians).
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
General options which are shared by all MVA trainings.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Class to store reconstructed particles.
const ECLCluster * getECLCluster() const
Returns the pointer to the ECLCluster object that was used to create this Particle (if ParticleType =...
bool hasExtraInfo(const std::string &name) const
Return whether the extra info with the given name is set.
double getCharge(void) const
Returns particle charge.
double getP() const
Returns momentum magnitude (same as getMomentumMagnitude but with shorter name)
int getArrayIndex() const
Returns this object's array index (in StoreArray), or -1 if not found.
Type-safe access to single objects in the data store.
Global list of available variables.
static Manager & Instance()
get singleton instance.
REG_MODULE(B2BIIConvertBeamParams)
Register the module.
double charge(int pdgCode)
Returns electric charge of a particle with given pdg code.
Abstract base class for different kinds of events.