Belle II Software  release-08-01-10
MVAFilter.icc.h
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 #pragma once
9 
10 #include <tracking/trackFindingCDC/filters/base/MVAFilter.dcl.h>
11 
12 #include <tracking/trackFindingCDC/mva/MVAExpert.h>
13 
14 #include <tracking/trackFindingCDC/filters/base/FilterOnVarSet.icc.h>
15 
16 #include <framework/core/ModuleParamList.templateDetails.h>
17 #include <tracking/trackFindingCDC/utilities/StringManipulation.h>
18 
19 #include <tracking/trackFindingCDC/utilities/Named.h>
20 
21 #include <RtypesCore.h>
22 
23 #include <vector>
24 #include <string>
25 #include <memory>
26 #include <cmath>
27 
28 namespace Belle2 {
34  namespace TrackFindingCDC {
35 
36  template <class AFilter>
37  MVA<AFilter>::MVA(std::unique_ptr<AVarSet> varSet,
38  const std::string& identifier,
39  double defaultCut,
40  const std::string& dbObjectName)
41  : Super(std::move(varSet)), m_identifier(identifier), m_cutValue(defaultCut), m_DBPayloadName(dbObjectName)
42  {
43  }
44 
45  template <class AFilter>
46  MVA<AFilter>::~MVA() = default;
47 
48  template <class AFilter>
49  void MVA<AFilter>::exposeParameters(ModuleParamList* moduleParamList, const std::string& prefix)
50  {
51  Super::exposeParameters(moduleParamList, prefix);
52  moduleParamList->addParameter(prefixed(prefix, "cut"),
53  m_cutValue,
54  "The cut value of the mva output below which the object is rejected",
55  m_cutValue);
56 
57  moduleParamList->addParameter(prefixed(prefix, "identifier"),
58  m_identifier,
59  "Database identfier of the expert of weight file name",
60  m_identifier);
61 
62  moduleParamList->addParameter(prefixed(prefix, "DBPayloadName"),
63  m_DBPayloadName,
64  "Name of the DB payload containing weightfile name and the cut value. If a DB payload with both values is available and valid, it will override the values provided by parameters.",
65  m_DBPayloadName);
66  }
67 
68  template <class AFilter>
70  {
71  Super::initialize();
72 
73  DBObjPtr<TrackingMVAFilterParameters> mvaParameterPayload(m_DBPayloadName);
74  if (mvaParameterPayload.isValid()) {
75  m_identifier = mvaParameterPayload->getIdentifierName();
76  m_cutValue = mvaParameterPayload->getCutValue();
77  B2DEBUG(20, "MVAFilter: Using DBObject " << m_DBPayloadName << " with weightfile " << m_identifier << " and cut value " <<
78  m_cutValue << ".");
79  } else {
80  B2FATAL("MVAFilter: No valid MVAFilter payload with name " + m_DBPayloadName + " was found.");
81  }
82 
83  std::vector<Named<Float_t*>> namedVariables = Super::getVarSet().getNamedVariables();
84 
85  m_mvaExpert = std::make_unique<MVAExpert>(m_identifier, std::move(namedVariables));
86  m_mvaExpert->initialize();
87  }
88 
89  template <class AFilter>
91  {
92  Super::beginRun();
93  m_mvaExpert->beginRun();
94  }
95 
96  template <class AFilter>
97  Weight MVA<AFilter>::operator()(const Object& obj)
98  {
99  double prediction = predict(obj);
100  return prediction < m_cutValue ? NAN : prediction;
101  }
102 
103  template <class AFilter>
104  double MVA<AFilter>::predict(const Object& obj)
105  {
106  Weight extracted = Super::operator()(obj);
107  if (std::isnan(extracted)) {
108  return NAN;
109  } else {
110  return m_mvaExpert->predict();
111  }
112  }
113 
114  template <class AVarSet>
115  MVAFilter<AVarSet>::MVAFilter(const std::string& defaultTrainingName,
116  double defaultCut,
117  const std::string& defaultDBObjectName)
118  : Super(std::make_unique<AVarSet>(), defaultTrainingName, defaultCut, defaultDBObjectName)
119  {
120  }
121 
122  template <class AVarSet>
123  MVAFilter<AVarSet>::~MVAFilter() = default;
124  }
126 }
Class for accessing objects in the database.
Definition: DBObjPtr.h:21
The Module parameter list class.
Generic class that generates some named float values from a given object.
Definition: BaseVarSet.h:33
~MVAFilter()
Default destructor.
MVAFilter(const std::string &defaultTrainingName="", double defaultCut=NAN, const std::string &defaultDBObjectName="")
Constructor of the filter.
Filter based on a mva method.
Definition: MVAFilter.dcl.h:36
virtual ~MVA()
Default destructor.
void initialize() override
Initialize the expert before event processing.
Definition: MVAFilter.icc.h:69
Weight operator()(const Object &obj) override
Function to evaluate the object for its signalness.
Definition: MVAFilter.icc.h:97
void beginRun() override
Signal to load new run parameters.
Definition: MVAFilter.icc.h:90
MVA(std::unique_ptr< AVarSet > varSet, const std::string &identifier="", double defaultCut=NAN, const std::string &dbObjectName="")
Constructor of the filter.
Definition: MVAFilter.icc.h:37
virtual double predict(const Object &obj)
Evaluate the mva method.
void exposeParameters(ModuleParamList *moduleParamList, const std::string &prefix) override
Expose the set of parameters of the filter to the module parameter list.
Definition: MVAFilter.icc.h:49
AFilter Super
Type of the base class.
typename AFilter::Object Object
Type of object to be filtered.
void addParameter(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module list.
Abstract base class for different kinds of events.