Belle II Software development
MVAFilter.icc.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
#pragma once

#include <tracking/trackFindingCDC/filters/base/MVAFilter.dcl.h>

#include <tracking/trackFindingCDC/mva/MVAExpert.h>

#include <tracking/trackFindingCDC/filters/base/FilterOnVarSet.icc.h>

#include <framework/core/ModuleParamList.templateDetails.h>
#include <tracking/trackFindingCDC/utilities/StringManipulation.h>

#include <tracking/trackFindingCDC/utilities/Named.h>
#include <mva/interface/Dataset.h>
#include <RtypesCore.h>

#include <algorithm>
#include <vector>
#include <string>
#include <memory>
#include <cmath>

27
28namespace Belle2 {
34 namespace TrackFindingCDC {
35
36 template <class AFilter>
37 MVA<AFilter>::MVA(std::unique_ptr<AVarSet> varSet,
38 const std::string& identifier,
39 double defaultCut,
40 const std::string& dbObjectName)
41 : Super(std::move(varSet)), m_identifier(identifier), m_cutValue(defaultCut), m_DBPayloadName(dbObjectName)
42 {
43 }
44
45 template <class AFilter>
46 MVA<AFilter>::~MVA() = default;
47
48 template <class AFilter>
49 void MVA<AFilter>::exposeParameters(ModuleParamList* moduleParamList, const std::string& prefix)
50 {
51 Super::exposeParameters(moduleParamList, prefix);
52 moduleParamList->addParameter(prefixed(prefix, "cut"),
53 m_cutValue,
54 "The cut value of the mva output below which the object is rejected",
55 m_cutValue);
56
57 moduleParamList->addParameter(prefixed(prefix, "identifier"),
58 m_identifier,
59 "Database identifier of the expert of weight file name",
60 m_identifier);
62 moduleParamList->addParameter(prefixed(prefix, "DBPayloadName"),
63 m_DBPayloadName,
64 "Name of the DB payload containing weightfile name and the cut value. If a DB payload with both values is available and valid, it will override the values provided by parameters.",
65 m_DBPayloadName);
66 }
68 template <class AFilter>
70 {
71 Super::initialize();
72
73 DBObjPtr<TrackingMVAFilterParameters> mvaParameterPayload(m_DBPayloadName);
74 if (mvaParameterPayload.isValid()) {
75 m_identifier = mvaParameterPayload->getIdentifierName();
76 m_cutValue = mvaParameterPayload->getCutValue();
77 B2DEBUG(20, "MVAFilter: Using DBObject " << m_DBPayloadName << " with weightfile " << m_identifier << " and cut value " <<
78 m_cutValue << ".");
79 } else {
80 B2FATAL("MVAFilter: No valid MVAFilter payload with name " + m_DBPayloadName + " was found.");
81 }
82
83 std::vector<Named<Float_t*>> namedVariables = Super::getVarSet().getNamedVariables();
84
85 m_mvaExpert = std::make_unique<MVAExpert>(m_identifier, std::move(namedVariables));
86 m_mvaExpert->initialize();
87 }
88
89 template <class AFilter>
91 {
92 Super::beginRun();
93 m_mvaExpert->beginRun();
95 const auto& selectedVars = m_mvaExpert->getVariableNames();
96 const std::vector<Named<Float_t*>>& namedVariables = Super::getVarSet().getNamedVariables();
97 m_namedVariables.clear();
98 for (const auto& name : selectedVars) {
99
100 auto itNamedVariable = std::find_if(namedVariables.begin(),
101 namedVariables.end(),
102 [name](const Named<Float_t*>& namedVariable) {
103 return namedVariable.getName() == name;
104 });
105 if (itNamedVariable == namedVariables.end()) {
106 B2ERROR("Variable name " << name << " mismatch for MVA filter. " <<
107 "Could not find expected variable '" << name << "'");
108 }
109 m_namedVariables.push_back(*itNamedVariable);
110 }
112
113 template <class AFilter>
115 {
116 double prediction = predict(obj);
117 return prediction < m_cutValue ? NAN : prediction;
118 }
119
120 template <class AFilter>
121 double MVA<AFilter>::predict(const Object& obj)
122 {
123 Weight extracted = Super::operator()(obj);
124 if (std::isnan(extracted)) {
125 return NAN;
126 } else {
127 return m_mvaExpert->predict();
128 }
129 }
130
131 template <class AFilter>
132 std::vector<float> MVA<AFilter>::predict(const std::vector<Object*>& objs)
133 {
134 const int nFeature = m_namedVariables.size();
135 const int nRows = objs.size();
136 auto allFeatures = std::unique_ptr<float[]>(new float[nRows * nFeature]);
137 size_t iRow = 0;
138 for (const auto& obj : objs) {
139 if (Super::getVarSet().extract(obj)) {
140 for (int iFeature = 0; iFeature < nFeature; iFeature += 1) {
141 allFeatures[nFeature * iRow + iFeature] = *m_namedVariables[iFeature];
142 }
143 iRow += 1;
144 }
145 }
146 return m_mvaExpert->predict(allFeatures.get(), nFeature, nRows);
147 }
148
149 template <class AFilter>
150 std::vector<float> MVA<AFilter>::operator()(const std::vector<Object*>& objs)
151 {
152 auto out = predict(objs);
153 for (auto& res : out) {
154 res = res < m_cutValue ? NAN : res;
155 }
156 return out;
157 }
158
159 template <class AVarSet>
160 MVAFilter<AVarSet>::MVAFilter(const std::string& defaultTrainingName,
161 double defaultCut,
162 const std::string& defaultDBObjectName)
163 : Super(std::make_unique<AVarSet>(), defaultTrainingName, defaultCut, defaultDBObjectName)
164 {
165 }
166
167 template <class AVarSet>
169 }
171}
Class for accessing objects in the database.
Definition: DBObjPtr.h:21
The Module parameter list class.
~MVAFilter()
Default destructor.
MVAFilter(const std::string &defaultTrainingName="", double defaultCut=NAN, const std::string &defaultDBObjectName="")
Constructor of the filter.
Filter based on a mva method.
Definition: MVAFilter.dcl.h:36
virtual ~MVA()
Default destructor.
void initialize() override
Initialize the expert before event processing.
Definition: MVAFilter.icc.h:69
Weight operator()(const Object &obj) override
Function to evaluate the object for its signalness.
void beginRun() override
Signal to load new run parameters.
Definition: MVAFilter.icc.h:90
MVA(std::unique_ptr< AVarSet > varSet, const std::string &identifier="", double defaultCut=NAN, const std::string &dbObjectName="")
Constructor of the filter.
Definition: MVAFilter.icc.h:37
typename AFilter::Object Object
Type of the object to be analysed.
Definition: MVAFilter.dcl.h:44
virtual double predict(const Object &obj)
Evaluate the mva method.
void exposeParameters(ModuleParamList *moduleParamList, const std::string &prefix) override
Expose the set of parameters of the filter to the module parameter list.
Definition: MVAFilter.icc.h:49
A mixin class to attach a name to an object.
Definition: Named.h:23
Filter adapter to make a filter work on a set of variables.
void addParameter(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module list.
Abstract base class for different kinds of events.
STL namespace.