Belle II Software  release-05-02-19
PyEstimator.h
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2015 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Nils Braun *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 #pragma once
11 
12 #include <tracking/trackFindingCDC/varsets/NamedFloatTuple.h>
13 #include <boost/python/object.hpp>
14 
15 namespace Belle2 {
20  namespace TrackFindingCDC {
21 
22  class NamedFloatTuple;
23 
/// Wrapper around a python estimator object that is loaded from a pickle file
/// and whose predict method is called from C++.
25  class PyEstimator {
26  public:
    /// Construct the Estimator.
    /// @param pickleFileName  File name of the pickle file that contains the trained estimator.
30  explicit PyEstimator(const std::string& pickleFileName);
31 
    /// Call the predict method of the estimator.
    /// @param inputVariables  Input values handed to the estimator.
33  double predict(const std::vector<double>& inputVariables);
34 
    /// Overload of predict taking NamedFloatTuple pointers.
    /// NOTE(review): presumably the values of the tuples are served to the
    /// estimator - confirm against PyEstimator.cc.
36  double predict(const std::vector<NamedFloatTuple*>& floatTuples);
37 
    /// Overload of predict taking a boost::python object.
    /// NOTE(review): judging by the parameter name this is presumably an
    /// already prepared input array - confirm against PyEstimator.cc.
39  double predict(boost::python::object& array);
40 
41  private:
    /// Load the estimator object from the pickled file.
43  void unpickleEstimator(const std::string& pickleFileName);
44 
    /// Reserves space for at least nVars variables in the input array.
46  void expand(size_t nVars);
47 
48  private:
    /// File name of the pickle file that contains the trained estimator.
50  std::string m_pickleFileName;
51 
    /// Python estimator object (presumably the one unpickled from m_pickleFileName).
53  boost::python::object m_estimator;
54 
    /// Python-bound prediction method, cached to avoid repeated lookup.
56  boost::python::object m_predict;
57 
    /// Array to be served to the estimator.
59  boost::python::object m_array;
60 
    /// Cache for the current length of the input array; starts at 0.
62  size_t m_nCurrent = 0;
63 
    /// Internal flag to keep track of whether a binary classification with
    /// predict_proba is evaluated; off by default.
65  bool m_is_binary_classification = false;
66 
67  };
68 
69  }
71 }
Belle2::TrackFindingCDC::PyEstimator::m_nCurrent
size_t m_nCurrent
Cache for the current length of the input array.
Definition: PyEstimator.h:70
Belle2::TrackFindingCDC::PyEstimator::m_pickleFileName
std::string m_pickleFileName
File name of the pickle file that contains the trained estimator.
Definition: PyEstimator.h:58
Belle2::TrackFindingCDC::PyEstimator::m_array
boost::python::object m_array
Array to be served to the estimator.
Definition: PyEstimator.h:67
Belle2::TrackFindingCDC::PyEstimator::expand
void expand(size_t nVars)
Reserves space for at least nVars variables in the input array.
Definition: PyEstimator.cc:125
Belle2::TrackFindingCDC::PyEstimator::m_is_binary_classification
bool m_is_binary_classification
Internal flag to keep track of whether a binary classification with predict_proba is evaluated.
Definition: PyEstimator.h:73
Belle2::TrackFindingCDC::PyEstimator::m_estimator
boost::python::object m_estimator
Trained python estimator object.
Definition: PyEstimator.h:61
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::TrackFindingCDC::PyEstimator::PyEstimator
PyEstimator(const std::string &pickleFileName)
Construct the Estimator.
Definition: PyEstimator.cc:24
Belle2::TrackFindingCDC::PyEstimator::unpickleEstimator
void unpickleEstimator(const std::string &pickleFileName)
Load the estimator object from the pickled file.
Definition: PyEstimator.cc:93
Belle2::TrackFindingCDC::PyEstimator::m_predict
boost::python::object m_predict
Python-bound prediction method, cached to avoid repeated lookups.
Definition: PyEstimator.h:64
Belle2::TrackFindingCDC::PyEstimator::predict
double predict(const std::vector< double > &inputVariables)
Call the predict method of the estimator.
Definition: PyEstimator.cc:43