10 #include <tracking/trackFindingCDC/filters/base/MVAFilter.dcl.h>
12 #include <tracking/trackFindingCDC/mva/MVAExpert.h>
14 #include <tracking/trackFindingCDC/filters/base/FilterOnVarSet.icc.h>
16 #include <framework/core/ModuleParamList.templateDetails.h>
17 #include <tracking/trackFindingCDC/utilities/StringManipulation.h>
19 #include <tracking/trackFindingCDC/utilities/Named.h>
21 #include <RtypesCore.h>
34 namespace TrackFindingCDC {
36 template <
class AFilter>
38 const std::string& identifier,
40 const std::string& dbObjectName)
41 :
Super(std::move(varSet)), m_identifier(identifier), m_cutValue(defaultCut), m_DBPayloadName(dbObjectName)
45 template <
class AFilter>
48 template <
class AFilter>
51 Super::exposeParameters(moduleParamList, prefix);
54 "The cut value of the mva output below which the object is rejected",
57 moduleParamList->
addParameter(prefixed(prefix,
"identifier"),
59 "Database identfier of the expert of weight file name",
62 moduleParamList->
addParameter(prefixed(prefix,
"DBPayloadName"),
64 "Name of the DB payload containing weightfile name and the cut value. If a DB payload with both values is available and valid, it will override the values provided by parameters.",
68 template <
class AFilter>
74 if (mvaParameterPayload.isValid()) {
75 m_identifier = mvaParameterPayload->getIdentifierName();
76 m_cutValue = mvaParameterPayload->getCutValue();
77 B2DEBUG(20,
"MVAFilter: Using DBObject " << m_DBPayloadName <<
" with weightfile " << m_identifier <<
" and cut value " <<
80 B2FATAL(
"MVAFilter: No valid MVAFilter payload with name " + m_DBPayloadName +
" was found.");
83 std::vector<Named<Float_t*>> namedVariables = Super::getVarSet().getNamedVariables();
85 m_mvaExpert = std::make_unique<MVAExpert>(m_identifier, std::move(namedVariables));
86 m_mvaExpert->initialize();
89 template <
class AFilter>
93 m_mvaExpert->beginRun();
96 template <
class AFilter>
99 double prediction = predict(obj);
100 return prediction < m_cutValue ? NAN : prediction;
103 template <
class AFilter>
106 Weight extracted = Super::operator()(obj);
107 if (std::isnan(extracted)) {
110 return m_mvaExpert->predict();
114 template <
class AVarSet>
117 const std::string& defaultDBObjectName)
118 :
Super(std::make_unique<
AVarSet>(), defaultTrainingName, defaultCut, defaultDBObjectName)
122 template <
class AVarSet>
Class for accessing objects in the database.
The Module parameter list class.
Generic class that generates some named float values from a given object.
~MVAFilter()
Default destructor.
MVAFilter(const std::string &defaultTrainingName="", double defaultCut=NAN, const std::string &defaultDBObjectName="")
Constructor of the filter.
Filter based on an MVA method.
virtual ~MVA()
Default destructor.
void initialize() override
Initialize the expert before event processing.
Weight operator()(const Object &obj) override
Function to evaluate the object for its signalness.
void beginRun() override
Signal to load new run parameters.
MVA(std::unique_ptr< AVarSet > varSet, const std::string &identifier="", double defaultCut=NAN, const std::string &dbObjectName="")
Constructor of the filter.
virtual double predict(const Object &obj)
Evaluate the mva method.
void exposeParameters(ModuleParamList *moduleParamList, const std::string &prefix) override
Expose the set of parameters of the filter to the module parameter list.
AFilter Super
Type of the base class.
typename AFilter::Object Object
Type of object to be filtered.
void addParameter(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module list.
Abstract base class for different kinds of events.