9 #include <analysis/modules/ChargedParticleIdentificator/ChargedPidMVAMulticlassModule.h>
12 #include <mva/interface/Interface.h>
13 #include <mva/methods/TMVA.h>
14 #include <analysis/VariableManager/Utility.h>
15 #include <analysis/dataobjects/Particle.h>
18 #include <mdst/dataobjects/ECLCluster.h>
26 setDescription(
"This module evaluates the response of a multi-class MVA trained for global charged particle identification.. It takes the Particle objects in the input charged stable particles' ParticleLists, calculates the MVA per-class score using the appropriate xml weight file, and adds it as ExtraInfo to the Particle objects.");
28 setPropertyFlags(c_ParallelProcessingCertified);
30 addParam(
"particleLists",
32 "The input list of ParticleList names.",
33 std::vector<std::string>());
34 addParam(
"payloadName",
36 "The name of the database payload object with the MVA weights.",
37 std::string(
"ChargedPidMVAWeights"));
38 addParam(
"chargeIndependent",
40 "Specify whether to use a charge-independent training of the MVA.",
42 addParam(
"useECLOnlyTraining",
44 "Specify whether to use an ECL-only training of the MVA.",
80 if (!pList) { B2FATAL(
"ParticleList: " << name <<
" could not be found. Aborting..."); }
83 int pdg = abs(pList->getPDGCode());
87 B2FATAL(
"PDG: " << pList->getPDGCode() <<
" of ParticleList: " << pList->getParticleListName() <<
88 " is not that of a valid particle in Const::chargedStableSet! Aborting...");
91 B2DEBUG(11,
"ParticleList: " << pList->getParticleListName() <<
" - N = " << pList->getListSize() <<
" particles.");
93 for (
unsigned int ipart(0); ipart < pList->getListSize(); ++ipart) {
95 Particle* particle = pList->getParticle(ipart);
97 B2DEBUG(11,
"\tParticle [" << ipart <<
"]");
101 const ECLCluster* eclCluster = particle->getECLCluster();
103 B2DEBUG(11,
"\t\tParticle has invalid Track-ECLCluster relation, skip MVA application...");
109 auto clusterTheta = eclCluster->
getTheta();
110 auto p = particle->getP();
113 int idx_theta, idx_p, idx_charge;
118 const auto cutstr = (!cuts->empty()) ? cuts->at(index) :
"";
120 B2DEBUG(11,
"\t\tclusterTheta = " << clusterTheta <<
" [rad]");
121 B2DEBUG(11,
"\t\tp = " << p <<
" [GeV/c]");
123 B2DEBUG(11,
"\t\tcharge = " << charge);
125 B2DEBUG(11,
"\t\tBrems corrected = " << particle->hasExtraInfo(
"bremsCorrectedPhotonEnergy"));
126 B2DEBUG(11,
"\t\tWeightfile idx = " << index <<
" - (clusterTheta, p, charge) = (" << idx_theta <<
", " << idx_p <<
", " <<
128 if (!cutstr.empty()) {
129 B2DEBUG(11,
"\t\tCategory cut = " << cutstr);
134 B2DEBUG(11,
"\tMVA variables:");
137 for (
unsigned int ivar(0); ivar < nvars; ++ivar) {
141 auto var = varobj->function(particle);
144 var = (std::isnan(var)) ? -999.0 : var;
146 B2DEBUG(11,
"\t\tvar[" << ivar <<
"] : " << varobj->name <<
" = " << var);
152 B2DEBUG(12,
"\tMVA spectators:");
155 for (
unsigned int ispec(0); ispec < nspecs; ++ispec) {
159 auto spec = specobj->function(particle);
161 B2DEBUG(12,
"\t\tspec[" << ispec <<
"] : " << specobj->name <<
" = " << spec);
163 m_datasets.at(index)->m_spectators[ispec] = spec;
168 if (!cutstr.empty()) {
172 if (!cut->check(particle)) {
173 B2WARNING(
"\tParticle didn't pass MVA category cut, skip MVA application...");
181 B2DEBUG(11,
"\tMVA response:");
183 std::string score_varname(
"");
185 std::vector<float> scores =
m_experts.at(index)->applyMulticlass(*
m_datasets.at(index))[0];
187 for (
unsigned int classID(0); classID <
m_classes.size(); ++classID) {
189 const std::string className(
m_classes.at(classID));
191 score_varname =
"pidChargedBDTScore_" + className;
194 score_varname +=
"_" + std::to_string(Const::ECL);
201 B2DEBUG(11,
"\t\tclass[" << classID <<
"] = " << className <<
" - score = " << scores[classID]);
202 B2DEBUG(12,
"\t\tExtraInfo: " << score_varname);
205 particle->writeExtraInfo(score_varname, scores[classID]);
219 ". Load supported MVA interfaces for multi-class charged particle identification...");
225 B2INFO(
"\tLoading weightfiles from the payload class.");
228 auto nfiles = serialized_weightfiles->size();
230 B2INFO(
"\tConstruct the MVA experts and datasets from N = " << nfiles <<
" weightfiles...");
239 for (
unsigned int idx(0); idx < nfiles; idx++) {
241 B2DEBUG(12,
"\t\tweightfile[" << idx <<
"]");
244 std::stringstream ss(serialized_weightfiles->at(idx));
248 weightfile.getOptions(general_options);
252 m_variables[idx] = manager.getVariables(general_options.m_variables);
253 m_spectators[idx] = manager.getVariables(general_options.m_spectators);
255 B2DEBUG(12,
"\t\tRetrieved N = " << general_options.m_variables.size()
256 <<
" variables, N = " << general_options.m_spectators.size()
260 m_experts[idx] = supported_interfaces[general_options.m_method]->getExpert();
263 B2DEBUG(12,
"\t\tweightfile loaded successfully into expert[" << idx <<
"]!");
266 std::vector<float> v(general_options.m_variables.size(), 0.0);
267 std::vector<float> s(general_options.m_spectators.size(), 0.0);
268 m_datasets[idx] = std::make_unique<MVA::SingleDataset>(general_options, v, 1.0, s);
270 B2DEBUG(12,
"\t\tdataset[" << idx <<
"] created successfully!");
278 weightfile.getOptions(specific_options);
280 if (specific_options.m_classes.empty()) {
281 B2FATAL(
"MVA::SpecificOptions of weightfile[" << idx <<
282 "] has no registered MVA classes! This shouldn't happen in multi-class mode. Aborting...");
286 for (
const auto& cls : specific_options.m_classes) {
This module evaluates the response of a multi-class MVA trained for global charged particle identification.
StoreObjPtr< EventMetaData > m_event_metadata
The event information.
virtual void initialize() override
Use this to initialize resources or memory your module needs.
bool m_ecl_only
Flag to specify if we use an ECL-only based training.
virtual void event() override
Called once for each event.
std::unique_ptr< DBObjPtr< ChargedPidMVAWeights > > m_weightfiles_representation
Interface to get the database payload with the MVA weight files.
DatasetsList m_datasets
List of MVA::SingleDataset objects.
std::vector< std::string > m_classes
List of MVA class names.
bool m_charge_independent
Flag to specify if we use a charge-independent training.
virtual ~ChargedPidMVAMulticlassModule()
Destructor, use this to clean up anything you created in the constructor.
virtual void beginRun() override
Called once before a new run begins.
VariablesLists m_variables
List of lists of feature variables.
VariablesLists m_spectators
List of lists of spectator variables.
std::vector< std::string > m_particle_lists
The input list of ParticleList names.
ExpertsList m_experts
List of MVA::Expert objects.
std::string m_payload_name
The name of the database payload object with the MVA weights.
size_t size() const
Getter for number of detector IDs in this set.
static DetectorSet set()
Accessor function for the set of valid detectors.
double getTheta() const
Return Corrected Theta of Shower (radian).
static std::unique_ptr< GeneralCut > compile(const std::string &cut)
Creates an instance of a cut and returns a unique_ptr to it, if you need a copy-able object instead y...
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
General options which are shared by all MVA trainings.
Options for the TMVA Multiclass MVA method.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Class to store reconstructed particles.
Type-safe access to single objects in the data store.
Global list of available variables.
static Manager & Instance()
get singleton instance.
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Abstract base class for different kinds of events.