10#include <tracking/trackFindingCDC/filters/base/MVAFilter.dcl.h>
12#include <tracking/trackFindingCDC/mva/MVAExpert.h>
14#include <tracking/trackFindingCDC/filters/base/FilterOnVarSet.icc.h>
16#include <framework/core/ModuleParamList.templateDetails.h>
17#include <tracking/trackFindingCDC/utilities/StringManipulation.h>
19#include <tracking/trackFindingCDC/utilities/Named.h>
20#include <mva/interface/Dataset.h>
21#include <RtypesCore.h>
34 namespace TrackFindingCDC {
36 template <
class AFilter>
38 const std::string& identifier,
40 const std::string& dbObjectName)
45 template <
class AFilter>
48 template <
class AFilter>
54 "The cut value of the mva output below which the object is rejected",
57 moduleParamList->
addParameter(prefixed(prefix,
"identifier"),
59 "Database identifier of the expert of weight file name",
62 moduleParamList->
addParameter(prefixed(prefix,
"DBPayloadName"),
64 "Name of the DB payload containing weightfile name and the cut value. If a DB payload with both values is available and valid, it will override the values provided by parameters.",
68 template <
class AFilter>
74 if (mvaParameterPayload.isValid()) {
76 m_cutValue = mvaParameterPayload->getCutValue();
80 B2FATAL(
"MVAFilter: No valid MVAFilter payload with name " +
m_DBPayloadName +
" was found.");
83 std::vector<Named<Float_t*>> namedVariables = Super::getVarSet().getNamedVariables();
85 m_mvaExpert = std::make_unique<MVAExpert>(m_identifier, std::move(namedVariables));
86 m_mvaExpert->initialize();
89 template <
class AFilter>
95 const auto& selectedVars =
m_mvaExpert->getVariableNames();
96 const std::vector<Named<Float_t*>>& namedVariables =
Super::getVarSet().getNamedVariables();
98 for (
const auto& name : selectedVars) {
100 auto itNamedVariable = std::find_if(namedVariables.begin(),
101 namedVariables.end(),
103 return namedVariable.getName() == name;
105 if (itNamedVariable == namedVariables.end()) {
106 B2ERROR(
"Variable name " << name <<
" mismatch for MVA filter. " <<
107 "Could not find expected variable '" << name <<
"'");
113 template <
class AFilter>
116 double prediction =
predict(obj);
117 return prediction <
m_cutValue ? NAN : prediction;
120 template <
class AFilter>
124 if (std::isnan(extracted)) {
131 template <
class AFilter>
135 const int nRows = objs.size();
136 auto allFeatures = std::unique_ptr<float[]>(
new float[nRows * nFeature]);
138 for (
const auto& obj : objs) {
140 for (
int iFeature = 0; iFeature < nFeature; iFeature += 1) {
146 return m_mvaExpert->predict(allFeatures.get(), nFeature, nRows);
149 template <
class AFilter>
153 for (
auto& res : out) {
159 template <
class AVarSet>
162 const std::string& defaultDBObjectName)
163 :
Super(
std::make_unique<
AVarSet>(), defaultTrainingName, defaultCut, defaultDBObjectName)
167 template <
class AVarSet>
Class for accessing objects in the database.
The module parameter list class.
MVA< Filter< typename AVarSet::Object > > Super
Type of the super class.
~MVAFilter()
Default destructor.
MVAFilter(const std::string &defaultTrainingName="", double defaultCut=NAN, const std::string &defaultDBObjectName="")
Constructor of the filter.
virtual ~MVA()
Default destructor.
void initialize() override
Initialize the expert before event processing.
Weight operator()(const Object &obj) override
Function to evaluate the object for its signalness.
OnVarSet< AFilter > Super
Type of the super class.
std::unique_ptr< MVAExpert > m_mvaExpert
MVA Expert to examine the object.
std::string m_DBPayloadName
Name of the DB payload.
void beginRun() override
Signal to load new run parameters.
MVA(std::unique_ptr< AVarSet > varSet, const std::string &identifier="", double defaultCut=NAN, const std::string &dbObjectName="")
Constructor of the filter.
typename AFilter::Object Object
Type of the object to be analysed.
BaseVarSet< Object > AVarSet
virtual double predict(const Object &obj)
Evaluate the mva method.
void exposeParameters(ModuleParamList *moduleParamList, const std::string &prefix) override
Expose the set of parameters of the filter to the module parameter list.
double m_cutValue
The cut on the MVA output.
std::vector< Named< Float_t * > > m_namedVariables
Named variables, ordered as in the weight file.
std::string m_identifier
Database identifier of the expert or weight file name.
A mixin class to attach a name to an object.
void initialize() override
No reassignment of variable set possible for now.
Weight operator()(const Object &obj) override
Function extracting the variables of the object into the variable set.
AVarSet & getVarSet() const
Getter for the set of variables.
virtual void exposeParameters(ModuleParamList *moduleParamList, const std::string &prefix) override
Forward prefixed parameters of this findlet to the module parameter list.
void addParameter(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module list.
Abstract base class for different kinds of events.