9 #include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAMulticlassModule.h>
12 #include <mva/interface/Interface.h>
13 #include <mva/methods/TMVA.h>
14 #include <analysis/VariableManager/Utility.h>
15 #include <analysis/dataobjects/Particle.h>
18 #include <mdst/dataobjects/ECLCluster.h>
26 setDescription(
"This module evaluates the response of a multi-class MVA trained for global charged particle identification.. It takes the Particle objects in the input charged stable particles' ParticleLists, calculates the MVA per-class score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
32 "The input list of decay strings, where the mother particle string should correspond to a full name of a particle list. One can select to run on daughters instead of mother particle, e.g. ['Lambda0 -> ^p+ ^pi-'].",
33 std::vector<std::string>());
36 "The name of the database payload object with the MVA weights.",
37 std::string(
"ChargedPidMVAWeights"));
40 "Specify whether to use a charge-independent training of the MVA.",
44 "Specify whether to use an ECL-only training of the MVA.",
78 decayDescriptor.init(decayString);
79 auto pl_name = decayDescriptor.getMother()->getFullName();
81 unsigned short m_nSelectedDaughters = decayDescriptor.getSelectionNames().size();
84 if (!pList) { B2FATAL(
"ParticleList: " << pl_name <<
" could not be found. Aborting..."); }
85 const auto nTargetParticles = (m_nSelectedDaughters == 0) ? pList->getListSize() : pList->getListSize() *
88 std::vector<int> pdgs;
89 if (m_nSelectedDaughters == 0)
90 pdgs.push_back(pList->getPDGCode());
92 pdgs = decayDescriptor.getSelectionPDGCodes();
93 for (
auto pdg : pdgs) {
96 B2FATAL(
"PDG: " << pdg <<
" of ParticleList: " << pl_name <<
97 " is not that of a valid particle in Const::chargedStableSet! Aborting...");
100 std::vector<const Particle*> targetParticles;
101 if (m_nSelectedDaughters > 0) {
102 for (
unsigned int iPart(0); iPart < pList->getListSize(); ++iPart) {
103 auto* iParticle = pList->getParticle(iPart);
104 auto daughters = decayDescriptor.getSelectionParticles(iParticle);
105 for (
auto* iDaughter : daughters) {
106 targetParticles.push_back(iDaughter);
110 B2DEBUG(11,
"ParticleList: " << pList->getParticleListName() <<
" - N = " << pList->getListSize() <<
" particles.");
112 for (
unsigned int ipart(0); ipart < nTargetParticles; ++ipart) {
114 const Particle* particle = (m_nSelectedDaughters > 0) ? targetParticles[ipart] : pList->getParticle(ipart);
116 B2DEBUG(11,
"\tParticle [" << ipart <<
"]");
122 B2DEBUG(11,
"\t\tParticle has invalid Track-ECLCluster relation, skip MVA application...");
128 auto clusterTheta = eclCluster->
getTheta();
129 auto p = particle->
getP();
132 int idx_theta, idx_p, idx_charge;
137 const auto cutstr = (!cuts->empty()) ? cuts->at(index) :
"";
139 B2DEBUG(11,
"\t\tclusterTheta = " << clusterTheta <<
" [rad]");
140 B2DEBUG(11,
"\t\tp = " << p <<
" [GeV/c]");
142 B2DEBUG(11,
"\t\tcharge = " << charge);
144 B2DEBUG(11,
"\t\tBrems corrected = " << particle->
hasExtraInfo(
"bremsCorrectedPhotonEnergy"));
145 B2DEBUG(11,
"\t\tWeightfile idx = " << index <<
" - (clusterTheta, p, charge) = (" << idx_theta <<
", " << idx_p <<
", " <<
147 if (!cutstr.empty()) {
148 B2DEBUG(11,
"\t\tCategory cut = " << cutstr);
153 B2DEBUG(11,
"\tMVA variables:");
156 for (
unsigned int ivar(0); ivar < nvars; ++ivar) {
161 auto var_result = varobj->function(particle);
162 if (std::holds_alternative<double>(var_result)) {
163 var = std::get<double>(var_result);
164 }
else if (std::holds_alternative<int>(var_result)) {
165 var = std::get<int>(var_result);
166 }
else if (std::holds_alternative<bool>(var_result)) {
167 var = std::get<bool>(var_result);
169 B2ERROR(
"Variable '" << varobj->name <<
"' has wrong data type! It must be one of double, integer, or bool.");
173 var = (std::isnan(var)) ? -999.0 : var;
175 B2DEBUG(11,
"\t\tvar[" << ivar <<
"] : " << varobj->name <<
" = " << var);
181 B2DEBUG(12,
"\tMVA spectators:");
184 for (
unsigned int ispec(0); ispec < nspecs; ++ispec) {
188 double spec = std::numeric_limits<double>::quiet_NaN();
189 auto spec_result = specobj->function(particle);
190 if (std::holds_alternative<double>(spec_result)) {
191 spec = std::get<double>(spec_result);
192 }
else if (std::holds_alternative<int>(spec_result)) {
193 spec = std::get<int>(spec_result);
194 }
else if (std::holds_alternative<bool>(spec_result)) {
195 spec = std::get<bool>(spec_result);
197 B2ERROR(
"Variable '" << specobj->name <<
"' has wrong data type! It must be one of double, integer, or bool.");
200 B2DEBUG(12,
"\t\tspec[" << ispec <<
"] : " << specobj->name <<
" = " << spec);
202 m_datasets.at(index)->m_spectators[ispec] = spec;
207 if (!cutstr.empty()) {
211 if (!cut->check(particle)) {
212 B2DEBUG(11,
"\t\tParticle didn't pass MVA category cut, skip MVA application...");
220 B2DEBUG(11,
"\tMVA response:");
222 std::string score_varname(
"");
224 std::vector<float> scores =
m_experts.at(index)->applyMulticlass(*
m_datasets.at(index))[0];
226 for (
unsigned int classID(0); classID <
m_classes.size(); ++classID) {
228 const std::string className(
m_classes.at(classID));
230 score_varname =
"pidChargedBDTScore_" + className;
233 score_varname +=
"_" + std::to_string(Const::ECL);
240 B2DEBUG(11,
"\t\tclass[" << classID <<
"] = " << className <<
" - score = " << scores[classID]);
241 B2DEBUG(12,
"\t\tExtraInfo: " << score_varname);
258 ". Load supported MVA interfaces for multi-class charged particle identification...");
264 B2INFO(
"\tLoading weightfiles from the payload class.");
267 auto nfiles = serialized_weightfiles->size();
269 B2INFO(
"\tConstruct the MVA experts and datasets from N = " << nfiles <<
" weightfiles...");
278 for (
unsigned int idx(0); idx < nfiles; idx++) {
280 B2DEBUG(12,
"\t\tweightfile[" << idx <<
"]");
283 std::stringstream ss(serialized_weightfiles->at(idx));
287 weightfile.getOptions(general_options);
291 m_variables[idx] = manager.getVariables(general_options.m_variables);
292 m_spectators[idx] = manager.getVariables(general_options.m_spectators);
294 B2DEBUG(12,
"\t\tRetrieved N = " << general_options.m_variables.size()
295 <<
" variables, N = " << general_options.m_spectators.size()
299 m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
302 B2DEBUG(12,
"\t\tweightfile loaded successfully into expert[" << idx <<
"]!");
305 std::vector<float> v(general_options.m_variables.size(), 0.0);
306 std::vector<float> s(general_options.m_spectators.size(), 0.0);
307 m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
309 B2DEBUG(12,
"\t\tdataset[" << idx <<
"] created successfully!");
317 weightfile.getOptions(specific_options);
319 if (specific_options.m_classes.empty()) {
320 B2FATAL(
"MVA::SpecificOptions of weightfile[" << idx <<
321 "] has no registered MVA classes! This shouldn't happen in multi-class mode. Aborting...");
325 for (
const auto& cls : specific_options.m_classes) {
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
std::vector< std::string > m_decayStrings
The input list of decay strings to which MVA weights will be applied.
virtual void initialize() override
Use this to initialize resources or memory your module needs.
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
virtual void event() override
Called once for each event.
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
StoreArray< Particle > m_particles
StoreArray of Particles.
DatasetsList m_datasets
List of MVA::SingleDataset objects.
std::vector< std::string > m_classes
List of MVA class names.
ChargedPidMVAMulticlassModule()
Constructor, for setting module description and parameters.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual ~ChargedPidMVAMulticlassModule()
Destructor, use this to clean up anything you created in the constructor.
virtual void beginRun() override
Called once before a new run begins.
VariablesLists m_variables
List of lists of feature variables.
VariablesLists m_spectators
List of lists of spectator variables.
ExpertsList m_experts
List of MVA::Expert objects.
std::string m_payload_name
The name of the database payload object with the MVA weights.
size_t size() const
Getter for number of detector IDs in this set.
static DetectorSet set()
Accessor function for the set of valid detectors.
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
double getTheta() const
Return Corrected Theta of Shower (radian).
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
static void initSupportedInterfaces()
Static function which initliazes all supported interfaces, has to be called once before getSupportedI...
General options which are shared by all MVA trainings.
Options for the TMVA Multiclass MVA method.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
void setDescription(const std::string &description)
Sets the description of the module.
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Class to store reconstructed particles.
const ECLCluster * getECLCluster() const
Returns the pointer to the ECLCluster object that was used to create this Particle (if ParticleType =...
bool hasExtraInfo(const std::string &name) const
Return whether the extra info with the given name is set.
double getCharge(void) const
Returns particle charge.
double getP() const
Returns momentum magnitude (same as getMomentumMagnitude but with shorter name)
int getArrayIndex() const
Returns this object's array index (in StoreArray), or -1 if not found.
Type-safe access to single objects in the data store.
Global list of available variables.
static Manager & Instance()
get singleton instance.
REG_MODULE(B2BIIConvertBeamParams)
Register the module.
void addParam(const std::string &name, T ¶mVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Abstract base class for different kinds of events.