Belle II Software  release-08-01-10
PyEstimator.h
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 #pragma once
9 
10 #include <tracking/trackFindingCDC/varsets/NamedFloatTuple.h>
11 #include <boost/python/object.hpp>
12 
13 namespace Belle2 {
18  namespace TrackFindingCDC {
19 
20  class NamedFloatTuple;
21 
/// Class to invoke a pretrained python estimator that follows the sklearn interface.
23  class PyEstimator {
24  public:
    /**
     *  Construct the estimator.
     *  @param pickleFileName  Name of the pickle file that contains the trained estimator
     *                         (presumably stored in m_pickleFileName - confirm in PyEstimator.cc).
     */
28  explicit PyEstimator(const std::string& pickleFileName);
29 
    /// Call the predict method of the estimator with a plain vector of input variables.
31  double predict(const std::vector<double>& inputVariables);
32 
    /// Overload of predict taking the input as a list of pointers to NamedFloatTuple objects.
34  double predict(const std::vector<NamedFloatTuple*>& floatTuples);
35 
    /// Overload of predict taking a python object as input
    /// (presumably an already filled array - confirm in PyEstimator.cc).
37  double predict(boost::python::object& array);
38 
    // Private helper methods
39  private:
    /// Load the estimator object from the pickled file.
41  void unpickleEstimator(const std::string& pickleFileName);
42 
    /// Reserves space for at least nVars variables in the input array.
44  void expand(size_t nVars);
45 
    // Private data members
46  private:
    /// File name of the pickle file that contains the trained estimator.
48  std::string m_pickleFileName;
49 
    /// Pretrained python estimator object.
51  boost::python::object m_estimator;
52 
    /// Python bound prediction method - cached to avoid repeated lookup.
54  boost::python::object m_predict;
55 
    /// Array to be served to the estimator.
57  boost::python::object m_array;
58 
    /// Cache for the current length of the input array.
60  size_t m_nCurrent = 0;
61 
    // NOTE(review): the generated documentation references a member
    // `bool m_is_binary_classification` defined at PyEstimator.h:63, but its
    // declaration is not visible in this listing - confirm against the real header.
64 
65  };
66 
67  }
69 }
Class to invoke a pretrained python estimator that follows the sklearn interface.
Definition: PyEstimator.h:23
void expand(size_t nVars)
Reserves space for at least n variables in the input array.
Definition: PyEstimator.cc:124
void unpickleEstimator(const std::string &pickleFileName)
Load the estimator object from the pickled file.
Definition: PyEstimator.cc:92
std::string m_pickleFileName
File name of the pickle file that contains the trained estimator.
Definition: PyEstimator.h:48
PyEstimator(const std::string &pickleFileName)
Construct the Estimator.
Definition: PyEstimator.cc:22
double predict(const std::vector< double > &inputVariables)
Call the predict method of the estimator.
Definition: PyEstimator.cc:42
boost::python::object m_estimator
Pretrained python estimator object.
Definition: PyEstimator.h:51
boost::python::object m_predict
Python bound prediction method - cached to avoid repeated lookup.
Definition: PyEstimator.h:54
size_t m_nCurrent
Cache for the current length of the input array.
Definition: PyEstimator.h:60
bool m_is_binary_classification
Internal flag to keep track of whether a binary classification with predict_proba is evaluated.
Definition: PyEstimator.h:63
boost::python::object m_array
Array to be served to the estimator.
Definition: PyEstimator.h:57
Abstract base class for different kinds of events.