Belle II Software  release-05-01-25
MVAFilter.icc.h
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2016 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Oliver Frost *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 #pragma once
11 
12 #include <tracking/trackFindingCDC/filters/base/MVAFilter.dcl.h>
13 
14 #include <tracking/trackFindingCDC/mva/MVAExpert.h>
15 
16 #include <tracking/trackFindingCDC/filters/base/FilterOnVarSet.icc.h>
17 
18 #include <framework/core/ModuleParamList.templateDetails.h>
19 #include <tracking/trackFindingCDC/utilities/StringManipulation.h>
20 
21 #include <tracking/trackFindingCDC/utilities/Named.h>
22 
23 #include <RtypesCore.h>
24 
25 #include <vector>
26 #include <string>
27 #include <memory>
28 #include <cmath>
29 
30 namespace Belle2 {
36  namespace TrackFindingCDC {
37 
38  template <class AFilter>
39  MVA<AFilter>::MVA(std::unique_ptr<AVarSet> varSet,
40  const std::string& identifier,
41  double defaultCut)
42  : Super(std::move(varSet))
43  , m_param_cut(defaultCut)
44  , m_param_identifier(identifier)
45  {
46  }
47 
// Out-of-line defaulted destructor of MVA.
// NOTE(review): presumably defined here rather than in the declaration header so that
// m_mvaExpert (assigned from std::make_unique<MVAExpert> in initialize()) is destroyed where
// MVAExpert is a complete type; this file includes MVAExpert.h. Confirm against MVAFilter.dcl.h.
48  template <class AFilter>
49  MVA<AFilter>::~MVA() = default;
50 
51  template <class AFilter>
52  void MVA<AFilter>::exposeParameters(ModuleParamList* moduleParamList, const std::string& prefix)
53  {
54  Super::exposeParameters(moduleParamList, prefix);
55  moduleParamList->addParameter(prefixed(prefix, "cut"),
56  m_param_cut,
57  "The cut value of the mva output below which the object is rejected",
58  m_param_cut);
59 
60  moduleParamList->addParameter(prefixed(prefix, "identifier"),
61  m_param_identifier,
62  "Database identfier of the expert of weight file name",
63  m_param_identifier);
64  }
65 
// Initializes the base filter, then (re)creates the MVA expert from the "identifier" parameter
// and the variable set's named variables (Named<Float_t*> entries), and initializes the expert.
// NOTE(review): the declarator line of this definition (original line 67) is missing from this
// listing; presumably it reads `void MVA<AFilter>::initialize()`. Confirm against the source.
66  template <class AFilter>
68  {
69  Super::initialize();
70  std::vector<Named<Float_t*>> namedVariables = Super::getVarSet().getNamedVariables();
// The name/pointer list is moved into the expert; presumably the expert keeps the Float_t
// pointers and reads the variable values through them in predict(). Verify in MVAExpert.
71  m_mvaExpert = std::make_unique<MVAExpert>(m_param_identifier, std::move(namedVariables));
72  m_mvaExpert->initialize();
73  }
74 
// Forwards the begin-of-run signal first to the base filter, then to the MVA expert.
// Dereferences m_mvaExpert without a null check, so it relies on initialize() having run first.
// NOTE(review): the declarator line of this definition (original line 76) is missing from this
// listing; presumably it reads `void MVA<AFilter>::beginRun()`. Confirm against the source.
75  template <class AFilter>
77  {
78  Super::beginRun();
79  m_mvaExpert->beginRun();
80  }
81 
82  template <class AFilter>
83  Weight MVA<AFilter>::operator()(const Object& obj)
84  {
85  double prediction = predict(obj);
86  return prediction < m_param_cut ? NAN : prediction;
87  }
88 
89  template <class AFilter>
90  double MVA<AFilter>::predict(const Object& obj)
91  {
92  Weight extracted = Super::operator()(obj);
93  if (std::isnan(extracted)) {
94  return NAN;
95  } else {
96  return m_mvaExpert->predict();
97  }
98  }
99 
100  template <class AVarSet>
101  MVAFilter<AVarSet>::MVAFilter(const std::string& defaultTrainingName,
102  double defaultCut)
103  : Super(std::make_unique<AVarSet>(), defaultTrainingName, defaultCut)
104  {
105  }
106 
// Out-of-line defaulted destructor of MVAFilter. All cleanup is done by the base class and by
// the members' own destructors.
107  template <class AVarSet>
108  MVAFilter<AVarSet>::~MVAFilter() = default;
109  }
111 }
Belle2::TrackFindingCDC::MVAFilter
Convenience template to create an MVA filter for a set of variables.
Definition: MVAFilter.dcl.h:95
Belle2::ModuleParamList::addParameter
void addParameter(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module list.
Definition: ModuleParamList.templateDetails.h:38
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::TrackFindingCDC::MVA
Filter based on an MVA method.
Definition: MVAFilter.dcl.h:43
Belle2::TrackFindingCDC::MVA::MVA
MVA(std::unique_ptr< AVarSet > varSet, const std::string &identifier="", double defaultCut=NAN)
Constructor of the filter.
Definition: MVAFilter.icc.h:47
Belle2::ModuleParamList
The Module parameter list class.
Definition: ModuleParamList.h:46