10#include <tracking/trackFindingCDC/filters/base/MVAFilter.dcl.h>
12#include <tracking/trackFindingCDC/mva/MVAExpert.h>
14#include <tracking/trackFindingCDC/filters/base/FilterOnVarSet.icc.h>
16#include <framework/core/ModuleParamList.templateDetails.h>
17#include <tracking/trackFindingCDC/utilities/StringManipulation.h>
19#include <tracking/trackFindingCDC/utilities/Named.h>
20#include <mva/interface/Dataset.h>
21#include <RtypesCore.h>
34 namespace TrackFindingCDC {
36 template <
class AFilter>
38 const std::string& identifier,
40 const std::string& dbObjectName)
41 :
Super(
std::move(varSet)), m_identifier(identifier), m_cutValue(defaultCut), m_DBPayloadName(dbObjectName)
45 template <
class AFilter>
48 template <
class AFilter>
51 Super::exposeParameters(moduleParamList, prefix);
54 "The cut value of the mva output below which the object is rejected",
57 moduleParamList->
addParameter(prefixed(prefix,
"identifier"),
59 "Database identifier of the expert of weight file name",
62 moduleParamList->
addParameter(prefixed(prefix,
"DBPayloadName"),
64 "Name of the DB payload containing weightfile name and the cut value. If a DB payload with both values is available and valid, it will override the values provided by parameters.",
68 template <
class AFilter>
74 if (mvaParameterPayload.isValid()) {
75 m_identifier = mvaParameterPayload->getIdentifierName();
76 m_cutValue = mvaParameterPayload->getCutValue();
77 B2DEBUG(20,
"MVAFilter: Using DBObject " << m_DBPayloadName <<
" with weightfile " << m_identifier <<
" and cut value " <<
80 B2FATAL(
"MVAFilter: No valid MVAFilter payload with name " + m_DBPayloadName +
" was found.");
83 std::vector<Named<Float_t*>> namedVariables = Super::getVarSet().getNamedVariables();
85 m_mvaExpert = std::make_unique<MVAExpert>(m_identifier, std::move(namedVariables));
86 m_mvaExpert->initialize();
89 template <
class AFilter>
93 m_mvaExpert->beginRun();
95 const auto& selectedVars = m_mvaExpert->getVariableNames();
96 const std::vector<Named<Float_t*>>& namedVariables = Super::getVarSet().getNamedVariables();
97 m_namedVariables.clear();
98 for (
const auto& name : selectedVars) {
100 auto itNamedVariable = std::find_if(namedVariables.begin(),
101 namedVariables.end(),
103 return namedVariable.getName() == name;
105 if (itNamedVariable == namedVariables.end()) {
106 B2ERROR(
"Variable name " << name <<
" mismatch for MVA filter. " <<
107 "Could not find expected variable '" << name <<
"'");
109 m_namedVariables.push_back(*itNamedVariable);
113 template <
class AFilter>
116 double prediction = predict(obj);
117 return prediction < m_cutValue ? NAN : prediction;
120 template <
class AFilter>
123 Weight extracted = Super::operator()(obj);
124 if (std::isnan(extracted)) {
127 return m_mvaExpert->predict();
131 template <
class AFilter>
134 const int nFeature = m_namedVariables.size();
135 const int nRows = objs.size();
136 auto allFeatures = std::unique_ptr<float[]>(
new float[nRows * nFeature]);
138 for (
const auto& obj : objs) {
139 if (Super::getVarSet().extract(obj)) {
140 for (
int iFeature = 0; iFeature < nFeature; iFeature += 1) {
141 allFeatures[nFeature * iRow + iFeature] = *m_namedVariables[iFeature];
146 return m_mvaExpert->predict(allFeatures.get(), nFeature, nRows);
149 template <
class AFilter>
152 auto out = predict(objs);
153 for (
auto& res : out) {
154 res = res < m_cutValue ? NAN : res;
159 template <
class AVarSet>
162 const std::string& defaultDBObjectName)
163 :
Super(
std::make_unique<
AVarSet>(), defaultTrainingName, defaultCut, defaultDBObjectName)
167 template <
class AVarSet>
Class for accessing objects in the database.
The Module parameter list class.
~MVAFilter()
Default destructor.
MVAFilter(const std::string &defaultTrainingName="", double defaultCut=NAN, const std::string &defaultDBObjectName="")
Constructor of the filter.
Filter based on an MVA method.
virtual ~MVA()
Default destructor.
void initialize() override
Initialize the expert before event processing.
Weight operator()(const Object &obj) override
Function to evaluate the object for its signalness.
void beginRun() override
Signal to load new run parameters.
MVA(std::unique_ptr< AVarSet > varSet, const std::string &identifier="", double defaultCut=NAN, const std::string &dbObjectName="")
Constructor of the filter.
typename AFilter::Object Object
Type of the object to be analysed.
virtual double predict(const Object &obj)
Evaluate the mva method.
void exposeParameters(ModuleParamList *moduleParamList, const std::string &prefix) override
Expose the set of parameters of the filter to the module parameter list.
A mixin class to attach a name to an object.
Filter adapter to make a filter work on a set of variables.
void addParameter(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module list.
Abstract base class for different kinds of events.